diff --git a/.bazelrc b/.bazelrc index 35e8665f78f247..8b0ece5aa0bb62 100644 --- a/.bazelrc +++ b/.bazelrc @@ -164,15 +164,19 @@ build:android --host_crosstool_top=@bazel_tools//tools/cpp:toolchain build:android_arm --config=android build:android_arm --cpu=armeabi-v7a build:android_arm --fat_apk_cpu=armeabi-v7a +build:android_arm --platforms=@org_tensorflow//tensorflow/tools/toolchains/android:armeabi-v7a build:android_arm64 --config=android build:android_arm64 --cpu=arm64-v8a build:android_arm64 --fat_apk_cpu=arm64-v8a +build:android_arm64 --platforms=@org_tensorflow//tensorflow/tools/toolchains/android:arm64-v8a build:android_x86 --config=android build:android_x86 --cpu=x86 build:android_x86 --fat_apk_cpu=x86 +build:android_x86 --platforms=@org_tensorflow//tensorflow/tools/toolchains/android:x86 build:android_x86_64 --config=android build:android_x86_64 --cpu=x86_64 build:android_x86_64 --fat_apk_cpu=x86_64 +build:android_x86_64 --platforms=@org_tensorflow//tensorflow/tools/toolchains/android:x86_64 # Build everything statically for Android since all static libs are later # bundled together into a single .so for deployment. @@ -205,6 +209,7 @@ build:apple-toolchain --host_crosstool_top=@local_config_apple_cc//:toolchain # Settings for MacOS on ARM CPUs. build:macos_arm64 --cpu=darwin_arm64 build:macos_arm64 --macos_minimum_os=11.0 +build:macos_arm64 --platforms=@build_bazel_apple_support//configs/platforms:darwin_arm64 # iOS configs for each architecture and the fat binary builds. 
build:ios --apple_platform_type=ios @@ -213,14 +218,19 @@ build:ios --copt=-Wno-c++11-narrowing build:ios --config=apple-toolchain build:ios_armv7 --config=ios build:ios_armv7 --cpu=ios_armv7 +build:ios_armv7 --platforms=@org_tensorflow//tensorflow/tools/toolchains/ios:ios_armv7 build:ios_arm64 --config=ios build:ios_arm64 --cpu=ios_arm64 +build:ios_arm64 --platforms=@build_bazel_apple_support//configs/platforms:ios_arm64 build:ios_arm64e --config=ios build:ios_arm64e --cpu=ios_arm64e +build:ios_arm64e --platforms=@build_bazel_apple_support//configs/platforms:ios_arm64e build:ios_sim_arm64 --config=ios build:ios_sim_arm64 --cpu=ios_sim_arm64 +build:ios_sim_arm64 --platforms=@build_bazel_apple_support//configs/platforms:ios_sim_arm64 build:ios_x86_64 --config=ios build:ios_x86_64 --cpu=ios_x86_64 +build:ios_x86_64 --platforms=@build_bazel_apple_support//configs/platforms:ios_x86_64 build:ios_fat --config=ios build:ios_fat --ios_multi_cpus=armv7,arm64,i386,x86_64 @@ -257,13 +267,15 @@ build:mkl_aarch64 -c opt build:mkl_aarch64_threadpool --define=build_with_mkl_aarch64=true build:mkl_aarch64_threadpool -c opt +# Default CUDA and CUDNN versions. +build:cuda_version --repo_env=HERMETIC_CUDA_VERSION="12.5.1" +build:cuda_version --repo_env=HERMETIC_CUDNN_VERSION="9.3.0" + # CUDA: This config refers to building CUDA op kernels with nvcc. build:cuda --repo_env TF_NEED_CUDA=1 build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain build:cuda --@local_config_cuda//:enable_cuda -# Default CUDA and CUDNN versions. -build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.5.1" -build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.3.0" +build:cuda --config=cuda_version # This flag is needed to include CUDA libraries. build:cuda --@local_config_cuda//cuda:include_cuda_libs=true @@ -293,8 +305,7 @@ build:cuda_clang --linkopt="-lm" # Set up compilation CUDA version and paths and use the CUDA Clang toolchain. 
build:cuda_clang_official --config=cuda_clang -build:cuda_clang_official --repo_env=HERMETIC_CUDA_VERSION="12.5.1" -build:cuda_clang_official --repo_env=HERMETIC_CUDNN_VERSION="9.3.0" +build:cuda_clang_official --config=cuda_version build:cuda_clang_official --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-18/bin/clang" build:cuda_clang_official --crosstool_top="@local_config_cuda//crosstool:toolchain" @@ -623,6 +634,12 @@ build:rbe_linux_cpu --python_path="/usr/bin/python3" # These you may need to change for your own GCP project. common:rbe_linux_cpu --remote_instance_name=projects/tensorflow-testing/instances/default_instance +# Download CUDA/CUDNN redistributions to preserve the repositories cache between +# CPU and GPU builds. +# TODO(ybaturina): Uncomment when RBE is ready to support this. +# build:rbe_linux_cpu --repo_env USE_CUDA_REDISTRIBUTIONS=1 +# build:rbe_linux_cpu --config=cuda_version + # TODO(kanglan): Remove it after toolchain update is complete. build:rbe_linux_cpu_old --config=rbe_linux build:rbe_linux_cpu_old --host_crosstool_top="@ubuntu20.04-gcc9_manylinux2014-cuda11.2-cudnn8.1-tensorrt7.2_config_cuda//crosstool:toolchain" @@ -676,8 +693,10 @@ build:elinux --crosstool_top=@local_config_embedded_arm//:toolchain build:elinux --host_crosstool_top=@bazel_tools//tools/cpp:toolchain build:elinux_aarch64 --config=elinux build:elinux_aarch64 --cpu=aarch64 +build:elinux_aarch64 --platforms=@org_tensorflow//tensorflow/tools/toolchains/linux:linux_aarch64 build:elinux_armhf --config=elinux build:elinux_armhf --cpu=armhf +build:elinux_armhf --platforms=@org_tensorflow//tensorflow/tools/toolchains/linux:linux_armhf build:elinux_armhf --copt -mfp16-format=ieee # Config-specific options should come above this line. 
diff --git a/.github/workflows/osv-scanner-scheduled.yml b/.github/workflows/osv-scanner-scheduled.yml index e612b642fb1959..c1d680fcdfcea6 100644 --- a/.github/workflows/osv-scanner-scheduled.yml +++ b/.github/workflows/osv-scanner-scheduled.yml @@ -28,7 +28,7 @@ permissions: jobs: scan-scheduled: if: github.repository == 'tensorflow/tensorflow' - uses: "google/osv-scanner-action/.github/workflows/osv-scanner-reusable.yml@v1.9.2" + uses: "google/osv-scanner-action/.github/workflows/osv-scanner-reusable.yml@v2.0.0" with: scan-args: |- --lockfile=requirements.txt:./requirements_lock_3_9.txt diff --git a/.github/workflows/pylint-presubmit.yml b/.github/workflows/pylint-presubmit.yml index 09801d29b69797..97a7b7a5f8285f 100644 --- a/.github/workflows/pylint-presubmit.yml +++ b/.github/workflows/pylint-presubmit.yml @@ -38,7 +38,7 @@ jobs: run: | echo Changed files: ${{ steps.get_file_changes.outputs.files }} - name: Set up Python 3.9 - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 + uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0 with: python-version: "3.9" - name: Install Python dependencies diff --git a/.github/workflows/release-branch-cherrypick.yml b/.github/workflows/release-branch-cherrypick.yml index 6587769b85b868..4fa4f8d5b9435a 100644 --- a/.github/workflows/release-branch-cherrypick.yml +++ b/.github/workflows/release-branch-cherrypick.yml @@ -58,7 +58,7 @@ jobs: echo "SHORTSHA=$(git log -1 ${{ github.event.inputs.git_commit }} --format="%h")" >> "$GITHUB_OUTPUT" echo "TITLE=$(git log -1 ${{ github.event.inputs.git_commit }} --format="%s")" >> "$GITHUB_OUTPUT" - name: Create Pull Request with changes - uses: peter-evans/create-pull-request@dd2324fc52d5d43c699a5636bcf19fceaa70c284 # v7.0.7 + uses: peter-evans/create-pull-request@271a8d0340265f705b14b6d32b9829c1cb33d45e # v7.0.8 with: title: '${{ github.event.inputs.release_branch }} cherry-pick: ${{ steps.cherrypick.outputs.SHORTSHA }} "${{ 
steps.cherrypick.outputs.TITLE }}"' committer: TensorFlow Release Automation diff --git a/.github/workflows/scorecards-analysis.yml b/.github/workflows/scorecards-analysis.yml index 6adc36c3749df4..c68351d3bd3a23 100644 --- a/.github/workflows/scorecards-analysis.yml +++ b/.github/workflows/scorecards-analysis.yml @@ -55,7 +55,7 @@ jobs: # Upload the results as artifacts (optional). Commenting out will disable uploads of run results in SARIF # format to the repository Actions tab. - name: "Upload artifact" - uses: actions/upload-artifact@4cec3d8aa04e39d1a68397de0c4cd6fb9dce8ec1 # v4.6.1 + uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 with: name: SARIF file path: results.sarif @@ -64,6 +64,6 @@ jobs: # Upload the results to GitHub's code scanning dashboard (optional). # Commenting out will disable upload of results to your repo's Code Scanning dashboard - name: "Upload to code-scanning" - uses: github/codeql-action/upload-sarif@b56ba49b26e50535fa1e7f7db0f4f7b4bf65d80d # v3.28.10 + uses: github/codeql-action/upload-sarif@1b549b9259bda1cb5ddde3b41741a82a2d15a841 # v3.28.13 with: sarif_file: results.sarif diff --git a/.github/workflows/update-rbe.yml b/.github/workflows/update-rbe.yml index 11b83f43e70882..a06d2e0125f6b9 100644 --- a/.github/workflows/update-rbe.yml +++ b/.github/workflows/update-rbe.yml @@ -130,7 +130,7 @@ jobs: map sigbuild-r2.17-clang-python3.11 2.17-python3.11 map sigbuild-r2.17-clang-python3.12 2.17-python3.12 - name: Create Pull Request with changes - uses: peter-evans/create-pull-request@dd2324fc52d5d43c699a5636bcf19fceaa70c284 # v7.0.7 + uses: peter-evans/create-pull-request@271a8d0340265f705b14b6d32b9829c1cb33d45e # v7.0.8 with: title: Update the RBE images to the latest container versions committer: TensorFlow Release Automation diff --git a/RELEASE.md b/RELEASE.md index 5320705c80c85f..be7804c815cc07 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -3234,7 +3234,7 @@ This release introduces several vulnerability 
fixes: * Keras been split into a separate PIP package (`keras`), and its code has been moved to the GitHub - repository[keras-team/keras](http://github.com/keras-team/keras). The + repository[keras-team/keras](https://github.com/keras-team/keras). The API endpoints for `tf.keras` stay unchanged, but are now backed by the `keras` PIP package. The existing code in tensorflow/python/keras is a staled copy and will be removed in future release (2.7). Please remove @@ -10309,7 +10309,7 @@ answered questions, and were part of inspiring discussions. ## Major Features And Improvements * `tf.keras` is now part of the core TensorFlow API. -* [`tf.data`](http://tensorflow.org/guide/data) is now part of the core +* [`tf.data`](https://tensorflow.org/guide/data) is now part of the core TensorFlow API. * The API is now subject to backwards compatibility guarantees. * For a guide to migrating from the `tf.contrib.data` API, see the diff --git a/WORKSPACE b/WORKSPACE index 445f974b094333..e42663c6922986 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -43,6 +43,7 @@ python_init_repositories( "3.10": "//:requirements_lock_3_10.txt", "3.11": "//:requirements_lock_3_11.txt", "3.12": "//:requirements_lock_3_12.txt", + "3.13": "//:requirements_lock_3_13.txt", }, ) diff --git a/ci/official/envs/linux_arm64 b/ci/official/envs/linux_arm64 index 643be7a872e0df..2b6e38b0e42f04 100644 --- a/ci/official/envs/linux_arm64 +++ b/ci/official/envs/linux_arm64 @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -TFCI_BAZEL_COMMON_ARGS="--repo_env=HERMETIC_PYTHON_VERSION=$TFCI_PYTHON_VERSION --config release_arm64_linux" +TFCI_BAZEL_COMMON_ARGS="--repo_env=HERMETIC_PYTHON_VERSION=$TFCI_PYTHON_VERSION --repo_env=USE_PYWRAP_RULES=True --config release_arm64_linux" TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_arm64 # Note: this is not set to "--cpu", because that changes the package name # to tensorflow_cpu. These ARM builds are supposed to have the name "tensorflow" diff --git a/ci/official/envs/linux_arm64_cross_compile b/ci/official/envs/linux_arm64_cross_compile index e4e9004b4f1c3a..7333be2ff9fff8 100644 --- a/ci/official/envs/linux_arm64_cross_compile +++ b/ci/official/envs/linux_arm64_cross_compile @@ -13,5 +13,5 @@ # limitations under the License. # ============================================================================== source ci/official/envs/linux_arm64 -TFCI_BAZEL_COMMON_ARGS="--repo_env=HERMETIC_PYTHON_VERSION=$TFCI_PYTHON_VERSION --config cross_compile_linux_arm64" +TFCI_BAZEL_COMMON_ARGS="--repo_env=HERMETIC_PYTHON_VERSION=$TFCI_PYTHON_VERSION --config cross_compile_linux_arm64 --repo_env=USE_PYWRAP_RULES=True" TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=cross_compile_linux_arm64 diff --git a/ci/official/envs/macos_arm64 b/ci/official/envs/macos_arm64 index c789a2dc2d0990..96d8c14655cea8 100644 --- a/ci/official/envs/macos_arm64 +++ b/ci/official/envs/macos_arm64 @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -TFCI_BAZEL_COMMON_ARGS="--repo_env=HERMETIC_PYTHON_VERSION=$TFCI_PYTHON_VERSION --config release_macos_arm64" +TFCI_BAZEL_COMMON_ARGS="--repo_env=HERMETIC_PYTHON_VERSION=$TFCI_PYTHON_VERSION --repo_env=USE_PYWRAP_RULES=True --config release_macos_arm64" TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=macos_arm64 TFCI_BUILD_PIP_PACKAGE_WHEEL_NAME_ARG="--repo_env=WHEEL_NAME=tensorflow" TFCI_INDEX_HTML_ENABLE=1 @@ -29,7 +29,12 @@ case $TFCI_PYTHON_VERSION in 3.11) TFCI_MACOS_PYENV_INSTALL_ENABLE=0 ;; +3.13) + TFCI_MACOS_UPGRADE_PYENV_ENABLE=1 + TFCI_MACOS_PYENV_INSTALL_ENABLE=1 + ;; *) TFCI_MACOS_PYENV_INSTALL_ENABLE=1 ;; esac + diff --git a/ci/official/envs/py313 b/ci/official/envs/py313 new file mode 100644 index 00000000000000..1210c5eca815f8 --- /dev/null +++ b/ci/official/envs/py313 @@ -0,0 +1,15 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +TFCI_PYTHON_VERSION=3.13 diff --git a/ci/official/requirements_updater/requirements.in b/ci/official/requirements_updater/requirements.in index 0cfbaf22f820b1..f63fa5ccc52934 100644 --- a/ci/official/requirements_updater/requirements.in +++ b/ci/official/requirements_updater/requirements.in @@ -28,7 +28,7 @@ requests >= 2.31.0 packaging==23.2 setuptools==70.0.0 jax==0.4.7 -zstandard=0.23.0 +zstandard==0.23.0 # NVIDIA CUDA dependencies # Note that the wheels are downloaded only when the targets in bazel command # contain dependencies on these wheels. @@ -44,7 +44,7 @@ nvidia-cusparse-cu12 == 12.5.1.3 nvidia-nccl-cu12 == 2.25.1 nvidia-nvjitlink-cu12 == 12.5.82 # The dependencies below are needed for TF wheel testing. -tensorflow-io-gcs-filesystem==0.37.1 +tensorflow-io-gcs-filesystem==0.37.1 ; python_version <= "3.12" libclang >= 13.0.0 google_pasta ~= 0.2 flatbuffers ~= 24.3.25 diff --git a/ci/official/utilities/rename_and_verify_wheels.sh b/ci/official/utilities/rename_and_verify_wheels.sh index 34389f79264f12..e5b08d1a9d4acc 100755 --- a/ci/official/utilities/rename_and_verify_wheels.sh +++ b/ci/official/utilities/rename_and_verify_wheels.sh @@ -69,7 +69,11 @@ fi # TODO(b/366266944) Remove the check after tf docker image upgrade for NumPy 2 # and numpy 1 support is dropped b/361369076. 
if [[ "$TFCI_WHL_NUMPY_VERSION" == 1 ]]; then - "$python" -m pip install numpy==1.26.0 + if [[ "$TFCI_PYTHON_VERSION" == "3.13" ]]; then + "$python" -m pip install numpy==1.26.4 + else + "$python" -m pip install numpy==1.26.0 + fi fi "$python" -m pip install *.whl $TFCI_PYTHON_VERIFY_PIP_INSTALL_ARGS if [[ "$TFCI_WHL_IMPORT_TEST_ENABLE" == "1" ]]; then diff --git a/ci/official/utilities/setup_macos.sh b/ci/official/utilities/setup_macos.sh index 8a63d318c6e18e..05c0cf27581dcb 100644 --- a/ci/official/utilities/setup_macos.sh +++ b/ci/official/utilities/setup_macos.sh @@ -61,10 +61,23 @@ fi # those VMs does not support installing Python 3.12 and above which we need # for running smoke tests in nightly/release wheel builds. if [[ "${TFCI_MACOS_UPGRADE_PYENV_ENABLE}" == 1 ]]; then - # The TFCI Mac VM image seems to have uncommitted local changes to the Pyenv - # repository so we have to discard them and reset the working directory before - # we can pull in the latest changes. - cd /Users/kbuilder/.pyenv/ && git reset --hard HEAD && git pull && cd - + echo "Upgrading pyenv..." + echo "Current pyevn version: $(pyenv --version)" + + # Check if pyenv is managed by homebrew. If so, update and upgrade pyenv. + # Otherwise, install the latest pyenv from github. + if command -v brew &> /dev/null && brew list pyenv &> /dev/null; then + # On "ventura-slcn" VMs, pyenv is managed via Homebrew. + echo "pyenv is installed and managed by homebrew." + brew update && brew upgrade pyenv + else + echo "pyenv is not managed by homebrew. Installing it via github..." + # On "ventura" VMs, pyenv is not managed by Homebrew. Install the latest + # pyenv from github. 
+ rm -rf "$PYENV_ROOT" + git clone https://github.com/pyenv/pyenv.git "$PYENV_ROOT" + fi + echo "Upgraded pyenv version: $(pyenv --version)" fi # "TFCI_MACOS_PYENV_INSTALL_ENABLE" controls whether to use Pyenv to install diff --git a/requirements_lock_3_13.txt b/requirements_lock_3_13.txt new file mode 100644 index 00000000000000..a03c65b0b2486c --- /dev/null +++ b/requirements_lock_3_13.txt @@ -0,0 +1,842 @@ +# +# This file is autogenerated by pip-compile with Python 3.13 +# by the following command: +# +# bazel run //ci/official/requirements_updater:requirements.update +# +absl-py==2.2.1 \ + --hash=sha256:4c7bc50d42d021c12d4f31b7001167925e0bd71ade853069f64af410f5565ff9 \ + --hash=sha256:ca8209abd5005ae6e700ef36e2edc84ad5338678f95625a3f15275410a89ffbc + # via + # dm-tree + # keras-nightly + # tb-nightly +astor==0.7.1 \ + --hash=sha256:95c30d87a6c2cf89aa628b87398466840f0ad8652f88eb173125a6df8533fb8d \ + --hash=sha256:fb503b9e2fdd05609fbf557b916b4a7824171203701660f0c55bbf5a7a68713e + # via -r ci/official/requirements_updater/requirements.in +astunparse==1.6.3 \ + --hash=sha256:5ad93a8456f0d084c3456d059fd9a92cce667963232cbf763eac3bc5b7940872 \ + --hash=sha256:c2652417f2c8b5bb325c885ae329bdf3f86424075c4fd1a128674bc6fba4b8e8 + # via -r ci/official/requirements_updater/requirements.in +attrs==25.3.0 \ + --hash=sha256:427318ce031701fea540783410126f03899a97ffc6f61596ad581ac2e40e3bc3 \ + --hash=sha256:75d7cefc7fb576747b2c81b4442d4d4a1ce0900973527c011d1030fd3bf4af1b + # via dm-tree +auditwheel==6.3.0 \ + --hash=sha256:05c70a234fa14c140aa6d9076135d9550962d95849911b8d5d0419a3add09f00 \ + --hash=sha256:31cbd8045d4ff6776f79bef328b5fd563e5ecc8ae82ea34b6fe5e76efe2a84eb + # via -r ci/official/requirements_updater/requirements.in +certifi==2025.1.31 \ + --hash=sha256:3d5da6925056f6f18f119200434a4780a94263f10d1c21d032a6f6b2baa20651 \ + --hash=sha256:ca78db4565a652026a4db2bcdf68f2fb589ea80d0be70e03929ed730746b84fe + # via requests +charset-normalizer==3.4.1 \ + 
--hash=sha256:0167ddc8ab6508fe81860a57dd472b2ef4060e8d378f0cc555707126830f2537 \ + --hash=sha256:01732659ba9b5b873fc117534143e4feefecf3b2078b0a6a2e925271bb6f4cfa \ + --hash=sha256:01ad647cdd609225c5350561d084b42ddf732f4eeefe6e678765636791e78b9a \ + --hash=sha256:04432ad9479fa40ec0f387795ddad4437a2b50417c69fa275e212933519ff294 \ + --hash=sha256:0907f11d019260cdc3f94fbdb23ff9125f6b5d1039b76003b5b0ac9d6a6c9d5b \ + --hash=sha256:0924e81d3d5e70f8126529951dac65c1010cdf117bb75eb02dd12339b57749dd \ + --hash=sha256:09b26ae6b1abf0d27570633b2b078a2a20419c99d66fb2823173d73f188ce601 \ + --hash=sha256:09b5e6733cbd160dcc09589227187e242a30a49ca5cefa5a7edd3f9d19ed53fd \ + --hash=sha256:0af291f4fe114be0280cdd29d533696a77b5b49cfde5467176ecab32353395c4 \ + --hash=sha256:0f55e69f030f7163dffe9fd0752b32f070566451afe180f99dbeeb81f511ad8d \ + --hash=sha256:1a2bc9f351a75ef49d664206d51f8e5ede9da246602dc2d2726837620ea034b2 \ + --hash=sha256:22e14b5d70560b8dd51ec22863f370d1e595ac3d024cb8ad7d308b4cd95f8313 \ + --hash=sha256:234ac59ea147c59ee4da87a0c0f098e9c8d169f4dc2a159ef720f1a61bbe27cd \ + --hash=sha256:2369eea1ee4a7610a860d88f268eb39b95cb588acd7235e02fd5a5601773d4fa \ + --hash=sha256:237bdbe6159cff53b4f24f397d43c6336c6b0b42affbe857970cefbb620911c8 \ + --hash=sha256:28bf57629c75e810b6ae989f03c0828d64d6b26a5e205535585f96093e405ed1 \ + --hash=sha256:2967f74ad52c3b98de4c3b32e1a44e32975e008a9cd2a8cc8966d6a5218c5cb2 \ + --hash=sha256:2a75d49014d118e4198bcee5ee0a6f25856b29b12dbf7cd012791f8a6cc5c496 \ + --hash=sha256:2bdfe3ac2e1bbe5b59a1a63721eb3b95fc9b6817ae4a46debbb4e11f6232428d \ + --hash=sha256:2d074908e1aecee37a7635990b2c6d504cd4766c7bc9fc86d63f9c09af3fa11b \ + --hash=sha256:2fb9bd477fdea8684f78791a6de97a953c51831ee2981f8e4f583ff3b9d9687e \ + --hash=sha256:311f30128d7d333eebd7896965bfcfbd0065f1716ec92bd5638d7748eb6f936a \ + --hash=sha256:329ce159e82018d646c7ac45b01a430369d526569ec08516081727a20e9e4af4 \ + --hash=sha256:345b0426edd4e18138d6528aed636de7a9ed169b4aaf9d61a8c19e39d26838ca \ + 
--hash=sha256:363e2f92b0f0174b2f8238240a1a30142e3db7b957a5dd5689b0e75fb717cc78 \ + --hash=sha256:3a3bd0dcd373514dcec91c411ddb9632c0d7d92aed7093b8c3bbb6d69ca74408 \ + --hash=sha256:3bed14e9c89dcb10e8f3a29f9ccac4955aebe93c71ae803af79265c9ca5644c5 \ + --hash=sha256:44251f18cd68a75b56585dd00dae26183e102cd5e0f9f1466e6df5da2ed64ea3 \ + --hash=sha256:44ecbf16649486d4aebafeaa7ec4c9fed8b88101f4dd612dcaf65d5e815f837f \ + --hash=sha256:4532bff1b8421fd0a320463030c7520f56a79c9024a4e88f01c537316019005a \ + --hash=sha256:49402233c892a461407c512a19435d1ce275543138294f7ef013f0b63d5d3765 \ + --hash=sha256:4c0907b1928a36d5a998d72d64d8eaa7244989f7aaaf947500d3a800c83a3fd6 \ + --hash=sha256:4d86f7aff21ee58f26dcf5ae81a9addbd914115cdebcbb2217e4f0ed8982e146 \ + --hash=sha256:5777ee0881f9499ed0f71cc82cf873d9a0ca8af166dfa0af8ec4e675b7df48e6 \ + --hash=sha256:5df196eb874dae23dcfb968c83d4f8fdccb333330fe1fc278ac5ceeb101003a9 \ + --hash=sha256:619a609aa74ae43d90ed2e89bdd784765de0a25ca761b93e196d938b8fd1dbbd \ + --hash=sha256:6e27f48bcd0957c6d4cb9d6fa6b61d192d0b13d5ef563e5f2ae35feafc0d179c \ + --hash=sha256:6ff8a4a60c227ad87030d76e99cd1698345d4491638dfa6673027c48b3cd395f \ + --hash=sha256:73d94b58ec7fecbc7366247d3b0b10a21681004153238750bb67bd9012414545 \ + --hash=sha256:7461baadb4dc00fd9e0acbe254e3d7d2112e7f92ced2adc96e54ef6501c5f176 \ + --hash=sha256:75832c08354f595c760a804588b9357d34ec00ba1c940c15e31e96d902093770 \ + --hash=sha256:7709f51f5f7c853f0fb938bcd3bc59cdfdc5203635ffd18bf354f6967ea0f824 \ + --hash=sha256:78baa6d91634dfb69ec52a463534bc0df05dbd546209b79a3880a34487f4b84f \ + --hash=sha256:7974a0b5ecd505609e3b19742b60cee7aa2aa2fb3151bc917e6e2646d7667dcf \ + --hash=sha256:7a4f97a081603d2050bfaffdefa5b02a9ec823f8348a572e39032caa8404a487 \ + --hash=sha256:7b1bef6280950ee6c177b326508f86cad7ad4dff12454483b51d8b7d673a2c5d \ + --hash=sha256:7d053096f67cd1241601111b698f5cad775f97ab25d81567d3f59219b5f1adbd \ + --hash=sha256:804a4d582ba6e5b747c625bf1255e6b1507465494a40a2130978bda7b932c90b \ + 
--hash=sha256:807f52c1f798eef6cf26beb819eeb8819b1622ddfeef9d0977a8502d4db6d534 \ + --hash=sha256:80ed5e856eb7f30115aaf94e4a08114ccc8813e6ed1b5efa74f9f82e8509858f \ + --hash=sha256:8417cb1f36cc0bc7eaba8ccb0e04d55f0ee52df06df3ad55259b9a323555fc8b \ + --hash=sha256:8436c508b408b82d87dc5f62496973a1805cd46727c34440b0d29d8a2f50a6c9 \ + --hash=sha256:89149166622f4db9b4b6a449256291dc87a99ee53151c74cbd82a53c8c2f6ccd \ + --hash=sha256:8bfa33f4f2672964266e940dd22a195989ba31669bd84629f05fab3ef4e2d125 \ + --hash=sha256:8c60ca7339acd497a55b0ea5d506b2a2612afb2826560416f6894e8b5770d4a9 \ + --hash=sha256:91b36a978b5ae0ee86c394f5a54d6ef44db1de0815eb43de826d41d21e4af3de \ + --hash=sha256:955f8851919303c92343d2f66165294848d57e9bba6cf6e3625485a70a038d11 \ + --hash=sha256:97f68b8d6831127e4787ad15e6757232e14e12060bec17091b85eb1486b91d8d \ + --hash=sha256:9b23ca7ef998bc739bf6ffc077c2116917eabcc901f88da1b9856b210ef63f35 \ + --hash=sha256:9f0b8b1c6d84c8034a44893aba5e767bf9c7a211e313a9605d9c617d7083829f \ + --hash=sha256:aabfa34badd18f1da5ec1bc2715cadc8dca465868a4e73a0173466b688f29dda \ + --hash=sha256:ab36c8eb7e454e34e60eb55ca5d241a5d18b2c6244f6827a30e451c42410b5f7 \ + --hash=sha256:b010a7a4fd316c3c484d482922d13044979e78d1861f0e0650423144c616a46a \ + --hash=sha256:b1ac5992a838106edb89654e0aebfc24f5848ae2547d22c2c3f66454daa11971 \ + --hash=sha256:b7b2d86dd06bfc2ade3312a83a5c364c7ec2e3498f8734282c6c3d4b07b346b8 \ + --hash=sha256:b97e690a2118911e39b4042088092771b4ae3fc3aa86518f84b8cf6888dbdb41 \ + --hash=sha256:bc2722592d8998c870fa4e290c2eec2c1569b87fe58618e67d38b4665dfa680d \ + --hash=sha256:c0429126cf75e16c4f0ad00ee0eae4242dc652290f940152ca8c75c3a4b6ee8f \ + --hash=sha256:c30197aa96e8eed02200a83fba2657b4c3acd0f0aa4bdc9f6c1af8e8962e0757 \ + --hash=sha256:c4c3e6da02df6fa1410a7680bd3f63d4f710232d3139089536310d027950696a \ + --hash=sha256:c75cb2a3e389853835e84a2d8fb2b81a10645b503eca9bcb98df6b5a43eb8886 \ + --hash=sha256:c96836c97b1238e9c9e3fe90844c947d5afbf4f4c92762679acfe19927d81d77 \ + 
--hash=sha256:d7f50a1f8c450f3925cb367d011448c39239bb3eb4117c36a6d354794de4ce76 \ + --hash=sha256:d973f03c0cb71c5ed99037b870f2be986c3c05e63622c017ea9816881d2dd247 \ + --hash=sha256:d98b1668f06378c6dbefec3b92299716b931cd4e6061f3c875a71ced1780ab85 \ + --hash=sha256:d9c3cdf5390dcd29aa8056d13e8e99526cda0305acc038b96b30352aff5ff2bb \ + --hash=sha256:dad3e487649f498dd991eeb901125411559b22e8d7ab25d3aeb1af367df5efd7 \ + --hash=sha256:dccbe65bd2f7f7ec22c4ff99ed56faa1e9f785482b9bbd7c717e26fd723a1d1e \ + --hash=sha256:dd78cfcda14a1ef52584dbb008f7ac81c1328c0f58184bf9a84c49c605002da6 \ + --hash=sha256:e218488cd232553829be0664c2292d3af2eeeb94b32bea483cf79ac6a694e037 \ + --hash=sha256:e358e64305fe12299a08e08978f51fc21fac060dcfcddd95453eabe5b93ed0e1 \ + --hash=sha256:ea0d8d539afa5eb2728aa1932a988a9a7af94f18582ffae4bc10b3fbdad0626e \ + --hash=sha256:eab677309cdb30d047996b36d34caeda1dc91149e4fdca0b1a039b3f79d9a807 \ + --hash=sha256:eb8178fe3dba6450a3e024e95ac49ed3400e506fd4e9e5c32d30adda88cbd407 \ + --hash=sha256:ecddf25bee22fe4fe3737a399d0d177d72bc22be6913acfab364b40bce1ba83c \ + --hash=sha256:eea6ee1db730b3483adf394ea72f808b6e18cf3cb6454b4d86e04fa8c4327a12 \ + --hash=sha256:f08ff5e948271dc7e18a35641d2f11a4cd8dfd5634f55228b691e62b37125eb3 \ + --hash=sha256:f30bf9fd9be89ecb2360c7d94a711f00c09b976258846efe40db3d05828e8089 \ + --hash=sha256:fa88b843d6e211393a37219e6a1c1df99d35e8fd90446f1118f4216e307e48cd \ + --hash=sha256:fc54db6c8593ef7d4b2a331b58653356cf04f67c960f584edb7c3d8c97e8f39e \ + --hash=sha256:fd4ec41f914fa74ad1b8304bbc634b3de73d2a0889bd32076342a573e0779e00 \ + --hash=sha256:ffc9202a29ab3920fa812879e95a9e78b2465fd10be7fcbd042899695d75e616 + # via requests +dill==0.3.7 \ + --hash=sha256:76b122c08ef4ce2eedcd4d1abd8e641114bfc6c2867f49f3c41facf65bf19f5e \ + --hash=sha256:cc1c8b182eb3013e24bd475ff2e9295af86c1a38eb1aff128dac8962a9ce3c03 + # via -r ci/official/requirements_updater/requirements.in +dm-tree==0.1.9 \ + 
--hash=sha256:12f4cc6cd52a39aa38ff31577b6d79b6136a9a89273a876bf62335c9f65c27bf \ + --hash=sha256:1ae3cbff592bb3f2e197f5a8030de4a94e292e6cdd85adeea0b971d07a1b85f2 \ + --hash=sha256:2334cfe9d2ed4293f9f1c7aefba0657deaab9ea74b5fadd966f6d01d9b6b42d9 \ + --hash=sha256:294dc1cecf87552a45cdd5ddb215e7f5295a5a47c46f1f0a0463c3dd02a527d7 \ + --hash=sha256:54d5616015412311df154908069fcf2c2d8786f6088a2ae3554d186cdf2b1e15 \ + --hash=sha256:5d5b28ee2e461b6af65330c143806a6d0945dcabbb8d22d2ba863e6dabd9254e \ + --hash=sha256:6893fcdc5cf1a4f459cfc383526d35d42e7c671ae565d7e429a2f2cb2cb93e89 \ + --hash=sha256:7d7d784afaeb4b67d87d858261aaf02503939ddc1f09c4cca70728f9892ab004 \ + --hash=sha256:80c43417814b1181d3367b335460bfdd30b79ee187a64220e11f6ddd093a4b15 \ + --hash=sha256:831699d2c60a1b38776a193b7143ae0acad0a687d87654e6d3342584166816bc \ + --hash=sha256:9020a5ce256fcc83aa4bc190cc96dd66e87685db0a6e501b0c06aa492c2e38fc \ + --hash=sha256:a4c7db3d3935a5a2d5e4b383fc26c6b0cd6f78c6d4605d3e7b518800ecd5342b \ + --hash=sha256:a8d20eeab7fde77a3ed71f07716021eb0edfb4812a128eb381d108af3a310257 \ + --hash=sha256:b06e7a5da1c31a82521a60060573527e8d24b9920fdd20b2ec86f08412737598 \ + --hash=sha256:cfa33c2e028155810ad1b4e11928707bf47489516763a86e79cab2954d23bf68 \ + --hash=sha256:d05622d074353cf434049206e53c12147903a048c4bd7d77f2800d427413ad78 \ + --hash=sha256:e1f5d1e96b3a7de22b25b13a5eb30f41f8cf9c02dd4479a24920de99e780903c \ + --hash=sha256:e660d1779ddcbd1348410d08f67db4870d413a3ec4ba8b4b045bd5ce4bd8f35c \ + --hash=sha256:e97c34fcb44941c36b7ee81dcdbceba0fbe728bddcc77e5837ab2eb665bcbff8 \ + --hash=sha256:f68b0efad76703dd4648586c75618a48cdd671b68c3266fe980e323c15423607 + # via keras-nightly +flatbuffers==24.3.25 \ + --hash=sha256:8dbdec58f935f3765e4f7f3cf635ac3a77f83568138d6a2311f524ec96364812 \ + --hash=sha256:de2ec5b203f21441716617f38443e0a8ebf3d25bf0d9c0bb0ce68fa00ad546a4 + # via -r ci/official/requirements_updater/requirements.in +gast==0.4.0 \ + 
--hash=sha256:40feb7b8b8434785585ab224d1568b857edb18297e5a3047f1ba012bc83b42c1 \ + --hash=sha256:b7adcdd5adbebf1adf17378da5ba3f543684dbec47b1cda1f3997e573cd542c4 + # via -r ci/official/requirements_updater/requirements.in +google-pasta==0.2.0 \ + --hash=sha256:4612951da876b1a10fe3960d7226f0c7682cf901e16ac06e473b267a5afa8954 \ + --hash=sha256:b32482794a366b5366a32c92a9a9201b107821889935a02b3e51f6b432ea84ed \ + --hash=sha256:c9f2c8dfc8f96d0d5808299920721be30c9eec37f2389f28904f454565c8a16e + # via -r ci/official/requirements_updater/requirements.in +grpcio==1.71.0 \ + --hash=sha256:0ab8b2864396663a5b0b0d6d79495657ae85fa37dcb6498a2669d067c65c11ea \ + --hash=sha256:0fa05ee31a20456b13ae49ad2e5d585265f71dd19fbd9ef983c28f926d45d0a7 \ + --hash=sha256:0ff35c8d807c1c7531d3002be03221ff9ae15712b53ab46e2a0b4bb271f38537 \ + --hash=sha256:1be857615e26a86d7363e8a163fade914595c81fec962b3d514a4b1e8760467b \ + --hash=sha256:20e8f653abd5ec606be69540f57289274c9ca503ed38388481e98fa396ed0b41 \ + --hash=sha256:22c3bc8d488c039a199f7a003a38cb7635db6656fa96437a8accde8322ce2366 \ + --hash=sha256:24e867651fc67717b6f896d5f0cac0ec863a8b5fb7d6441c2ab428f52c651c6b \ + --hash=sha256:2b85f7820475ad3edec209d3d89a7909ada16caab05d3f2e08a7e8ae3200a55c \ + --hash=sha256:39983a9245d37394fd59de71e88c4b295eb510a3555e0a847d9965088cdbd033 \ + --hash=sha256:3d081e859fb1ebe176de33fc3adb26c7d46b8812f906042705346b314bde32c3 \ + --hash=sha256:469f42a0b410883185eab4689060a20488a1a0a00f8bbb3cbc1061197b4c5a79 \ + --hash=sha256:47be9584729534660416f6d2a3108aaeac1122f6b5bdbf9fd823e11fe6fbaa29 \ + --hash=sha256:4be74ddeeb92cc87190e0e376dbc8fc7736dbb6d3d454f2fa1f5be1dee26b9d7 \ + --hash=sha256:4dd0dfbe4d5eb1fcfec9490ca13f82b089a309dc3678e2edabc144051270a66e \ + --hash=sha256:5b08d03ace7aca7b2fadd4baf291139b4a5f058805a8327bfe9aece7253b6d67 \ + --hash=sha256:63e41b91032f298b3e973b3fa4093cbbc620c875e2da7b93e249d4728b54559a \ + --hash=sha256:652350609332de6dac4ece254e5d7e1ff834e203d6afb769601f286886f6f3a8 \ + 
--hash=sha256:693bc706c031aeb848849b9d1c6b63ae6bcc64057984bb91a542332b75aa4c3d \ + --hash=sha256:74258dce215cb1995083daa17b379a1a5a87d275387b7ffe137f1d5131e2cfbb \ + --hash=sha256:789d5e2a3a15419374b7b45cd680b1e83bbc1e52b9086e49308e2c0b5bbae6e3 \ + --hash=sha256:7c9c80ac6091c916db81131d50926a93ab162a7e97e4428ffc186b6e80d6dda4 \ + --hash=sha256:7d6ac9481d9d0d129224f6d5934d5832c4b1cddb96b59e7eba8416868909786a \ + --hash=sha256:85da336e3649a3d2171e82f696b5cad2c6231fdd5bad52616476235681bee5b3 \ + --hash=sha256:8700a2a57771cc43ea295296330daaddc0d93c088f0a35cc969292b6db959bf3 \ + --hash=sha256:8997d6785e93308f277884ee6899ba63baafa0dfb4729748200fcc537858a509 \ + --hash=sha256:9182e0063112e55e74ee7584769ec5a0b4f18252c35787f48738627e23a62b97 \ + --hash=sha256:9b91879d6da1605811ebc60d21ab6a7e4bae6c35f6b63a061d61eb818c8168f6 \ + --hash=sha256:a2242d6950dc892afdf9e951ed7ff89473aaf744b7d5727ad56bdaace363722b \ + --hash=sha256:a371e6b6a5379d3692cc4ea1cb92754d2a47bdddeee755d3203d1f84ae08e03e \ + --hash=sha256:a76d39b5fafd79ed604c4be0a869ec3581a172a707e2a8d7a4858cb05a5a7637 \ + --hash=sha256:ad9f30838550695b5eb302add33f21f7301b882937460dd24f24b3cc5a95067a \ + --hash=sha256:b2266862c5ad664a380fbbcdbdb8289d71464c42a8c29053820ee78ba0119e5d \ + --hash=sha256:b78a99cd1ece4be92ab7c07765a0b038194ded2e0a26fd654591ee136088d8d7 \ + --hash=sha256:c200cb6f2393468142eb50ab19613229dcc7829b5ccee8b658a36005f6669fdd \ + --hash=sha256:c30f393f9d5ff00a71bb56de4aa75b8fe91b161aeb61d39528db6b768d7eac69 \ + --hash=sha256:c6a0a28450c16809f94e0b5bfe52cabff63e7e4b97b44123ebf77f448534d07d \ + --hash=sha256:cebc1b34ba40a312ab480ccdb396ff3c529377a2fce72c45a741f7215bfe8379 \ + --hash=sha256:d2c170247315f2d7e5798a22358e982ad6eeb68fa20cf7a820bb74c11f0736e7 \ + --hash=sha256:d35a95f05a8a2cbe8e02be137740138b3b2ea5f80bd004444e4f9a1ffc511e32 \ + --hash=sha256:d5170929109450a2c031cfe87d6716f2fae39695ad5335d9106ae88cc32dc84c \ + --hash=sha256:d6aa986318c36508dc1d5001a3ff169a15b99b9f96ef5e98e13522c506b37eef \ + 
--hash=sha256:d6de81c9c00c8a23047136b11794b3584cdc1460ed7cbc10eada50614baa1444 \ + --hash=sha256:dc1a1231ed23caac1de9f943d031f1bc38d0f69d2a3b243ea0d664fc1fbd7fec \ + --hash=sha256:e6beeea5566092c5e3c4896c6d1d307fb46b1d4bdf3e70c8340b190a69198594 \ + --hash=sha256:e6d8de076528f7c43a2f576bc311799f89d795aa6c9b637377cc2b1616473804 \ + --hash=sha256:e6f83a583ed0a5b08c5bc7a3fe860bb3c2eac1f03f1f63e0bc2091325605d2b7 \ + --hash=sha256:f250ff44843d9a0615e350c77f890082102a0318d66a99540f54769c8766ab73 \ + --hash=sha256:f71574afdf944e6652203cd1badcda195b2a27d9c83e6d88dc1ce3cfb73b31a5 \ + --hash=sha256:f903017db76bf9cc2b2d8bdd37bf04b505bbccad6be8a81e1542206875d0e9db \ + --hash=sha256:f9a412f55bb6e8f3bb000e020dbc1e709627dcb3a56f6431fa7076b4c1aab0db \ + --hash=sha256:f9c30c464cb2ddfbc2ddf9400287701270fdc0f14be5f08a1e3939f1e749b455 + # via + # -r ci/official/requirements_updater/requirements.in + # tb-nightly +h5py==3.13.0 \ + --hash=sha256:10894c55d46df502d82a7a4ed38f9c3fdbcb93efb42e25d275193e093071fade \ + --hash=sha256:1870e46518720023da85d0895a1960ff2ce398c5671eac3b1a41ec696b7105c3 \ + --hash=sha256:21daf38171753899b5905f3d82c99b0b1ec2cbbe282a037cad431feb620e62ec \ + --hash=sha256:22ffe2a25770a2d67213a1b94f58006c14dce06933a42d2aaa0318c5868d1508 \ + --hash=sha256:337af114616f3656da0c83b68fcf53ecd9ce9989a700b0883a6e7c483c3235d4 \ + --hash=sha256:357e6dc20b101a805ccfd0024731fbaf6e8718c18c09baf3b5e4e9d198d13fca \ + --hash=sha256:477c58307b6b9a2509c59c57811afb9f598aedede24a67da808262dfa0ee37b4 \ + --hash=sha256:4f97ecde7ac6513b21cd95efdfc38dc6d19f96f6ca6f2a30550e94e551458e0a \ + --hash=sha256:5540daee2b236d9569c950b417f13fd112d51d78b4c43012de05774908dff3f5 \ + --hash=sha256:560e71220dc92dfa254b10a4dcb12d56b574d2d87e095db20466b32a93fec3f9 \ + --hash=sha256:56dd172d862e850823c4af02dc4ddbc308f042b85472ffdaca67f1598dff4a57 \ + --hash=sha256:57c4c74f627c616f02b7aec608a8c706fe08cb5b0ba7c08555a4eb1dde20805a \ + --hash=sha256:782ff0ac39f455f21fd1c8ebc007328f65f43d56718a89327eec76677ebf238a \ 
+ --hash=sha256:82690e89c72b85addf4fc4d5058fb1e387b6c14eb063b0b879bf3f42c3b93c35 \ + --hash=sha256:851ae3a8563d87a5a0dc49c2e2529c75b8842582ccaefbf84297d2cfceeacd61 \ + --hash=sha256:8a8e38ef4ceb969f832cc230c0cf808c613cc47e31e768fd7b1106c55afa1cb8 \ + --hash=sha256:9c82ece71ed1c2b807b6628e3933bc6eae57ea21dac207dca3470e3ceaaf437c \ + --hash=sha256:be949b46b7388074c5acae017fbbe3e5ba303fd9daaa52157fdfef30bbdacadd \ + --hash=sha256:c10f061764d8dce0a9592ce08bfd5f243a00703325c388f1086037e5d619c5f1 \ + --hash=sha256:d2cf6a231a07c14acd504a945a6e9ec115e0007f675bde5e0de30a4dc8d86a31 \ + --hash=sha256:d571644958c5e19a61c793d8d23cd02479572da828e333498c9acc463f4a3997 \ + --hash=sha256:d6f13f9b5ce549448c01e4dfe08ea8d1772e6078799af2c1c8d09e941230a90d \ + --hash=sha256:e520ec76de00943dd017c8ea3f354fa1d2f542eac994811943a8faedf2a7d5cb \ + --hash=sha256:e79d8368cd9295045956bfb436656bea3f915beaa11d342e9f79f129f5178763 \ + --hash=sha256:f35640e81b03c02a88b8bf99fb6a9d3023cc52f7c627694db2f379e0028f2868 \ + --hash=sha256:fb267ce4b83f9c42560e9ff4d30f60f7ae492eacf9c7ede849edf8c1b860e16b + # via + # -r ci/official/requirements_updater/requirements.in + # keras-nightly +idna==3.10 \ + --hash=sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9 \ + --hash=sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3 + # via requests +jax==0.4.7 \ + --hash=sha256:5e7002d74db25f97c99b979d4ba1233b1ef26e1597e5fc468ad11d1c8a9dc4f8 + # via -r ci/official/requirements_updater/requirements.in +keras-nightly==3.0.4.dev2024021403 \ + --hash=sha256:24ce69d29d582771685bf4235f59663723405b5a5b16f3eaff2657e52e74663a \ + --hash=sha256:9f416e66b820ef833779d219d255b346b8b90a72fdbd0b2f1e90a43ad142a03d + # via -r ci/official/requirements_updater/requirements.in +libclang==18.1.1 \ + --hash=sha256:0b2e143f0fac830156feb56f9231ff8338c20aecfe72b4ffe96f19e5a1dbb69a \ + --hash=sha256:3f0e1f49f04d3cd198985fea0511576b0aee16f9ff0e0f0cad7f9c57ec3c20e8 \ + 
--hash=sha256:4dd2d3b82fab35e2bf9ca717d7b63ac990a3519c7e312f19fa8e86dcc712f7fb \ + --hash=sha256:54dda940a4a0491a9d1532bf071ea3ef26e6dbaf03b5000ed94dd7174e8f9592 \ + --hash=sha256:69f8eb8f65c279e765ffd28aaa7e9e364c776c17618af8bff22a8df58677ff4f \ + --hash=sha256:6f14c3f194704e5d09769108f03185fce7acaf1d1ae4bbb2f30a72c2400cb7c5 \ + --hash=sha256:83ce5045d101b669ac38e6da8e58765f12da2d3aafb3b9b98d88b286a60964d8 \ + --hash=sha256:a1214966d08d73d971287fc3ead8dfaf82eb07fb197680d8b3859dbbbbf78250 \ + --hash=sha256:c533091d8a3bbf7460a00cb6c1a71da93bffe148f172c7d03b1c31fbf8aa2a0b \ + --hash=sha256:cf4a99b05376513717ab5d82a0db832c56ccea4fd61a69dbb7bccf2dfb207dbe + # via -r ci/official/requirements_updater/requirements.in +lit==17.0.6 \ + --hash=sha256:dfa9af9b55fc4509a56be7bf2346f079d7f4a242d583b9f2e0b078fd0abae31b + # via -r ci/official/requirements_updater/requirements.in +markdown==3.7 \ + --hash=sha256:2ae2471477cfd02dbbf038d5d9bc226d40def84b4fe2986e49b59b6b472bbed2 \ + --hash=sha256:7eb6df5690b81a1d7942992c97fad2938e956e79df20cbc6186e9c3a77b1c803 + # via tb-nightly +markdown-it-py==3.0.0 \ + --hash=sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1 \ + --hash=sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb + # via rich +markupsafe==3.0.2 \ + --hash=sha256:0bff5e0ae4ef2e1ae4fdf2dfd5b76c75e5c2fa4132d05fc1b0dabcd20c7e28c4 \ + --hash=sha256:0f4ca02bea9a23221c0182836703cbf8930c5e9454bacce27e767509fa286a30 \ + --hash=sha256:1225beacc926f536dc82e45f8a4d68502949dc67eea90eab715dea3a21c1b5f0 \ + --hash=sha256:131a3c7689c85f5ad20f9f6fb1b866f402c445b220c19fe4308c0b147ccd2ad9 \ + --hash=sha256:15ab75ef81add55874e7ab7055e9c397312385bd9ced94920f2802310c930396 \ + --hash=sha256:1a9d3f5f0901fdec14d8d2f66ef7d035f2157240a433441719ac9a3fba440b13 \ + --hash=sha256:1c99d261bd2d5f6b59325c92c73df481e05e57f19837bdca8413b9eac4bd8028 \ + --hash=sha256:1e084f686b92e5b83186b07e8a17fc09e38fff551f3602b249881fec658d3eca \ + 
--hash=sha256:2181e67807fc2fa785d0592dc2d6206c019b9502410671cc905d132a92866557 \ + --hash=sha256:2cb8438c3cbb25e220c2ab33bb226559e7afb3baec11c4f218ffa7308603c832 \ + --hash=sha256:3169b1eefae027567d1ce6ee7cae382c57fe26e82775f460f0b2778beaad66c0 \ + --hash=sha256:3809ede931876f5b2ec92eef964286840ed3540dadf803dd570c3b7e13141a3b \ + --hash=sha256:38a9ef736c01fccdd6600705b09dc574584b89bea478200c5fbf112a6b0d5579 \ + --hash=sha256:3d79d162e7be8f996986c064d1c7c817f6df3a77fe3d6859f6f9e7be4b8c213a \ + --hash=sha256:444dcda765c8a838eaae23112db52f1efaf750daddb2d9ca300bcae1039adc5c \ + --hash=sha256:48032821bbdf20f5799ff537c7ac3d1fba0ba032cfc06194faffa8cda8b560ff \ + --hash=sha256:4aa4e5faecf353ed117801a068ebab7b7e09ffb6e1d5e412dc852e0da018126c \ + --hash=sha256:52305740fe773d09cffb16f8ed0427942901f00adedac82ec8b67752f58a1b22 \ + --hash=sha256:569511d3b58c8791ab4c2e1285575265991e6d8f8700c7be0e88f86cb0672094 \ + --hash=sha256:57cb5a3cf367aeb1d316576250f65edec5bb3be939e9247ae594b4bcbc317dfb \ + --hash=sha256:5b02fb34468b6aaa40dfc198d813a641e3a63b98c2b05a16b9f80b7ec314185e \ + --hash=sha256:6381026f158fdb7c72a168278597a5e3a5222e83ea18f543112b2662a9b699c5 \ + --hash=sha256:6af100e168aa82a50e186c82875a5893c5597a0c1ccdb0d8b40240b1f28b969a \ + --hash=sha256:6c89876f41da747c8d3677a2b540fb32ef5715f97b66eeb0c6b66f5e3ef6f59d \ + --hash=sha256:6e296a513ca3d94054c2c881cc913116e90fd030ad1c656b3869762b754f5f8a \ + --hash=sha256:70a87b411535ccad5ef2f1df5136506a10775d267e197e4cf531ced10537bd6b \ + --hash=sha256:7e94c425039cde14257288fd61dcfb01963e658efbc0ff54f5306b06054700f8 \ + --hash=sha256:846ade7b71e3536c4e56b386c2a47adf5741d2d8b94ec9dc3e92e5e1ee1e2225 \ + --hash=sha256:88416bd1e65dcea10bc7569faacb2c20ce071dd1f87539ca2ab364bf6231393c \ + --hash=sha256:88b49a3b9ff31e19998750c38e030fc7bb937398b1f78cfa599aaef92d693144 \ + --hash=sha256:8c4e8c3ce11e1f92f6536ff07154f9d49677ebaaafc32db9db4620bc11ed480f \ + --hash=sha256:8e06879fc22a25ca47312fbe7c8264eb0b662f6db27cb2d3bbbc74b1df4b9b87 \ + 
--hash=sha256:9025b4018f3a1314059769c7bf15441064b2207cb3f065e6ea1e7359cb46db9d \ + --hash=sha256:93335ca3812df2f366e80509ae119189886b0f3c2b81325d39efdb84a1e2ae93 \ + --hash=sha256:9778bd8ab0a994ebf6f84c2b949e65736d5575320a17ae8984a77fab08db94cf \ + --hash=sha256:9e2d922824181480953426608b81967de705c3cef4d1af983af849d7bd619158 \ + --hash=sha256:a123e330ef0853c6e822384873bef7507557d8e4a082961e1defa947aa59ba84 \ + --hash=sha256:a904af0a6162c73e3edcb969eeeb53a63ceeb5d8cf642fade7d39e7963a22ddb \ + --hash=sha256:ad10d3ded218f1039f11a75f8091880239651b52e9bb592ca27de44eed242a48 \ + --hash=sha256:b424c77b206d63d500bcb69fa55ed8d0e6a3774056bdc4839fc9298a7edca171 \ + --hash=sha256:b5a6b3ada725cea8a5e634536b1b01c30bcdcd7f9c6fff4151548d5bf6b3a36c \ + --hash=sha256:ba8062ed2cf21c07a9e295d5b8a2a5ce678b913b45fdf68c32d95d6c1291e0b6 \ + --hash=sha256:ba9527cdd4c926ed0760bc301f6728ef34d841f405abf9d4f959c478421e4efd \ + --hash=sha256:bbcb445fa71794da8f178f0f6d66789a28d7319071af7a496d4d507ed566270d \ + --hash=sha256:bcf3e58998965654fdaff38e58584d8937aa3096ab5354d493c77d1fdd66d7a1 \ + --hash=sha256:c0ef13eaeee5b615fb07c9a7dadb38eac06a0608b41570d8ade51c56539e509d \ + --hash=sha256:cabc348d87e913db6ab4aa100f01b08f481097838bdddf7c7a84b7575b7309ca \ + --hash=sha256:cdb82a876c47801bb54a690c5ae105a46b392ac6099881cdfb9f6e95e4014c6a \ + --hash=sha256:cfad01eed2c2e0c01fd0ecd2ef42c492f7f93902e39a42fc9ee1692961443a29 \ + --hash=sha256:d16a81a06776313e817c951135cf7340a3e91e8c1ff2fac444cfd75fffa04afe \ + --hash=sha256:d8213e09c917a951de9d09ecee036d5c7d36cb6cb7dbaece4c71a60d79fb9798 \ + --hash=sha256:e07c3764494e3776c602c1e78e298937c3315ccc9043ead7e685b7f2b8d47b3c \ + --hash=sha256:e17c96c14e19278594aa4841ec148115f9c7615a47382ecb6b82bd8fea3ab0c8 \ + --hash=sha256:e444a31f8db13eb18ada366ab3cf45fd4b31e4db1236a4448f68778c1d1a5a2f \ + --hash=sha256:e6a2a455bd412959b57a172ce6328d2dd1f01cb2135efda2e4576e8a23fa3b0f \ + --hash=sha256:eaa0a10b7f72326f1372a713e73c3f739b524b3af41feb43e4921cb529f5929a \ + 
--hash=sha256:eb7972a85c54febfb25b5c4b4f3af4dcc731994c7da0d8a0b4a6eb0640e1d178 \ + --hash=sha256:ee55d3edf80167e48ea11a923c7386f4669df67d7994554387f84e7d8b0a2bf0 \ + --hash=sha256:f3818cb119498c0678015754eba762e0d61e5b52d34c8b13d770f0719f7b1d79 \ + --hash=sha256:f8b3d067f2e40fe93e1ccdd6b2e1d16c43140e76f02fb1319a05cf2b79d99430 \ + --hash=sha256:fcabf5ff6eea076f859677f5f0b6b5c1a51e70a376b0579e0eadef8db48c6b50 + # via werkzeug +mdurl==0.1.2 \ + --hash=sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8 \ + --hash=sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba + # via markdown-it-py +ml-dtypes==0.5.1 \ + --hash=sha256:023ce2f502efd4d6c1e0472cc58ce3640d051d40e71e27386bed33901e201327 \ + --hash=sha256:05f23447a1c20ddf4dc7c2c661aa9ed93fcb2658f1017c204d1e758714dc28a8 \ + --hash=sha256:12651420130ee7cc13059fc56dac6ad300c3af3848b802d475148c9defd27c23 \ + --hash=sha256:141b2ea2f20bb10802ddca55d91fe21231ef49715cfc971998e8f2a9838f3dbe \ + --hash=sha256:15ad0f3b0323ce96c24637a88a6f44f6713c64032f27277b069f285c3cf66478 \ + --hash=sha256:1b7fbe5571fdf28fd3aaab3ef4aafc847de9ebf263be959958c1ca58ec8eadf5 \ + --hash=sha256:26ebcc69d7b779c8f129393e99732961b5cc33fcff84090451f448c89b0e01b4 \ + --hash=sha256:6f462f5eca22fb66d7ff9c4744a3db4463af06c49816c4b6ac89b16bfcdc592e \ + --hash=sha256:6f76232163b5b9c34291b54621ee60417601e2e4802a188a0ea7157cd9b323f4 \ + --hash=sha256:7000b6e4d8ef07542c05044ec5d8bbae1df083b3f56822c3da63993a113e716f \ + --hash=sha256:810512e2eccdfc3b41eefa3a27402371a3411453a1efc7e9c000318196140fed \ + --hash=sha256:8f2c028954f16ede77902b223a8da2d9cbb3892375b85809a5c3cfb1587960c4 \ + --hash=sha256:9626d0bca1fb387d5791ca36bacbba298c5ef554747b7ebeafefb4564fc83566 \ + --hash=sha256:ac5b58559bb84a95848ed6984eb8013249f90b6bab62aa5acbad876e256002c9 \ + --hash=sha256:ad4953c5eb9c25a56d11a913c2011d7e580a435ef5145f804d98efa14477d390 \ + --hash=sha256:aefedc579ece2f8fb38f876aa7698204ee4c372d0e54f1c1ffa8ca580b54cc60 \ + 
--hash=sha256:afb2009ac98da274e893e03162f6269398b2b00d947e7057ee2469a921d58135 \ + --hash=sha256:b8a9d46b4df5ae2135a8e8e72b465448ebbc1559997f4f9304a9ecc3413efb5b \ + --hash=sha256:bd73f51957949069573ff783563486339a9285d72e2f36c18e0c1aa9ca7eb190 \ + --hash=sha256:bf9975bda82a99dc935f2ae4c83846d86df8fd6ba179614acac8e686910851da \ + --hash=sha256:c09526488c3a9e8b7a23a388d4974b670a9a3dd40c5c8a61db5593ce9b725bab \ + --hash=sha256:c9945669d3dadf8acb40ec2e57d38c985d8c285ea73af57fc5b09872c516106d \ + --hash=sha256:d13755f8e8445b3870114e5b6240facaa7cb0c3361e54beba3e07fa912a6e12b \ + --hash=sha256:fd918d4e6a4e0c110e2e05be7a7814d10dc1b95872accbf6512b80a109b71ae1 + # via + # -r ci/official/requirements_updater/requirements.in + # jax + # keras-nightly +namex==0.0.8 \ + --hash=sha256:32a50f6c565c0bb10aa76298c959507abdc0e850efe085dc38f3440fcb3aa90b \ + --hash=sha256:7ddb6c2bb0e753a311b7590f84f6da659dd0c05e65cb89d519d54c0a250c0487 + # via keras-nightly +numpy==2.1.3 \ + --hash=sha256:016d0f6f5e77b0f0d45d77387ffa4bb89816b57c835580c3ce8e099ef830befe \ + --hash=sha256:02135ade8b8a84011cbb67dc44e07c58f28575cf9ecf8ab304e51c05528c19f0 \ + --hash=sha256:08788d27a5fd867a663f6fc753fd7c3ad7e92747efc73c53bca2f19f8bc06f48 \ + --hash=sha256:0d30c543f02e84e92c4b1f415b7c6b5326cbe45ee7882b6b77db7195fb971e3a \ + --hash=sha256:0fa14563cc46422e99daef53d725d0c326e99e468a9320a240affffe87852564 \ + --hash=sha256:13138eadd4f4da03074851a698ffa7e405f41a0845a6b1ad135b81596e4e9958 \ + --hash=sha256:14e253bd43fc6b37af4921b10f6add6925878a42a0c5fe83daee390bca80bc17 \ + --hash=sha256:15cb89f39fa6d0bdfb600ea24b250e5f1a3df23f901f51c8debaa6a5d122b2f0 \ + --hash=sha256:17ee83a1f4fef3c94d16dc1802b998668b5419362c8a4f4e8a491de1b41cc3ee \ + --hash=sha256:2312b2aa89e1f43ecea6da6ea9a810d06aae08321609d8dc0d0eda6d946a541b \ + --hash=sha256:2564fbdf2b99b3f815f2107c1bbc93e2de8ee655a69c261363a1172a79a257d4 \ + --hash=sha256:3522b0dfe983a575e6a9ab3a4a4dfe156c3e428468ff08ce582b9bb6bd1d71d4 \ + 
--hash=sha256:4394bc0dbd074b7f9b52024832d16e019decebf86caf909d94f6b3f77a8ee3b6 \ + --hash=sha256:45966d859916ad02b779706bb43b954281db43e185015df6eb3323120188f9e4 \ + --hash=sha256:4d1167c53b93f1f5d8a139a742b3c6f4d429b54e74e6b57d0eff40045187b15d \ + --hash=sha256:4f2015dfe437dfebbfce7c85c7b53d81ba49e71ba7eadbf1df40c915af75979f \ + --hash=sha256:50ca6aba6e163363f132b5c101ba078b8cbd3fa92c7865fd7d4d62d9779ac29f \ + --hash=sha256:50d18c4358a0a8a53f12a8ba9d772ab2d460321e6a93d6064fc22443d189853f \ + --hash=sha256:5641516794ca9e5f8a4d17bb45446998c6554704d888f86df9b200e66bdcce56 \ + --hash=sha256:576a1c1d25e9e02ed7fa5477f30a127fe56debd53b8d2c89d5578f9857d03ca9 \ + --hash=sha256:6a4825252fcc430a182ac4dee5a505053d262c807f8a924603d411f6718b88fd \ + --hash=sha256:72dcc4a35a8515d83e76b58fdf8113a5c969ccd505c8a946759b24e3182d1f23 \ + --hash=sha256:747641635d3d44bcb380d950679462fae44f54b131be347d5ec2bce47d3df9ed \ + --hash=sha256:762479be47a4863e261a840e8e01608d124ee1361e48b96916f38b119cfda04a \ + --hash=sha256:78574ac2d1a4a02421f25da9559850d59457bac82f2b8d7a44fe83a64f770098 \ + --hash=sha256:825656d0743699c529c5943554d223c021ff0494ff1442152ce887ef4f7561a1 \ + --hash=sha256:8637dcd2caa676e475503d1f8fdb327bc495554e10838019651b76d17b98e512 \ + --hash=sha256:96fe52fcdb9345b7cd82ecd34547fca4321f7656d500eca497eb7ea5a926692f \ + --hash=sha256:973faafebaae4c0aaa1a1ca1ce02434554d67e628b8d805e61f874b84e136b09 \ + --hash=sha256:996bb9399059c5b82f76b53ff8bb686069c05acc94656bb259b1d63d04a9506f \ + --hash=sha256:a38c19106902bb19351b83802531fea19dee18e5b37b36454f27f11ff956f7fc \ + --hash=sha256:a6b46587b14b888e95e4a24d7b13ae91fa22386c199ee7b418f449032b2fa3b8 \ + --hash=sha256:a9f7f672a3388133335589cfca93ed468509cb7b93ba3105fce780d04a6576a0 \ + --hash=sha256:aa08e04e08aaf974d4458def539dece0d28146d866a39da5639596f4921fd761 \ + --hash=sha256:b0df3635b9c8ef48bd3be5f862cf71b0a4716fa0e702155c45067c6b711ddcef \ + --hash=sha256:b47fbb433d3260adcd51eb54f92a2ffbc90a4595f8970ee00e064c644ac788f5 \ + 
--hash=sha256:baed7e8d7481bfe0874b566850cb0b85243e982388b7b23348c6db2ee2b2ae8e \ + --hash=sha256:bc6f24b3d1ecc1eebfbf5d6051faa49af40b03be1aaa781ebdadcbc090b4539b \ + --hash=sha256:c006b607a865b07cd981ccb218a04fc86b600411d83d6fc261357f1c0966755d \ + --hash=sha256:c181ba05ce8299c7aa3125c27b9c2167bca4a4445b7ce73d5febc411ca692e43 \ + --hash=sha256:c7662f0e3673fe4e832fe07b65c50342ea27d989f92c80355658c7f888fcc83c \ + --hash=sha256:c80e4a09b3d95b4e1cac08643f1152fa71a0a821a2d4277334c88d54b2219a41 \ + --hash=sha256:c894b4305373b9c5576d7a12b473702afdf48ce5369c074ba304cc5ad8730dff \ + --hash=sha256:d7aac50327da5d208db2eec22eb11e491e3fe13d22653dce51b0f4109101b408 \ + --hash=sha256:d89dd2b6da69c4fff5e39c28a382199ddedc3a5be5390115608345dec660b9e2 \ + --hash=sha256:d9beb777a78c331580705326d2367488d5bc473b49a9bc3036c154832520aca9 \ + --hash=sha256:dc258a761a16daa791081d026f0ed4399b582712e6fc887a95af09df10c5ca57 \ + --hash=sha256:e14e26956e6f1696070788252dcdff11b4aca4c3e8bd166e0df1bb8f315a67cb \ + --hash=sha256:e6988e90fcf617da2b5c78902fe8e668361b43b4fe26dbf2d7b0f8034d4cafb9 \ + --hash=sha256:e711e02f49e176a01d0349d82cb5f05ba4db7d5e7e0defd026328e5cfb3226d3 \ + --hash=sha256:ea4dedd6e394a9c180b33c2c872b92f7ce0f8e7ad93e9585312b0c5a04777a4a \ + --hash=sha256:ecc76a9ba2911d8d37ac01de72834d8849e55473457558e12995f4cd53e778e0 \ + --hash=sha256:f55ba01150f52b1027829b50d70ef1dafd9821ea82905b63936668403c3b471e \ + --hash=sha256:f653490b33e9c3a4c1c01d41bc2aef08f9475af51146e4a7710c450cf9761598 \ + --hash=sha256:fa2d1337dc61c8dc417fbccf20f6d1e139896a30721b7f1e832b2bb6ef4eb6c4 + # via + # -r ci/official/requirements_updater/requirements.in + # dm-tree + # h5py + # jax + # keras-nightly + # ml-dtypes + # opt-einsum + # scipy + # tb-nightly +nvidia-cublas-cu12==12.5.3.2 \ + --hash=sha256:4960f3dc5f39699acadf76fa6d94b10a2a00f2956c2c442efa299fb22b0748f3 \ + --hash=sha256:7d0191251180de606023d396b94d66f66470a0ae96d1dbb906c7656ea0f71eda \ + 
--hash=sha256:ca070ad70e9fa6654084575d01bd001f30cc4665e33d4bb9fc8e0f321caa034b + # via + # -r ci/official/requirements_updater/requirements.in + # nvidia-cudnn-cu12 + # nvidia-cusolver-cu12 +nvidia-cuda-cupti-cu12==12.5.82 \ + --hash=sha256:4f835281cf492e2bedd153f5c3de9da8f1d775a419468305e64ce73b3b0c6dc3 \ + --hash=sha256:bde77a5feb66752ec61db2adfe47f56b941842825b4c7e2068aff27c9d107953 \ + --hash=sha256:d32c06490c6ba35c4323730820c7d0c4c126c04ed58d2f57275adb8d54b138fe + # via -r ci/official/requirements_updater/requirements.in +nvidia-cuda-nvrtc-cu12==12.5.82 \ + --hash=sha256:3dbd97b0104b4bfbc3c4f8c79cd2496307c89c43c29a9f83125f1d76296ff3fd \ + --hash=sha256:5bb6a0eb01d4974bb7ca3d48bd3859472debb3c3057a5e7de2b08fbdf35eed7e \ + --hash=sha256:e5db37e990056c70953b7772dd778336ef9da0a0b5bb28f9f2a61c2e42b51d78 + # via -r ci/official/requirements_updater/requirements.in +nvidia-cuda-runtime-cu12==12.5.82 \ + --hash=sha256:0fd5fbca289bceb9f0690aa9858f06187b554fdeb7e2711dfd5bb3ce58900b46 \ + --hash=sha256:3e79a060e126df40fd3a068f3f787eb000fa51b251ec6cd97d09579632687115 \ + --hash=sha256:71f015dbf9df05dd71f7480132c6ebf47a6ceb2ab53d7db8e08e4b30ebb87e14 + # via -r ci/official/requirements_updater/requirements.in +nvidia-cudnn-cu12==9.3.0.75 \ + --hash=sha256:9ad9c6929ebb5295eb4a1728024666d1c88283373e265a0c5c883e6f9d5cd76d \ + --hash=sha256:c5cf7ff3415e446adf195a5b7dd2ba56cd00c3ee78bfdc566e51698931aa4b7f \ + --hash=sha256:c819e82eed8cf564b9d37478ea4eab9e87194bb3b7f7f8098bc1f67c9b80f1b6 + # via -r ci/official/requirements_updater/requirements.in +nvidia-cufft-cu12==11.2.3.61 \ + --hash=sha256:4a8f6f0ce93c52a50ee83422a80472b5f376054a63f38532d0eab4007e7ef28b \ + --hash=sha256:6d45b48a5ee7599e57131129cda2c58544d9b78b95064d3ec3e5c6b96e2b58cc \ + --hash=sha256:9a6e8df162585750f61983a638104a48c756aa13f9f48e19ab079b38e3c828b8 + # via -r ci/official/requirements_updater/requirements.in +nvidia-curand-cu12==10.3.6.82 \ + 
--hash=sha256:0631ba65231260ad832ce233ddda57e7b3b7158eabf000d78e46cbb5bd5b7aae \ + --hash=sha256:2823fb27de4e44dbb22394a6adf53aa6e1b013aca0f8c22867d1cfae58405536 \ + --hash=sha256:36aabeb5990297bbce3df324ea7c7c13c3aabb140c86d50ab3b23e4ec61672f1 + # via -r ci/official/requirements_updater/requirements.in +nvidia-cusolver-cu12==11.6.3.83 \ + --hash=sha256:1b8b77d2fe8abe72bb722dafb708cceaeb81f1a03999477f20b33b34f46ab885 \ + --hash=sha256:6224732963cba312a84c78114b9a38c4ffabb2e2a6a120923ac99ba6f895c8cf \ + --hash=sha256:93cfafacde4428b71778eeb092ec615a02a3d05404da1bcf91c53e3fa1bce42b + # via -r ci/official/requirements_updater/requirements.in +nvidia-cusparse-cu12==12.5.1.3 \ + --hash=sha256:016df8e993c437e8301e62739f01775cba988fd5253cd4c64173f8e8d2f8e752 \ + --hash=sha256:33520db374e2f5ebc976d6faa1852b98c398a57e6f71150fe59705928596ffd1 \ + --hash=sha256:7b97fd01f0a61628af99d0efd52132fccc8c18fc5c509f13802dccf0574a19c2 + # via + # -r ci/official/requirements_updater/requirements.in + # nvidia-cusolver-cu12 +nvidia-nccl-cu12==2.25.1 \ + --hash=sha256:362aed5963fb9ea2ed2f264409baae30143498fd0e5c503aeaa1badd88cdc54a \ + --hash=sha256:4ab428bc915785cc66e8c57cb34c7a64cf739c46702b8db748b6ad6cc7180cf8 + # via -r ci/official/requirements_updater/requirements.in +nvidia-nvjitlink-cu12==12.5.82 \ + --hash=sha256:98103729cc5226e13ca319a10bbf9433bbbd44ef64fe72f45f067cacc14b8d27 \ + --hash=sha256:e782564d705ff0bf61ac3e1bf730166da66dd2fe9012f111ede5fc49b64ae697 \ + --hash=sha256:f9b37bc5c8cf7509665cb6ada5aaa0ce65618f2332b7d3e78e9790511f111212 + # via + # -r ci/official/requirements_updater/requirements.in + # nvidia-cufft-cu12 + # nvidia-cusolver-cu12 + # nvidia-cusparse-cu12 +opt-einsum==3.3.0 \ + --hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \ + --hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549 + # via + # -r ci/official/requirements_updater/requirements.in + # jax +packaging==23.2 \ + 
--hash=sha256:048fb0e9405036518eaaf48a55953c750c11e1a1b68e0dd1a9d62ed0c092cfc5 \ + --hash=sha256:8c491190033a9af7e1d931d0b5dacc2ef47509b34dd0de67ed209b5203fc88c7 + # via + # -r ci/official/requirements_updater/requirements.in + # auditwheel + # tb-nightly +portpicker==1.6.0 \ + --hash=sha256:b2787a41404cf7edbe29b07b9e0ed863b09f2665dcc01c1eb0c2261c1e7d0755 \ + --hash=sha256:bd507fd6f96f65ee02781f2e674e9dc6c99bbfa6e3c39992e3916204c9d431fa + # via -r ci/official/requirements_updater/requirements.in +protobuf==6.30.2 \ + --hash=sha256:0eb523c550a66a09a0c20f86dd554afbf4d32b02af34ae53d93268c1f73bc65b \ + --hash=sha256:35c859ae076d8c56054c25b59e5e59638d86545ed6e2b6efac6be0b6ea3ba048 \ + --hash=sha256:4f6c687ae8efae6cf6093389a596548214467778146b7245e886f35e1485315d \ + --hash=sha256:50f32cc9fd9cb09c783ebc275611b4f19dfdfb68d1ee55d2f0c7fa040df96815 \ + --hash=sha256:524afedc03b31b15586ca7f64d877a98b184f007180ce25183d1a5cb230ee72b \ + --hash=sha256:7653c99774f73fe6b9301b87da52af0e69783a2e371e8b599b3e9cb4da4b12b9 \ + --hash=sha256:acec579c39c88bd8fbbacab1b8052c793efe83a0a5bd99db4a31423a25c0a0e2 \ + --hash=sha256:ae86b030e69a98e08c77beab574cbcb9fff6d031d57209f574a5aea1445f4b51 \ + --hash=sha256:b12ef7df7b9329886e66404bef5e9ce6a26b54069d7f7436a0853ccdeb91c103 + # via tb-nightly +psutil==7.0.0 \ + --hash=sha256:101d71dc322e3cffd7cea0650b09b3d08b8e7c4109dd6809fe452dfd00e58b25 \ + --hash=sha256:1e744154a6580bc968a0195fd25e80432d3afec619daf145b9e5ba16cc1d688e \ + --hash=sha256:1fcee592b4c6f146991ca55919ea3d1f8926497a713ed7faaf8225e174581e91 \ + --hash=sha256:39db632f6bb862eeccf56660871433e111b6ea58f2caea825571951d4b6aa3da \ + --hash=sha256:4b1388a4f6875d7e2aff5c4ca1cc16c545ed41dd8bb596cefea80111db353a34 \ + --hash=sha256:4cf3d4eb1aa9b348dec30105c55cd9b7d4629285735a102beb4441e38db90553 \ + --hash=sha256:7be9c3eba38beccb6495ea33afd982a44074b78f28c434a1f51cc07fd315c456 \ + --hash=sha256:84df4eb63e16849689f76b1ffcb36db7b8de703d1bc1fe41773db487621b6c17 \ + 
--hash=sha256:a5f098451abc2828f7dc6b58d44b532b22f2088f4999a937557b603ce72b1993 \ + --hash=sha256:ba3fcef7523064a6c9da440fc4d6bd07da93ac726b5733c29027d7dc95b39d99 + # via portpicker +pyelftools==0.32 \ + --hash=sha256:013df952a006db5e138b1edf6d8a68ecc50630adbd0d83a2d41e7f846163d738 \ + --hash=sha256:6de90ee7b8263e740c8715a925382d4099b354f29ac48ea40d840cf7aa14ace5 + # via auditwheel +pygments==2.19.1 \ + --hash=sha256:61c16d2a8576dc0649d9f39e089b5f02bcd27fba10d8fb4dcc28173f7a45151f \ + --hash=sha256:9ea1544ad55cecf4b8242fab6dd35a93bbce657034b0611ee383099054ab6d8c + # via rich +requests==2.32.3 \ + --hash=sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760 \ + --hash=sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6 + # via -r ci/official/requirements_updater/requirements.in +rich==14.0.0 \ + --hash=sha256:1c9491e1951aac09caffd42f448ee3d04e58923ffe14993f6e83068dc395d7e0 \ + --hash=sha256:82f1bc23a6a21ebca4ae0c45af9bdbc492ed20231dcb63f297d6d1021a9d5725 + # via keras-nightly +scipy==1.15.2 \ + --hash=sha256:01edfac9f0798ad6b46d9c4c9ca0e0ad23dbf0b1eb70e96adb9fa7f525eff0bf \ + --hash=sha256:03205d57a28e18dfd39f0377d5002725bf1f19a46f444108c29bdb246b6c8a11 \ + --hash=sha256:08b57a9336b8e79b305a143c3655cc5bdbe6d5ece3378578888d2afbb51c4e37 \ + --hash=sha256:11e7ad32cf184b74380f43d3c0a706f49358b904fa7d5345f16ddf993609184d \ + --hash=sha256:28a0d2c2075946346e4408b211240764759e0fabaeb08d871639b5f3b1aca8a0 \ + --hash=sha256:2b871df1fe1a3ba85d90e22742b93584f8d2b8e6124f8372ab15c71b73e428b8 \ + --hash=sha256:302093e7dfb120e55515936cb55618ee0b895f8bcaf18ff81eca086c17bd80af \ + --hash=sha256:42dabaaa798e987c425ed76062794e93a243be8f0f20fff6e7a89f4d61cb3d40 \ + --hash=sha256:447ce30cee6a9d5d1379087c9e474628dab3db4a67484be1b7dc3196bfb2fac9 \ + --hash=sha256:4c6676490ad76d1c2894d77f976144b41bd1a4052107902238047fb6a473e971 \ + --hash=sha256:54c462098484e7466362a9f1672d20888f724911a74c22ae35b61f9c5919183d \ + 
--hash=sha256:597a0c7008b21c035831c39927406c6181bcf8f60a73f36219b69d010aa04737 \ + --hash=sha256:5a6fd6eac1ce74a9f77a7fc724080d507c5812d61e72bd5e4c489b042455865e \ + --hash=sha256:5ea7ed46d437fc52350b028b1d44e002646e28f3e8ddc714011aaf87330f2f32 \ + --hash=sha256:601881dfb761311045b03114c5fe718a12634e5608c3b403737ae463c9885d53 \ + --hash=sha256:62ca1ff3eb513e09ed17a5736929429189adf16d2d740f44e53270cc800ecff1 \ + --hash=sha256:69ea6e56d00977f355c0f84eba69877b6df084516c602d93a33812aa04d90a3d \ + --hash=sha256:6a8e34cf4c188b6dd004654f88586d78f95639e48a25dfae9c5e34a6dc34547e \ + --hash=sha256:6d0194c37037707b2afa7a2f2a924cf7bac3dc292d51b6a925e5fcb89bc5c776 \ + --hash=sha256:6f223753c6ea76983af380787611ae1291e3ceb23917393079dcc746ba60cfb5 \ + --hash=sha256:6f5e296ec63c5da6ba6fa0343ea73fd51b8b3e1a300b0a8cae3ed4b1122c7462 \ + --hash=sha256:7cd5b77413e1855351cdde594eca99c1f4a588c2d63711388b6a1f1c01f62274 \ + --hash=sha256:869269b767d5ee7ea6991ed7e22b3ca1f22de73ab9a49c44bad338b725603301 \ + --hash=sha256:87994da02e73549dfecaed9e09a4f9d58a045a053865679aeb8d6d43747d4df3 \ + --hash=sha256:888307125ea0c4466287191e5606a2c910963405ce9671448ff9c81c53f85f58 \ + --hash=sha256:92233b2df6938147be6fa8824b8136f29a18f016ecde986666be5f4d686a91a4 \ + --hash=sha256:9412f5e408b397ff5641080ed1e798623dbe1ec0d78e72c9eca8992976fa65aa \ + --hash=sha256:9b18aa747da280664642997e65aab1dd19d0c3d17068a04b3fe34e2559196cb9 \ + --hash=sha256:9de9d1416b3d9e7df9923ab23cd2fe714244af10b763975bea9e4f2e81cebd27 \ + --hash=sha256:a2ec871edaa863e8213ea5df811cd600734f6400b4af272e1c011e69401218e9 \ + --hash=sha256:a5080a79dfb9b78b768cebf3c9dcbc7b665c5875793569f48bf0e2b1d7f68f6f \ + --hash=sha256:a8bf5cb4a25046ac61d38f8d3c3426ec11ebc350246a4642f2f315fe95bda655 \ + --hash=sha256:b09ae80010f52efddb15551025f9016c910296cf70adbf03ce2a8704f3a5ad20 \ + --hash=sha256:b5e025e903b4f166ea03b109bb241355b9c42c279ea694d8864d033727205e65 \ + --hash=sha256:bad78d580270a4d32470563ea86c6590b465cb98f83d760ff5b0990cb5518a93 \ + 
--hash=sha256:bae43364d600fdc3ac327db99659dcb79e6e7ecd279a75fe1266669d9a652828 \ + --hash=sha256:c4697a10da8f8765bb7c83e24a470da5797e37041edfd77fd95ba3811a47c4fd \ + --hash=sha256:c90ebe8aaa4397eaefa8455a8182b164a6cc1d59ad53f79943f266d99f68687f \ + --hash=sha256:cd58a314d92838f7e6f755c8a2167ead4f27e1fd5c1251fd54289569ef3495ec \ + --hash=sha256:cf72ff559a53a6a6d77bd8eefd12a17995ffa44ad86c77a5df96f533d4e6c6bb \ + --hash=sha256:def751dd08243934c884a3221156d63e15234a3155cf25978b0a668409d45eb6 \ + --hash=sha256:e7c68b6a43259ba0aab737237876e5c2c549a031ddb7abc28c7b47f22e202ded \ + --hash=sha256:ecf797d2d798cf7c838c6d98321061eb3e72a74710e6c40540f0e8087e3b499e \ + --hash=sha256:f031846580d9acccd0044efd1a90e6f4df3a6e12b4b6bd694a7bc03a89892b28 \ + --hash=sha256:fb530e4794fc8ea76a4a21ccb67dea33e5e0e60f07fc38a49e821e1eae3b71a0 \ + --hash=sha256:fe8a9eb875d430d81755472c5ba75e84acc980e4a8f6204d402849234d3017db + # via + # -r ci/official/requirements_updater/requirements.in + # jax +six==1.17.0 \ + --hash=sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274 \ + --hash=sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81 + # via + # astunparse + # google-pasta + # tb-nightly +tb-nightly==2.19.0a20250218 \ + --hash=sha256:7c7fea911a9e113e7d40fa9aed96168840e2443c5ada52fba5bc3645ec6e206f + # via -r ci/official/requirements_updater/requirements.in +tblib==2.0.0 \ + --hash=sha256:9100bfa016b047d5b980d66e7efed952fbd20bd85b56110aaf473cb97d18709a \ + --hash=sha256:a6df30f272c08bf8be66e0775fad862005d950a6b8449b94f7c788731d70ecd7 + # via -r ci/official/requirements_updater/requirements.in +tensorboard-data-server==0.7.2 \ + --hash=sha256:7e0610d205889588983836ec05dc098e80f97b7e7bbff7e994ebb78f578d0ddb \ + --hash=sha256:9fe5d24221b29625dbc7328b0436ca7fc1c23de4acf4d272f1180856e32f9f60 \ + --hash=sha256:ef687163c24185ae9754ed5650eb5bc4d84ff257aabdc33f0cc6f74d8ba54530 + # via tb-nightly +termcolor==2.3.0 \ + 
--hash=sha256:3afb05607b89aed0ffe25202399ee0867ad4d3cb4180d98aaf8eefa6a5f7d475 \ + --hash=sha256:b5b08f68937f138fe92f6c089b99f1e2da0ae56c52b78bf7075fd95420fd9a5a + # via -r ci/official/requirements_updater/requirements.in +typing-extensions==4.8.0 \ + --hash=sha256:8f92fc8806f9a6b641eaa5318da32b44d401efaac0f6678c9bc448ba3605faa0 \ + --hash=sha256:df8e4339e9cb77357558cbdbceca33c303714cf861d1eef15e1070055ae8b7ef + # via -r ci/official/requirements_updater/requirements.in +urllib3==2.3.0 \ + --hash=sha256:1cee9ad369867bfdbbb48b7dd50374c0967a0bb7710050facf0dd6911440e3df \ + --hash=sha256:f8c5449b3cf0861679ce7e0503c7b44b5ec981bec0d1d3795a07f1ba96f0204d + # via requests +werkzeug==3.1.3 \ + --hash=sha256:54b78bf3716d19a65be4fceccc0d1d7b89e608834989dfae50ea87564639213e \ + --hash=sha256:60723ce945c19328679790e3282cc758aa4a6040e4bb330f53d30fa546d44746 + # via tb-nightly +wheel==0.41.3 \ + --hash=sha256:488609bc63a29322326e05560731bf7bfea8e48ad646e1f5e40d366607de0942 \ + --hash=sha256:4d4987ce51a49370ea65c0bfd2234e8ce80a12780820d9dc462597a6e60d0841 + # via + # -r ci/official/requirements_updater/requirements.in + # astunparse +wrapt==1.16.0 \ + --hash=sha256:0d2691979e93d06a95a26257adb7bfd0c93818e89b1406f5a28f36e0d8c1e1fc \ + --hash=sha256:14d7dc606219cdd7405133c713f2c218d4252f2a469003f8c46bb92d5d095d81 \ + --hash=sha256:1a5db485fe2de4403f13fafdc231b0dbae5eca4359232d2efc79025527375b09 \ + --hash=sha256:1acd723ee2a8826f3d53910255643e33673e1d11db84ce5880675954183ec47e \ + --hash=sha256:1ca9b6085e4f866bd584fb135a041bfc32cab916e69f714a7d1d397f8c4891ca \ + --hash=sha256:1dd50a2696ff89f57bd8847647a1c363b687d3d796dc30d4dd4a9d1689a706f0 \ + --hash=sha256:2076fad65c6736184e77d7d4729b63a6d1ae0b70da4868adeec40989858eb3fb \ + --hash=sha256:2a88e6010048489cda82b1326889ec075a8c856c2e6a256072b28eaee3ccf487 \ + --hash=sha256:3ebf019be5c09d400cf7b024aa52b1f3aeebeff51550d007e92c3c1c4afc2a40 \ + --hash=sha256:418abb18146475c310d7a6dc71143d6f7adec5b004ac9ce08dc7a34e2babdc5c \ + 
--hash=sha256:43aa59eadec7890d9958748db829df269f0368521ba6dc68cc172d5d03ed8060 \ + --hash=sha256:44a2754372e32ab315734c6c73b24351d06e77ffff6ae27d2ecf14cf3d229202 \ + --hash=sha256:490b0ee15c1a55be9c1bd8609b8cecd60e325f0575fc98f50058eae366e01f41 \ + --hash=sha256:49aac49dc4782cb04f58986e81ea0b4768e4ff197b57324dcbd7699c5dfb40b9 \ + --hash=sha256:5eb404d89131ec9b4f748fa5cfb5346802e5ee8836f57d516576e61f304f3b7b \ + --hash=sha256:5f15814a33e42b04e3de432e573aa557f9f0f56458745c2074952f564c50e664 \ + --hash=sha256:5f370f952971e7d17c7d1ead40e49f32345a7f7a5373571ef44d800d06b1899d \ + --hash=sha256:66027d667efe95cc4fa945af59f92c5a02c6f5bb6012bff9e60542c74c75c362 \ + --hash=sha256:66dfbaa7cfa3eb707bbfcd46dab2bc6207b005cbc9caa2199bcbc81d95071a00 \ + --hash=sha256:685f568fa5e627e93f3b52fda002c7ed2fa1800b50ce51f6ed1d572d8ab3e7fc \ + --hash=sha256:6906c4100a8fcbf2fa735f6059214bb13b97f75b1a61777fcf6432121ef12ef1 \ + --hash=sha256:6a42cd0cfa8ffc1915aef79cb4284f6383d8a3e9dcca70c445dcfdd639d51267 \ + --hash=sha256:6dcfcffe73710be01d90cae08c3e548d90932d37b39ef83969ae135d36ef3956 \ + --hash=sha256:6f6eac2360f2d543cc875a0e5efd413b6cbd483cb3ad7ebf888884a6e0d2e966 \ + --hash=sha256:72554a23c78a8e7aa02abbd699d129eead8b147a23c56e08d08dfc29cfdddca1 \ + --hash=sha256:73870c364c11f03ed072dda68ff7aea6d2a3a5c3fe250d917a429c7432e15228 \ + --hash=sha256:73aa7d98215d39b8455f103de64391cb79dfcad601701a3aa0dddacf74911d72 \ + --hash=sha256:75ea7d0ee2a15733684badb16de6794894ed9c55aa5e9903260922f0482e687d \ + --hash=sha256:7bd2d7ff69a2cac767fbf7a2b206add2e9a210e57947dd7ce03e25d03d2de292 \ + --hash=sha256:807cc8543a477ab7422f1120a217054f958a66ef7314f76dd9e77d3f02cdccd0 \ + --hash=sha256:8e9723528b9f787dc59168369e42ae1c3b0d3fadb2f1a71de14531d321ee05b0 \ + --hash=sha256:9090c9e676d5236a6948330e83cb89969f433b1943a558968f659ead07cb3b36 \ + --hash=sha256:9153ed35fc5e4fa3b2fe97bddaa7cbec0ed22412b85bcdaf54aeba92ea37428c \ + --hash=sha256:9159485323798c8dc530a224bd3ffcf76659319ccc7bbd52e01e73bd0241a0c5 \ + 
--hash=sha256:941988b89b4fd6b41c3f0bfb20e92bd23746579736b7343283297c4c8cbae68f \ + --hash=sha256:94265b00870aa407bd0cbcfd536f17ecde43b94fb8d228560a1e9d3041462d73 \ + --hash=sha256:98b5e1f498a8ca1858a1cdbffb023bfd954da4e3fa2c0cb5853d40014557248b \ + --hash=sha256:9b201ae332c3637a42f02d1045e1d0cccfdc41f1f2f801dafbaa7e9b4797bfc2 \ + --hash=sha256:a0ea261ce52b5952bf669684a251a66df239ec6d441ccb59ec7afa882265d593 \ + --hash=sha256:a33a747400b94b6d6b8a165e4480264a64a78c8a4c734b62136062e9a248dd39 \ + --hash=sha256:a452f9ca3e3267cd4d0fcf2edd0d035b1934ac2bd7e0e57ac91ad6b95c0c6389 \ + --hash=sha256:a86373cf37cd7764f2201b76496aba58a52e76dedfaa698ef9e9688bfd9e41cf \ + --hash=sha256:ac83a914ebaf589b69f7d0a1277602ff494e21f4c2f743313414378f8f50a4cf \ + --hash=sha256:aefbc4cb0a54f91af643660a0a150ce2c090d3652cf4052a5397fb2de549cd89 \ + --hash=sha256:b3646eefa23daeba62643a58aac816945cadc0afaf21800a1421eeba5f6cfb9c \ + --hash=sha256:b47cfad9e9bbbed2339081f4e346c93ecd7ab504299403320bf85f7f85c7d46c \ + --hash=sha256:b935ae30c6e7400022b50f8d359c03ed233d45b725cfdd299462f41ee5ffba6f \ + --hash=sha256:bb2dee3874a500de01c93d5c71415fcaef1d858370d405824783e7a8ef5db440 \ + --hash=sha256:bc57efac2da352a51cc4658878a68d2b1b67dbe9d33c36cb826ca449d80a8465 \ + --hash=sha256:bf5703fdeb350e36885f2875d853ce13172ae281c56e509f4e6eca049bdfb136 \ + --hash=sha256:c31f72b1b6624c9d863fc095da460802f43a7c6868c5dda140f51da24fd47d7b \ + --hash=sha256:c5cd603b575ebceca7da5a3a251e69561bec509e0b46e4993e1cac402b7247b8 \ + --hash=sha256:d2efee35b4b0a347e0d99d28e884dfd82797852d62fcd7ebdeee26f3ceb72cf3 \ + --hash=sha256:d462f28826f4657968ae51d2181a074dfe03c200d6131690b7d65d55b0f360f8 \ + --hash=sha256:d5e49454f19ef621089e204f862388d29e6e8d8b162efce05208913dde5b9ad6 \ + --hash=sha256:da4813f751142436b075ed7aa012a8778aa43a99f7b36afe9b742d3ed8bdc95e \ + --hash=sha256:db2e408d983b0e61e238cf579c09ef7020560441906ca990fe8412153e3b291f \ + --hash=sha256:db98ad84a55eb09b3c32a96c576476777e87c520a34e2519d3e59c44710c002c \ + 
--hash=sha256:dbed418ba5c3dce92619656802cc5355cb679e58d0d89b50f116e4a9d5a9603e \ + --hash=sha256:dcdba5c86e368442528f7060039eda390cc4091bfd1dca41e8046af7c910dda8 \ + --hash=sha256:decbfa2f618fa8ed81c95ee18a387ff973143c656ef800c9f24fb7e9c16054e2 \ + --hash=sha256:e4fdb9275308292e880dcbeb12546df7f3e0f96c6b41197e0cf37d2826359020 \ + --hash=sha256:eb1b046be06b0fce7249f1d025cd359b4b80fc1c3e24ad9eca33e0dcdb2e4a35 \ + --hash=sha256:eb6e651000a19c96f452c85132811d25e9264d836951022d6e81df2fff38337d \ + --hash=sha256:ed867c42c268f876097248e05b6117a65bcd1e63b779e916fe2e33cd6fd0d3c3 \ + --hash=sha256:edfad1d29c73f9b863ebe7082ae9321374ccb10879eeabc84ba3b69f2579d537 \ + --hash=sha256:f2058f813d4f2b5e3a9eb2eb3faf8f1d99b81c3e51aeda4b168406443e8ba809 \ + --hash=sha256:f6b2d0c6703c988d334f297aa5df18c45e97b0af3679bb75059e0e0bd8b1069d \ + --hash=sha256:f8212564d49c50eb4565e502814f694e240c55551a5f1bc841d4fcaabb0a9b8a \ + --hash=sha256:ffa565331890b90056c01db69c0fe634a776f8019c143a5ae265f9c6bc4bd6d4 + # via + # -r ci/official/requirements_updater/requirements.in + # dm-tree +zstandard==0.23.0 \ + --hash=sha256:034b88913ecc1b097f528e42b539453fa82c3557e414b3de9d5632c80439a473 \ + --hash=sha256:0a7f0804bb3799414af278e9ad51be25edf67f78f916e08afdb983e74161b916 \ + --hash=sha256:11e3bf3c924853a2d5835b24f03eeba7fc9b07d8ca499e247e06ff5676461a15 \ + --hash=sha256:12a289832e520c6bd4dcaad68e944b86da3bad0d339ef7989fb7e88f92e96072 \ + --hash=sha256:1516c8c37d3a053b01c1c15b182f3b5f5eef19ced9b930b684a73bad121addf4 \ + --hash=sha256:157e89ceb4054029a289fb504c98c6a9fe8010f1680de0201b3eb5dc20aa6d9e \ + --hash=sha256:1bfe8de1da6d104f15a60d4a8a768288f66aa953bbe00d027398b93fb9680b26 \ + --hash=sha256:1e172f57cd78c20f13a3415cc8dfe24bf388614324d25539146594c16d78fcc8 \ + --hash=sha256:1fd7e0f1cfb70eb2f95a19b472ee7ad6d9a0a992ec0ae53286870c104ca939e5 \ + --hash=sha256:203d236f4c94cd8379d1ea61db2fce20730b4c38d7f1c34506a31b34edc87bdd \ + --hash=sha256:27d3ef2252d2e62476389ca8f9b0cf2bbafb082a3b6bfe9d90cbcbb5529ecf7c 
\ + --hash=sha256:29a2bc7c1b09b0af938b7a8343174b987ae021705acabcbae560166567f5a8db \ + --hash=sha256:2ef230a8fd217a2015bc91b74f6b3b7d6522ba48be29ad4ea0ca3a3775bf7dd5 \ + --hash=sha256:2ef3775758346d9ac6214123887d25c7061c92afe1f2b354f9388e9e4d48acfc \ + --hash=sha256:2f146f50723defec2975fb7e388ae3a024eb7151542d1599527ec2aa9cacb152 \ + --hash=sha256:2fb4535137de7e244c230e24f9d1ec194f61721c86ebea04e1581d9d06ea1269 \ + --hash=sha256:32ba3b5ccde2d581b1e6aa952c836a6291e8435d788f656fe5976445865ae045 \ + --hash=sha256:34895a41273ad33347b2fc70e1bff4240556de3c46c6ea430a7ed91f9042aa4e \ + --hash=sha256:379b378ae694ba78cef921581ebd420c938936a153ded602c4fea612b7eaa90d \ + --hash=sha256:38302b78a850ff82656beaddeb0bb989a0322a8bbb1bf1ab10c17506681d772a \ + --hash=sha256:3aa014d55c3af933c1315eb4bb06dd0459661cc0b15cd61077afa6489bec63bb \ + --hash=sha256:4051e406288b8cdbb993798b9a45c59a4896b6ecee2f875424ec10276a895740 \ + --hash=sha256:40b33d93c6eddf02d2c19f5773196068d875c41ca25730e8288e9b672897c105 \ + --hash=sha256:43da0f0092281bf501f9c5f6f3b4c975a8a0ea82de49ba3f7100e64d422a1274 \ + --hash=sha256:445e4cb5048b04e90ce96a79b4b63140e3f4ab5f662321975679b5f6360b90e2 \ + --hash=sha256:48ef6a43b1846f6025dde6ed9fee0c24e1149c1c25f7fb0a0585572b2f3adc58 \ + --hash=sha256:50a80baba0285386f97ea36239855f6020ce452456605f262b2d33ac35c7770b \ + --hash=sha256:519fbf169dfac1222a76ba8861ef4ac7f0530c35dd79ba5727014613f91613d4 \ + --hash=sha256:53dd9d5e3d29f95acd5de6802e909ada8d8d8cfa37a3ac64836f3bc4bc5512db \ + --hash=sha256:53ea7cdc96c6eb56e76bb06894bcfb5dfa93b7adcf59d61c6b92674e24e2dd5e \ + --hash=sha256:576856e8594e6649aee06ddbfc738fec6a834f7c85bf7cadd1c53d4a58186ef9 \ + --hash=sha256:59556bf80a7094d0cfb9f5e50bb2db27fefb75d5138bb16fb052b61b0e0eeeb0 \ + --hash=sha256:5d41d5e025f1e0bccae4928981e71b2334c60f580bdc8345f824e7c0a4c2a813 \ + --hash=sha256:61062387ad820c654b6a6b5f0b94484fa19515e0c5116faf29f41a6bc91ded6e \ + --hash=sha256:61f89436cbfede4bc4e91b4397eaa3e2108ebe96d05e93d6ccc95ab5714be512 \ + 
--hash=sha256:62136da96a973bd2557f06ddd4e8e807f9e13cbb0bfb9cc06cfe6d98ea90dfe0 \ + --hash=sha256:64585e1dba664dc67c7cdabd56c1e5685233fbb1fc1966cfba2a340ec0dfff7b \ + --hash=sha256:65308f4b4890aa12d9b6ad9f2844b7ee42c7f7a4fd3390425b242ffc57498f48 \ + --hash=sha256:66b689c107857eceabf2cf3d3fc699c3c0fe8ccd18df2219d978c0283e4c508a \ + --hash=sha256:6a41c120c3dbc0d81a8e8adc73312d668cd34acd7725f036992b1b72d22c1772 \ + --hash=sha256:6f77fa49079891a4aab203d0b1744acc85577ed16d767b52fc089d83faf8d8ed \ + --hash=sha256:72c68dda124a1a138340fb62fa21b9bf4848437d9ca60bd35db36f2d3345f373 \ + --hash=sha256:752bf8a74412b9892f4e5b58f2f890a039f57037f52c89a740757ebd807f33ea \ + --hash=sha256:76e79bc28a65f467e0409098fa2c4376931fd3207fbeb6b956c7c476d53746dd \ + --hash=sha256:774d45b1fac1461f48698a9d4b5fa19a69d47ece02fa469825b442263f04021f \ + --hash=sha256:77da4c6bfa20dd5ea25cbf12c76f181a8e8cd7ea231c673828d0386b1740b8dc \ + --hash=sha256:77ea385f7dd5b5676d7fd943292ffa18fbf5c72ba98f7d09fc1fb9e819b34c23 \ + --hash=sha256:80080816b4f52a9d886e67f1f96912891074903238fe54f2de8b786f86baded2 \ + --hash=sha256:80a539906390591dd39ebb8d773771dc4db82ace6372c4d41e2d293f8e32b8db \ + --hash=sha256:82d17e94d735c99621bf8ebf9995f870a6b3e6d14543b99e201ae046dfe7de70 \ + --hash=sha256:837bb6764be6919963ef41235fd56a6486b132ea64afe5fafb4cb279ac44f259 \ + --hash=sha256:84433dddea68571a6d6bd4fbf8ff398236031149116a7fff6f777ff95cad3df9 \ + --hash=sha256:8c24f21fa2af4bb9f2c492a86fe0c34e6d2c63812a839590edaf177b7398f700 \ + --hash=sha256:8ed7d27cb56b3e058d3cf684d7200703bcae623e1dcc06ed1e18ecda39fee003 \ + --hash=sha256:9206649ec587e6b02bd124fb7799b86cddec350f6f6c14bc82a2b70183e708ba \ + --hash=sha256:983b6efd649723474f29ed42e1467f90a35a74793437d0bc64a5bf482bedfa0a \ + --hash=sha256:98da17ce9cbf3bfe4617e836d561e433f871129e3a7ac16d6ef4c680f13a839c \ + --hash=sha256:9c236e635582742fee16603042553d276cca506e824fa2e6489db04039521e90 \ + --hash=sha256:9da6bc32faac9a293ddfdcb9108d4b20416219461e4ec64dfea8383cac186690 \ + 
--hash=sha256:a05e6d6218461eb1b4771d973728f0133b2a4613a6779995df557f70794fd60f \ + --hash=sha256:a0817825b900fcd43ac5d05b8b3079937073d2b1ff9cf89427590718b70dd840 \ + --hash=sha256:a4ae99c57668ca1e78597d8b06d5af837f377f340f4cce993b551b2d7731778d \ + --hash=sha256:a8c86881813a78a6f4508ef9daf9d4995b8ac2d147dcb1a450448941398091c9 \ + --hash=sha256:a8fffdbd9d1408006baaf02f1068d7dd1f016c6bcb7538682622c556e7b68e35 \ + --hash=sha256:a9b07268d0c3ca5c170a385a0ab9fb7fdd9f5fd866be004c4ea39e44edce47dd \ + --hash=sha256:ab19a2d91963ed9e42b4e8d77cd847ae8381576585bad79dbd0a8837a9f6620a \ + --hash=sha256:ac184f87ff521f4840e6ea0b10c0ec90c6b1dcd0bad2f1e4a9a1b4fa177982ea \ + --hash=sha256:b0e166f698c5a3e914947388c162be2583e0c638a4703fc6a543e23a88dea3c1 \ + --hash=sha256:b2170c7e0367dde86a2647ed5b6f57394ea7f53545746104c6b09fc1f4223573 \ + --hash=sha256:b2d8c62d08e7255f68f7a740bae85b3c9b8e5466baa9cbf7f57f1cde0ac6bc09 \ + --hash=sha256:b4567955a6bc1b20e9c31612e615af6b53733491aeaa19a6b3b37f3b65477094 \ + --hash=sha256:b69bb4f51daf461b15e7b3db033160937d3ff88303a7bc808c67bbc1eaf98c78 \ + --hash=sha256:b8c0bd73aeac689beacd4e7667d48c299f61b959475cdbb91e7d3d88d27c56b9 \ + --hash=sha256:be9b5b8659dff1f913039c2feee1aca499cfbc19e98fa12bc85e037c17ec6ca5 \ + --hash=sha256:bf0a05b6059c0528477fba9054d09179beb63744355cab9f38059548fedd46a9 \ + --hash=sha256:c16842b846a8d2a145223f520b7e18b57c8f476924bda92aeee3a88d11cfc391 \ + --hash=sha256:c363b53e257246a954ebc7c488304b5592b9c53fbe74d03bc1c64dda153fb847 \ + --hash=sha256:c7c517d74bea1a6afd39aa612fa025e6b8011982a0897768a2f7c8ab4ebb78a2 \ + --hash=sha256:d20fd853fbb5807c8e84c136c278827b6167ded66c72ec6f9a14b863d809211c \ + --hash=sha256:d2240ddc86b74966c34554c49d00eaafa8200a18d3a5b6ffbf7da63b11d74ee2 \ + --hash=sha256:d477ed829077cd945b01fc3115edd132c47e6540ddcd96ca169facff28173057 \ + --hash=sha256:d50d31bfedd53a928fed6707b15a8dbeef011bb6366297cc435accc888b27c20 \ + --hash=sha256:dc1d33abb8a0d754ea4763bad944fd965d3d95b5baef6b121c0c9013eaf1907d \ + 
--hash=sha256:dc5d1a49d3f8262be192589a4b72f0d03b72dcf46c51ad5852a4fdc67be7b9e4 \ + --hash=sha256:e2d1a054f8f0a191004675755448d12be47fa9bebbcffa3cdf01db19f2d30a54 \ + --hash=sha256:e7792606d606c8df5277c32ccb58f29b9b8603bf83b48639b7aedf6df4fe8171 \ + --hash=sha256:ed1708dbf4d2e3a1c5c69110ba2b4eb6678262028afd6c6fbcc5a8dac9cda68e \ + --hash=sha256:f2d4380bf5f62daabd7b751ea2339c1a21d1c9463f1feb7fc2bdcea2c29c3160 \ + --hash=sha256:f3513916e8c645d0610815c257cbfd3242adfd5c4cfa78be514e5a3ebb42a41b \ + --hash=sha256:f8346bfa098532bc1fb6c7ef06783e969d87a99dd1d2a5a18a892c1d7a643c58 \ + --hash=sha256:f83fa6cae3fff8e98691248c9320356971b59678a17f20656a9e59cd32cee6d8 \ + --hash=sha256:fa6ce8b52c5987b3e34d5674b0ab529a4602b632ebab0a93b07bfb4dfc8f8a33 \ + --hash=sha256:fb2b1ecfef1e67897d336de3a0e3f52478182d6a47eda86cbd42504c5cbd009a \ + --hash=sha256:fc9ca1c9718cb3b06634c7c8dec57d24e9438b2aa9a0f02b8bb36bf478538880 \ + --hash=sha256:fd30d9c67d13d891f2360b2a120186729c111238ac63b43dbd37a5a40670b8ca \ + --hash=sha256:fd7699e8fd9969f455ef2926221e0233f81a2542921471382e77a9e2f2b57f4b \ + --hash=sha256:fe3b385d996ee0822fd46528d9f0443b880d4d05528fd26a9119a54ec3f91c69 + # via -r ci/official/requirements_updater/requirements.in + +# The following packages are considered to be unsafe in a requirements file: +setuptools==70.0.0 \ + --hash=sha256:54faa7f2e8d2d11bcd2c07bed282eef1046b5c080d1c32add737d7b5817b1ad4 \ + --hash=sha256:f211a66637b8fa059bb28183da127d4e86396c991a942b028c6650d4319c3fd0 + # via + # -r ci/official/requirements_updater/requirements.in + # tb-nightly diff --git a/tensorflow/BUILD b/tensorflow/BUILD index bd81809453539e..ead208a518b77a 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -253,7 +253,7 @@ config_setting( config_setting( name = "android", constraint_values = if_google( - ["//third_party/bazel_platforms/os:android"], + ["@platforms//os:android"], [], ), values = if_oss( @@ -265,45 +265,45 @@ config_setting( config_setting( name = "android_x86", - 
constraint_values = if_google( - ["//third_party/bazel_platforms/os:android"], - [], - ), + constraint_values = + [ + "@platforms//cpu:x86_32", + "@platforms//os:android", + ], values = dict( if_oss( {"crosstool_top": "//external:android/crosstool"}, ), - cpu = "x86", ), visibility = ["//visibility:public"], ) config_setting( name = "android_x86_64", - constraint_values = if_google( - ["//third_party/bazel_platforms/os:android"], - [], - ), + constraint_values = + [ + "@platforms//cpu:x86_64", + "@platforms//os:android", + ], values = dict( if_oss( {"crosstool_top": "//external:android/crosstool"}, ), - cpu = "x86_64", ), visibility = ["//visibility:public"], ) config_setting( name = "android_armeabi", - constraint_values = if_google( - ["//third_party/bazel_platforms/os:android"], - [], - ), + constraint_values = + [ + "@platforms//cpu:armv6-m", + "@platforms//os:android", + ], values = dict( if_oss( {"crosstool_top": "//external:android/crosstool"}, ), - cpu = "armeabi", ), visibility = ["//visibility:public"], ) @@ -311,22 +311,28 @@ config_setting( # copybara:uncomment_begin(google-only) # config_setting( # name = "chromiumos_x86_64", -# constraint_values = ["//third_party/bazel_platforms/os:chromiumos"], -# values = {"cpu": "k8"}, +# constraint_values = [ +# "@platforms//cpu:x86_64", +# "@platforms//os:chromiumos", +# ], # visibility = ["//visibility:public"], # ) # # config_setting( # name = "chromiumos_arm64", -# constraint_values = ["//third_party/bazel_platforms/os:chromiumos"], -# values = {"cpu": "arm"}, +# constraint_values = [ +# "@platforms//cpu:aarch64", +# "@platforms//os:chromiumos", +# ], # visibility = ["//visibility:public"], # ) # # config_setting( # name = "chromiumos_armv7", -# constraint_values = ["//third_party/bazel_platforms/os:chromiumos"], -# values = {"cpu": "armeabi-v7a"}, +# constraint_values = [ +# "@platforms//cpu:armv7", +# "@platforms//os:chromiumos", +# ], # visibility = ["//visibility:public"], # ) # copybara:uncomment_end @@ 
-334,7 +340,7 @@ config_setting( config_setting( name = "emscripten", constraint_values = if_google( - ["//third_party/bazel_platforms/os:emscripten"], + ["@platforms//os:emscripten"], [], ), values = if_oss( @@ -346,57 +352,56 @@ config_setting( config_setting( name = "raspberry_pi_armeabi", + constraint_values = + [ + "@platforms//cpu:armv6-m", + "@platforms//os:linux", + ], values = { "crosstool_top": "@local_config_arm_compiler//:toolchain", - "cpu": "armeabi", }, visibility = ["//visibility:public"], ) config_setting( name = "android_arm", - constraint_values = if_google( - ["//third_party/bazel_platforms/os:android"], - [], - ), + constraint_values = + [ + "@platforms//cpu:armv7", + "@platforms//os:android", + ], values = dict( if_oss( {"crosstool_top": "//external:android/crosstool"}, ), - cpu = "armeabi-v7a", ), visibility = ["//visibility:public"], ) config_setting( name = "android_arm64", - constraint_values = if_google( - ["//third_party/bazel_platforms/os:android"], - [], - ), + constraint_values = + [ + "@platforms//cpu:aarch64", + "@platforms//os:android", + ], values = dict( if_oss( {"crosstool_top": "//external:android/crosstool"}, ), - cpu = "arm64-v8a", ), visibility = ["//visibility:public"], ) -config_setting( - name = "android_mips", - values = { - "crosstool_top": "//external:android/crosstool", - "cpu": "mips", - }, - visibility = ["//visibility:public"], -) - config_setting( name = "android_mips64", + constraint_values = + [ + "@platforms//cpu:mips64", + "@platforms//os:android", + ], values = { "crosstool_top": "//external:android/crosstool", - "cpu": "mips64", }, visibility = ["//visibility:public"], ) @@ -404,16 +409,10 @@ config_setting( # TODO(jakeharmon8): Remove in favor of TSL version config_setting( name = "windows", - # Internal builds query the target OS. - constraint_values = if_google( - ["//third_party/bazel_platforms/os:windows"], - [], - ), - # OSS builds query the CPU type. 
- values = if_oss( - {"cpu": "x64_windows"}, - {}, - ), + constraint_values = + [ + "@platforms//os:windows", + ], visibility = ["//visibility:public"], ) @@ -423,52 +422,28 @@ config_setting( visibility = ["//visibility:public"], ) -# Sometimes Bazel reports darwin_x86_64 as "darwin" and sometimes as -# "darwin_x86_64". The former shows up when building on a Mac x86_64 host for a Mac x86_64 target. -# The latter shows up when cross-compiling for Mac x86_64 from a Mac ARM machine and in internal -# Google builds. -config_setting( - name = "macos_x86_64_default", - constraint_values = if_google( - ["//third_party/bazel_platforms/os:macos"], - [], - ), - values = { - "apple_platform_type": "macos", - "cpu": "darwin", - }, -) - config_setting( - name = "macos_x86_64_crosscompile", - constraint_values = if_google( - ["//third_party/bazel_platforms/os:macos"], - [], - ), + name = "macos_x86_64", + constraint_values = + [ + "@platforms//cpu:x86_64", + "@platforms//os:macos", + ], values = { "apple_platform_type": "macos", - "cpu": "darwin_x86_64", }, -) - -selects.config_setting_group( - name = "macos_x86_64", - match_any = [ - ":macos_x86_64_default", - ":macos_x86_64_crosscompile", - ], visibility = ["//visibility:public"], ) config_setting( name = "macos_arm64", - constraint_values = if_google( - ["//third_party/bazel_platforms/os:macos"], - [], - ), + constraint_values = + [ + "@platforms//cpu:aarch64", + "@platforms//os:macos", + ], values = { "apple_platform_type": "macos", - "cpu": "darwin_arm64", }, visibility = ["//visibility:public"], ) @@ -486,7 +461,7 @@ selects.config_setting_group( config_setting( name = "ios", constraint_values = if_google( - ["//third_party/bazel_platforms/os:ios"], + ["@platforms//os:ios"], [], ), values = if_oss( @@ -499,41 +474,32 @@ config_setting( # TODO(jakeharmon8): Remove in favor of TSL version config_setting( name = "fuchsia", - constraint_values = if_google( - ["//third_party/bazel_platforms/os:fuchsia"], - [], - ), - values = 
if_oss( - # TODO(b/149248802) When we have a Fuchsia Bazel SDK update to use the values it sets. - {"cpu": "fuchsia"}, - {}, - ), + constraint_values = + ["@platforms//os:fuchsia"], visibility = ["//visibility:public"], ) config_setting( name = "fuchsia_x86_64", - constraint_values = if_google( - ["//third_party/bazel_platforms/os:fuchsia"], - [], - ), - values = { - "cpu": "x86_64", - }, + constraint_values = + [ + "@platforms//cpu:x86_64", + "@platforms//os:fuchsia", + ], visibility = ["//visibility:public"], ) config_setting( name = "ios_x86_64", - constraint_values = if_google( - ["//third_party/bazel_platforms/os:ios"], - [], - ), + constraint_values = + [ + "@platforms//cpu:x86_64", + "@platforms//os:ios", + ], values = dict( if_oss( {"crosstool_top": "//tools/osx/crosstool:crosstool"}, ), - cpu = "ios_x86_64", ), visibility = ["//visibility:public"], ) @@ -541,7 +507,7 @@ config_setting( config_setting( name = "chromiumos", constraint_values = if_google( - ["//third_party/bazel_platforms/os:chromiumos"], + ["@platforms//os:chromiumos"], [], ), values = if_oss( @@ -553,49 +519,43 @@ config_setting( config_setting( name = "linux_aarch64", - constraint_values = if_google( - ["//third_party/bazel_platforms/os:linux"], - [], - ), - values = {"cpu": "aarch64"}, + constraint_values = + [ + "@platforms//cpu:aarch64", + "@platforms//os:linux", + ], visibility = ["//visibility:public"], ) config_setting( name = "linux_armhf", - constraint_values = if_google( - ["//third_party/bazel_platforms/os:linux"], - [], - ), - values = {"cpu": "armhf"}, + constraint_values = + [ + "@platforms//cpu:armv7e-mf", + "@platforms//os:linux", + ], visibility = ["//visibility:public"], ) config_setting( name = "linux_x86_64", - constraint_values = if_google( - ["//third_party/bazel_platforms/os:linux"], - [], - ), - values = {"cpu": "k8"}, - visibility = ["//visibility:public"], -) - -config_setting( - name = "haswell", - values = {"cpu": "haswell"}, + constraint_values = + [ + 
"@platforms//cpu:x86_64", + "@platforms//os:linux", + ], visibility = ["//visibility:public"], ) # This condition takes precedence over :linux_x86_64 config_setting( name = "linux_x86_64_no_sse", - constraint_values = if_google( - ["//third_party/bazel_platforms/os:linux"], - [], - ), + constraint_values = + [ + "@platforms//cpu:x86_64", + "@platforms//os:linux", + ], values = { - "cpu": "k8", "copt": "-mno-sse4.2", }, visibility = ["//visibility:public"], @@ -605,52 +565,52 @@ config_setting( # TODO(b/290533709): Remove this with PJRT build rule cleanup. config_setting( name = "linux_x86_64_with_weightwatcher", - constraint_values = if_google( - ["//third_party/bazel_platforms/os:linux"], - [], - ), + constraint_values = + [ + "@platforms//cpu:x86_64", + "@platforms//os:linux", + ], define_values = {"tensorflow_weightwatcher": "true"}, - values = {"cpu": "k8"}, visibility = ["//visibility:public"], ) config_setting( name = "linux_ppc64le", - constraint_values = if_google( - ["//third_party/bazel_platforms/os:linux"], - [], - ), - values = {"cpu": "ppc"}, + constraint_values = + [ + "@platforms//cpu:ppc64le", + "@platforms//os:linux", + ], visibility = ["//visibility:public"], ) config_setting( name = "linux_s390x", - constraint_values = if_google( - ["//third_party/bazel_platforms/os:linux"], - [], - ), - values = {"cpu": "s390x"}, + constraint_values = + [ + "@platforms//cpu:s390x", + "@platforms//os:linux", + ], visibility = ["//visibility:public"], ) config_setting( name = "linux_mips64", - constraint_values = if_google( - ["//third_party/bazel_platforms/os:linux"], - [], - ), - values = {"cpu": "mips64"}, + constraint_values = + [ + "@platforms//cpu:mips64", + "@platforms//os:linux", + ], visibility = ["//visibility:public"], ) config_setting( name = "linux_riscv64", - constraint_values = if_google( - ["//third_party/bazel_platforms/os:linux"], - [], - ), - values = {"cpu": "riscv64"}, + constraint_values = + [ + "@platforms//cpu:riscv64", + 
"@platforms//os:linux", + ], visibility = ["//visibility:public"], ) @@ -670,45 +630,25 @@ config_setting( visibility = ["//visibility:public"], ) -config_setting( - name = "arm", - values = {"cpu": "arm"}, - visibility = ["//visibility:public"], -) - -config_setting( - name = "armeabi", - values = {"cpu": "armeabi"}, - visibility = ["//visibility:public"], -) - -config_setting( - name = "armeabi-v7a", - values = {"cpu": "armeabi-v7a"}, - visibility = ["//visibility:public"], -) - -config_setting( - name = "arm64-v8a", - values = {"cpu": "arm64-v8a"}, - visibility = ["//visibility:public"], -) - selects.config_setting_group( name = "arm_any", match_any = [ - ":arm", - ":armeabi", - ":armeabi-v7a", - ":arm64-v8a", - ":linux_aarch64", - ":linux_armhf", + "@platforms//cpu:aarch32", + "@platforms//cpu:aarch64", + "@platforms//cpu:armv6-m", + "@platforms//cpu:armv7", + "@platforms//cpu:armv7-m", + "@platforms//cpu:armv7e-m", + "@platforms//cpu:armv7e-mf", ], ) config_setting( name = "freebsd", - values = {"cpu": "freebsd"}, + constraint_values = [ + "@platforms//os:freebsd", + "@platforms//cpu:x86_64", + ], visibility = ["//visibility:public"], ) @@ -1775,6 +1715,7 @@ py_library( "//tensorflow/lite/python:lite", "//tensorflow/lite/python/authoring", "//tensorflow/python:no_contrib", + "//tensorflow/python/profiler:profiler_client", "@pypi_keras_nightly//:pkg", "@pypi_tb_nightly//:pkg", ], diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index 4e8f5156761f9c..e4c2c92783d4d8 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -1119,7 +1119,7 @@ cc_library( tf_cuda_cc_test( name = "dlpack_test", - size = "small", + size = "medium", srcs = [ "dlpack_test.cc", ], diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 23a6fa0d240440..39f93d17aa2932 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -1,3 +1,5 @@ +load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") 
+load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm") load("@local_xla//xla/stream_executor:build_defs.bzl", "if_cuda_or_rocm") load( "@local_xla//xla/tsl:tsl.bzl", @@ -106,6 +108,10 @@ cc_library( "//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops", "@local_xla//xla/service:gpu_plugin", "//tensorflow/core/tfrt/common:pjrt_gpu_client_registration", + ]) + if_cuda([ + "@local_xla//xla/stream_executor/cuda:all_runtime", # buildcleaner: keep + ]) + if_rocm([ + "@local_xla//xla/stream_executor/rocm:all_runtime", # buildcleaner: keep ]), alwayslink = 1, ) diff --git a/tensorflow/compiler/jit/compilability_check_util.cc b/tensorflow/compiler/jit/compilability_check_util.cc index 2b15a4affc76af..50b26371698877 100644 --- a/tensorflow/compiler/jit/compilability_check_util.cc +++ b/tensorflow/compiler/jit/compilability_check_util.cc @@ -370,11 +370,16 @@ bool RecursiveCompilabilityChecker::OpIsSlow(const Node& node) const { // https://github.com/tensorflow/tensorflow/pull/31012: // ResizeNearestNeighbor, ResizeBilinear, and ResizeBilinearGrad sometimes // create convolutions too large for CuDNN to handle. + // NonMaxSuppressionV3/V4 in XLA runs significantly slower than TF kernel in + // object detection models, specially when there are a lot of proposed + // bounding boxes. 
return node.type_string() == "SelfAdjointEigV2" || node.type_string() == "Svd" || node.type_string() == "Qr" || node.type_string() == "MatrixInverse" || node.type_string() == "MatrixSolve" || - node.type_string() == "ResizeBilinearGrad"; + node.type_string() == "ResizeBilinearGrad" || + node.type_string() == "NonMaxSuppressionV3" || + node.type_string() == "NonMaxSuppressionV4"; } bool RecursiveCompilabilityChecker::IsCompilableNode( diff --git a/tensorflow/compiler/jit/compilability_check_util_test.cc b/tensorflow/compiler/jit/compilability_check_util_test.cc index 0fe2d2d2fe96b7..ea24176bb04a4a 100644 --- a/tensorflow/compiler/jit/compilability_check_util_test.cc +++ b/tensorflow/compiler/jit/compilability_check_util_test.cc @@ -51,6 +51,7 @@ constexpr char kUncompilableFunctionName[] = "UncompilableFn"; constexpr char kUncompilableFunctionNodeName[] = "n_c_uncompilable"; constexpr char kUncompilableFunctionTwoName[] = "UncompilableFnTwo"; constexpr char kUncompilableFunctionNodeTwoName[] = "n_d_uncompilable"; +constexpr char kNonMaxSuppressionNodeName[] = "NonMaxSuppression"; // A dummy OpKernel for testing. class DummyCompilableOp : public XlaOpKernel { @@ -63,6 +64,7 @@ class DummyCompilableOp : public XlaOpKernel { // Register the DummyCompilableOp kernel for CPU. 
REGISTER_OP("InputFloatOp").Output("o: float"); +REGISTER_OP("InputInt32Op").Output("o: int32"); REGISTER_OP("CompilableOp").Input("i: float").Output("o: float"); REGISTER_XLA_OP(Name("CompilableOp").Device(DEVICE_CPU_XLA_JIT), DummyCompilableOp); @@ -554,5 +556,90 @@ TEST_F(CompilabilityCheckUtilTest, TestCanTriggerXlaCompilation) { EXPECT_TRUE(CanTriggerXlaCompilation(graph_def)); } +TEST_F(CompilabilityCheckUtilTest, CheckNonMaxSuppressionV3UncompilableSlowOp) { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + auto opts = builder.opts(); + + Node* boxes = ops::SourceOp("InputFloatOp", opts); + Node* scores = ops::SourceOp("InputFloatOp", opts); + Node* max_output_size = ops::SourceOp("InputInt32Op", opts); + Node* iou_threshold = ops::SourceOp("InputFloatOp", opts); + Node* score_threshold = ops::SourceOp("InputFloatOp", opts); + + NodeBuilder non_max_suppression_builder( + kNonMaxSuppressionNodeName, "NonMaxSuppressionV3", opts.op_registry()); + non_max_suppression_builder.Input(boxes) + .Input(scores) + .Input(max_output_size) + .Input(iou_threshold) + .Input(score_threshold) + .Attr("T", DT_FLOAT); + Node* non_max_suppression; + non_max_suppression = + builder.opts().FinalizeBuilder(&non_max_suppression_builder); + + GraphDef graph_def; + TF_EXPECT_OK(builder.ToGraphDef(&graph_def)); + auto* flib_runtime = GetFunctionLibraryRuntime(); + + EXPECT_FALSE(checker_->IsCompilableNode(*non_max_suppression, flib_runtime)); + + const auto uncompilable_nodes = + checker_->FindUncompilableNodes(*non_max_suppression, flib_runtime); + ASSERT_EQ(1, uncompilable_nodes.size()); + auto node_info_it = + uncompilable_nodes.find(NameAttrList().ShortDebugString()); + ASSERT_NE(uncompilable_nodes.end(), node_info_it); + + const auto& uncompilable_nodes_inside_function = node_info_it->second.second; + ASSERT_EQ(1, uncompilable_nodes_inside_function.size()); + const auto& uncompilable_node_info = uncompilable_nodes_inside_function.at(0); + 
EXPECT_TRUE(absl::StrContains(uncompilable_node_info.uncompilable_reason, + "slow operation")); +} + +TEST_F(CompilabilityCheckUtilTest, CheckNonMaxSuppressionV4UncompilableSlowOp) { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + auto opts = builder.opts(); + + Node* boxes = ops::SourceOp("InputFloatOp", opts); + Node* scores = ops::SourceOp("InputFloatOp", opts); + Node* max_output_size = ops::SourceOp("InputInt32Op", opts); + Node* iou_threshold = ops::SourceOp("InputFloatOp", opts); + Node* score_threshold = ops::SourceOp("InputFloatOp", opts); + + NodeBuilder non_max_suppression_v4_builder( + kNonMaxSuppressionNodeName, "NonMaxSuppressionV4", opts.op_registry()); + non_max_suppression_v4_builder.Input(boxes) + .Input(scores) + .Input(max_output_size) + .Input(iou_threshold) + .Input(score_threshold) + .Attr("T", DT_FLOAT); + Node* non_max_suppression_v4; + non_max_suppression_v4 = + builder.opts().FinalizeBuilder(&non_max_suppression_v4_builder); + + GraphDef graph_def; + TF_EXPECT_OK(builder.ToGraphDef(&graph_def)); + auto* flib_runtime = GetFunctionLibraryRuntime(); + + EXPECT_FALSE( + checker_->IsCompilableNode(*non_max_suppression_v4, flib_runtime)); + + const auto uncompilable_nodes = + checker_->FindUncompilableNodes(*non_max_suppression_v4, flib_runtime); + ASSERT_EQ(1, uncompilable_nodes.size()); + auto node_info_it = + uncompilable_nodes.find(NameAttrList().ShortDebugString()); + ASSERT_NE(uncompilable_nodes.end(), node_info_it); + + const auto& uncompilable_nodes_inside_function = node_info_it->second.second; + ASSERT_EQ(1, uncompilable_nodes_inside_function.size()); + const auto& uncompilable_node_info = uncompilable_nodes_inside_function.at(0); + EXPECT_TRUE(absl::StrContains(uncompilable_node_info.uncompilable_reason, + "slow operation")); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/device_compilation_cache.h b/tensorflow/compiler/jit/device_compilation_cache.h index 
6137d1bfd95adc..e6938024344b3d 100644 --- a/tensorflow/compiler/jit/device_compilation_cache.h +++ b/tensorflow/compiler/jit/device_compilation_cache.h @@ -107,8 +107,8 @@ class DeviceCompilationCache { const mutex_lock lock(compile_cache_mu_); absl::erase_if( cache_, - [&](std::pair>>& kv) { - const absl::Nullable entry = kv.second.get(); + [&](std::pair>& kv) { + Entry* absl_nullable const entry = kv.second.get(); if (entry == nullptr) { return true; } diff --git a/tensorflow/compiler/jit/device_compiler.h b/tensorflow/compiler/jit/device_compiler.h index fb0dbd2ae4171a..34b22033129b96 100644 --- a/tensorflow/compiler/jit/device_compiler.h +++ b/tensorflow/compiler/jit/device_compiler.h @@ -406,7 +406,7 @@ absl::Status DeviceCompiler::CompileAsynchronous( template void DeviceCompiler::Finalize() { const mutex_lock lock(cluster_mutexes_mu_); - std::vector> cluster_mutexes; + std::vector cluster_mutexes; cluster_mutexes.reserve(cluster_mutexes_.size()); for (auto& [_, mutex] : cluster_mutexes_) { if (mutex != nullptr) { @@ -420,7 +420,7 @@ void DeviceCompiler::Finalize() { absl::c_sort(cluster_mutexes); std::vector cluster_mutex_locks; cluster_mutex_locks.reserve(cluster_mutexes.size()); - for (const absl::Nonnull mutex : cluster_mutexes) { + for (mutex* absl_nonnull const mutex : cluster_mutexes) { cluster_mutex_locks.emplace_back(*mutex); } diff --git a/tensorflow/compiler/jit/tests/auto_clustering_test_helper.cc b/tensorflow/compiler/jit/tests/auto_clustering_test_helper.cc index 74462a1cdfd1c6..dee77ac750ee54 100644 --- a/tensorflow/compiler/jit/tests/auto_clustering_test_helper.cc +++ b/tensorflow/compiler/jit/tests/auto_clustering_test_helper.cc @@ -188,7 +188,7 @@ absl::Status AutoClusteringTest::RunAutoClusteringTestWithGzippedPbtxt( io::ZlibCompressionOptions::GZIP()); tstring decompressed_pbtxt_string; absl::Status s = in.ReadNBytes(INT_MAX, &decompressed_pbtxt_string); - if (!s.ok() && !errors::IsOutOfRange(s)) { + if (!s.ok() && !absl::IsOutOfRange(s)) 
{ // OutOfRange is fine since we set the number of read bytes to INT_MAX. // Only return other kinds of errors. return s; diff --git a/tensorflow/compiler/jit/xla_platform_info.cc b/tensorflow/compiler/jit/xla_platform_info.cc index f9af695e33c163..2fa93816071225 100644 --- a/tensorflow/compiler/jit/xla_platform_info.cc +++ b/tensorflow/compiler/jit/xla_platform_info.cc @@ -255,7 +255,7 @@ absl::Status BuildXlaDeviceCompiler(DeviceBase* device, return platform.status(); } - absl::StatusOr compiler_for_platform = + absl::StatusOr> compiler_for_platform = xla::Compiler::GetForPlatform(platform.value()); if (!compiler_for_platform.ok()) { // In some rare cases (usually in unit tests with very small clusters) we diff --git a/tensorflow/compiler/mlir/BUILD b/tensorflow/compiler/mlir/BUILD index 20c7d3abb35dfc..c11a761a089128 100644 --- a/tensorflow/compiler/mlir/BUILD +++ b/tensorflow/compiler/mlir/BUILD @@ -220,11 +220,11 @@ tf_cc_binary( srcs = ["tf_mlir_translate_main.cc"], deps = [ ":init_mlir", - "//tensorflow/compiler/mlir/lite/tools:translate_cl_options", - "//tensorflow/compiler/mlir/lite/tools:translate_registration", "//tensorflow/compiler/mlir/tensorflow:tf_xla_mlir_translate", "//tensorflow/compiler/mlir/tensorflow:translate_lib", "//tensorflow/compiler/mlir/tf2xla/tests/registration:graph_to_tf_executor_registration", + "//tensorflow/compiler/mlir/tools:translate_cl_options", + "//tensorflow/compiler/mlir/tools:translate_registration", "//tensorflow/core:lib", "//tensorflow/core:tensorflow", "@com_google_absl//absl/strings", diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index b5cac23baa56b4..207effbe3d11e2 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -101,15 +101,10 @@ td_library( gentbl_cc_library( name = "tensorflow_lite_passes_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=TensorFlowLiteTd", - ], - 
"transforms/passes.h.inc", - ), - ], + tbl_outs = {"transforms/passes.h.inc": [ + "-gen-pass-decls", + "-name=TensorFlowLiteTd", + ]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/passes.td", deps = [ @@ -120,23 +115,14 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_lite_ops_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-op-decls"], - "ir/tfl_ops.h.inc", - ), - ( - ["-gen-op-defs"], - "ir/tfl_ops.cc.inc", - ), - ( - [ - "-gen-dialect-doc", - "-dialect=tfl", - ], - "g3doc/tfl_ops.md", - ), - ], + tbl_outs = { + "ir/tfl_ops.h.inc": ["-gen-op-decls"], + "ir/tfl_ops.cc.inc": ["-gen-op-defs"], + "g3doc/tfl_ops.md": [ + "-gen-dialect-doc", + "-dialect=tfl", + ], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "ir/tfl_ops.td", deps = [ @@ -147,24 +133,12 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_lite_op_interfaces_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-op-interface-decls"], - "ir/tfl_ops_interface.h.inc", - ), - ( - ["-gen-op-interface-defs"], - "ir/tfl_ops_interface.cc.inc", - ), - ( - ["-gen-dialect-decls"], - "ir/tfl_ops_dialect.h.inc", - ), - ( - ["-gen-dialect-defs"], - "ir/tfl_ops_dialect.cc.inc", - ), - ], + tbl_outs = { + "ir/tfl_ops_interface.h.inc": ["-gen-op-interface-decls"], + "ir/tfl_ops_interface.cc.inc": ["-gen-op-interface-defs"], + "ir/tfl_ops_dialect.h.inc": ["-gen-dialect-decls"], + "ir/tfl_ops_dialect.cc.inc": ["-gen-dialect-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "ir/tfl_op_interfaces.td", deps = [ @@ -175,24 +149,12 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_lite_op_enums_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-enum-decls"], - "ir/tfl_ops_enums.h.inc", - ), - ( - ["-gen-enum-defs"], - "ir/tfl_ops_enums.cc.inc", - ), - ( - ["-gen-attrdef-decls"], - "ir/tfl_ops_attrdefs.h.inc", - ), - ( - ["-gen-attrdef-defs"], 
- "ir/tfl_ops_attrdefs.cc.inc", - ), - ], + tbl_outs = { + "ir/tfl_ops_enums.h.inc": ["-gen-enum-decls"], + "ir/tfl_ops_enums.cc.inc": ["-gen-enum-defs"], + "ir/tfl_ops_attrdefs.h.inc": ["-gen-attrdef-decls"], + "ir/tfl_ops_attrdefs.cc.inc": ["-gen-attrdef-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "ir/tfl_op_enums.td", deps = [ @@ -203,12 +165,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_lite_prepare_tf_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "transforms/generated_prepare_tf.inc", - ), - ], + tbl_outs = {"transforms/generated_prepare_tf.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/prepare_patterns.td", deps = [":tensorflow_lite_patterns_td_files"], @@ -217,12 +174,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_lite_lower_static_tensor_list_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "transforms/generated_lower_static_tensor_list.inc", - ), - ], + tbl_outs = {"transforms/generated_lower_static_tensor_list.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/tensorlist_patterns.td", deps = [":tensorflow_lite_patterns_td_files"], @@ -231,12 +183,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_lite_legalize_tf_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "transforms/generated_legalize_tf.inc", - ), - ], + tbl_outs = {"transforms/generated_legalize_tf.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/legalize_patterns.td", deps = [":tensorflow_lite_patterns_td_files"], @@ -245,12 +192,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_lite_legalize_variables_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - 
"transforms/generated_legalize_variables.inc", - ), - ], + tbl_outs = {"transforms/generated_legalize_variables.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/legalize_variables.td", deps = [":tensorflow_lite_patterns_td_files"], @@ -259,12 +201,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_lite_optimize_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "transforms/generated_optimize.inc", - ), - ], + tbl_outs = {"transforms/generated_optimize.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/optimize_patterns.td", deps = [":tensorflow_lite_patterns_td_files"], @@ -273,12 +210,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_lite_optimize_batch_matmul_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "transforms/generated_optimize_batch_matmul.inc", - ), - ], + tbl_outs = {"transforms/generated_optimize_batch_matmul.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/optimize_batch_matmul.td", deps = [":tensorflow_lite_patterns_td_files"], @@ -287,12 +219,7 @@ gentbl_cc_library( gentbl_cc_library( name = "optimize_broadcast_like_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "transforms/generated_optimize_broadcast_like.inc", - ), - ], + tbl_outs = {"transforms/generated_optimize_broadcast_like.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/optimize_broadcast_like_patterns.td", deps = [":tensorflow_lite_patterns_td_files"], @@ -301,12 +228,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_lite_quantize_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "transforms/generated_quantize.inc", - ), - ], + tbl_outs = 
{"transforms/generated_quantize.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/quantize_patterns.td", deps = [":tensorflow_lite_patterns_td_files"], @@ -315,12 +237,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_lite_quantize_by_converter_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "transforms/generated_quantize_by_converter.inc", - ), - ], + tbl_outs = {"transforms/generated_quantize_by_converter.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/quantize_by_converter_patterns.td", deps = [":tensorflow_lite_patterns_td_files"], @@ -329,12 +246,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_lite_post_quantize_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "transforms/generated_post_quantize.inc", - ), - ], + tbl_outs = {"transforms/generated_post_quantize.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/post_quantize_patterns.td", deps = [":tensorflow_lite_patterns_td_files"], @@ -343,12 +255,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_lite_legalize_tensorlist_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "transforms/generated_legalize_tensorlist.inc", - ), - ], + tbl_outs = {"transforms/generated_legalize_tensorlist.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/legalize_tensorlist.td", deps = [":tensorflow_lite_patterns_td_files"], @@ -380,12 +287,7 @@ cc_library( gentbl_cc_library( name = "tensorflow_lite_canonicalize_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "ir/tfl_canonicalize.inc", - ), - ], + tbl_outs = {"ir/tfl_canonicalize.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = 
"ir/tfl_canonicalize.td", deps = [":tensorflow_lite_patterns_td_files"], @@ -395,6 +297,8 @@ cc_library( name = "utils", hdrs = ["utils/utils.h"], deps = [ + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@llvm-project//llvm:Support", "@llvm-project//mlir:Dialect", "@llvm-project//mlir:IR", @@ -402,6 +306,18 @@ cc_library( ], ) +tf_cc_test( + name = "utils_test", + srcs = ["utils/utils_test.cc"], + deps = [ + ":utils", + "@com_google_googletest//:gtest_main", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) + cc_library( name = "attribute_utils", srcs = ["utils/attribute_utils.cc"], @@ -473,6 +389,7 @@ cc_library( deps = [ ":common", ":converter_flags_proto_cc", + ":optimize_broadcast_like_pass_options", ":optimize_pass_options", ":pass_options", ":pass_options_setter", @@ -587,6 +504,7 @@ cc_library( hdrs = [ "ir/tfl_ops.h", "transforms/canonicalize_boundary_value_pass.h", + "transforms/cleanup_optimization_barrier_pass.h", "transforms/optimize_batch_matmul_pass.h", "transforms/optimize_broadcast_like_pass.h", "transforms/optimize_pass.h", @@ -605,9 +523,11 @@ cc_library( deps = [ ":attribute_utils", ":canonicalize_boundary_value", + ":cleanup_optimization_barrier", ":converter_inc", ":cost_estimators", ":optimize_broadcast_like_pass", + ":optimize_broadcast_like_pass_options", ":optimize_pass_options", ":pass", ":pass_options", @@ -642,6 +562,8 @@ cc_library( "//tensorflow/core:framework", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@eigen_archive//:eigen3", "@llvm-project//llvm:Support", @@ -957,6 +879,26 @@ cc_library( ], ) +cc_library( + name = "cleanup_optimization_barrier", + srcs = [ + "transforms/cleanup_optimization_barrier_pass.cc", + ], + hdrs = [ + "transforms/cleanup_optimization_barrier_pass.h", 
+ ], + deps = [ + ":pass", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@stablehlo//:stablehlo_ops", + ], +) + cc_library( name = "tensorflow_lite_legalize_tf_analyze_variables", srcs = [ @@ -1107,6 +1049,7 @@ cc_library( ":fake_quant_utils", ":lstm_utils", ":nms_utils", + ":optimize_broadcast_like_pass_options", ":perception_ops_utils", ":shape_and_size_utils", ":stateful_ops_utils", @@ -1125,7 +1068,6 @@ cc_library( "//tensorflow/compiler/mlir:op_or_arg_name_mapper", "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", "//tensorflow/compiler/mlir/lite/schema:schema_fbs", - "//tensorflow/compiler/mlir/lite/stablehlo:legalize_tf", "//tensorflow/compiler/mlir/lite/stablehlo:optimize_layout", "//tensorflow/compiler/mlir/lite/stablehlo:prepare_hlo", "//tensorflow/compiler/mlir/lite/stablehlo:tf_legalize_hlo", @@ -1134,6 +1076,7 @@ cc_library( "//tensorflow/compiler/mlir/quantization/common/ir:QuantOps", "//tensorflow/compiler/mlir/quantization/common/quantization_lib", "//tensorflow/compiler/mlir/quantization/common/quantization_lib:quantization_config", + "//tensorflow/compiler/mlir/stablehlo:legalize_tf", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:convert_tensor", "//tensorflow/compiler/mlir/tensorflow:dynamic_shape_utils", @@ -1173,6 +1116,8 @@ cc_library( "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "@local_xla//xla/mlir_hlo", + "@local_xla//xla/mlir_hlo:type_conversion", + "@local_xla//xla/mlir_hlo:unfuse_batch_norm", "@stablehlo//:stablehlo_ops", ], ) @@ -1213,6 +1158,29 @@ cc_library( ], ) +cc_library( + name = "optimize_batch_matmul_utils", + srcs = ["transforms/tflite_passes/optimize_batch_matmul_utils.cc"], + hdrs = ["transforms/tflite_passes/optimize_batch_matmul_utils.h"], + deps = [ + ":utils", 
+ "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) + +tf_cc_test( + name = "optimize_batch_matmul_utils_test", + srcs = ["transforms/tflite_passes/optimize_batch_matmul_utils_test.cc"], + deps = [ + ":optimize_batch_matmul_utils", + "@com_google_googletest//:gtest_main", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + ], +) + cc_library( name = "tensorflow_lite_optimize_batch_matmul", srcs = [ @@ -1224,6 +1192,7 @@ cc_library( ], deps = [ ":convert_type", + ":optimize_batch_matmul_utils", ":pass", ":pass_options", ":tensorflow_lite_ops", @@ -1258,8 +1227,8 @@ cc_library( ], deps = [ ":optimize_broadcast_like_inc_gen", + ":optimize_broadcast_like_pass_options", ":pass", - ":pass_options", ":tensorflow_lite_ops", ":utils", "@llvm-project//llvm:Support", @@ -1267,6 +1236,7 @@ cc_library( "@llvm-project//mlir:Dialect", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", ], @@ -1340,6 +1310,7 @@ cc_library( deps = [ "convert_type", ":op_quant_spec_getters_inc", + ":optimize_broadcast_like_pass_options", ":shape_and_size_utils", ":stateful_ops_utils", ":tensorflow_lite", @@ -1358,6 +1329,7 @@ cc_library( "//tensorflow/compiler/mlir/quantization/common/quantization_lib", "//tensorflow/compiler/mlir/quantization/common/quantization_lib:quantization_config", "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:dynamic_shape_utils", "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", @@ -1442,7 +1414,7 @@ filegroup( gentbl_cc_library( name = "op_quant_spec_getters_inc", compatible_with = get_compatible_with_portable(), - tbl_outs = [([], "utils/generated_op_quant_spec_getters.inc")], + tbl_outs = {"utils/generated_op_quant_spec_getters.inc": []}, tblgen = 
"//tensorflow/compiler/mlir/lite/quantization:op_quant_spec_getters_gen", td_file = "ir/tfl_ops.td", deps = [ @@ -1453,7 +1425,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tflite_op_coverage_spec_inc", compatible_with = get_compatible_with_portable(), - tbl_outs = [([], "utils/tflite_op_coverage_spec.inc")], + tbl_outs = {"utils/tflite_op_coverage_spec.inc": []}, tblgen = "//tensorflow/compiler/mlir/lite/quantization:tflite_op_coverage_spec_getters_gen", td_file = "ir/tfl_ops.td", visibility = ["//learning/brain/mobile/model_optimization/g3doc/autogen:__pkg__"], @@ -1478,16 +1450,10 @@ tf_native_cc_binary( gentbl_cc_library( name = "converter_inc", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["--gen-operator-converters"], - "operator_converters.inc", - ), - ( - ["--gen-runtime-verifiers"], - "runtime_verifiers.inc", - ), - ], + tbl_outs = { + "operator_converters.inc": ["--gen-operator-converters"], + "runtime_verifiers.inc": ["--gen-runtime-verifiers"], + }, tblgen = ":converter-gen", td_file = "ir/tfl_ops.td", test = 1, @@ -1741,6 +1707,15 @@ cc_library( ], ) +cc_library( + name = "optimize_broadcast_like_pass_options", + hdrs = ["transforms/optimize_broadcast_like_pass_options.h"], + deps = [ + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Pass", + ], +) + cc_library( name = "flatbuffer_translate_lib", hdrs = [ @@ -1845,7 +1820,6 @@ tf_cc_binary( ":tf_to_tfl_flatbuffer", "//tensorflow/compiler/mlir:init_mlir", "//tensorflow/compiler/mlir/lite:converter_flags_proto_cc", - "//tensorflow/compiler/mlir/lite/tools:translate_cl_options", "//tensorflow/compiler/mlir/quantization/common/quantization_lib:quantization_config", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", @@ -1889,7 +1863,6 @@ cc_library( ":tensorflow_lite_optimize_batch_matmul", # buildcleaner: keep ":tensorflow_lite_push_transpose_through_ewise_pass", # buildcleaner: keep ":tensorflow_lite_quantize", # 
buildcleaner: keep - ":tensorflow_lite_tf_unfreeze_global_tensors", ":variable_freezing_pipeline", "//tensorflow/compiler/mlir/lite/core:macros", "//tensorflow/compiler/mlir/lite/quantization:quantization_passes", diff --git a/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h b/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h index db9715e99c1acd..aa552ec43d138a 100644 --- a/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h +++ b/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h @@ -107,6 +107,12 @@ struct PassConfig { // When set to true, convert +Inf/-Inf to MIN/MAX float value and output of // convert only contains finite values. bool canonicalizing_inf_as_min_max_float = true; + + // When set to true, allows fusion of dynamic shaped broadcast ops. It helps + // fusing implicit broadcasting ops when output shape has dynamic dimensions, + // but it may cause incorrect results when broadcasting ops are introduced by + // explicit broadcasting in the source model. + bool unsafe_fuse_dynamic_shaped_broadcast = false; }; inline llvm::raw_ostream& operator<<(llvm::raw_ostream& os, @@ -133,6 +139,8 @@ inline llvm::raw_ostream& operator<<(llvm::raw_ostream& os, << pass_config.enable_stablehlo_conversion << "\nlegalize_custom_tensor_list_ops: " << pass_config.legalize_custom_tensor_list_ops + << "\nunsafe_fuse_dynamic_shaped_broadcast: " + << pass_config.unsafe_fuse_dynamic_shaped_broadcast << "\nreduce_type_precision: " << pass_config.reduce_type_precision << "\nconvert_qdq_format: " << GetQDQQuantModeString(pass_config.qdq_conversion_mode) diff --git a/tensorflow/compiler/mlir/lite/converter_flags.proto b/tensorflow/compiler/mlir/lite/converter_flags.proto index 5b6b9e2ca752a6..1c1a1ad00aea74 100644 --- a/tensorflow/compiler/mlir/lite/converter_flags.proto +++ b/tensorflow/compiler/mlir/lite/converter_flags.proto @@ -41,7 +41,7 @@ enum FileFormat { // of as properties of models, instead describing how models are to be // processed in the context of 
the present tooling job. // -// Next ID to use: 68. +// Next ID to use: 69. message ConverterFlags { // Input file format optional FileFormat input_format = 1; @@ -385,4 +385,10 @@ message ConverterFlags { // possible rather than quantizing any op that is possible to quantize. // WARNING: Experimental interface, subject to change. optional bool strict_qdq_mode = 67 [default = false]; + + // When set to true, allows fusion of dynamic shaped broadcast ops. It helps + // fusing implicit broadcasting ops when output shape has dynamic dimensions, + // but it may cause incorrect results when broadcasting ops are introduced by + // explicit broadcasting in the source model. + optional bool unsafe_fuse_dynamic_shaped_broadcast = 68 [default = false]; } diff --git a/tensorflow/compiler/mlir/lite/core/c/builtin_op_data.h b/tensorflow/compiler/mlir/lite/core/c/builtin_op_data.h index 1327162f23262b..c580bf03cd3f59 100644 --- a/tensorflow/compiler/mlir/lite/core/c/builtin_op_data.h +++ b/tensorflow/compiler/mlir/lite/core/c/builtin_op_data.h @@ -12,11 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ + /// WARNING: Users of TensorFlow Lite should not include this file directly, -/// but should instead include -/// "third_party/tensorflow/lite/c/builtin_op_data.h". -/// Only the TensorFlow Lite implementation itself should include this -/// file directly. +/// only the TensorFlow Lite implementation itself should. 
+ +// IWYU pragma: private, include "third_party/tensorflow/lite/c/builtin_op_data.h" + #ifndef TENSORFLOW_COMPILER_MLIR_LITE_CORE_C_BUILTIN_OP_DATA_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_CORE_C_BUILTIN_OP_DATA_H_ diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/BUILD b/tensorflow/compiler/mlir/lite/experimental/tac/BUILD index f3edb169515bb6..e8018e4d9acdae 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/BUILD +++ b/tensorflow/compiler/mlir/lite/experimental/tac/BUILD @@ -98,12 +98,7 @@ cc_library( gentbl_cc_library( name = "transform_patterns_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "transforms/generated_transform_patterns.inc", - ), - ], + tbl_outs = {"transforms/generated_transform_patterns.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/transform_patterns.td", deps = [ diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/gpu_hardware.cc b/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/gpu_hardware.cc index 19cd2e081a7d1e..91dc26155fc659 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/gpu_hardware.cc +++ b/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/gpu_hardware.cc @@ -59,6 +59,15 @@ double GpuHardware::GetHardwareSwitchingCost(const TargetHardware* from, kCrossHardwareTransferFixedCost; } +bool GpuHardware::IsOpSupported(mlir::Operation* op) const { + if (TargetHardware::IsOpSupported(op)) { + return true; + } + + // We also support quantized ops. 
+ return !NotTFLQuantDequantizeOp(op); +} + namespace { // GPU constexpr float kGPUArithmeticUnitCost = 0.2; diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/gpu_hardware.h b/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/gpu_hardware.h index 149c2076a6154a..cc13c6e36be269 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/gpu_hardware.h +++ b/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/gpu_hardware.h @@ -41,6 +41,8 @@ class GpuHardware : public TargetHardware { double GetHardwareSwitchingCost(const TargetHardware* from, size_t buffer_size) const override; + + bool IsOpSupported(mlir::Operation* op) const override; }; } // namespace tac } // namespace TFL diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/get_alternative_subgraph.cc b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/get_alternative_subgraph.cc index fd4852b34ed3cf..f9a14eef837823 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/get_alternative_subgraph.cc +++ b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/get_alternative_subgraph.cc @@ -202,7 +202,6 @@ bool AlternativeSubgraphPass::IsAllSupportedbySpec( bool found_unsupported = false; func.walk([&](Operation* op) { if (IsNonConstOp(op) && !IsTerminatorOp(op) && - NotTFLQuantDequantizeOp(op) && !llvm::isa(op) && !IsSupported(op, device_inference_type.hardware)) { found_unsupported = true; diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/tac_filter.cc b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/tac_filter.cc index 61d481fe6a3242..8dee7c0902262a 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/tac_filter.cc +++ b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/tac_filter.cc @@ -127,12 +127,11 @@ void ApplyTacFilter( } auto should_filter_op = [](mlir::Operation* op) { - return IsNonConstOp(op) && NotTFLQuantDequantizeOp(op) && - !IsTerminatorOp(op) && + 
return IsNonConstOp(op) && !IsTerminatorOp(op) && !llvm::isa(op); }; - auto map_op_to_cpu = [&](mlir::Operation* op, std::string name) { + auto map_op_to_cpu = [&](mlir::Operation* op) { if (!should_filter_op(op)) { return; } @@ -181,7 +180,7 @@ void ApplyTacFilter( switch (match_type) { case OpFilter::MATCH: if (device_type == OpFilter::CPU) { - map_op_to_cpu(op, loc.getName().str()); + map_op_to_cpu(op); return; } map_op_to_custom_device(op); @@ -193,7 +192,7 @@ void ApplyTacFilter( switch (match_type) { case OpFilter::INVERT_MATCH: if (device_type == OpFilter::CPU) { - map_op_to_cpu(op, loc.getName().str()); + map_op_to_cpu(op); return; } map_op_to_custom_device(op); diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/target_annotation.cc b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/target_annotation.cc index 6d1bf7ab9341df..e3d1a4e47e782d 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/target_annotation.cc +++ b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/target_annotation.cc @@ -140,8 +140,7 @@ void TargetAnnotationPass::runOnFunction() { func.walk([&](Operation* op) { // We only care about TFL dialect. 
- if (IsNonConstOp(op) && NotTFLQuantDequantizeOp(op) && - !IsTerminatorOp(op) && + if (IsNonConstOp(op) && !IsTerminatorOp(op) && !llvm::isa(op)) { SetTargetAnnotation(op, device_specs_flag_, &builder); } diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc index 3429c156f1551b..6045278ffa541e 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc @@ -1716,30 +1716,34 @@ void CreateFlexbufferVector( const std::unique_ptr& flex_builder, std::string& name, const mlir::Attribute& attr) { auto start = flex_builder->StartVector(name.c_str()); - auto array = attr.cast().getValue(); + auto array = mlir::cast(attr).getValue(); for (int i = 0; i < array.size(); i++) { if (llvm::isa(array[i])) { flex_builder->Bool(name.c_str(), - array[i].cast().getValue()); + mlir::cast(array[i]).getValue()); } else if (llvm::isa(attr)) { - flex_builder->String(name.c_str(), - array[i].cast().getValue().str()); + flex_builder->String( + name.c_str(), + mlir::cast(array[i]).getValue().str()); } else if (llvm::isa(array[i])) { - flex_builder->Bool(name.c_str(), - array[i].cast().getValue()); + flex_builder->Bool( + name.c_str(), + mlir::cast(array[i]).getValue()); } else if (llvm::isa(array[i])) { flex_builder->String( name.c_str(), - array[i].cast().getValue().str()); + mlir::cast(array[i]).getValue().str()); } else if (llvm::isa(array[i])) { - flex_builder->Int( - name.c_str(), - array[i].cast().getValue().getSExtValue()); + flex_builder->Int(name.c_str(), + mlir::cast(array[i]) + .getValue() + .getSExtValue()); } else if (llvm::isa(array[i])) { - flex_builder->Float( - name.c_str(), - array[i].cast().getValue().convertToFloat()); + flex_builder->Float(name.c_str(), + mlir::cast(array[i]) + .getValue() + .convertToFloat()); } else if (llvm::isa(array[i])) { CreateFlexbufferVector(flex_builder, name, array[i]); @@ -1835,43 +1839,49 @@ 
Translator::BuildVhloCompositeV1Op(mlir::vhlo::CompositeOpV1 composite_op, uint32_t opcode_index = GetOpcodeIndex(op_name, tflite::BuiltinOperator_STABLEHLO_COMPOSITE); - int32_t api_version = composite_op.getVersion() - .cast() - .getValue() - .getSExtValue(); + int32_t api_version = + mlir::cast(composite_op.getVersion()) + .getValue() + .getSExtValue(); auto name = builder_.CreateString( - composite_op.getName().cast().getValue().str()); + mlir::cast(composite_op.getName()) + .getValue() + .str()); - auto composite_attributes = composite_op.getCompositeAttributes() - .cast(); + auto composite_attributes = mlir::cast( + composite_op.getCompositeAttributes()); auto flex_builder = std::make_unique(); size_t map_start = flex_builder->StartMap(); for (auto namedAttr : composite_attributes.getValue()) { auto name = - namedAttr.first.cast().getValue().str(); + mlir::cast(namedAttr.first).getValue().str(); auto attr = namedAttr.second; if (llvm::isa(attr)) - flex_builder->Bool(name.c_str(), attr.cast().getValue()); + flex_builder->Bool(name.c_str(), + mlir::cast(attr).getValue()); else if (llvm::isa(attr)) flex_builder->String(name.c_str(), - attr.cast().getValue().str()); + mlir::cast(attr).getValue().str()); else if (llvm::isa(attr)) - flex_builder->Bool(name.c_str(), - attr.cast().getValue()); + flex_builder->Bool( + name.c_str(), mlir::cast(attr).getValue()); else if (llvm::isa(attr)) flex_builder->String( - name.c_str(), attr.cast().getValue().str()); - else if (llvm::isa(attr)) - flex_builder->Int( name.c_str(), - attr.cast().getValue().getSExtValue()); + mlir::cast(attr).getValue().str()); + else if (llvm::isa(attr)) + flex_builder->Int(name.c_str(), + mlir::cast(attr) + .getValue() + .getSExtValue()); else if (llvm::isa(attr)) - flex_builder->Float( - name.c_str(), - attr.cast().getValue().convertToFloat()); + flex_builder->Float(name.c_str(), + mlir::cast(attr) + .getValue() + .convertToFloat()); else if (llvm::isa(attr)) CreateFlexbufferVector(flex_builder, 
name, attr); else if (llvm::isa(attr)) { @@ -1932,8 +1942,8 @@ Translator::BuildVhloCompositeV1Op(mlir::vhlo::CompositeOpV1 composite_op, flex_builder->Finish(); int32_t decomposition_subgraph_index = - subgraph_index_map_[composite_op.getDecomposition() - .cast() + subgraph_index_map_[mlir::cast( + composite_op.getDecomposition()) .getValue() .str()]; diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc index 132d87c93cd4f2..12b1da1da8ab96 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc @@ -377,10 +377,8 @@ mlir::Operation* ConvertMinMaxToStatsOp(const TensorT& tensor, OpBuilder b, // min/max stats is just for comments, so ignore it. if (!tensor.quantization || tfl::IsQuantized(tensor)) return nullptr; // If the result isn't float and unquantizable, the min/max is ignored. - if (!res.getType() - .cast() - .getElementType() - .isa()) { + if (!llvm::isa( + llvm::cast(res.getType()).getElementType())) { return nullptr; } auto mins = tensor.quantization->min; @@ -438,7 +436,7 @@ StatusOr BuildExternalConstOp(const tflite::TensorT& tensor, TF_ASSIGN_OR_RETURN(mlir::TensorType type, tfl::GetTensorType(tensor, builder, /*is_constant=*/true)); - auto shaped_type = type.dyn_cast(); + auto shaped_type = llvm::dyn_cast(type); if (!shaped_type) { return errors::Internal("Constant doesn't have a shape"); } @@ -457,7 +455,7 @@ StatusOr BuildVariableOp(const tflite::TensorT& tensor, TF_ASSIGN_OR_RETURN(mlir::TensorType type, tfl::GetTensorType(tensor, builder, /*is_constant=*/true)); - auto shaped_type = type.dyn_cast(); + auto shaped_type = llvm::dyn_cast(type); if (!shaped_type) { return errors::Internal("Constant doesn't have a shape"); } @@ -510,7 +508,7 @@ static StatusOr BuildSparseConstOp( TF_ASSIGN_OR_RETURN(mlir::TensorType type, tfl::GetTensorType(tensor, builder, /*is_constant=*/true)); - auto shaped_type = type.dyn_cast(); + auto 
shaped_type = llvm::dyn_cast(type); if (!shaped_type) { return errors::Internal("Constant doesn't have a shape"); } @@ -598,7 +596,7 @@ StatusOr BuildConstOp(const tflite::TensorT& tensor, /*is_constant=*/true, /*is_intermediate=*/false, /*get_storage=*/true)); - auto shaped_type = type.dyn_cast(); + auto shaped_type = llvm::dyn_cast(type); if (!shaped_type) { return errors::Internal("Constant doesn't have a shape"); } @@ -619,11 +617,11 @@ StatusOr BuildConstOp(const tflite::TensorT& tensor, } auto elem_type = shaped_type.getElementType(); - if (auto float_type = elem_type.dyn_cast()) { + if (auto float_type = llvm::dyn_cast(elem_type)) { TF_ASSIGN_OR_RETURN(value, tfl::ConvertFloatBuffer(shaped_type, buffer)); - } else if (elem_type.isa()) { + } else if (llvm::isa(elem_type)) { TF_ASSIGN_OR_RETURN(value, tfl::ConvertIntBuffer(shaped_type, buffer)); - } else if (elem_type.isa()) { + } else if (llvm::isa(elem_type)) { tensorflow::TensorProto repr = tfl::ConvertTfliteConstTensor(tensor, buffer); std::vector refs; @@ -633,7 +631,8 @@ StatusOr BuildConstOp(const tflite::TensorT& tensor, refs.push_back({ref.data(), ref.size()}); value = mlir::DenseStringElementsAttr::get(shaped_type, refs); - } else if (elem_type.isa()) { + } else if (llvm::isa( + elem_type)) { tensorflow::TensorProto repr = tfl::ConvertTfliteConstTensor(tensor, buffer); std::string mangled = tensorflow::mangling_util::MangleTensor(repr); @@ -929,8 +928,8 @@ StatusOr ConvertOp( // Flattens reshape ops when more than one dimension shape operand is given. 
mlir::DenseIntElementsAttr shape_attr; if (matchPattern(op_state.operands[1], m_Constant(&shape_attr))) { - auto shape_ty = - op_state.operands[1].getType().dyn_cast(); + auto shape_ty = llvm::dyn_cast( + op_state.operands[1].getType()); if (shape_ty != nullptr && shape_ty.hasRank() && shape_ty.getRank() > 1) { llvm::SmallVector shape; int32_t dim_size = 0; @@ -1117,8 +1116,8 @@ static StatusOr PostProcessFuncOp(FuncOp func) { value.getType()); // Only the 8-bit constants are imported with narrow range. if (!qtype || qtype.getStorageTypeIntegralWidth() != 8 || - !(qtype.isa() || - qtype.isa())) { + !(llvm::isa(qtype) || + llvm::isa(qtype))) { return; } for (auto& use : value.getUses()) { @@ -1134,14 +1133,16 @@ static StatusOr PostProcessFuncOp(FuncOp func) { if (full_range_const == value) { mlir::quant::QuantizedType new_qtype; if (auto per_axis = - qtype.dyn_cast()) { + llvm::dyn_cast( + qtype)) { new_qtype = mlir::quant::UniformQuantizedPerAxisType::get( per_axis.getFlags(), per_axis.getStorageType(), per_axis.getExpressedType(), per_axis.getScales(), per_axis.getZeroPoints(), per_axis.getQuantizedDimension(), per_axis.getStorageTypeMin() - 1, per_axis.getStorageTypeMax()); } else if (auto per_tensor = - qtype.dyn_cast()) { + llvm::dyn_cast( + qtype)) { new_qtype = mlir::quant::UniformQuantizedType::get( per_tensor.getFlags(), per_tensor.getStorageType(), per_tensor.getExpressedType(), per_tensor.getScale(), @@ -1185,7 +1186,8 @@ int GetTensorIndex(const std::string& tensor_name, llvm::SmallVector GetStringsFromAttrWithSeparator( mlir::DictionaryAttr attr, const std::string& attr_key) { llvm::SmallVector result; - if (auto str = attr.get(attr_key).dyn_cast_or_null()) { + if (auto str = + llvm::dyn_cast_if_present(attr.get(attr_key))) { str.getValue().split(result, ',', /*MaxSplit=*/-1, /*KeepEmpty=*/false); } @@ -1643,11 +1645,13 @@ void AddRegionsForTflWhileOp(mlir::ModuleOp module) { mlir::SymbolTable symbol_table(module); module.walk([&](mlir::TFL::WhileOp 
while_op) { auto cond = symbol_table.lookup( - while_op->getAttr("cond").cast().getValue()); + llvm::cast(while_op->getAttr("cond")) + .getValue()); AddCallOpInWhileOpRegion(while_op.getCond(), cond); while_op->removeAttr("cond"); auto body = symbol_table.lookup( - while_op->getAttr("body").cast().getValue()); + llvm::cast(while_op->getAttr("body")) + .getValue()); AddCallOpInWhileOpRegion(while_op.getBody(), body); while_op->removeAttr("body"); }); @@ -1658,15 +1662,15 @@ void AddRegionsForStableHLOOp(mlir::ModuleOp module) { std::vector to_delete_funcs; module.walk([&](mlir::vhlo::ReduceOpV1 reduce_op) { auto body = symbol_table.lookup( - reduce_op->getAttr("body").cast().getValue()); + llvm::cast(reduce_op->getAttr("body")) + .getValue()); InlineVhloOpRegion(reduce_op.getBody(), body); reduce_op->removeAttr("body"); to_delete_funcs.push_back(body); }); module.walk([&](mlir::vhlo::ReduceWindowOpV1 reduce_window_op) { auto body = symbol_table.lookup( - reduce_window_op->getAttr("body") - .cast() + llvm::cast(reduce_window_op->getAttr("body")) .getValue()); InlineVhloOpRegion(reduce_window_op.getBody(), body); reduce_window_op->removeAttr("body"); @@ -1674,8 +1678,8 @@ void AddRegionsForStableHLOOp(mlir::ModuleOp module) { }); module.walk([&](mlir::vhlo::ScatterOpV1 scatter_op) { auto update_computation = symbol_table.lookup( - scatter_op->getAttr(kScatterRegionFuncName) - .cast() + llvm::cast( + scatter_op->getAttr(kScatterRegionFuncName)) .getValue()); InlineVhloOpRegion(scatter_op.getUpdateComputation(), update_computation); scatter_op->removeAttr(kScatterRegionFuncName); @@ -1683,8 +1687,7 @@ void AddRegionsForStableHLOOp(mlir::ModuleOp module) { }); module.walk([&](mlir::vhlo::SortOpV1 sort_op) { auto comparator = symbol_table.lookup( - sort_op->getAttr("comparator") - .cast() + llvm::cast(sort_op->getAttr("comparator")) .getValue()); InlineVhloOpRegion(sort_op.getComparator(), comparator); sort_op->removeAttr("comparator"); @@ -1692,11 +1695,13 @@ void 
AddRegionsForStableHLOOp(mlir::ModuleOp module) { }); module.walk([&](mlir::vhlo::WhileOpV1 while_op) { auto cond = symbol_table.lookup( - while_op->getAttr("cond").cast().getValue()); + llvm::cast(while_op->getAttr("cond")) + .getValue()); InlineVhloOpRegion(while_op.getCond(), cond); while_op->removeAttr("cond"); auto body = symbol_table.lookup( - while_op->getAttr("body").cast().getValue()); + llvm::cast(while_op->getAttr("body")) + .getValue()); InlineVhloOpRegion(while_op.getBody(), body); while_op->removeAttr("body"); to_delete_funcs.push_back(body); diff --git a/tensorflow/compiler/mlir/lite/integrations/BUILD b/tensorflow/compiler/mlir/lite/integrations/BUILD index 1a54d980c52074..899c936e9929a9 100644 --- a/tensorflow/compiler/mlir/lite/integrations/BUILD +++ b/tensorflow/compiler/mlir/lite/integrations/BUILD @@ -20,8 +20,7 @@ package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [ "//tensorflow/compiler/mlir/lite/integrations:__subpackages__", - "//tensorflow/lite/experimental/litert/python/google/tools/model_utils:__subpackages__", - "//third_party/odml/litert/litert/python/google/tools/model_utils:__subpackages__", + "//third_party/odml/litert/litert/python/tools/model_utils:__subpackages__", ], licenses = ["notice"], ) diff --git a/tensorflow/compiler/mlir/lite/integrations/model_utils_core_pybind.cc b/tensorflow/compiler/mlir/lite/integrations/model_utils_core_pybind.cc index 53095f8436ccd5..42ae13c57e4272 100644 --- a/tensorflow/compiler/mlir/lite/integrations/model_utils_core_pybind.cc +++ b/tensorflow/compiler/mlir/lite/integrations/model_utils_core_pybind.cc @@ -31,6 +31,7 @@ limitations under the License. 
#include "mlir/Dialect/Func/Transforms/Passes.h" // from @llvm-project #include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project #include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project @@ -71,7 +72,7 @@ class MlirPythonPass pyfunc.inc_ref(); } - ~MlirPythonPass() = default; + ~MlirPythonPass() override = default; mlir::StringRef getName() const override { return name_; } mlir::StringRef getArgument() const override { return name_; } @@ -198,6 +199,15 @@ PYBIND11_MODULE(model_utils_core_pybind, m) { return attr_names; }); + m.def("get_dictionary_attr_names", [](MlirAttribute c_attr) { + auto attr = mlir::cast(unwrap(c_attr)); + std::vector attr_names; + for (auto attr : attr) { + attr_names.push_back(attr.getName().str()); + } + return attr_names; + }); + m.def("get_elements_attr_buffer", [](MlirAttribute c_attr) { auto attr = mlir::cast(unwrap(c_attr)); diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_canonicalize.td b/tensorflow/compiler/mlir/lite/ir/tfl_canonicalize.td index a359fb9506b2b3..3881a1e291770d 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_canonicalize.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_canonicalize.td @@ -32,7 +32,7 @@ def GetSqueezedPermutation: NativeCodeCall<"GetSqueezedPermutation($0, $1)">; // Check to see if the tensor dimensions can be Squeezed by eliminating 1s' def CanSqueezeTensor : Constraint GetSqueezedShape($0).getNumElements()">>; + "GetShapeAttr($0).getNumElements() > GetSqueezedShape($0).getNumElements()">>; // Pattern to convert TFL_TransposeOp with rank>6 to rank<=6 if there are @@ -50,7 +50,7 @@ def ConvertTransposeToDecreaseRank : Pat< (TFL_TransposeOp (TFL_ReshapeOp $input, (Arith_ConstantOp (GetSqueezedShape $input))), (Arith_ConstantOp (GetSqueezedPermutation $input, 
$permutation))), - (Arith_ConstantOp (GetShape $output_transpose))), + (Arith_ConstantOp (GetShapeAttr $output_transpose))), [(AnyStaticShapeTensor $input), (HasRankAtLeast<7> $input), (CanSqueezeTensor $input)]>; diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_op_enums.td b/tensorflow/compiler/mlir/lite/ir/tfl_op_enums.td index fa85389789e554..57e4ec22976df3 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_op_enums.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_op_enums.td @@ -27,9 +27,9 @@ include "tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td" // Referred TF_AnyStrAttrOf in tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td class TFL_AnyStrAttrOf cases> : StringBasedAttr< CPred().getValue() == \"" # !head(cases) # "\"", + "llvm::cast($_self).getValue() == \"" # !head(cases) # "\"", !foreach(case, !tail(cases), - "$_self.cast().getValue() == \"" # case # "\""), + "llvm::cast($_self).getValue() == \"" # case # "\""), prev, cur, prev # " || " # cur)>, "string attribute whose value is " # !foldl(/*init*/!head(cases), /*list*/!tail(cases), diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc index dfa9f5b094b949..9ca7f1451cba2d 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc @@ -260,6 +260,58 @@ bool ShouldFoldOperation(Operation* inst) { (results_size <= kSizeFactor * operands_size)); } +// Returns dimension index for the given axis that supports negative +// indexing. +int64_t NormalizeDim(int64_t axis, int64_t rank) { + return axis >= 0 ? axis : axis + rank; +} + +Type InferReductionOpType(Value input, Value reduction_indices, + BoolAttr keep_dims) { + Type input_ty = input.getType(); + Type element_ty = getElementTypeOrSelf(input_ty); + + // Output type is unranked if input type is not ranked. 
+ auto ranked_ty = mlir::dyn_cast(input_ty); + if (!ranked_ty) return UnrankedTensorType::get(element_ty); + int64_t rank = ranked_ty.getRank(); + + DenseIntElementsAttr indices; + if (!matchPattern(reduction_indices, m_Constant(&indices))) { + // Output type is unranked if reduction indices are not constant and reduced + // dimensions are not kept. + if (!keep_dims.getValue()) return UnrankedTensorType::get(element_ty); + + // Otherwise, output type has same rank as the input. + return RankedTensorType::get( + SmallVector(rank, ShapedType::kDynamic), element_ty); + } + + int64_t num_reduce_dim = 0; + llvm::SmallVector is_reduce_dim(rank, false); + for (const APInt& index : indices.getValues()) { + int64_t dim = NormalizeDim(index.getSExtValue(), rank); + // Invalid input. + assert(dim >= 0 && dim < rank); + + if (!is_reduce_dim[dim]) { + is_reduce_dim[dim] = true; + num_reduce_dim++; + } + } + + ArrayRef shape = ranked_ty.getShape(); + SmallVector out_shape; + out_shape.reserve(rank - (keep_dims.getValue() ? 0 : num_reduce_dim)); + for (int64_t i = 0; i < rank; ++i) { + if (!is_reduce_dim[i]) + out_shape.push_back(shape[i]); + else if (keep_dims.getValue()) + out_shape.push_back(1); + } + return RankedTensorType::get(out_shape, element_ty); +} + #include "tensorflow/compiler/mlir/lite/ir/tfl_canonicalize.inc" } // namespace @@ -1529,7 +1581,7 @@ LogicalResult FullyConnectedOp::verify() { // Input's element size must be multiple of parameter's z_in dimension. 
const int z_in = filter_type.getDimSize(1); - const int num_input_elements = input_type.getNumElements(); + const int64_t num_input_elements = input_type.getNumElements(); if (z_in != 0 && num_input_elements % z_in != 0) { return op.emitOpError(llvm::formatv( "expect 'input' num_elements % {0} == 0, got input type ", z_in)) @@ -1545,7 +1597,7 @@ LogicalResult FullyConnectedOp::verify() { return mlir::success(); } - const int num_output_elements = output_type.getNumElements(); + const int64_t num_output_elements = output_type.getNumElements(); const int z_out = filter_type.getDimSize(0); if (num_output_elements % z_out != 0) { return op.emitOpError(llvm::formatv( @@ -2230,21 +2282,18 @@ namespace { // * The input's defining op is another tfl.reshape. // TODO(antiagainst): This pattern probably should be moved to the peephole // category, after we have the infra for peephole passes. -struct RemoveAdjacentReshape : public RewritePattern::SplitMatchAndRewrite { +struct RemoveAdjacentReshape : public RewritePattern { explicit RemoveAdjacentReshape(MLIRContext* context) - : RewritePattern::SplitMatchAndRewrite(ReshapeOp::getOperationName(), 1, - context) {} - - LogicalResult match(Operation* op) const override { - auto thisOp = cast(op); - auto prevOp = thisOp.getOperand(0).getDefiningOp(); - return isa_and_nonnull(prevOp) ? 
success() : failure(); - } + : RewritePattern(ReshapeOp::getOperationName(), 1, context) {} - void rewrite(Operation* op, PatternRewriter& rewriter) const override { + LogicalResult matchAndRewrite(Operation* op, + PatternRewriter& rewriter) const override { auto thisOp = cast(op); - auto prevOp = cast(thisOp.getOperand(0).getDefiningOp()); - + auto prevOp = + dyn_cast_or_null(thisOp.getOperand(0).getDefiningOp()); + if (!prevOp) { + return failure(); + } // Replace // %1 = "tfl.reshape"(%0, %shape0) // %2 = "tfl.reshape"(%1, %shape1) @@ -2252,6 +2301,7 @@ struct RemoveAdjacentReshape : public RewritePattern::SplitMatchAndRewrite { // %2 = "tfl.reshape"(%0, %shape1) rewriter.replaceOpWithNewOp( op, thisOp.getType(), prevOp.getOperand(0), thisOp.getOperand(1)); + return success(); } }; @@ -2963,12 +3013,12 @@ namespace { /// This pattern matches and remove a tfl.fake_quant if all the users of this op /// and itself have "minmax" attribute set. -struct DropFakeQuant : public RewritePattern::SplitMatchAndRewrite { +struct DropFakeQuant : public RewritePattern { explicit DropFakeQuant(MLIRContext* context) - : RewritePattern::SplitMatchAndRewrite(FakeQuantOp::getOperationName(), 1, - context) {} + : RewritePattern(FakeQuantOp::getOperationName(), 1, context) {} - LogicalResult match(Operation* op) const override { + LogicalResult matchAndRewrite(Operation* op, + PatternRewriter& rewriter) const override { // We only match the op with valid "minmax" attribute. if (!HasValidMinMaxAttribute(op)) return failure(); @@ -2978,12 +3028,9 @@ struct DropFakeQuant : public RewritePattern::SplitMatchAndRewrite { for (auto* operand : fakeQuantOp.getResult().getUsers()) if (!HasValidMinMaxAttribute(operand)) return failure(); - return success(); - } - - void rewrite(Operation* op, PatternRewriter& rewriter) const override { // Replace the matched FakeQuantOp by its primary operand. 
rewriter.replaceOp(op, op->getOperand(0)); + return success(); } }; } // end anonymous namespace @@ -4041,6 +4088,12 @@ OpFoldResult SumOp::fold(FoldAdaptor adaptor) { return DenseFPElementsAttr::get(out_type, out_data); } +void SumOp::build(OpBuilder& builder, OperationState& result, Value input, + Value axes, BoolAttr keep_dims) { + Type out_ty = InferReductionOpType(input, axes, keep_dims); + build(builder, result, out_ty, input, axes, keep_dims); +} + //===----------------------------------------------------------------------===// // RankOp //===----------------------------------------------------------------------===// @@ -4447,6 +4500,27 @@ int64_t TransposeConvOp::GetArithmeticCount(Operation* op) { // StridedSliceOp //===----------------------------------------------------------------------===// +bool VerifyStridedSliceOpInputRankConstraints(StridedSliceOp op) { + auto ranked_input_type = + mlir::dyn_cast(op.getInput().getType()); + + // If input is unranked, there is nothing else to be verified. + if (!ranked_input_type) return true; + const int num_input_dims = ranked_input_type.getRank(); + + // The kernel will reshape the input tensor with new axis, it only supports + // this reshaped tensor up to 5D. + const uint32_t ellipsis_mask = op.getEllipsisMask(); + const uint32_t new_axis_mask = op.getNewAxisMask(); + int num_added_axis = 0; + for (int i = 0; i < 8; ++i) { + if (!((1 << i) & ellipsis_mask) && ((1 << i) & new_axis_mask)) { + num_added_axis++; + } + } + return (num_input_dims + num_added_axis <= 5); +} + LogicalResult StridedSliceOp::verify() { StridedSliceOp op = *this; auto ranked_input_type = @@ -4473,17 +4547,6 @@ LogicalResult StridedSliceOp::verify() { if (strides_type.getDimSize(0) > num_input_dims) return failure(); } - // The kernel will reshape the input tensor with new axis, it only supports - // this reshaped tensor up to 5D. 
- uint32_t ellipsis_mask = op.getEllipsisMask(); - uint32_t new_axis_mask = op.getNewAxisMask(); - int num_added_axis = 0; - for (int i = 0; i < 8; ++i) { - if (!((1 << i) & ellipsis_mask) && ((1 << i) & new_axis_mask)) { - num_added_axis++; - } - } - if (num_input_dims + num_added_axis > 5) return failure(); return success(); } diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td index 986e02fe8e335a..5b082fe74b53d3 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td @@ -34,21 +34,21 @@ include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td" //===----------------------------------------------------------------------===// // TFLite dialect string type - uses the TF string type as implementation //===----------------------------------------------------------------------===// -def TFL_Str : Type()">, +def TFL_Str : Type($_self)">, "TFLite string type">, BuildableType<"getType()">; //===----------------------------------------------------------------------===// // TFLite dialect quint8 type - uses the TF quint8 type as implementation //===----------------------------------------------------------------------===// -def TFL_Quint8 : Type()">, +def TFL_Quint8 : Type($_self)">, "TFLite quint8 type">, BuildableType<"getType()">; //===----------------------------------------------------------------------===// // Type that represents control dependencies //===----------------------------------------------------------------------===// -def TFL_Control: Type()">, "control">, +def TFL_Control: Type($_self)">, "control">, BuildableType<"$_builder.getType()">; @@ -151,10 +151,10 @@ def TFL_StatefulTensor : TypeAlias; // Returns true of operand is none type. 
class TFL_OperandIsNoneType : - CPred<"$_op.getOperand(" # i # ").getType().isa()">; + CPred<"llvm::isa($_op.getOperand(" # i # ").getType())">; class TFL_OperandIsUnrankedPred : - CPred<"$_op.getOperand(" # n # ").getType().isa()">; + CPred<"llvm::isa($_op.getOperand(" # n # ").getType())">; // TODO: Some of these could be generalized and/or moved to more general // location. @@ -162,52 +162,52 @@ class TFL_OperandIsUnrankedPred : class TFL_OperandHasRank : PredOpTrait<"operand " # n # " is " # m # "-D", Or<[TFL_OperandIsUnrankedPred, - CPred<"$_op.getOperand(" # n # - ").getType().cast().getRank() == " # m>]>>; + CPred<"llvm::cast($_op.getOperand(" # n # + ").getType()).getRank() == " # m>]>>; // Returns true if the n-th operand is ranked and has rank dim. class TFL_OperandHasKnownRank : And<[ - CPred<"$_op.getOperand(" # n # ").getType().isa()">, - CPred<"$_op.getOperand(" # n # ").getType().cast().getRank() == " + CPred<"llvm::isa($_op.getOperand(" # n # ").getType())">, + CPred<"llvm::cast($_op.getOperand(" # n # ").getType()).getRank() == " # dim>]>; // True if operand n is ranked and has a rank > dim. class TFL_OperandIsRankedAndHasDimPred : And<[ - CPred<"$_op.getOperand(" # n # ").getType().isa()">, - CPred<"$_op.getOperand(" # n # ").getType().cast().getRank() > " + CPred<"llvm::isa($_op.getOperand(" # n # ").getType())">, + CPred<"llvm::cast($_op.getOperand(" # n # ").getType()).getRank() > " # dim>]>; // Returns true if the n-th operand is ranked and has a dimension length = size // at the rank dim. class TFL_OperandDimEquals : And<[ TFL_OperandIsRankedAndHasDimPred, - CPred<"$_op.getOperand(" # n # ").getType().cast()" + CPred<"llvm::cast($_op.getOperand(" # n # ").getType())" ".getShape()[" # dim # " ] == " # size>]>; // Returns true if the n-th operand is ranked and has a dimension length <= // size at the rank dim. 
class TFL_OperandDimIsAtMost : And<[ TFL_OperandIsRankedAndHasDimPred, - CPred<"$_op.getOperand(" # n # ").getType().cast()" + CPred<"llvm::cast($_op.getOperand(" # n # ").getType())" ".getShape()[" # dim # " ] <= " # size>]>; // Returns true if the n-th operand has unknown rank or at least rank m. class TFL_OperandHasAtleastRank : PredOpTrait<"operand " # n # " is " # m # "-D", - Or<[CPred<"$_op.getOperand(" # n # ").getType().isa()">, - CPred<"$_op.getOperand(" # n # - ").getType().cast().getRank() >= " # m>]>>; + Or<[CPred<"llvm::isa($_op.getOperand(" # n # ").getType())">, + CPred<"llvm::cast($_op.getOperand(" # n # + ").getType()).getRank() >= " # m>]>>; class TFL_OperandRankEquals1DimOfOperand : PredOpTrait<"operand " # x # "'s rank equals operand " # y # "'s size", Or<[TFL_OperandIsUnrankedPred, TFL_OperandIsUnrankedPred, - CPred<"!$_op.getOperand(" # y # - ").getType().cast().hasStaticShape()">, - CPred<"$_op.getOperand(" # x # - ").getType().cast().getRank() == " - "$_op.getOperand(" # y # - ").getType().cast().getShape()[0]">]>>; + CPred<"!llvm::cast($_op.getOperand(" # y # + ").getType()).hasStaticShape()">, + CPred<"llvm::cast($_op.getOperand(" # x # + ").getType()).getRank() == " + "llvm::cast($_op.getOperand(" # y # + ").getType()).getShape()[0]">]>>; class TFL_Operand0DOr1ElementTensor : PredOpTrait<"operand #" # x # " is an 0-d tensor or 1-d tensor w/ 1 element", @@ -219,14 +219,14 @@ class TFL_Operand0DOr1ElementTensor : class TFL_OperandsHaveSameDims : Or<[TFL_OperandIsUnrankedPred, TFL_OperandIsUnrankedPred, - CPred<"!$_op.getOperand(" # x # - ").getType().cast().hasStaticShape()">, - CPred<"!$_op.getOperand(" # y # - ").getType().cast().hasStaticShape()">, - CPred<"$_op.getOperand(" # x # - ").getType().cast().getShape()[" # i # "] == " - "$_op.getOperand(" # y # - ").getType().cast().getShape()[" # j # "]">]>; + CPred<"!llvm::cast($_op.getOperand(" # x # + ").getType()).hasStaticShape()">, + CPred<"!llvm::cast($_op.getOperand(" # y # + 
").getType()).hasStaticShape()">, + CPred<"llvm::cast($_op.getOperand(" # x # + ").getType()).getShape()[" # i # "] == " + "llvm::cast($_op.getOperand(" # y # + ").getType()).getShape()[" # j # "]">]>; class TFL_OperandsHaveSameDimsTrait : PredOpTrait<"dim " # i # " of operand " # x # " equals to dim " # j # @@ -238,14 +238,14 @@ class TFL_OperandsHaveSameDimsTrait : class TFL_NumElementsEqualsDim : Or<[TFL_OperandIsUnrankedPred, TFL_OperandIsUnrankedPred, - CPred<"!$_op.getOperand(" # x # - ").getType().cast().hasStaticShape()">, - CPred<"!$_op.getOperand(" # y # - ").getType().cast().hasStaticShape()">, - CPred<"$_op.getOperand(" # x # - ").getType().cast().getNumElements() == " - "$_op.getOperand(" # y # - ").getType().cast().getShape()[" # j # "]">]>; + CPred<"!llvm::cast($_op.getOperand(" # x # + ").getType()).hasStaticShape()">, + CPred<"!llvm::cast($_op.getOperand(" # y # + ").getType()).hasStaticShape()">, + CPred<"llvm::cast($_op.getOperand(" # x # + ").getType()).getNumElements() == " + "llvm::cast($_op.getOperand(" # y # + ").getType()).getShape()[" # j # "]">]>; class TFL_NumElementsEqualsDimTrait : PredOpTrait<"operand " # x # " has num of elements equals to dim " # j # @@ -255,10 +255,10 @@ class TFL_NumElementsEqualsDimTrait : // Return true if number of elements of x-th operand equals to n. class TFL_NumElements : Or<[TFL_OperandIsUnrankedPred, - CPred<"!$_op.getOperand(" # x # - ").getType().cast().hasStaticShape()">, - CPred<"$_op.getOperand(" # x # - ").getType().cast().getNumElements() == " # n>]>; + CPred<"!llvm::cast($_op.getOperand(" # x # + ").getType()).hasStaticShape()">, + CPred<"llvm::cast($_op.getOperand(" # x # + ").getType()).getNumElements() == " # n>]>; class TFL_NumElementsTrait : PredOpTrait<"operand " # x # " has num of elements equals to " # n, @@ -268,16 +268,16 @@ class TFL_NumElementsTrait : // when used as element types. 
class TFL_TFTypesWithSameBits : And<[ - Or<[CPred<"getElementTypeOrSelf($_op.getResult(" # i # ")).isa()">, + Or<[CPred<"llvm::isa(getElementTypeOrSelf($_op.getResult(" # i # ")))">, CPred<"getElementTypeOrSelf($_op.getResult(" # i # ")).isUnsignedInteger(" # num # ")">]>, - Or<[CPred<"getElementTypeOrSelf($_op.getOperand(" # j # ")).isa()">, + Or<[CPred<"llvm::isa(getElementTypeOrSelf($_op.getOperand(" # j # ")))">, CPred<"getElementTypeOrSelf($_op.getOperand(" # j # ")).isUnsignedInteger(" # num # ")">]>]>; class TFL_TFOperandTypesWithSameBits : And<[ - Or<[CPred<"getElementTypeOrSelf($_op.getOperand(" # i # ")).isa()">, + Or<[CPred<"llvm::isa(getElementTypeOrSelf($_op.getOperand(" # i # ")))">, CPred<"getElementTypeOrSelf($_op.getOperand(" # i # ")).isUnsignedInteger(" # num # ")">]>, - Or<[CPred<"getElementTypeOrSelf($_op.getOperand(" # j # ")).isa()">, + Or<[CPred<"llvm::isa(getElementTypeOrSelf($_op.getOperand(" # j # ")))">, CPred<"getElementTypeOrSelf($_op.getOperand(" # j # ")).isUnsignedInteger(" # num # ")">]>]>; class TFL_OperandIsNoneOrHasRank : @@ -285,21 +285,21 @@ class TFL_OperandIsNoneOrHasRank : Or<[ TFL_OperandIsNoneType, TFL_OperandIsUnrankedPred, - CPred<"$_op.getOperand(" # n # - ").getType().cast().getRank() == " # m>]>>; + CPred<"llvm::cast($_op.getOperand(" # n # + ").getType()).getRank() == " # m>]>>; class TFL_OperandIsNoneOrHasRankAtMost : PredOpTrait<"operand " # n # " is at most " # m # "-D", Or<[ TFL_OperandIsNoneType, TFL_OperandIsUnrankedPred, - CPred<"$_op.getOperand(" # n # - ").getType().cast().getRank() <= " # m>]>>; + CPred<"llvm::cast($_op.getOperand(" # n # + ").getType()).getRank() <= " # m>]>>; class TFL_OperandHasRankAtMostPred : Or<[TFL_OperandIsUnrankedPred, - CPred<"$_op.getOperand(" # n # - ").getType().cast().getRank() <= " # m>]>; + CPred<"llvm::cast($_op.getOperand(" # n # + ").getType()).getRank() <= " # m>]>; class TFL_OperandHasRankAtMost : PredOpTrait<"operand " # n # " is at most " # m # "-D", @@ -310,54 
+310,54 @@ class TFL_OperandHasRankAtMost : class TFL_TransposeOperandHasEffectiveRankAtMost : PredOpTrait<"operand " # n # " is at most " # m # "-D", Or<[TFL_OperandIsUnrankedPred, - CPred<"GetSqueezedShape($_op.getOperand(" # n # - ")).cast().size() <= " # m>]>>; + CPred<"llvm::cast(GetSqueezedShape($_op.getOperand(" # n # + "))).size() <= " # m>]>>; class TFL_OperandHasRankAtLeast : PredOpTrait<"operand " # n # " is at least " # m # "-D", Or<[TFL_OperandIsUnrankedPred, - CPred<"$_op.getOperand(" # n # - ").getType().cast().getRank() >= " # m>]>>; + CPred<"llvm::cast($_op.getOperand(" # n # + ").getType()).getRank() >= " # m>]>>; class TFL_OperandHasRankRange : PredOpTrait<"operand " # n # " has rank range [" # x # ", " # y # "]", Or<[TFL_OperandIsUnrankedPred, - CPred<"$_op.getOperand(" # n # ").getType().cast().getRank() " - ">= " # x # " && $_op.getOperand(" # n # ").getType().cast()." + CPred<"llvm::cast($_op.getOperand(" # n # ").getType()).getRank() " + ">= " # x # " && llvm::cast($_op.getOperand(" # n # ").getType())." "getRank() <= " # y>]>>; def TFL_FloatNonNegative : AttrConstraint< - CPred<"$_self.isa() && " - "!$_self.cast().getValue().isNegative()">, + CPred<"llvm::isa($_self) && " + "!llvm::cast($_self).getValue().isNegative()">, "whose value is non-negative">; def TFL_BoolTrue : AttrConstraint< - CPred<"$_self.isa() && $_self.cast().getValue()">, + CPred<"llvm::isa($_self) && llvm::cast($_self).getValue()">, "whose value is true">; def TFL_BoolFalse : AttrConstraint< - CPred<"$_self.isa() && !$_self.cast().getValue()">, + CPred<"llvm::isa($_self) && !llvm::cast($_self).getValue()">, "whose value is false">; class TFL_StringEqualsTo : AttrConstraint< - CPred<"$_self.cast().getValue() == \"" # value # "\"">, + CPred<"llvm::cast($_self).getValue() == \"" # value # "\"">, "whose value equals to '" # value # "'">; // Ensures the array attribute's size is within the given maximum size. 
class TFL_ArrayMaxCount : AttrConstraint< - CPred<"$_self.isa() && $_self.cast().size() <= " # n>, + CPred<"llvm::isa($_self) && llvm::cast($_self).size() <= " # n>, "whose size is at most " # n>; // Ensures the given integer attribute has the given value. class TFL_IntEqualsTo : AttrConstraint< - CPred<"$_self.isa() && " - "$_self.cast().getInt() == " # n>, + CPred<"llvm::isa($_self) && " + "llvm::cast($_self).getInt() == " # n>, "whose value is " # n>; // Ensures the given LSTMKernelType attribute has the given value. class TFL_LSTMKernelTypeEqualsTo : AttrConstraint< - CPred<"$_self.isa() && " - "$_self.cast().getValue() == " # value>, + CPred<"llvm::isa($_self) && " + "llvm::cast($_self).getValue() == " # value>, "whose value is " # value>; // This is a quantization-aware version of TCresVTEtIsSameAsOp @@ -769,11 +769,12 @@ def TFL_ArgMaxOp : TFL_Op<"arg_max", [ let hasOptions = 1; DerivedTFLiteTypeAttr output_type = DerivedTFLiteTypeAttr<[{ - return getResult().getType().cast().getElementType(). - cast().getWidth() > 32 ? tflite::TensorType_INT64 : + return llvm::cast(llvm::cast( + getResult().getType()).getElementType()).getWidth() > 32 ? + tflite::TensorType_INT64 : tflite::TensorType_INT32; }], [{ - TypeAttr::get(getResult().getType().cast().getElementType()) + TypeAttr::get(llvm::cast(getResult().getType()).getElementType()) }]>; } @@ -801,11 +802,12 @@ def TFL_ArgMinOp : TFL_Op<"arg_min", [ let hasOptions = 1; DerivedTFLiteTypeAttr output_type = DerivedTFLiteTypeAttr<[{ - return getResult().getType().cast().getElementType(). - cast().getWidth() > 32 ? tflite::TensorType_INT64 : + return llvm::cast(llvm::cast( + getResult().getType()).getElementType()).getWidth() > 32 ? 
+ tflite::TensorType_INT64 : tflite::TensorType_INT32; }], [{ - TypeAttr::get(getResult().getType().cast().getElementType()) + TypeAttr::get(llvm::cast(getResult().getType()).getElementType()) }]>; } @@ -2430,7 +2432,8 @@ def TFL_SliceOp : TFL_Op<"slice", [ TFL_TCresVTEtIsSameAsOp<0, 0>>, Pure, SameOperandsAndResultsScale, - TFL_OperandHasRankAtMost<0, 5>, + TFL_RuntimePredOpTrait<"input must have rank at most 5", + TFL_OperandHasRankAtMostPred<0, 5>>, TFL_OperandHasRankAtMost<1, 1>, TFL_OperandHasRankAtMost<2, 1>]> { let summary = "Return a slice from 'input'."; @@ -2493,6 +2496,11 @@ def TFL_SumOp: TFL_Op<"sum", [ let hasFolder = 1; + let builders = [ + OpBuilder<(ins "Value":$input, "Value":$axes, + "BoolAttr":$keep_dims)> + ]; + // TODO(b/215655380): Re-enable this once there is 16-bit MLIR quantizer. // //let extraClassDeclaration = [{ @@ -3161,7 +3169,7 @@ def TFL_ShapeOp: TFL_Op<"shape", [ let results = (outs TFL_TensorOf<[I32, I64]>:$output); DerivedTypeAttr out_type = DerivedTypeAttr<[{ - return getResult().getType().cast().getElementType(); + return llvm::cast(getResult().getType()).getElementType(); }]>; let hasOptions = 1; @@ -3934,9 +3942,9 @@ def TFL_SparseToDenseOp : TFL_Op<"sparse_to_dense", [ TFL_OperandHasRankAtMost<2, 1>, PredOpTrait<"the first operand should have a rank <= 2, when its rank is 2 and has static shape, the second dim should be <= 4", Or<[TFL_OperandIsUnrankedPred<0>, - CPred<"$_op.getOperand(0).getType().cast().getRank() <= 1">, - CPred<"$_op.getOperand(0).getType().cast().getRank() == 2 && !$_op.getOperand(0).getType().cast().hasStaticShape()">, - CPred<"$_op.getOperand(0).getType().cast().getRank() == 2 && $_op.getOperand(0).getType().cast().getShape()[1] <= 4">]>>]> { + CPred<"llvm::cast($_op.getOperand(0).getType()).getRank() <= 1">, + CPred<"llvm::cast($_op.getOperand(0).getType()).getRank() == 2 && !llvm::cast($_op.getOperand(0).getType()).hasStaticShape()">, + CPred<"llvm::cast($_op.getOperand(0).getType()).getRank() == 2 && 
llvm::cast($_op.getOperand(0).getType()).getShape()[1] <= 4">]>>]> { let summary = "Converts a sparse representation into a dense tensor."; let description = [{ @@ -3979,7 +3987,8 @@ def TFL_StridedSliceOp: TFL_Op<"strided_slice", [ PredOpTrait<"input and output must have same element type", TFL_TCresVTEtIsSameAsOp<0, 0>>, SameOperandsAndResultsScale, - TFL_OperandHasRankAtMost<0, 5>, + TFL_RuntimePredOpTrait<"input (with new_axis) must have rank at most 5", + CPred<"TFL::VerifyStridedSliceOpInputRankConstraints(llvm::cast($_op))">>, TFL_OperandHasRank<1, 1>, TFL_OperandHasRank<2, 1>, TFL_OperandHasRank<3, 1> @@ -4107,11 +4116,12 @@ value of `input` in the unique output `output`. In other words: ); DerivedTFLiteTypeAttr idx_out_type = DerivedTFLiteTypeAttr<[{ - return getResult(1).getType().cast().getElementType(). - cast().getWidth() > 32 ? tflite::TensorType_INT64 : + return llvm::cast(llvm::cast( + getResult(1).getType()).getElementType()).getWidth() > 32 ? + tflite::TensorType_INT64 : tflite::TensorType_INT32; }], [{ - TypeAttr::get(getResult(1).getType().cast().getElementType()) + TypeAttr::get(llvm::cast(getResult(1).getType()).getElementType()) }]>; let hasOptions = 1; @@ -4153,13 +4163,13 @@ def TFL_DynamicUpdateSliceOp: TFL_Op<"dynamic_update_slice", [ }]; let arguments = (ins - TFL_TensorOf<[I1, I8, I32, I64, F32, F16]>:$operand, - TFL_TensorOf<[I1, I8, I32, I64, F32, F16]>:$update, + TFL_TensorOf<[I1, I8, I16, I32, I64, F32, F16]>:$operand, + TFL_TensorOf<[I1, I8, I16, I32, I64, F32, F16]>:$update, TFL_I32OrI64Tensor:$start_indices ); let results = ( - outs TFL_TensorOf<[I1, I8, I32, I64, F32, F16]>:$output); + outs TFL_TensorOf<[I1, I8, I16, I32, I64, F32, F16]>:$output); let hasFolder = 1; } diff --git a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc index 4dcf1497476f77..6cf9cd3cc9711d 100644 --- 
a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc @@ -217,6 +217,8 @@ absl::Status ConvertSavedModelToTFLiteFlatBuffer( pass_config.model_origin_framework = converter_flags.model_origin_framework(); pass_config.canonicalizing_inf_as_min_max_float = converter_flags.canonicalizing_inf_as_min_max_float(); + pass_config.unsafe_fuse_dynamic_shaped_broadcast = + converter_flags.unsafe_fuse_dynamic_shaped_broadcast(); if (converter_flags.strict_qdq_mode()) { pass_config.quant_specs.qdq_conversion_mode = diff --git a/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/BUILD b/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/BUILD index f572ad6418feba..55f5727b81cd1d 100644 --- a/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/BUILD @@ -109,16 +109,10 @@ td_library( gentbl_cc_library( name = "quantization_interfaces_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-op-interface-decls"], - "quantization_interface.h.inc", - ), - ( - ["-gen-op-interface-defs"], - "quantization_interface.cc.inc", - ), - ], + tbl_outs = { + "quantization_interface.h.inc": ["-gen-op-interface-decls"], + "quantization_interface.cc.inc": ["-gen-op-interface-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "quantization.td", deps = [ diff --git a/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization.td b/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization.td index 690fe4be1d46eb..143996e8816cac 100644 --- a/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization.td +++ b/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization.td @@ -31,12 +31,12 @@ include "mlir/Dialect/Quant/IR/QuantBase.td" // explicit signedness 
check to differentiate the signed/unsigned constraints // predicates from one another at the TD level. class QuantizedType params, bit signed> - : Type()">, - CPred<"$_self.cast()" # + : Type($_self)">, + CPred<"llvm::cast($_self)" # ".getStorageTypeIntegralWidth() == " # !head(params)>, - Or<[CPred<"$_self.cast()" # + Or<[CPred<"llvm::cast($_self)" # ".getStorageType().isSignlessInteger()">, - CPred<"$_self.cast()" # + CPred<"llvm::cast($_self)" # ".getStorageType().isSignedInteger() == " # signed>]>]>, "Q" # !if (signed, "I", "UI") # !head(params) # " type"> { string name = n; diff --git a/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_utils.h b/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_utils.h index 4bac179aff06fb..205f69531a6788 100644 --- a/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_utils.h +++ b/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_utils.h @@ -201,7 +201,7 @@ bool QuantizableOpSupportsFloatOutputType(mlir::Operation* op); // Specialized version of location to string for flatbuffer exported locations. 
inline std::string GetTensorNameFromLoc(Location loc) { - if (auto name_loc = loc.dyn_cast()) { + if (auto name_loc = llvm::dyn_cast(loc)) { return name_loc.getName().str(); } return ""; @@ -219,7 +219,7 @@ struct ConvertStatsToQDQs : public OpRewritePattern { LogicalResult matchAndRewrite(quantfork::StatisticsOp op, PatternRewriter& rewriter) const override { - Type expressed = op.getType().cast().getElementType(); + Type expressed = llvm::cast(op.getType()).getElementType(); quant::QuantizedType quant_type; SmallVector mins, maxs; @@ -227,7 +227,7 @@ struct ConvertStatsToQDQs : public OpRewritePattern { // Per axis quantization (or per channel quantization) int stats_num = op.getAxisStats()->getNumElements(); if (stats_num == 0 || stats_num % 2 != 0) return failure(); - auto stats = op.getAxisStats()->dyn_cast(); + auto stats = llvm::dyn_cast(*op.getAxisStats()); if (!stats) return failure(); for (auto it = stats.begin(), e = stats.end(); it != e; ++it) { @@ -256,7 +256,7 @@ struct ConvertStatsToQDQs : public OpRewritePattern { quant_type = DownCastScale(quant_type, mins, maxs, op->getLoc()); } } else if (auto stats = - op.getLayerStats().dyn_cast()) { + llvm::dyn_cast(op.getLayerStats())) { // Per tensor quantization auto statValues = stats.getValues(); double rmin = FloatAttr::getValueAsDouble(statValues[0]); @@ -482,7 +482,7 @@ class QuantizationPattern : public RewritePattern { } if (!nodes_blocklist.empty()) { - if (auto name_loc = quantizing_op->getLoc().dyn_cast()) { + if (auto name_loc = llvm::dyn_cast(quantizing_op->getLoc())) { std::string sloc = name_loc.getName().str(); if (!sloc.empty() && (nodes_blocklist.find(sloc) != nodes_blocklist.end())) { @@ -504,12 +504,13 @@ class QuantizationPattern : public RewritePattern { inputs.reserve(quantizing_op->getNumOperands()); for (auto operand : quantizing_op->getOperands()) { Type operand_type = operand.getType(); - if (operand_type.isa()) { + if (isa(operand_type)) { inputs.push_back(operand); continue; } - 
auto ele_type = operand.getType().cast().getElementType(); + auto ele_type = + llvm::cast(operand.getType()).getElementType(); if (static_cast(this) ->AllowDynamicRangeQuantizedOperand(quantizing_op, custom_map)) { @@ -569,13 +570,13 @@ class QuantizationPattern : public RewritePattern { Type result_type = result.getType(); // Add this to the test coverage once we create test ops with none // type results. - if (result_type.isa()) { + if (isa(result_type)) { outputs_replaced.insert({result, enumerated_result.index()}); output_types.push_back(result_type); continue; } Type result_ele_type = - result.getType().cast().getElementType(); + llvm::cast(result.getType()).getElementType(); // If the user is the QuantizeOp, it must be the only user. if (result.hasOneUse() && llvm::isa(*result.user_begin())) { @@ -649,11 +650,9 @@ class QuantizationPattern : public RewritePattern { } for (int i = 0, e = quantized_op->getNumResults(); i < e; ++i) { - if (!quantizing_op->getResult(i) - .getType() - .cast() - .getElementType() - .isa()) { + if (!isa( + cast(quantizing_op->getResult(i).getType()) + .getElementType())) { continue; } CreateVerifier(quantizing_op, quantized_op, rewriter, i, @@ -674,9 +673,7 @@ class QuantizationPattern : public RewritePattern { void RewireFloatModelBackbone(mlir::Operation* quantized_op, mlir::Operation* float_op) const { for (int i = 0, e = quantized_op->getNumResults(); i < e; ++i) { - if (!float_op->getResult(i) - .getType() - .cast() + if (!llvm::cast(float_op->getResult(i).getType()) .getElementType() .isF32()) { continue; @@ -769,14 +766,14 @@ struct ConvertUnsignedToSigned : public OpRewritePattern { auto flags = quant::QuantizationFlags::Signed; QType new_qtype; - if (auto uqtype = qtype.template dyn_cast()) { + if (auto uqtype = llvm::dyn_cast(qtype)) { new_qtype = quant::UniformQuantizedType::getChecked( op.getLoc(), flags, qtype.getStorageType(), qtype.getExpressedType(), uqtype.getScale(), uqtype.getZeroPoint() - offset, 
uqtype.getStorageTypeMin() - offset, uqtype.getStorageTypeMax() - offset); - } else if (auto aqtype = qtype.template dyn_cast< - quant::UniformQuantizedPerAxisType>()) { + } else if (auto aqtype = + llvm::dyn_cast(qtype)) { auto zero_points = aqtype.getZeroPoints(); llvm::SmallVector new_zero_points(zero_points.begin(), zero_points.end()); diff --git a/tensorflow/compiler/mlir/lite/quantization/device_target.cc b/tensorflow/compiler/mlir/lite/quantization/device_target.cc index dcfca75a78fd4e..c55114e62accd6 100644 --- a/tensorflow/compiler/mlir/lite/quantization/device_target.cc +++ b/tensorflow/compiler/mlir/lite/quantization/device_target.cc @@ -72,11 +72,11 @@ ScaleDecomposeFn DeviceTarget::GetDecomposeFn( void DeviceTarget::AppendToSignature(Type spec, KernelSpecs::Signature* signature) { - if (auto quant = spec.dyn_cast_or_null()) { + if (auto quant = llvm::dyn_cast_or_null(spec)) { signature->push_back(AnyQuantizedType::get( quant.getFlags(), quant.getStorageType(), quant.getExpressedType(), quant.getStorageTypeMin(), quant.getStorageTypeMax())); - } else if (auto any = spec.dyn_cast_or_null()) { + } else if (auto any = llvm::dyn_cast_or_null(spec)) { signature->push_back(any); } else { // float signature->push_back(AnyQuantizedType()); @@ -113,17 +113,17 @@ LogicalResult DeviceTarget::DecomposeMultiplyAccumulateScale( llvm::SmallVector input_specs, out_specs; for (auto spec : rop.getInputSpecs()) { - input_specs.push_back(spec.cast().getValue()); + input_specs.push_back(llvm::cast(spec).getValue()); } for (auto spec : rop.getOutputSpecs()) { - out_specs.push_back(spec.cast().getValue()); + out_specs.push_back(llvm::cast(spec).getValue()); } - auto in_spec = input_specs[0].dyn_cast(); + auto in_spec = llvm::dyn_cast(input_specs[0]); // TODO(fengliuai): handles the PerAxis QuantizedType. 
- auto w_spec = input_specs[1].dyn_cast(); - auto b_spec = input_specs[2].dyn_cast(); - auto o_spec = out_specs[0].dyn_cast(); + auto w_spec = llvm::dyn_cast(input_specs[1]); + auto b_spec = llvm::dyn_cast(input_specs[2]); + auto o_spec = llvm::dyn_cast(out_specs[0]); if (!in_spec || !w_spec || !b_spec || !o_spec) return failure(); double scale_product = in_spec.getScale() * w_spec.getScale(); @@ -164,10 +164,8 @@ LogicalResult DeviceTarget::DecomposeSameScale( output_multipliers->push_back(kUnitQuantizedMultiplier); } - auto o_spec = rop.getOutputSpecs()[0] - .cast() - .getValue() - .dyn_cast(); + auto o_spec = llvm::dyn_cast( + llvm::cast(rop.getOutputSpecs()[0]).getValue()); if (!o_spec) return failure(); // output ranges diff --git a/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc b/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc index 9347e96330203e..2a35475dcceb1b 100644 --- a/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc +++ b/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc @@ -23,6 +23,7 @@ limitations under the License. #include "llvm/ADT/APFloat.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringMap.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Regex.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project @@ -106,8 +107,8 @@ class ImportQuantStatsPass if (index < 0 || index >= static_cast(op->getNumResults())) return false; Value res = op->getResult(index); - return res.getType().isa() && - res.getType().cast().getElementType().isa(); + return isa(res.getType()) && + isa(cast(res.getType()).getElementType()); } // A method to retrieve the name for the given op. 
@@ -235,11 +236,11 @@ std::unique_ptr> CreateImportQuantStatsPassForTFControlDialect(const std::string &stats_str) { auto get_name_func = [](Operation *op) { Location loc = tensorflow::GetLocationWithoutOpType(op->getLoc()); - if (auto name = loc.dyn_cast()) { + if (auto name = llvm::dyn_cast(loc)) { return name.getName().strref(); - } else if (auto fused_name = loc.dyn_cast()) { + } else if (auto fused_name = llvm::dyn_cast(loc)) { for (auto sub_loc : fused_name.getLocations()) { - if (auto named_sub_loc = sub_loc.dyn_cast()) { + if (auto named_sub_loc = llvm::dyn_cast(sub_loc)) { return named_sub_loc.getName().strref(); } } diff --git a/tensorflow/compiler/mlir/lite/quantization/ir/BUILD b/tensorflow/compiler/mlir/lite/quantization/ir/BUILD index a6d6c61444548e..88022e023443f6 100644 --- a/tensorflow/compiler/mlir/lite/quantization/ir/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/ir/BUILD @@ -26,30 +26,18 @@ td_library( gentbl_cc_library( name = "QuantOpsIncGen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-op-decls"], - "QuantOps.h.inc", - ), - ( - ["-gen-op-defs"], - "QuantOps.cc.inc", - ), - ( - [ - "-gen-dialect-decls", - "-dialect=quantfork", - ], - "QuantOpsDialect.h.inc", - ), - ( - [ - "-gen-dialect-defs", - "-dialect=quantfork", - ], - "QuantOpsDialect.cc.inc", - ), - ], + tbl_outs = { + "QuantOps.h.inc": ["-gen-op-decls"], + "QuantOps.cc.inc": ["-gen-op-defs"], + "QuantOpsDialect.h.inc": [ + "-gen-dialect-decls", + "-dialect=quantfork", + ], + "QuantOpsDialect.cc.inc": [ + "-gen-dialect-defs", + "-dialect=quantfork", + ], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "QuantOps.td", deps = [":QuantizationOpsTdFiles"], @@ -58,15 +46,10 @@ gentbl_cc_library( gentbl_cc_library( name = "QuantPassIncGen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=quantfork", - ], - "Passes.h.inc", - ), - ], + tbl_outs = {"Passes.h.inc": [ + "-gen-pass-decls", + 
"-name=quantfork", + ]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "Passes.td", deps = ["@llvm-project//mlir:PassBaseTdFiles"], diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/BUILD b/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/BUILD index c3945abc74f740..4f36cb7e7b3d4a 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/BUILD @@ -102,6 +102,7 @@ cc_library( "//tensorflow/core/platform:logging", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@flatbuffers//:runtime_cc", diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantize_weights.cc b/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantize_weights.cc index b2d6fe97280174..655c1e4deadf91 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantize_weights.cc +++ b/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantize_weights.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/log/log.h" #include "absl/status/status.h" #include "flatbuffers/buffer.h" // from @flatbuffers #include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_context.cc b/tensorflow/compiler/mlir/lite/quantization/quantization_context.cc index 8682cba5cdc5a9..e59ca11fbec4e5 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_context.cc +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_context.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" @@ -258,8 +259,9 @@ LogicalResult QuantizeContext::PropagateQuantParams( quant::AdjacentOperations *new_items, bool *changed) { // Use the final state to set all the operands' parameters. for (int i = 0, e = op->getNumOperands(); i != e; ++i) { - auto ele = op->getOperand(i).getType().cast().getElementType(); - if (ele.isa() && SetOperandParams(op, i, params)) { + auto ele = + llvm::cast(op->getOperand(i).getType()).getElementType(); + if (isa(ele) && SetOperandParams(op, i, params)) { *changed |= true; new_items->push_back(op->getOperand(i).getDefiningOp()); } @@ -267,8 +269,9 @@ LogicalResult QuantizeContext::PropagateQuantParams( // Use the final state to set all the results' parameters. for (int res = 0, e = op->getNumResults(); res != e; ++res) { - auto ele = op->getResult(res).getType().cast().getElementType(); - if (ele.isa() && SetResultParams(op, res, params)) { + auto ele = + llvm::cast(op->getResult(res).getType()).getElementType(); + if (isa(ele) && SetResultParams(op, res, params)) { auto users = op->getResult(res).getUsers(); *changed |= !users.empty(); new_items->append(users.begin(), users.end()); @@ -286,7 +289,7 @@ int QuantizeContext::StatesManager::InitializeState( params_attr = op.getInputSpecs()[index]; } QuantParams params = - params_attr.cast().getValue().dyn_cast(); + dyn_cast(cast(params_attr).getValue()); bool immutable = !EmptyParams(params); int next_state_index = states_.size(); states_.push_back({params, immutable}); diff --git a/tensorflow/compiler/mlir/lite/quantization/stablehlo/BUILD b/tensorflow/compiler/mlir/lite/quantization/stablehlo/BUILD index 7d2ff18de0ab51..2ce14328fb1ad2 100644 --- a/tensorflow/compiler/mlir/lite/quantization/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/stablehlo/BUILD @@ -15,13 
+15,13 @@ cc_library( "//tensorflow/cc/saved_model:constants", "//tensorflow/cc/saved_model:loader", "//tensorflow/compiler/mlir/lite:tensorflow_lite_tf_unfreeze_global_tensors", - "//tensorflow/compiler/mlir/lite/stablehlo:tf_stablehlo", "//tensorflow/compiler/mlir/quantization/stablehlo:passes", "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", "//tensorflow/compiler/mlir/quantization/stablehlo/cc:config", "//tensorflow/compiler/mlir/quantization/stablehlo/cc:static_range_ptq", "//tensorflow/compiler/mlir/quantization/stablehlo/cc:weight_only_ptq", "//tensorflow/compiler/mlir/quantization/tensorflow/python:py_function_lib", + "//tensorflow/compiler/mlir/stablehlo:tf_stablehlo", "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes", "//tensorflow/compiler/mlir/tensorflow/transforms:tf_saved_model_freeze_variables", "//tensorflow/core/protobuf:for_core_protos_cc", diff --git a/tensorflow/compiler/mlir/lite/quantization/stablehlo/quantization.cc b/tensorflow/compiler/mlir/lite/quantization/stablehlo/quantization.cc index ff68df33d74726..3f5bcc10eedd54 100644 --- a/tensorflow/compiler/mlir/lite/quantization/stablehlo/quantization.cc +++ b/tensorflow/compiler/mlir/lite/quantization/stablehlo/quantization.cc @@ -29,13 +29,13 @@ limitations under the License. 
#include "mlir/Pass/PassManager.h" // from @llvm-project #include "tensorflow/cc/saved_model/constants.h" #include "tensorflow/cc/saved_model/loader.h" -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_stablehlo_pass.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/config.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/static_range_ptq.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/weight_only_ptq.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" +#include "tensorflow/compiler/mlir/stablehlo/transforms/tf_stablehlo_pass.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_freeze_variables.h" #include "tensorflow/core/protobuf/meta_graph.pb.h" diff --git a/tensorflow/compiler/mlir/lite/quantization/tensorflow/BUILD b/tensorflow/compiler/mlir/lite/quantization/tensorflow/BUILD index 8a73407338f697..dec7cbc852da00 100644 --- a/tensorflow/compiler/mlir/lite/quantization/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/tensorflow/BUILD @@ -36,12 +36,7 @@ td_library( gentbl_cc_library( name = "ptq_fallback_to_flex_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "fallback_to_flex_patterns.inc", - ), - ], + tbl_outs = {"fallback_to_flex_patterns.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "fallback_to_flex_patterns.td", deps = [":ptq_td_files"], diff --git a/tensorflow/compiler/mlir/lite/stablehlo/BUILD b/tensorflow/compiler/mlir/lite/stablehlo/BUILD index feb8afac64c098..c23ee8b20b36a8 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/lite/stablehlo/BUILD @@ -1,6 +1,5 @@ 
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") -load("@local_xla//xla/tsl/platform:build_config_root.bzl", "if_static") -load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test") +load("//tensorflow:tensorflow.bzl", "tf_cc_binary") load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") @@ -97,38 +96,10 @@ cc_library( alwayslink = 1, ) -cc_library( - name = "legalize_utils", - srcs = ["transforms/utils.cc"], - hdrs = ["transforms/utils.h"], - deps = [ - "@llvm-project//llvm:Support", - "@llvm-project//mlir:IR", - "@local_xla//xla/mlir_hlo", - ], -) - -tf_cc_test( - name = "legalize_utils_test", - srcs = ["transforms/utils_test.cc"], - deps = [ - ":legalize_utils", - "@com_google_googletest//:gtest_main", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Support", - "@local_xla//xla/mlir_hlo", - ], -) - gentbl_cc_library( name = "legalize_tf_patterns_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "transforms/generated_legalize_tf.inc", - ), - ], + tbl_outs = {"transforms/generated_legalize_tf.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/legalize_tf_patterns.td", deps = [ @@ -137,88 +108,11 @@ gentbl_cc_library( "@llvm-project//mlir:FuncTdFiles", "@llvm-project//mlir:TensorOpsTdFiles", "@local_xla//xla/mlir_hlo:hlo_ops_td_files", + "@stablehlo//:chlo_ops_td_files", + "@stablehlo//:stablehlo_ops_td_files", ], ) -cc_library( - name = "legalize_tf", - srcs = [ - "transforms/generated_legalize_tf.inc", - "transforms/legalize_tf.cc", - ], - hdrs = [ - "transforms/legalize_tf_passes.h", - ], - deps = [ - ":legalize_tf_patterns_inc_gen", - ":legalize_utils", - "//tensorflow/compiler/mlir/tensorflow", - "//tensorflow/compiler/mlir/tensorflow:dynamic_shape_utils", - "//tensorflow/compiler/mlir/tensorflow:xla_sharding_util", - "//tensorflow/core:framework", - 
"//tensorflow/core/kernels:conv_grad_shape_utils", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:ArithDialect", - "@llvm-project//mlir:Dialect", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:MemRefDialect", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:ShapeDialect", - "@llvm-project//mlir:Support", - "@llvm-project//mlir:TensorDialect", - "@llvm-project//mlir:TransformUtils", - "@local_tsl//tsl/platform:bfloat16", - "@local_tsl//tsl/platform:tensor_float_32_hdr_lib", - "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:padding", - "@local_xla//xla/client:sharding_builder", - "@local_xla//xla/client/lib:conv_grad_size_util", - "@local_xla//xla/hlo/translate/hlo_to_mhlo:attribute_importer", - "@local_xla//xla/mlir_hlo", - "@local_xla//xla/mlir_hlo:convert_op_folder", - "@local_xla//xla/translate/hlo_to_mhlo:attribute_importer", - "@local_xla//xla/tsl/platform:status", - "@stablehlo//:chlo_ops", - ] + if_static(["@local_tsl//tsl/platform:tensor_float_32_utils"]), -) - -cc_library( - name = "tf_stablehlo", - srcs = [ - "transforms/tf_stablehlo_pass.cc", - ], - hdrs = [ - "transforms/tf_stablehlo_pass.h", - ], - copts = [ - "-Ithird_party", - ], - deps = [ - ":legalize_tf", - ":stablehlo_util", - "//tensorflow/compiler/mlir/tensorflow", - "//tensorflow/compiler/mlir/tensorflow/transforms:lower_tf_lib", - "//tensorflow/compiler/mlir/tf2xla/transforms:xla_legalize_tf_with_tf2xla", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:ArithDialect", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:ShapeDialect", - "@llvm-project//mlir:TensorDialect", - "@llvm-project//mlir:TransformUtils", - "@llvm-project//mlir:Transforms", - "@local_xla//xla/mlir_hlo", - "@local_xla//xla/mlir_hlo:hlo_dialect_registration", - "@local_xla//xla/mlir_hlo:mhlo_passes", - "@local_xla//xla/mlir_hlo:type_conversion", - "@stablehlo//:chlo_ops", - 
"@stablehlo//:register", - ], - alwayslink = 1, -) - cc_library( name = "tfl_stablehlo", srcs = [ @@ -264,12 +158,12 @@ cc_library( ":smuggle_disallowed_ops", ":stablehlo_fuse_convolution_pass", ":stablehlo_unfuse_batch_norm_pass", - ":tf_stablehlo", ":unfold_splat_constant_pass", "//tensorflow/compiler/mlir/quantization/stablehlo:bridge_passes", "//tensorflow/compiler/mlir/stablehlo:fold_broadcast_pass", "//tensorflow/compiler/mlir/stablehlo:fuse_convolution_pass", "//tensorflow/compiler/mlir/stablehlo:rename_entrypoint_to_main", + "//tensorflow/compiler/mlir/stablehlo:tf_stablehlo", "//tensorflow/compiler/mlir/stablehlo:unfuse_batch_norm_pass", "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes", "//tensorflow/compiler/mlir/tensorflow/transforms:tf_saved_model_passes", @@ -330,15 +224,10 @@ cc_library( gentbl_cc_library( name = "passes_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=OdmlStablehlo", - ], - "transforms/stablehlo_passes.h.inc", - ), - ], + tbl_outs = {"transforms/stablehlo_passes.h.inc": [ + "-gen-pass-decls", + "-name=OdmlStablehlo", + ]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/stablehlo_passes.td", deps = ["@llvm-project//mlir:PassBaseTdFiles"], @@ -626,12 +515,7 @@ cc_library( gentbl_cc_library( name = "hlo_legalize_tf_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "transforms/generated_legalize_hlo.inc", - ), - ], + tbl_outs = {"transforms/generated_legalize_hlo.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/legalize_hlo_patterns.td", deps = [ @@ -645,12 +529,7 @@ gentbl_cc_library( gentbl_cc_library( name = "hlo_legalize_tflite_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "transforms/generated_tflite_legalize_hlo.inc", - ), - ], + tbl_outs = 
{"transforms/generated_tflite_legalize_hlo.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/tflite_legalize_hlo_patterns.td", deps = [ @@ -708,12 +587,7 @@ cc_library( gentbl_cc_library( name = "prepare_hlo_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "transforms/generated_prepare_hlo.inc", - ), - ], + tbl_outs = {"transforms/generated_prepare_hlo.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/prepare_hlo.td", deps = [ @@ -806,6 +680,7 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow/transforms:tf_device_pass_inc_gen", "//tensorflow/core:framework", "//tensorflow/core:lib", + "@com_google_absl//absl/status", "@llvm-project//llvm:Support", "@llvm-project//mlir:Analysis", "@llvm-project//mlir:ArithDialect", @@ -959,12 +834,7 @@ cc_library( gentbl_cc_library( name = "composite_lowering_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "transforms/generated_composite_lowering.inc", - ), - ], + tbl_outs = {"transforms/generated_composite_lowering.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/composite_lowering_patterns.td", deps = [ @@ -988,7 +858,6 @@ tf_cc_binary( " [tf.lite.OpsSet.EXPERIMENTAL_STABLEHLO_OPS]", deps = [ ":check_accepted_ops_pass", - ":legalize_tf", ":op_stat_pass", ":stablehlo_util", ":transforms", @@ -1000,6 +869,7 @@ tf_cc_binary( "//tensorflow/compiler/mlir/lite:tf_to_tfl_flatbuffer", "//tensorflow/compiler/mlir/quantization/tensorflow:quantize_preprocess", "//tensorflow/compiler/mlir/quantization/tensorflow:tf_quant_ops", + "//tensorflow/compiler/mlir/stablehlo:legalize_tf", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes", @@ -1045,7 +915,6 @@ tf_cc_binary( 
":stablehlo_fuse_convolution_pass", ":stablehlo_unfuse_batch_norm_pass", ":tf_legalize_hlo", - ":tf_stablehlo", ":tfl_legalize_chlo", ":tfl_legalize_hlo", ":tfl_stablehlo", @@ -1054,6 +923,7 @@ tf_cc_binary( "//tensorflow/compiler/mlir:tf_mlir_opt_main", "//tensorflow/compiler/mlir/stablehlo:fold_broadcast_pass", "//tensorflow/compiler/mlir/stablehlo:fuse_convolution_pass", + "//tensorflow/compiler/mlir/stablehlo:tf_stablehlo", "//tensorflow/compiler/mlir/stablehlo:unfuse_batch_norm_pass", ], ) diff --git a/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/BUILD b/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/BUILD index c54545bd331391..d6b46ee3d31ad1 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/BUILD +++ b/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/BUILD @@ -64,12 +64,7 @@ cc_library( gentbl_cc_library( name = "shlo_simplify_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "transforms/generated_shlo_simplify.inc", - ), - ], + tbl_outs = {"transforms/generated_shlo_simplify.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/shlo_simplify.td", deps = ["@stablehlo//:stablehlo_ops_td_files"], @@ -91,15 +86,10 @@ cc_library( gentbl_cc_library( name = "passes_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=ODMLConverter", - ], - "passes.h.inc", - ), - ], + tbl_outs = {"passes.h.inc": [ + "-gen-pass-decls", + "-name=ODMLConverter", + ]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "passes.td", deps = ["@llvm-project//mlir:PassBaseTdFiles"], diff --git a/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/folders.cc b/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/folders.cc index cb48050db47cb5..778e76c79c984b 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/folders.cc +++ 
b/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/folders.cc @@ -104,7 +104,7 @@ static LogicalResult FoldDivOpInternal(stablehlo::DivOp op, } auto res_attr = DenseElementsAttr::get( - const_oprs[0].getType().cast(), res); + mlir::cast(const_oprs[0].getType()), res); rewriter.replaceOpWithNewOp(adaptor.value().Op(), res_attr); return success(); @@ -112,10 +112,10 @@ static LogicalResult FoldDivOpInternal(stablehlo::DivOp op, static LogicalResult FoldDivOp(stablehlo::DivOp op, PatternRewriter& rewriter) { auto etype = op.getType().getElementType(); - if (etype.isa()) { + if (mlir::isa(etype)) { return FoldDivOpInternal(op, rewriter); } - if (etype.isa()) { + if (mlir::isa(etype)) { return FoldDivOpInternal(op, rewriter); } return failure(); diff --git a/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/transforms/shlo_simplify.td b/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/transforms/shlo_simplify.td index c8d19baeb11d0d..620fd42ec05486 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/transforms/shlo_simplify.td +++ b/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/transforms/shlo_simplify.td @@ -19,10 +19,10 @@ include "mlir/IR/CommonAttrConstraints.td" include "mlir/IR/CommonTypeConstraints.td" def CloneF32ElementsAttrWithOnes - : NativeCodeCall<"DenseElementsAttr::get($0.getType().cast(), (float)1.0)">; + : NativeCodeCall<"DenseElementsAttr::get(llvm::cast($0.getType()), (float)1.0)">; def NotConstant : Constraint< - CPred<"$0.isa() || !llvm::isa($0.getDefiningOp())">, + CPred<"llvm::isa($0) || !llvm::isa($0.getDefiningOp())">, "Is not a constant.">; def : Pat<(StableHLO_DivOp $l, diff --git a/tensorflow/compiler/mlir/lite/stablehlo/odml_to_stablehlo.cc b/tensorflow/compiler/mlir/lite/stablehlo/odml_to_stablehlo.cc index fab718c7a4447b..5f5942dcb714db 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/odml_to_stablehlo.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/odml_to_stablehlo.cc @@ -56,13 +56,13 @@ 
limitations under the License. #include "tensorflow/cc/saved_model/loader.h" #include "tensorflow/compiler/mlir/init_mlir.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/check_accepted_ops_pass.h" -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_passes.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/op_stat_pass.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_util.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/transforms.h" #include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quant_ops.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/quantize_preprocess.h" +#include "tensorflow/compiler/mlir/stablehlo/transforms/legalize_tf_passes.h" #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.h" diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/composite-lowering.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/composite-lowering.mlir index a7095618ab0901..f9c8c4953fb931 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/composite-lowering.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/composite-lowering.mlir @@ -466,4 +466,14 @@ func.func private @XlaCallModule_odml.random_standard_normal.impl_0(%arg0: tenso } // CHECK-LABEL func.func @random_standard_normal // CHECK: %0 = "tfl.random_standard_normal"(%arg0) <{seed = 0 : i64, seed2 = 1 : i64}> : (tensor<3xi32>) -> tensor<1x2x3xf32> -// CHECK: return %0 : tensor<1x2x3xf32> \ No newline at end of file +// CHECK: return %0 : tensor<1x2x3xf32> + + +func.func private @XlaCallModule_tfl.unpack.impl_0(%arg0: tensor<1x3x4x1xf32>) -> (tensor<1x4x1xf32>, tensor<1x4x1xf32>, tensor<1x4x1xf32>) +func.func @jax_unstack(%arg0: tensor<1x3x4x1xf32>) -> 
(tensor<1x4x1xf32>, tensor<1x4x1xf32>, tensor<1x4x1xf32>) { + %0:3 = mhlo.composite "tfl.unpack" %arg0 {composite_attributes = {num = 3 : i32, axis = 1 : i32}, decomposition = @XlaCallModule_tfl.unpack.impl_0} : (tensor<1x3x4x1xf32>) -> (tensor<1x4x1xf32>, tensor<1x4x1xf32>, tensor<1x4x1xf32>) + return %0#0, %0#1, %0#2 : tensor<1x4x1xf32>, tensor<1x4x1xf32>, tensor<1x4x1xf32> +} + +// CHECK-LABEL: jax_unstack +// CHECK: %0:3 = "tfl.unpack"(%arg0) <{axis = 1 : i32, num = 3 : i32}> : (tensor<1x3x4x1xf32>) -> (tensor<1x4x1xf32>, tensor<1x4x1xf32>, tensor<1x4x1xf32>) diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/prepare_hlo.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/prepare_hlo.mlir index f363b369d76373..2fa440eee1a3f3 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/prepare_hlo.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/prepare_hlo.mlir @@ -845,3 +845,51 @@ func.func @mhlo_nd_fft(%arg0: tensor<2x3x345x256xf32>) -> tensor<2x3x345x129xcom // CHECK: return %2 : tensor<2x3x345x129xcomplex> // ----- + +// CHECK-LABEL: @mhlo_dynamic_fft_1 +func.func @mhlo_dynamic_fft_1(%arg0: tensor) -> tensor> { + %0 = "mhlo.fft"(%arg0) <{fft_length = dense<2560> : tensor<1xi64>, fft_type = #mhlo}> : (tensor) -> tensor> + return %0 : tensor> + // CHECK: %4 = "mhlo.get_dimension_size"(%arg0) <{dimension = 0 : i64}> : (tensor) -> tensor + // CHECK: %5 = mhlo.reshape %4 : (tensor) -> tensor<1xi32> + // CHECK: %6 = "mhlo.concatenate"(%5, %3, %2, %1) <{dimension = 0 : i64}> : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> + // CHECK: %7 = mhlo.dynamic_reshape %arg0, %6 : (tensor, tensor<4xi32>) -> tensor + // CHECK: %8 = "mhlo.fft"(%7) <{fft_length = dense<[1, 2560]> : tensor<2xi64>, fft_type = #mhlo}> : (tensor) -> tensor> + // CHECK: %9 = "mhlo.get_dimension_size"(%8) <{dimension = 0 : i64}> : (tensor>) -> tensor + // CHECK: %10 = mhlo.reshape %9 : (tensor) -> tensor<1xi32> + // CHECK: %11 = "mhlo.concatenate"(%10, 
%3, %0) <{dimension = 0 : i64}> : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32> + // CHECK: %12 = mhlo.dynamic_reshape %8, %11 : (tensor>, tensor<3xi32>) -> tensor> + // CHECK: return %12 : tensor> +} + +// ----- + +// CHECK-LABEL: @mhlo_dynamic_fft_2 +func.func @mhlo_dynamic_fft_2(%arg0: tensor) -> tensor> { + %0 = "mhlo.fft"(%arg0) <{fft_length = dense<2560> : tensor<1xi64>, fft_type = #mhlo}> : (tensor) -> tensor> + return %0 : tensor> + // CHECK: %3 = "mhlo.get_dimension_size"(%arg0) <{dimension = 0 : i64}> : (tensor) -> tensor + // CHECK: %4 = mhlo.reshape %3 : (tensor) -> tensor<1xi32> + // CHECK: %5 = "mhlo.get_dimension_size"(%arg0) <{dimension = 1 : i64}> : (tensor) -> tensor + // CHECK: %6 = mhlo.reshape %5 : (tensor) -> tensor<1xi32> + // CHECK: %7 = "mhlo.concatenate"(%4, %6, %2, %1) <{dimension = 0 : i64}> : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> + // CHECK: %8 = mhlo.dynamic_reshape %arg0, %7 : (tensor, tensor<4xi32>) -> tensor + // CHECK: %9 = "mhlo.fft"(%8) <{fft_length = dense<[1, 2560]> : tensor<2xi64>, fft_type = #mhlo}> : (tensor) -> tensor> + // CHECK: %10 = "mhlo.get_dimension_size"(%9) <{dimension = 0 : i64}> : (tensor>) -> tensor + // CHECK: %11 = mhlo.reshape %10 : (tensor) -> tensor<1xi32> + // CHECK: %12 = "mhlo.get_dimension_size"(%9) <{dimension = 1 : i64}> : (tensor>) -> tensor + // CHECK: %13 = mhlo.reshape %12 : (tensor) -> tensor<1xi32> + // CHECK: %14 = "mhlo.concatenate"(%11, %13, %0) <{dimension = 0 : i64}> : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32> + // CHECK: %15 = mhlo.dynamic_reshape %9, %14 : (tensor>, tensor<3xi32>) -> tensor> + // CHECK: return %15 : tensor> +} + +// ----- + +// CHECK-LABEL: @mhlo_dynamic_fft_2_neg +func.func @mhlo_dynamic_fft_2_neg(%arg0: tensor) -> tensor> { + %0 = "mhlo.fft"(%arg0) <{fft_length = dense<2560> : tensor<1xi64>, fft_type = #mhlo}> : (tensor) -> tensor> + return %0 : tensor> + // CHECK: %0 = "mhlo.fft"(%arg0) 
<{fft_length = dense<2560> : tensor<1xi64>, fft_type = #mhlo}> : (tensor) -> tensor> + // CHECK: return %0 : tensor> +} diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/tfl_legalize_hlo.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/tfl_legalize_hlo.mlir index 6d48cfee7e5438..a77d02e78c1dce 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/tfl_legalize_hlo.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/tfl_legalize_hlo.mlir @@ -3801,6 +3801,26 @@ func.func @mhlo_nd_fft_1(%arg0: tensor<2x3x345x4x256xf32>) -> tensor<2x3x345x4x1 // ----- +// CHECK-LABEL: @mhlo_dynamic_fft_1 +func.func @mhlo_dynamic_fft_1(%arg0: tensor) -> tensor> { + %0 = "mhlo.fft"(%arg0) <{fft_length = dense<[1, 2560]> : tensor<2xi64>, fft_type = #mhlo}> : (tensor) -> tensor> + return %0 : tensor> + // CHECK: %cst = arith.constant dense<[1, 2560]> : tensor<2xi32> + // CHECK: %0 = "tfl.rfft2d"(%arg0, %cst) : (tensor, tensor<2xi32>) -> tensor> + // CHECK: return %0 : tensor> +} + +// ----- + +// CHECK-LABEL: @mhlo_dynamic_fft_2 +func.func @mhlo_dynamic_fft_2(%arg0: tensor) -> tensor> { + %9 = "mhlo.fft"(%arg0) <{fft_length = dense<[1, 2560]> : tensor<2xi64>, fft_type = #mhlo}> : (tensor) -> tensor> + return %9 : tensor> + // CHECK: %cst = arith.constant dense<[1, 2560]> : tensor<2xi32> + // CHECK: %0 = "tfl.rfft2d"(%arg0, %cst) : (tensor, tensor<2xi32>) -> tensor> + // CHECK: return %0 : tensor> +} + //===----------------------------------------------------------------------===// // mhlo.imag //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/compose_uniform_quantized_type_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/compose_uniform_quantized_type_pass.cc index c94fd5cd5fede4..4107859b7412af 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/compose_uniform_quantized_type_pass.cc +++ 
b/tensorflow/compiler/mlir/lite/stablehlo/transforms/compose_uniform_quantized_type_pass.cc @@ -427,11 +427,21 @@ class UniformDequantizeFunctionCallPattern { // %4 = stablehlo.uniform_dequantize %3 // Dequantize the output. // ``` class ComposeUniformQuantizedConvolutionOp - : public OpRewritePattern::SplitMatchAndRewrite { + : public OpRewritePattern { public: - using SplitMatchAndRewrite::SplitMatchAndRewrite; + using OpRewritePattern::OpRewritePattern; - LogicalResult match(stablehlo::ConvolutionOp op) const final { + LogicalResult matchAndRewrite(stablehlo::ConvolutionOp op, + PatternRewriter& rewriter) const final { + if (match(op).failed()) { + return failure(); + } + rewrite(op, rewriter); + return success(); + } + + private: + LogicalResult match(stablehlo::ConvolutionOp op) const { // Verify operands' types. for (Type operand_type : op.getOperandTypes()) { if (Type element_type = @@ -643,8 +653,7 @@ class ComposeUniformQuantizedConvolutionOp return success(); } - void rewrite(stablehlo::ConvolutionOp op, - PatternRewriter& rewriter) const final { + void rewrite(stablehlo::ConvolutionOp op, PatternRewriter& rewriter) const { // Rewrite `call @uniform_quantize` -> `stablehlo.uniform_quantize`. auto input_i8_to_f32_convert_op = cast(op.getOperand(0).getDefiningOp()); @@ -881,10 +890,21 @@ class ComposeUniformQuantizedConvolutionOp // cast isn't present, the filter constant (%3) should be i8 quantized values // disguised in f32. 
class ComposeUniformQuantizedDotGeneralOp - : public OpRewritePattern::SplitMatchAndRewrite { + : public OpRewritePattern { public: - using SplitMatchAndRewrite::SplitMatchAndRewrite; - LogicalResult match(stablehlo::DotGeneralOp op) const final { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(stablehlo::DotGeneralOp op, + PatternRewriter& rewriter) const final { + if (match(op).failed()) { + return failure(); + } + rewrite(op, rewriter); + return success(); + } + + private: + LogicalResult match(stablehlo::DotGeneralOp op) const { auto input_i8_to_f32_convert_op = TryCast(op.getOperand(0).getDefiningOp(), /*name=*/"input_i8_to_f32_convert_op"); @@ -988,8 +1008,7 @@ class ComposeUniformQuantizedDotGeneralOp return success(); } - void rewrite(stablehlo::DotGeneralOp op, - PatternRewriter& rewriter) const final { + void rewrite(stablehlo::DotGeneralOp op, PatternRewriter& rewriter) const { // Build uniform quantized type for input. auto input_i8_to_f32_convert_op = cast(op.getOperand(0).getDefiningOp()); @@ -1304,11 +1323,21 @@ class ComposeUniformQuantizedDotGeneralOp // %5 = stablehlo.uniform_dequantize %4 // Dequantize the output. 
// ``` class ComposeUniformQuantizedDotGeneralOpWithTwoQuantizedActivations - : public OpRewritePattern::SplitMatchAndRewrite { + : public OpRewritePattern { public: - using SplitMatchAndRewrite::SplitMatchAndRewrite; + using OpRewritePattern::OpRewritePattern; - LogicalResult match(stablehlo::DotGeneralOp op) const final { + LogicalResult matchAndRewrite(stablehlo::DotGeneralOp op, + PatternRewriter& rewriter) const final { + if (match(op).failed()) { + return failure(); + } + rewrite(op, rewriter); + return success(); + } + + private: + LogicalResult match(stablehlo::DotGeneralOp op) const { // q1 - z1 if (failed(MatchQuantizedOperand(op.getOperand(0)))) { LLVM_DEBUG(llvm::dbgs() @@ -1365,8 +1394,7 @@ class ComposeUniformQuantizedDotGeneralOpWithTwoQuantizedActivations return success(); } - void rewrite(stablehlo::DotGeneralOp op, - PatternRewriter& rewriter) const final { + void rewrite(stablehlo::DotGeneralOp op, PatternRewriter& rewriter) const { // Build uniform quantized type for input 1 (lhs). 
auto input1_zero_point_subtract_op = cast(op.getOperand(0).getDefiningOp()); diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_lowering_patterns.td b/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_lowering_patterns.td index b7e1f252507035..2cf060c6379d53 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_lowering_patterns.td +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_lowering_patterns.td @@ -29,21 +29,21 @@ def LegalizeHardSwishComposite: Pat< (TFL_HardSwishOp $input)>; def IsNchwLayoutOp: Constraint() " + "$0.get(\"is_nchw_op\") && llvm::dyn_cast($0.get(\"is_nchw_op\")) " "== mlir::BoolAttr::get($_builder.getContext(), true)">>; def IsNhwcLayoutOp: Constraint>; class HasRank : Constraint< - CPred<"$0.getType().cast().hasRank() && " - "$0.getType().cast().getRank() == " # n>>; + CPred<"llvm::cast($0.getType()).hasRank() && " + "llvm::cast($0.getType()).getRank() == " # n>>; class HasRankAtLeast : Constraint< - CPred<"$0.getType().cast().hasRank() && " - "$0.getType().cast().getRank() >= " # n>>; + CPred<"llvm::cast($0.getType()).hasRank() && " + "llvm::cast($0.getType()).getRank() >= " # n>>; def I32ElementsVal : Constraint().getElementType().isInteger(32)">, + "llvm::cast($0.getType()).getElementType().isInteger(32)">, "32 bit integer tensor">; // TODO(b/343278954): Move the creation of transposes to a separate prepare pass @@ -133,6 +133,27 @@ def LegalizeCompositeGELU : Pat< (TFL_GeluOp $inputs, (GetCompositeAttributeAs<"approximate", "BoolAttr"> $attrs))>; +def LegalizeCompositeGELUDynamicShaped : Pat< + (MHLO_CompositeOp:$composite + (variadic $_, $inputs), + ConstantStrAttr, $attrs, $_, $_), + (TFL_GeluOp $inputs, + (GetCompositeAttributeAs<"approximate", "BoolAttr"> $attrs))>; + +def LegalizeCompositeGELUDynamicShaped2 : Pat< + (MHLO_CompositeOp:$composite + (variadic $_, $_, $inputs), + ConstantStrAttr, $attrs, $_, $_), + (TFL_GeluOp $inputs, + 
(GetCompositeAttributeAs<"approximate", "BoolAttr"> $attrs))>; + +def LegalizeCompositeGELUDynamicShaped3 : Pat< + (MHLO_CompositeOp:$composite + (variadic $_, $_, $_, $inputs), + ConstantStrAttr, $attrs, $_, $_), + (TFL_GeluOp $inputs, + (GetCompositeAttributeAs<"approximate", "BoolAttr"> $attrs))>; + def LegalizeCompositeOdmlEmbeddingLookup : Pat< (MHLO_CompositeOp:$composite (variadic $indices, $table), @@ -152,6 +173,15 @@ def LegalizeCompositeOdmlEmbeddingLookupDynamicShaped : Pat< (HasRankAtLeast<2> $table)]>; def LegalizeCompositeOdmlEmbeddingLookupDynamicShaped2 : Pat< + (MHLO_CompositeOp:$composite + (variadic $_, $_, $indices, $table), + ConstantStrAttr, $attrs, $_, $_), + (TFL_EmbeddingLookupOp $indices, $table), + [(HasRank<1> $indices), + (I32ElementsVal $indices), + (HasRankAtLeast<2> $table)]>; + +def LegalizeCompositeOdmlEmbeddingLookupDynamicShaped3 : Pat< (MHLO_CompositeOp:$composite (variadic $_, $indices, $table), ConstantStrAttr, $attrs, $_, $_), @@ -174,4 +204,22 @@ def LegalizeCompositeOdmlRandomStandardNormal : Pat< ConstantStrAttr, $attrs, $_, $_), (TFL_RandomStandardNormalOp $shape, (GetCompositeAttributeAs<"seed", "IntegerAttr"> $attrs), - (GetCompositeAttributeAs<"seed2", "IntegerAttr"> $attrs))>; \ No newline at end of file + (GetCompositeAttributeAs<"seed2", "IntegerAttr"> $attrs))>; + +def LegalizeCompositeUnpack : Pat< + (MHLO_CompositeOp:$composite + (variadic $inputs), + ConstantStrAttr, $attrs, $_, $_), + (TFL_UnpackOp $inputs, + (GetCompositeAttributeAs<"num", "IntegerAttr"> $attrs), + (GetCompositeAttributeAs<"axis", "IntegerAttr"> $attrs))>; + +def LegalizeCompositePack4Elements : Pat< + (MHLO_CompositeOp:$composite + // TD not able to represent variadic of variadic now. + // Move to C++ matcher to support more cases. 
+ (variadic $i0, $i1, $i2, $i3), + ConstantStrAttr, $attrs, $_, $_), + (TFL_PackOp (variadic $i0, $i1, $i2, $i3), + (GetCompositeAttributeAs<"values_count", "IntegerAttr"> $attrs), + (GetCompositeAttributeAs<"axis", "IntegerAttr"> $attrs))>; diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_utils.td b/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_utils.td index 30d6f4247fba52..7d905119b3f08f 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_utils.td +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_utils.td @@ -33,7 +33,7 @@ def GetI32DenseAttr: NativeCodeCall< // Receives a composite DictionaryAttr and returns the value of the Attribute // with the key `attr_name` as the type provided by `attr_type`. class GetCompositeAttributeAs: - NativeCodeCall<"$0.get(\"" # attr_name # "\").dyn_cast<" # attr_type # ">()">; + NativeCodeCall<"llvm::dyn_cast<" # attr_type # ">($0.get(\"" # attr_name # "\"))">; // Receives a composite DictionaryAttr and returns the value of the Attribute // with the key `attr_name` as a DenseIntElementsAttr. diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo.cc index 20c4a616554801..044848ce93ce61 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo.cc @@ -27,6 +27,7 @@ limitations under the License. 
#include #include +#include "absl/status/status.h" #include "llvm/ADT/APFloat.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/BUILD b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/BUILD index c58fca93e6e53d..9e2f1cf33f495f 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/BUILD +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/BUILD @@ -341,8 +341,8 @@ cc_library( srcs = ["fft.cc"], hdrs = ["fft.h"], deps = [ - "//tensorflow/compiler/mlir/lite:const_tensor_utils", "//tensorflow/compiler/mlir/lite:tensorflow_lite", + "//tensorflow/compiler/mlir/tensorflow:dynamic_shape_utils", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:IR", diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/fft.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/fft.cc index 8f08a0f8a2b1c0..f2d29774c31c89 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/fft.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/fft.cc @@ -19,10 +19,10 @@ limitations under the License. #include #include -#include #include #include +#include "mhlo/IR/hlo_ops.h" #include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project @@ -32,7 +32,7 @@ limitations under the License. 
#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" // IWYU pragma: keep -#include "tensorflow/compiler/mlir/lite/utils/const_tensor_utils.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/dynamic_shape_utils.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" namespace mlir::odml { @@ -62,14 +62,6 @@ bool IsSupportedRfftOp(mhlo::FftOp fft_op) { if (fft_lengths.size() > 2) return false; // Only support 2D FFT. - // TFLite RFFT2d supports only int32 fft_lengths that are powers of 2. - for (int64_t fft_length : fft_lengths) { - if (fft_length != 1 && (!TFL::IsPowerOfTwo(fft_length) || - fft_length > std::numeric_limits::max())) { - return false; - } - } - // Check if the trailing input shape matches the fft_lengths. const std::vector input_shape = mlir::cast(fft_op.getOperand().getType()).getShape(); @@ -77,6 +69,16 @@ bool IsSupportedRfftOp(mhlo::FftOp fft_op) { fft_lengths.begin(), fft_lengths.end()); } +// Returns a tensor of the dimension size of the input tensor. Result of +// mhlo::GetDimensionSizeOp is always a scalar value, but we need a tensor to +// concatenate with other dimension sizes. +Value GetDimensionSizeTensor(OpBuilder& rewriter, Location loc, Value input, + int64_t dim) { + auto size_scalar = rewriter.create(loc, input, dim); + return rewriter.create( + loc, RankedTensorType::get({1}, rewriter.getI32Type()), size_scalar); +} + // Convert rfft to rfft2d. // The transformation pattern looks like below: // @@ -114,18 +116,22 @@ class ConvertNDFftTo2DFftOp : public OpRewritePattern { auto input_type = mlir::dyn_cast_or_null(fft_op.getOperand().getType()); const std::vector input_shape = - mlir::cast(fft_op.getOperand().getType()).getShape(); + input_type + ? 
input_type.getShape() + : mlir::cast(fft_op.getOperand().getType()).getShape(); - auto fft_operand = fft_op.getOperand(); + Value fft_operand = fft_op.getOperand(); auto output_type = mlir::cast(fft_op.getResult().getType()); // Create a new fft_length attribute for the 2D FFT. SmallVector new_fft_lengths = {1, fft_lengths.back()}; auto new_fft_lengths_attr = rewriter.getI64TensorAttr(new_fft_lengths); + bool is_dynamic_shape = !input_type || !input_type.hasStaticShape(); + // Input can have a single trivial batch dim next to the fft dimension, in // which case we don't need to expand the input. - if (input_type && (input_shape[input_shape.size() - 2] != 1)) { + if (input_shape[input_shape.size() - 2] != 1) { const std::vector output_shape = output_type.getShape(); // [a, b, c, d, e] -> [a, b, c, d, 1, e] @@ -133,11 +139,42 @@ class ConvertNDFftTo2DFftOp : public OpRewritePattern { input_shape.end() - 1}; expanded_input_shape.push_back(1); expanded_input_shape.push_back(input_shape.back()); - // Replace the expand_dims op with a reshape op: - auto expanded_input_type = mlir::RankedTensorType::get( + auto expanded_input_type = tensorflow::GetTypeFromTFTensorShape( expanded_input_shape, input_type.getElementType()); - fft_operand = rewriter.create( - fft_op.getLoc(), expanded_input_type, fft_operand); + + // Dynamic shape needs to be handled separately as mhlo::ReshapeOp does + // not support dynamic shape. + if (is_dynamic_shape) { + // Programmatically- + // 1. Get the dimensions of the input tensor and create shape vector. + // 2. Insert a 1 as the penultimate dimension size. + // 3. Concatenate the dimension sizes to create a new SHAPE tensor. 
+ SmallVector expanded_input_shape_values; + for (int i = 0; i < input_shape.size() - 1; ++i) { + expanded_input_shape_values.push_back(GetDimensionSizeTensor( + rewriter, fft_op.getLoc(), fft_operand, i)); + } + expanded_input_shape_values.push_back(rewriter.create( + fft_op.getLoc(), rewriter.getI32TensorAttr({1}))); + expanded_input_shape_values.push_back(GetDimensionSizeTensor( + rewriter, fft_op.getLoc(), fft_operand, input_shape.size() - 1)); + + auto expanded_input_shape_tensor = rewriter.create( + fft_op.getLoc(), + RankedTensorType::get( + {static_cast(expanded_input_shape_values.size())}, + rewriter.getI32Type()), + expanded_input_shape_values, 0); + + // Create a new mhlo.dynamic_reshape op with the expanded input and + // expanded input shape. SHAPE tensor is created in the previous step. + fft_operand = rewriter.create( + fft_op.getLoc(), expanded_input_type, fft_operand, + expanded_input_shape_tensor); + } else { + fft_operand = rewriter.create( + fft_op.getLoc(), expanded_input_type, fft_operand); + } SmallVector new_output_shape = {output_shape.begin(), output_shape.end() - 1}; @@ -152,12 +189,34 @@ class ConvertNDFftTo2DFftOp : public OpRewritePattern { rewriter.create(fft_op.getLoc(), output_type, fft_operand, fft_op.getFftType(), new_fft_lengths_attr); - if (input_type && (input_shape[input_shape.size() - 2] != 1)) { + if (input_shape[input_shape.size() - 2] != 1) { // Squeeze the output dimensions back to 2D. 
- auto squeeze_op = rewriter.create( - fft_op.getLoc(), fft_op.getResult().getType(), new_fft.getResult()); - - rewriter.replaceOp(fft_op, squeeze_op.getResult()); + if (is_dynamic_shape) { + SmallVector output_shape_values; + for (int i = 0; i < new_fft.getResult().getType().getShape().size() - 2; + ++i) { + output_shape_values.push_back(GetDimensionSizeTensor( + rewriter, fft_op.getLoc(), new_fft.getResult(), i)); + } + output_shape_values.push_back(GetDimensionSizeTensor( + rewriter, fft_op.getLoc(), new_fft.getResult(), + new_fft.getResult().getType().getShape().size() - 1)); + + auto shape_tensor = rewriter.create( + fft_op.getLoc(), + RankedTensorType::get( + {static_cast(output_shape_values.size())}, + rewriter.getI32Type()), + output_shape_values, 0); + auto squeeze_op = rewriter.create( + fft_op.getLoc(), fft_op.getResult().getType(), new_fft.getResult(), + shape_tensor); + rewriter.replaceOp(fft_op, squeeze_op.getResult()); + } else { + auto squeeze_op = rewriter.create( + fft_op.getLoc(), fft_op.getResult().getType(), new_fft.getResult()); + rewriter.replaceOp(fft_op, squeeze_op.getResult()); + } } else { rewriter.replaceOp(fft_op, new_fft.getResult()); } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_patterns.td b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_patterns.td index f9fd092f1e04fb..05a68b2cff370e 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_patterns.td +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_patterns.td @@ -283,7 +283,7 @@ def : Pat<(MHLO_ConcatenateOp $inputs, $dim), //===----------------------------------------------------------------------===// class HasChloCompareType : - CPred<"$_self.cast<::mlir::chlo::ComparisonTypeAttr>().getValue() == " # value>; + CPred<"llvm::cast<::mlir::chlo::ComparisonTypeAttr>($_self).getValue() == " # value>; // Attribute value should be such that it matches the comparison used by // TensorFlow, if the 
attribute is present. @@ -298,7 +298,7 @@ class CHLO_ComparisonDirectionValue : ConstantAttr; class HasMhloCompareType : - CPred<"$_self.cast<::mlir::mhlo::ComparisonTypeAttr>().getValue() == " # value>; + CPred<"llvm::cast<::mlir::mhlo::ComparisonTypeAttr>($_self).getValue() == " # value>; // Attribute value should be such that it matches the comparison used by // TensorFlow, if the attribute is present. diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_stablehlo_to_vhlo.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_stablehlo_to_vhlo.cc index e1f1681a3d7ae1..b09864ac8eac2f 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_stablehlo_to_vhlo.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_stablehlo_to_vhlo.cc @@ -79,7 +79,6 @@ class StablehloToOdmlTypeConverter : public vhlo::VhloTypeConverter { }); addBuiltinToVhloConversions(); - addArgumentMaterialization(MaterializeIllegalCast); addSourceMaterialization(MaterializeIllegalCast); addTargetMaterialization(MaterializeIllegalCast); } @@ -112,7 +111,6 @@ class VhloToStablehloTypeConverter : public vhlo::VhloTypeConverter { }); addVhloToBuiltinConversions(); - addArgumentMaterialization(MaterializeIllegalCast); addSourceMaterialization(MaterializeIllegalCast); addTargetMaterialization(MaterializeIllegalCast); } @@ -144,7 +142,7 @@ void ConvertAndWrapUsesInUnrealizedCast(Value result, TypeConverter &converter, IRRewriter &rewriter) { auto type = result.getType(); result.setType(converter.convertType(result.getType())); - auto new_value = converter.materializeArgumentConversion( + auto new_value = converter.materializeSourceConversion( rewriter, result.getLoc(), type, {result}); rewriter.replaceAllUsesExcept(result, new_value, new_value.getDefiningOp()); } @@ -160,7 +158,7 @@ void WrapOperandsInUnrealizedCastAndConvert(Operation *op, IRRewriter &rewriter) { for (int i = 0; i < op->getNumOperands(); ++i) { auto operand = 
op->getOperand(i); - auto new_operand = converter.materializeArgumentConversion( + auto new_operand = converter.materializeSourceConversion( rewriter, op->getLoc(), converter.convertType(operand.getType()), {operand}); op->setOperand(i, new_operand); diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/optimize_layout.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/optimize_layout.cc index d251f49cfa28bf..b0bbeb57c5a6ac 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/optimize_layout.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/optimize_layout.cc @@ -91,7 +91,7 @@ struct TransposeCommuteWithPad : public OpRewritePattern { LogicalResult matchAndRewrite(stablehlo::PadOp pad_op, PatternRewriter& rewriter) const override { Value pad_input = pad_op.getOperand(); - RankedTensorType pad_type = pad_op.getType().cast(); + RankedTensorType pad_type = mlir::cast(pad_op.getType()); auto transpose_op = pad_input.getDefiningOp(); if (!transpose_op || !transpose_op->hasOneUse()) return failure(); @@ -132,7 +132,7 @@ struct TransposeCommuteWithReduceWindow Value reduce_input = inputs[0]; RankedTensorType reduce_type = - reduce_op.getResultTypes()[0].cast(); + mlir::cast(reduce_op.getResultTypes()[0]); auto transpose_op = reduce_input.getDefiningOp(); if (!transpose_op || !transpose_op->hasOneUse()) return failure(); diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/prepare_hlo.td b/tensorflow/compiler/mlir/lite/stablehlo/transforms/prepare_hlo.td index 9b6f6efbfcf4f6..c0b274ac1f852b 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/prepare_hlo.td +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/prepare_hlo.td @@ -56,10 +56,10 @@ def AreDnumsFullyDefined : Constraint()" + "llvm::cast($2.getType())" ".clone($0.PermuteShape(" "$1," - "$2.getType().cast().getShape()))">; + "llvm::cast($2.getType()).getShape()))">; def IsStandardConv : Constraint())">>; @@ -380,7 +380,7 @@ def GetExplicitPaddingArgs : 
NativeCodeCall< // Gets element type from Value. def GetElementType : NativeCodeCall< - "$0.getType().cast().getElementType()">; + "llvm::cast($0.getType()).getElementType()">; // Given element type, get a DenseElements with scalar shape and 0 value. def GetZeroScalarAttrFromType : NativeCodeCall< @@ -439,9 +439,9 @@ def UnfuseConvWithExplicitPadding : Pat<(MHLO_ConvolutionOp:$conv def TrivialStrides : NativeCodeCall< "DenseIntElementsAttr::get(" - "RankedTensorType::get({$0.getType().cast().getRank()}," + "RankedTensorType::get({llvm::cast($0.getType()).getRank()}," "$_builder.getI64Type())," - "llvm::SmallVector($0.getType().cast().getRank()," + "llvm::SmallVector(llvm::cast($0.getType()).getRank()," "1))">; def SliceStart : NativeCodeCall< diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_fuse_convolution_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_fuse_convolution_pass.cc index 6ccdb72abf3413..fcecd557aeab1c 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_fuse_convolution_pass.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_fuse_convolution_pass.cc @@ -14,7 +14,6 @@ limitations under the License. ==============================================================================*/ #include -#include #include #include "stablehlo/dialect/StablehloOps.h" diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_unfuse_batch_norm_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_unfuse_batch_norm_pass.cc index 3b0ec3c974007a..32d76f91848003 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_unfuse_batch_norm_pass.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_unfuse_batch_norm_pass.cc @@ -15,7 +15,6 @@ limitations under the License. 
#include #include -#include #include #include "stablehlo/dialect/StablehloOps.h" diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tfl_stablehlo_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tfl_stablehlo_pass.cc index e8a2bc870e960d..c876347d2a2c64 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tfl_stablehlo_pass.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tfl_stablehlo_pass.cc @@ -14,9 +14,10 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/tfl_stablehlo_pass.h" +#include +#include #include #include -#include #include #include "flatbuffers/flatbuffers.h" // from @flatbuffers diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo_patterns.td b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo_patterns.td index 4fb22ae6bbe992..e438e9580697e2 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo_patterns.td +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo_patterns.td @@ -56,7 +56,7 @@ def : Pat< def I64AttrToI32Attr: NativeCodeCall< "$_builder.getI32IntegerAttr(" - "static_cast($0.cast().getInt()))">; + "static_cast(llvm::cast($0).getInt()))">; def : Pat< (MHLO_ConcatenateOp $inputs, $dim), @@ -298,7 +298,7 @@ foreach pair = [ // Check implicit bool cast of `$_self` to ensure Attribute is non-null before // casting. 
def HasSupportedComparisonType : AttrConstraint< - CPred<"!$_self || SupportedComparisonType($_self.cast())">>; + CPred<"!$_self || SupportedComparisonType(llvm::cast($_self))">>; class MHLO_ComparisonDirectionValue : ConstantAttr() - .getValues(); + const auto filter_values = + llvm::cast(filter_constant_op.getValue()) + .getValues(); ArrayRef filter_shape = - filter_constant_op.getType().cast().getShape(); + llvm::cast(filter_constant_op.getType()).getShape(); // Reverse the shapes. This makes sense, assuming that the filter tensor has a // rank of 2 (no batch dimension). @@ -159,16 +160,16 @@ TFL::QConstOp CreateTransposedTflConstOpForFilter( Type new_filter_quantized_type; if (is_per_channel) { - auto filter_quantized_type = GetElementType(filter_constant_op.getResult()) - .cast(); + auto filter_quantized_type = llvm::cast( + GetElementType(filter_constant_op.getResult())); new_filter_quantized_type = CreateI8F32UniformQuantizedPerAxisType( filter_constant_op->getLoc(), *rewriter.getContext(), filter_quantized_type.getScales(), filter_quantized_type.getZeroPoints(), /*quantization_dimension=*/0, /*narrow_range=*/true); } else { - auto filter_quantized_type = GetElementType(filter_constant_op.getResult()) - .cast(); + auto filter_quantized_type = llvm::cast( + GetElementType(filter_constant_op.getResult())); new_filter_quantized_type = CreateI8F32UniformQuantizedType( filter_constant_op->getLoc(), *rewriter.getContext(), filter_quantized_type.getScale(), filter_quantized_type.getZeroPoint(), @@ -235,8 +236,8 @@ TFL::QConstOp CreateTflConstOpForDummyBias( Type bias_quantized_type; if (is_per_channel) { const auto filter_quantized_element_type = - GetElementType(filter_const_op.getResult()) - .cast(); + llvm::cast( + GetElementType(filter_const_op.getResult())); // The storage type is i32 for bias, which is the precision used for // accumulation. 
@@ -247,8 +248,8 @@ TFL::QConstOp CreateTflConstOpForDummyBias( /*quantization_dimension=*/0); } else { const auto filter_quantized_element_type = - GetElementType(filter_const_op.getResult()) - .cast(); + llvm::cast( + GetElementType(filter_const_op.getResult())); // The storage type is i32 for bias, which is the precision used for // accumulation. @@ -297,8 +298,8 @@ Type GetQuantizedOutputType(Operation* op, PatternRewriter& rewriter, } // StableHLO Quantizer outputs an i32 type. Rewrite to i8 type result // to meet TFLite op requirement. - auto result_quantized_type = GetElementType(uniform_quantize_op->getResult(0)) - .cast(); + auto result_quantized_type = llvm::cast( + GetElementType(uniform_quantize_op->getResult(0))); auto new_result_quantized_type = CreateI8F32UniformQuantizedType( uniform_quantize_op->getLoc(), *rewriter.getContext(), result_quantized_type.getScale(), result_quantized_type.getZeroPoint()); @@ -306,8 +307,8 @@ Type GetQuantizedOutputType(Operation* op, PatternRewriter& rewriter, // fused `qi8` type. rewriter.replaceAllUsesWith(uniform_quantize_op->getResult(0), op->getResult(0)); - return op->getResult(0).getType().cast().clone( - new_result_quantized_type); + return llvm::cast(op->getResult(0).getType()) + .clone(new_result_quantized_type); } // Matches kernel dimension numbers, ranks of input and output and constant @@ -331,7 +332,7 @@ LogicalResult MatchConvolutionFormat(stablehlo::ConvolutionOp op) { return failure(); } - const auto input_type = op.getLhs().getType().cast(); + const auto input_type = llvm::cast(op.getLhs().getType()); if (input_type.getRank() != 4) { LLVM_DEBUG(llvm::dbgs() << "Only 2D convolution op is supported. " "Expected input rank of 4. 
Got: " @@ -339,7 +340,7 @@ LogicalResult MatchConvolutionFormat(stablehlo::ConvolutionOp op) { return failure(); } - const auto filter_type = op.getRhs().getType().cast(); + const auto filter_type = llvm::cast(op.getRhs().getType()); if (filter_type.getRank() != 4) { LLVM_DEBUG(llvm::dbgs() << "Only 2D convolution op is supported. " "Expected filter rank of 4. Got: " @@ -444,17 +445,17 @@ int64_t GetConvolutionKernelInputFeatureDimension(bool is_depthwise) { // stablehlo.uniform_quantize -> tfl.quantize // TODO: b/322428814 - Add StableHLO quantizer integration tests for ODML. class RewriteUniformQuantizeOp - : public OpRewritePattern< - stablehlo::UniformQuantizeOp>::SplitMatchAndRewrite { - using SplitMatchAndRewrite::SplitMatchAndRewrite; + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; // Determines whether the input and output types are compatible with // `tfl.quantize`. See the definition for the `QUANTIZE` kernel for the // detailed limitations // (https://github.com/tensorflow/tensorflow/blob/8f145d579aa0ee7f4187af32dbbf4e12fdabbffe/tensorflow/lite/kernels/quantize.cc#L105). - LogicalResult match(stablehlo::UniformQuantizeOp op) const override { + LogicalResult matchAndRewrite(stablehlo::UniformQuantizeOp op, + PatternRewriter& rewriter) const override { const Type input_element_type = GetElementType(op.getOperand()); - if (!(input_element_type.isa() || + if (!(llvm::isa(input_element_type) || IsI32F32UniformQuantizedType(input_element_type) || IsI32F32UniformQuantizedPerAxisType(input_element_type))) { LLVM_DEBUG(llvm::dbgs() << "Uniform quantize op's input should be a " @@ -465,43 +466,37 @@ class RewriteUniformQuantizeOp // Output type of `UniformQuantizeOp` is guaranteed to be a quantized // tensor with integer storage type. 
- const auto output_storage_type = GetElementType(op.getResult()) - .cast() - .getStorageType() - .cast(); + const auto output_storage_type = llvm::cast( + llvm::cast(GetElementType(op.getResult())) + .getStorageType()); if (!IsSupportedByTfliteQuantizeOrDequantizeOps(output_storage_type)) { LLVM_DEBUG(llvm::dbgs() << "Failed to match storage type of output quantized type.\n"); return failure(); } - return success(); - } - - void rewrite(stablehlo::UniformQuantizeOp op, - PatternRewriter& rewriter) const override { Type output_type = *op->getResultTypes().begin(); rewriter.replaceOpWithNewOp( op, output_type, /*input=*/op.getOperand(), /*qtype=*/TypeAttr::get(output_type)); + return success(); } }; // stablehlo.uniform_dequantize -> tfl.dequantize class RewriteUniformDequantizeOp - : public OpRewritePattern< - stablehlo::UniformDequantizeOp>::SplitMatchAndRewrite { - using SplitMatchAndRewrite::SplitMatchAndRewrite; + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; // Determines whether the input and output types are compatible with // `tfl.dequantize`. See the definition for the `DEQUANTIZE` kernel for the // detailed limitations // (https://github.com/tensorflow/tensorflow/blob/8f145d579aa0ee7f4187af32dbbf4e12fdabbffe/tensorflow/lite/kernels/dequantize.cc#L52). 
- LogicalResult match(stablehlo::UniformDequantizeOp op) const override { - const auto input_storage_type = GetElementType(op.getOperand()) - .cast() - .getStorageType() - .cast(); + LogicalResult matchAndRewrite(stablehlo::UniformDequantizeOp op, + PatternRewriter& rewriter) const override { + const auto input_storage_type = llvm::cast( + llvm::cast(GetElementType(op.getOperand())) + .getStorageType()); if (!IsSupportedByTfliteQuantizeOrDequantizeOps(input_storage_type)) { LLVM_DEBUG(llvm::dbgs() << "Failed to match storage type of input quantized type.\n"); @@ -510,21 +505,17 @@ class RewriteUniformDequantizeOp // Output type is guaranteed to be a float tensor for a valid StableHLO. const auto output_element_type = - GetElementType(op.getResult()).cast(); - if (!output_element_type.isa()) { + llvm::cast(GetElementType(op.getResult())); + if (!llvm::isa(output_element_type)) { LLVM_DEBUG(llvm::dbgs() << "Uniform dequantize op's output element type " "should be f32. Got: " << output_element_type << ".\n"); return failure(); } - return success(); - } - - void rewrite(stablehlo::UniformDequantizeOp op, - PatternRewriter& rewriter) const override { rewriter.replaceOpWithNewOp( op, /*resultTypes=*/op->getResultTypes(), /*input=*/op.getOperand()); + return success(); } }; @@ -563,17 +554,26 @@ class RewriteUniformDequantizeOp // * The filter tensor's rank is 2. The contracting dimension should be the // first dimension (dim 0), i.e. [c_y, r_y] where c_y == r_x. class RewriteQuantizedDotGeneralOpToTflFullyConnectedOrBatchMatmulOp - : public OpRewritePattern::SplitMatchAndRewrite { + : public OpRewritePattern { public: // Sets benefit to 10 to make this pattern more preferred than smaller local // transformations like `stablehlo.transpose`->`tfl.transpose`, as this // pattern involves `stablehlo.transpose` in some cases. 
explicit RewriteQuantizedDotGeneralOpToTflFullyConnectedOrBatchMatmulOp( MLIRContext* ctx) - : OpRewritePattern::SplitMatchAndRewrite( - ctx, /*benefit=*/10) {} + : OpRewritePattern(ctx, /*benefit=*/10) {} - LogicalResult match(stablehlo::DotGeneralOp op) const override { + LogicalResult matchAndRewrite(stablehlo::DotGeneralOp op, + PatternRewriter& rewriter) const override { + if (match(op).failed()) { + return failure(); + } + rewrite(op, rewriter); + return success(); + } + + private: + LogicalResult match(stablehlo::DotGeneralOp op) const { const stablehlo::DotDimensionNumbersAttr dot_dimension_nums = op.getDotDimensionNumbers(); const bool is_batch_matmul = !IsDotGeneralFullyConnected(op).value(); @@ -605,8 +605,7 @@ class RewriteQuantizedDotGeneralOpToTflFullyConnectedOrBatchMatmulOp has_i32_output); } - void rewrite(stablehlo::DotGeneralOp op, - PatternRewriter& rewriter) const override { + void rewrite(stablehlo::DotGeneralOp op, PatternRewriter& rewriter) const { const Type output_type = GetElementType(op.getResult()); const bool has_i32_output = IsI32F32UniformQuantizedType(output_type) || @@ -624,7 +623,6 @@ class RewriteQuantizedDotGeneralOpToTflFullyConnectedOrBatchMatmulOp } } - private: static LogicalResult MatchDotGeneralToTflBatchMatmulOp( stablehlo::DotGeneralOp op, const stablehlo::DotDimensionNumbersAttr dot_dimension_nums, @@ -655,7 +653,7 @@ class RewriteQuantizedDotGeneralOpToTflFullyConnectedOrBatchMatmulOp "quantized dot_general.\n"); return failure(); } - const auto input_type = op.getLhs().getType().cast(); + const auto input_type = llvm::cast(op.getLhs().getType()); const int input_rank = input_type.getRank(); const auto input_contracting_dim = dot_dimension_nums.getLhsContractingDimensions()[0]; @@ -666,7 +664,7 @@ class RewriteQuantizedDotGeneralOpToTflFullyConnectedOrBatchMatmulOp return failure(); } - const auto filter_type = op.getRhs().getType().cast(); + const auto filter_type = llvm::cast(op.getRhs().getType()); const Type 
filter_element_type = filter_type.getElementType(); if (!IsI8F32UniformQuantizedType(filter_element_type)) { LLVM_DEBUG(llvm::dbgs() @@ -675,7 +673,7 @@ class RewriteQuantizedDotGeneralOpToTflFullyConnectedOrBatchMatmulOp << filter_type << "\n"); return failure(); } - const int rhs_rank = filter_type.cast().getRank(); + const int rhs_rank = llvm::cast(filter_type).getRank(); const auto rhs_contracting_dim = dot_dimension_nums.getRhsContractingDimensions()[0]; if ((rhs_contracting_dim != rhs_rank - 1) && @@ -702,7 +700,7 @@ class RewriteQuantizedDotGeneralOpToTflFullyConnectedOrBatchMatmulOp return failure(); } - const auto input_type = op.getLhs().getType().cast(); + const auto input_type = llvm::cast(op.getLhs().getType()); if (!(input_type.getRank() == 2 || input_type.getRank() == 3)) { LLVM_DEBUG(llvm::dbgs() << "Input expected to have rank of 2 or 3. Got: " << input_type << ".\n"); @@ -710,7 +708,7 @@ class RewriteQuantizedDotGeneralOpToTflFullyConnectedOrBatchMatmulOp } const Value filter = op.getRhs(); - const auto filter_type = filter.getType().cast(); + const auto filter_type = llvm::cast(filter.getType()); if (filter_type.getRank() != 2) { LLVM_DEBUG(llvm::dbgs() << "Filter tensor expected to have a tensor rank of 2. 
Got: " @@ -752,7 +750,7 @@ class RewriteQuantizedDotGeneralOpToTflFullyConnectedOrBatchMatmulOp } static LogicalResult MatchInputDotGeneralCommonPattern(const Value input) { - const auto input_type = input.getType().cast(); + const auto input_type = llvm::cast(input.getType()); if (const auto input_element_type = input_type.getElementType(); !IsI8F32UniformQuantizedType(input_element_type)) { LLVM_DEBUG(llvm::dbgs() @@ -769,7 +767,7 @@ class RewriteQuantizedDotGeneralOpToTflFullyConnectedOrBatchMatmulOp } static LogicalResult MatchFilterCommonPattern(const Value filter) { - auto filter_type = filter.getType().cast(); + auto filter_type = llvm::cast(filter.getType()); if (!filter_type.hasRank()) { LLVM_DEBUG(llvm::dbgs() << "Expected rhs of dot_general has rank. Got: " << filter.getType() << "\n"); @@ -830,11 +828,11 @@ class RewriteQuantizedDotGeneralOpToTflFullyConnectedOrBatchMatmulOp // dynamic-range quantized. const BoolAttr asymmetric_quantize_inputs = nullptr; - const int lhs_rank = lhs_value.getType().cast().getRank(); + const int lhs_rank = llvm::cast(lhs_value.getType()).getRank(); const BoolAttr adj_x = (lhs_contracting_dims[0] == lhs_rank - 2 ? rewriter.getBoolAttr(true) : rewriter.getBoolAttr(false)); - const int rhs_rank = rhs_value.getType().cast().getRank(); + const int rhs_rank = llvm::cast(rhs_value.getType()).getRank(); const BoolAttr adj_y = (rhs_contracting_dims[0] == rhs_rank - 1 ? rewriter.getBoolAttr(true) : rewriter.getBoolAttr(false)); @@ -855,7 +853,7 @@ class RewriteQuantizedDotGeneralOpToTflFullyConnectedOrBatchMatmulOp // Update BMM if rhs is a constant. 
if (filter_constant_op != nullptr) { const auto rhs_uniform_quantized_type = - rhs_value.getType().cast(); + llvm::cast(rhs_value.getType()); const auto rhs_constant_value_attr = cast(filter_constant_op.getValue()); auto rhs_constant_op = rewriter.create( @@ -886,7 +884,8 @@ class RewriteQuantizedDotGeneralOpToTflFullyConnectedOrBatchMatmulOp rhs_value.getDefiningOp(), rewriter, /*is_per_channel=*/true); const double input_scale = - GetElementType(lhs_value).cast().getScale(); + llvm::cast(GetElementType(lhs_value)) + .getScale(); TFL::QConstOp bias_tfl_op; bool fuse_bias_constant = FindUserOfType(op) && has_i32_output; @@ -922,23 +921,23 @@ class RewriteQuantizedDotGeneralOpToTflFullyConnectedOrBatchMatmulOp Operation* add_op = FindUserOfType(op); uniform_quantize_op = FindUserOfType(add_op); const auto filter_quantized_type = - GetElementType(op->getOperand(1)) - .cast(); + llvm::cast( + GetElementType(op->getOperand(1))); const SmallVector bias_scales = GetBiasScales( - /*input_scale=*/GetElementType(op->getOperand(0)) - .cast() + /*input_scale=*/llvm::cast( + GetElementType(op->getOperand(0))) .getScale(), /*filter_scales=*/filter_quantized_type.getScales()); const ArrayRef output_shape = - op->getResult(0).getType().cast().getShape(); + llvm::cast(op->getResult(0).getType()).getShape(); const SmallVector bias_shape = { output_shape[output_shape.size() - 1]}; // `tfl.fully_connected`'s `GetChannelDimIndex` is 0. 
const auto bias_quantized_type = CreateI32F32UniformQuantizedPerAxisType( op->getLoc(), *op->getContext(), std::move(bias_scales), - GetElementType(op->getResult(0)) - .cast() + llvm::cast( + GetElementType(op->getResult(0))) .getZeroPoints(), /*quantization_dimension=*/0); Operation* bias_const_op = GetBiasConstOp(add_op); @@ -957,14 +956,14 @@ class RewriteQuantizedDotGeneralOpToTflFullyConnectedOrBatchMatmulOp } const auto result_quantized_type = - GetElementType(uniform_quantize_op->getResult(0)) - .cast(); + llvm::cast( + GetElementType(uniform_quantize_op->getResult(0))); const auto new_result_quantized_type = CreateI8F32UniformQuantizedType( uniform_quantize_op->getLoc(), *rewriter.getContext(), result_quantized_type.getScale(), result_quantized_type.getZeroPoint()); - output_type = op->getResult(0).getType().cast().clone( - new_result_quantized_type); + output_type = llvm::cast(op->getResult(0).getType()) + .clone(new_result_quantized_type); // Omit any bias and requantize ops as `tfl.fully_connected` outputs a // fused `qi8` type. FindUserOfType<>(uniform_quantize_op)->setOperand(0, op->getResult(0)); @@ -1007,10 +1006,21 @@ class RewriteQuantizedDotGeneralOpToTflFullyConnectedOrBatchMatmulOp // * The filter tensor's format is `[0, 1, i, o]`. // * Not a depthwise convolution. 
class RewriteQuantizedConvolutionOp - : public OpRewritePattern::SplitMatchAndRewrite { + : public OpRewritePattern { public: - using SplitMatchAndRewrite::SplitMatchAndRewrite; - LogicalResult match(stablehlo::ConvolutionOp op) const override { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(stablehlo::ConvolutionOp op, + PatternRewriter& rewriter) const override { + if (match(op).failed()) { + return failure(); + } + rewrite(op, rewriter); + return success(); + } + + private: + LogicalResult match(stablehlo::ConvolutionOp op) const { const bool has_i32_output = IsI32F32UniformQuantizedPerAxisType(GetElementType(op.getResult())); const bool fuse_bias_constant = @@ -1056,8 +1066,7 @@ class RewriteQuantizedConvolutionOp return success(); } - void rewrite(stablehlo::ConvolutionOp op, - PatternRewriter& rewriter) const override { + void rewrite(stablehlo::ConvolutionOp op, PatternRewriter& rewriter) const { const bool has_i32_output = IsI32F32UniformQuantizedPerAxisType(GetElementType(op.getResult())); stablehlo::ConvDimensionNumbersAttr dimension_numbers = @@ -1148,9 +1157,8 @@ class RewriteQuantizedConvolutionOp } } - private: static LogicalResult MatchInput(Value input) { - auto input_type = input.getType().cast(); + auto input_type = llvm::cast(input.getType()); if (const auto input_element_type = input_type.getElementType(); !IsI8F32UniformQuantizedType(input_element_type)) { LLVM_DEBUG(llvm::dbgs() @@ -1163,7 +1171,7 @@ class RewriteQuantizedConvolutionOp } static LogicalResult MatchFilter(Value filter) { - auto filter_type = filter.getType().cast(); + auto filter_type = llvm::cast(filter.getType()); const Type filter_element_type = filter_type.getElementType(); if (!IsI8F32UniformQuantizedPerAxisType(filter_type.getElementType())) { LLVM_DEBUG( @@ -1173,7 +1181,7 @@ class RewriteQuantizedConvolutionOp return failure(); } - if (filter_element_type.cast() + if (llvm::cast(filter_element_type) .getQuantizedDimension() != 3) { 
LLVM_DEBUG(llvm::dbgs() << "Quantized dimension should be 3. Got: " << filter_element_type << "\n"); @@ -1220,7 +1228,7 @@ class RewriteQuantizedConvolutionOp tfl_pad_values.push_back(0); const auto input_tensor_type = - input_value.getType().cast(); + llvm::cast(input_value.getType()); const int64_t rank = input_tensor_type.getRank(); SmallVector padded_output_tensor_shape = @@ -1356,12 +1364,12 @@ class RewriteQuantizedConvolutionOp std::tuple GetInOutDimensions( stablehlo::ConvolutionOp op, stablehlo::ConvDimensionNumbersAttr dimension_numbers) const { - const auto [input_height, input_width] = - GetDimSize(op->getOperand(0).getType().cast().getShape(), - dimension_numbers.getInputSpatialDimensions()); - const auto [output_height, output_width] = - GetDimSize(op->getResult(0).getType().cast().getShape(), - dimension_numbers.getOutputSpatialDimensions()); + const auto [input_height, input_width] = GetDimSize( + llvm::cast(op->getOperand(0).getType()).getShape(), + dimension_numbers.getInputSpatialDimensions()); + const auto [output_height, output_width] = GetDimSize( + llvm::cast(op->getResult(0).getType()).getShape(), + dimension_numbers.getOutputSpatialDimensions()); return {input_height, input_width, output_height, output_width}; } @@ -1400,7 +1408,8 @@ class RewriteQuantizedConvolutionOp Value filter_value = op.getOperand(1); Operation* filter_op = filter_value.getDefiningOp(); auto filter_uniform_quantized_type = - GetElementType(filter_value).cast(); + llvm::cast( + GetElementType(filter_value)); auto filter_constant_value_attr = cast( cast(filter_value.getDefiningOp()).getValue()); const DenseIntElementsAttr new_filter_value_attr = @@ -1443,8 +1452,8 @@ class RewriteQuantizedConvolutionOp const SmallVector bias_shape, const bool has_i32_output, const bool fuse_bias_constant) const { const SmallVector bias_scales = GetBiasScales( - /*input_scale=*/GetElementType(op.getOperand(0)) - .cast() + /*input_scale=*/llvm::cast( + GetElementType(op.getOperand(0))) 
.getScale(), /*filter_scales=*/new_filter_quantized_type.getScales()); @@ -1481,17 +1490,16 @@ class RewriteQuantizedConvolutionOp // Rewrites quantized `stablehlo.transpose` to `tfl.transpose`. class RewriteQuantizedTransposeOp - : public OpRewritePattern::SplitMatchAndRewrite { + : public OpRewritePattern { public: - using SplitMatchAndRewrite::SplitMatchAndRewrite; - - LogicalResult match(stablehlo::TransposeOp op) const override { - return success(IsOpFullyQuantized(op)); - } + using OpRewritePattern::OpRewritePattern; - void rewrite(stablehlo::TransposeOp op, - PatternRewriter& rewriter) const override { - auto operand_type = op.getOperand().getType().cast(); + LogicalResult matchAndRewrite(stablehlo::TransposeOp op, + PatternRewriter& rewriter) const override { + if (!IsOpFullyQuantized(op)) { + return failure(); + } + auto operand_type = llvm::cast(op.getOperand().getType()); const int64_t rank = operand_type.getRank(); ArrayRef shape(rank); TensorType permutation_type = @@ -1506,54 +1514,54 @@ class RewriteQuantizedTransposeOp rewriter.create(op.getLoc(), permutation_attr); rewriter.replaceOpWithNewOp(op, op.getOperand(), permutation); + return success(); } }; // Rewrites quantized stablehlo.reshape to tfl.reshape. 
class RewriteQuantizedReshapeOp - : public OpRewritePattern::SplitMatchAndRewrite { + : public OpRewritePattern { public: - using SplitMatchAndRewrite::SplitMatchAndRewrite; + using OpRewritePattern::OpRewritePattern; - LogicalResult match(stablehlo::ReshapeOp op) const override { - return success(IsOpFullyQuantized(op)); - } - - void rewrite(stablehlo::ReshapeOp op, - PatternRewriter& rewriter) const override { + LogicalResult matchAndRewrite(stablehlo::ReshapeOp op, + PatternRewriter& rewriter) const override { + if (!IsOpFullyQuantized(op)) { + return failure(); + } rewriter.replaceOpWithNewOp( op, op.getOperand(), CreateI32ShapeConstantOp(op.getResult().getType(), op->getLoc(), rewriter)); + return success(); } }; class RewriteQuantizedDynamicReshapeOp - : public OpRewritePattern< - stablehlo::DynamicReshapeOp>::SplitMatchAndRewrite { + : public OpRewritePattern { public: - using SplitMatchAndRewrite::SplitMatchAndRewrite; - - LogicalResult match(stablehlo::DynamicReshapeOp op) const override { - return success(IsQuantizedTensorType(op.getOperand().getType()) && - IsQuantizedTensorType(op.getResult().getType())); - } + using OpRewritePattern::OpRewritePattern; - void rewrite(stablehlo::DynamicReshapeOp op, - PatternRewriter& rewriter) const override { + LogicalResult matchAndRewrite(stablehlo::DynamicReshapeOp op, + PatternRewriter& rewriter) const override { + if (!IsQuantizedTensorType(op.getOperand().getType()) || + !IsQuantizedTensorType(op.getResult().getType())) { + return failure(); + } rewriter.replaceOpWithNewOp(op, op.getOperand(), op.getOutputShape()); + return success(); } }; // Rewrites quantized stablehlo.select to tfl.select_v2. // TODO: b/322428814 - Add StableHLO quantizer integration tests for ODML. 
-class RewriteQuantizedSelectOp - : public OpRewritePattern::SplitMatchAndRewrite { +class RewriteQuantizedSelectOp : public OpRewritePattern { public: - using SplitMatchAndRewrite::SplitMatchAndRewrite; + using OpRewritePattern::OpRewritePattern; - LogicalResult match(stablehlo::SelectOp op) const override { + LogicalResult matchAndRewrite(stablehlo::SelectOp op, + PatternRewriter& rewriter) const override { if (!IsQuantizedTensorType(op.getOperand(1).getType())) { return failure(); } @@ -1563,52 +1571,47 @@ class RewriteQuantizedSelectOp if (!IsQuantizedTensorType(op.getResult().getType())) { return failure(); } - return success(); - } - - void rewrite(stablehlo::SelectOp op, - PatternRewriter& rewriter) const override { Value pred = op.getOperand(0); Value on_true = op.getOperand(1); Value on_false = op.getOperand(2); rewriter.replaceOpWithNewOp(op, pred, on_true, on_false); + return success(); } }; // Rewrites quantized stablehlo.concatenate to tfl.concatenation. // TODO: b/322428814 - Add StableHLO quantizer integration tests for ODML. class RewriteQuantizedConcatenateOp - : public OpRewritePattern::SplitMatchAndRewrite { + : public OpRewritePattern { public: - using SplitMatchAndRewrite::SplitMatchAndRewrite; + using OpRewritePattern::OpRewritePattern; - LogicalResult match(stablehlo::ConcatenateOp op) const override { - return success(IsOpFullyQuantized(op)); - } - - void rewrite(stablehlo::ConcatenateOp op, - PatternRewriter& rewriter) const override { + LogicalResult matchAndRewrite(stablehlo::ConcatenateOp op, + PatternRewriter& rewriter) const override { + if (!IsOpFullyQuantized(op)) { + return failure(); + } Type output_type = op.getResult().getType(); uint32_t axis = CastI64ToI32(op.getDimension()).value(); rewriter.replaceOpWithNewOp( op, output_type, op.getOperands(), axis, /*fused_activation_function=*/rewriter.getStringAttr("NONE")); + return success(); } }; // Rewrites quantized stablehlo.pad to tfl.padv2. 
// tfl.dilate is introduced in between when interior padding exists. // TODO: b/322428814 - Add StableHLO quantizer integration tests for ODML. -class RewriteQuantizedPadOp - : public OpRewritePattern::SplitMatchAndRewrite { +class RewriteQuantizedPadOp : public OpRewritePattern { public: - using SplitMatchAndRewrite::SplitMatchAndRewrite; + using OpRewritePattern::OpRewritePattern; - LogicalResult match(stablehlo::PadOp op) const override { - return success(IsOpFullyQuantized(op)); - } - - void rewrite(stablehlo::PadOp op, PatternRewriter& rewriter) const override { + LogicalResult matchAndRewrite(stablehlo::PadOp op, + PatternRewriter& rewriter) const override { + if (!IsOpFullyQuantized(op)) { + return failure(); + } Value input = op.getOperand(); // If any of the interior padding is non-zero, operand should be dilated // first, and then padded. @@ -1617,7 +1620,7 @@ class RewriteQuantizedPadOp input = InsertDilateOp(op, rewriter); } - TensorType operand_type = input.getType().cast(); + TensorType operand_type = llvm::cast(input.getType()); const int64_t rank = operand_type.getRank(); // Shape of padding should be [rank, 2]. 
SmallVector shape{rank, 2}; @@ -1632,18 +1635,19 @@ class RewriteQuantizedPadOp padding_value.push_back(CastI64ToI32(padding_high[i]).value()); } - TensorType output_type = op.getResult().getType().cast(); + TensorType output_type = llvm::cast(op.getResult().getType()); Value constant_values = op.getPaddingValue(); auto padding_attr = DenseIntElementsAttr::get(padding_type, padding_value); auto padding = rewriter.create(op.getLoc(), padding_attr); rewriter.replaceOpWithNewOp(op, output_type, input, padding, constant_values); + return success(); } Value InsertDilateOp(stablehlo::PadOp op, PatternRewriter& rewriter) const { Value input = op.getOperand(); - TensorType operand_type = input.getType().cast(); + TensorType operand_type = llvm::cast(input.getType()); const int64_t rank = operand_type.getRank(); ArrayRef dilate_shape(rank); @@ -1663,7 +1667,7 @@ class RewriteQuantizedPadOp dilated_shape[i] = operand_shape[i] + interior_padding_i64[i] * (operand_shape[i] - 1); } - TensorType output_type = op.getResult().getType().cast(); + TensorType output_type = llvm::cast(op.getResult().getType()); Type dilated_output_type = output_type.clone(dilated_shape); Value constant_values = op.getPaddingValue(); @@ -1673,18 +1677,16 @@ class RewriteQuantizedPadOp }; // Rewrites quantized stablehlo.slice to tfl.slice or tfl.strided_slice. 
-class RewriteQuantizedSliceOp - : public OpRewritePattern::SplitMatchAndRewrite { +class RewriteQuantizedSliceOp : public OpRewritePattern { public: - using SplitMatchAndRewrite::SplitMatchAndRewrite; - - LogicalResult match(stablehlo::SliceOp op) const override { - return success(IsOpFullyQuantized(op)); - } + using OpRewritePattern::OpRewritePattern; - void rewrite(stablehlo::SliceOp op, - PatternRewriter& rewriter) const override { - auto operand_type = op.getOperand().getType().cast(); + LogicalResult matchAndRewrite(stablehlo::SliceOp op, + PatternRewriter& rewriter) const override { + if (!IsOpFullyQuantized(op)) { + return failure(); + } + auto operand_type = llvm::cast(op.getOperand().getType()); Type output_type = op.getResult().getType(); const int64_t rank = operand_type.getRank(); @@ -1716,7 +1718,7 @@ class RewriteQuantizedSliceOp if (llvm::all_of(strides, [](int64_t stride) { return stride == 1; })) { rewriter.replaceOpWithNewOp( op, output_type, op.getOperand(), start_idx, slice_size); - return; + return success(); } SmallVector stride_i32 = CastI64ArrayToI32(strides).value(); @@ -1727,6 +1729,7 @@ class RewriteQuantizedSliceOp /*begin_mask=*/0, /*end_mask=*/0, /*ellipsis_mask=*/0, /*new_axis_mask=*/0, /*shrink_axis_mask=*/0, /*offset=*/false); + return success(); } }; @@ -1736,19 +1739,17 @@ class RewriteQuantizedSliceOp // output rank. // TODO: b/322428814 - Add StableHLO quantizer integration tests for ODML. 
class RewriteQuantizedBroadcastInDimOp - : public OpRewritePattern< - stablehlo::BroadcastInDimOp>::SplitMatchAndRewrite { + : public OpRewritePattern { public: - using SplitMatchAndRewrite::SplitMatchAndRewrite; + using OpRewritePattern::OpRewritePattern; - LogicalResult match(stablehlo::BroadcastInDimOp op) const override { - return success(IsOpFullyQuantized(op)); - } - - void rewrite(stablehlo::BroadcastInDimOp op, - PatternRewriter& rewriter) const override { - auto operand_type = op.getOperand().getType().cast(); - auto output_type = op.getResult().getType().cast(); + LogicalResult matchAndRewrite(stablehlo::BroadcastInDimOp op, + PatternRewriter& rewriter) const override { + if (!IsOpFullyQuantized(op)) { + return failure(); + } + auto operand_type = llvm::cast(op.getOperand().getType()); + auto output_type = llvm::cast(op.getResult().getType()); Value input = op.getOperand(); // If broadcast_dimensions is not in ascending order, transpose first. @@ -1773,6 +1774,7 @@ class RewriteQuantizedBroadcastInDimOp rewriter.replaceOpWithNewOp(op, output_type, input, shape); + return success(); } Value InsertTransposeOp(stablehlo::BroadcastInDimOp op, @@ -1786,7 +1788,7 @@ class RewriteQuantizedBroadcastInDimOp return static_cast(llvm::find(sorted_dims, dim) - sorted_dims.begin()); })); - auto operand_type = op.getOperand().getType().cast(); + auto operand_type = llvm::cast(op.getOperand().getType()); TensorType perm_type = operand_type.cloneWith( {static_cast(permutation.size())}, rewriter.getI32Type()); auto perm_attr = DenseIntElementsAttr::get(perm_type, permutation); @@ -1799,7 +1801,7 @@ class RewriteQuantizedBroadcastInDimOp Value InsertExpandDimsOp(stablehlo::BroadcastInDimOp op, PatternRewriter& rewriter, Value input, int64_t output_rank) const { - auto input_type = input.getType().cast(); + auto input_type = llvm::cast(input.getType()); SmallVector input_shape(input_type.getShape()); SmallVector input_dims = llvm::to_vector(op.getBroadcastDimensions()); @@ 
-1834,10 +1836,20 @@ class RewriteQuantizedBroadcastInDimOp // Rewrites quantized stablehlo.reduce_window with max to tfl.max_pool_2d. class RewriteQuantizedReduceWindowOpWithMax - : public OpRewritePattern::SplitMatchAndRewrite { + : public OpRewritePattern { public: - using SplitMatchAndRewrite::SplitMatchAndRewrite; + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(stablehlo::ReduceWindowOp op, + PatternRewriter& rewriter) const override { + if (match(op).failed()) { + return failure(); + } + rewrite(op, rewriter); + return success(); + } + + private: LogicalResult MatchBinaryReduceFunction(Region& function) const { Block& body = function.front(); if (body.getNumArguments() != 2) return failure(); @@ -1853,7 +1865,7 @@ class RewriteQuantizedReduceWindowOpWithMax reduce_op.getRhs() == body.getArgument(1)); } - LogicalResult match(stablehlo::ReduceWindowOp op) const override { + LogicalResult match(stablehlo::ReduceWindowOp op) const { // Check that the reduce-window is a max-reduce-window. if (failed(MatchBinaryReduceFunction(op.getBody()))) { return failure(); @@ -1887,8 +1899,7 @@ class RewriteQuantizedReduceWindowOpWithMax return success(IsOpFullyQuantized(op)); } - void rewrite(stablehlo::ReduceWindowOp op, - PatternRewriter& rewriter) const override { + void rewrite(stablehlo::ReduceWindowOp op, PatternRewriter& rewriter) const { Type result_type = op.getResult(0).getType(); Value input = op.getOperand(0); // Ops with padding is rejected in matching function, so we can use the @@ -1929,12 +1940,12 @@ class RewriteQuantizedReduceWindowOpWithMax // Condition 3 - `offset_dims` should be the last dimensions of `output`. // Condition 4 - shape of slice should be same with shape of input on the // offset dimensions. 
-class RewriteQuantizedGatherOp - : public OpRewritePattern::SplitMatchAndRewrite { +class RewriteQuantizedGatherOp : public OpRewritePattern { public: - using SplitMatchAndRewrite::SplitMatchAndRewrite; + using OpRewritePattern::OpRewritePattern; - LogicalResult match(stablehlo::GatherOp op) const override { + LogicalResult matchAndRewrite(stablehlo::GatherOp op, + PatternRewriter& rewriter) const override { const Type input_type = op.getOperand().getType(); const Type output_type = op.getResult().getType(); if (!IsQuantizedTensorType(input_type) || @@ -1942,7 +1953,7 @@ class RewriteQuantizedGatherOp return failure(); } - auto output_tensor_type = output_type.cast(); + auto output_tensor_type = llvm::cast(output_type); if (!output_tensor_type.hasRank()) { return failure(); } @@ -1998,7 +2009,7 @@ class RewriteQuantizedGatherOp // Input type is checked to be quantized tensor type. const auto input_shape = - op.getOperand().getType().cast().getShape(); + llvm::cast(op.getOperand().getType()).getShape(); SmallVector input_offset_shape; for (int64_t i = 0; i < input_shape.size(); ++i) { if (!llvm::is_contained(start_index_map, i)) { @@ -2014,38 +2025,31 @@ class RewriteQuantizedGatherOp } } - return success(); - } - - void rewrite(stablehlo::GatherOp op, - PatternRewriter& rewriter) const override { rewriter.replaceOpWithNewOp( op, /*output=*/op.getResult().getType(), /*params=*/op.getOperand(), /*indices=*/op.getStartIndices()); + return success(); } }; // Rewrites quantized stablehlo.dynamic_slice to tfl.slice. // TODO: b/322428814 - Add StableHLO quantizer integration tests for ODML. 
class RewriteQuantizedDynamicSliceOp - : public OpRewritePattern::SplitMatchAndRewrite { + : public OpRewritePattern { public: - using SplitMatchAndRewrite::SplitMatchAndRewrite; + using OpRewritePattern::OpRewritePattern; - LogicalResult match(stablehlo::DynamicSliceOp op) const override { + LogicalResult matchAndRewrite(stablehlo::DynamicSliceOp op, + PatternRewriter& rewriter) const override { if (!IsQuantizedTensorType(op.getOperand().getType()) || - !IsQuantizedTensorType(op.getResult().getType())) { + !IsQuantizedTensorType(op.getResult().getType()) || + !quant::HasStaticShape(op.getOperand())) { return failure(); } - return success(quant::HasStaticShape(op.getOperand())); - } - - void rewrite(stablehlo::DynamicSliceOp op, - PatternRewriter& rewriter) const override { Type output = op.getResult().getType(); Value input = op.getOperand(); - TensorType operand_type = input.getType().cast(); + TensorType operand_type = llvm::cast(input.getType()); ArrayRef operand_shape = operand_type.getShape(); const int64_t rank = operand_type.getRank(); const Type i64_type = rewriter.getI64Type(); @@ -2098,20 +2102,20 @@ class RewriteQuantizedDynamicSliceOp auto size = rewriter.create(op.getLoc(), size_attr); rewriter.replaceOpWithNewOp(op, output, input, begin, size); + return success(); } }; -class RewriteQuantizedAddOp - : public OpRewritePattern::SplitMatchAndRewrite { +class RewriteQuantizedAddOp : public OpRewritePattern { public: - using SplitMatchAndRewrite::SplitMatchAndRewrite; + using OpRewritePattern::OpRewritePattern; - LogicalResult match(stablehlo::AddOp op) const override { - return success(IsI8F32UniformQuantizedType(GetElementType(op.getLhs())) && - IsI8F32UniformQuantizedType(GetElementType(op.getRhs()))); - } - - void rewrite(stablehlo::AddOp op, PatternRewriter& rewriter) const override { + LogicalResult matchAndRewrite(stablehlo::AddOp op, + PatternRewriter& rewriter) const override { + if (!IsI8F32UniformQuantizedType(GetElementType(op.getLhs())) || + 
!IsI8F32UniformQuantizedType(GetElementType(op.getRhs()))) { + return failure(); + } TFL::QConstOp lhs_qconst_op; TFL::QConstOp rhs_qconst_op; @@ -2121,7 +2125,7 @@ class RewriteQuantizedAddOp auto stablehlo_const_op = dyn_cast_or_null( broadcast_op.getOperand().getDefiningOp()); auto const_uniform_quantized_type = - stablehlo_const_op.getResult().getType().cast(); + llvm::cast(stablehlo_const_op.getResult().getType()); return rewriter.create( op.getLoc(), TypeAttr::get(const_uniform_quantized_type), cast(stablehlo_const_op.getValue())); @@ -2137,24 +2141,25 @@ class RewriteQuantizedAddOp lhs_qconst_op ? lhs_qconst_op : op.getOperand(0), rhs_qconst_op ? rhs_qconst_op : op.getOperand(1), /*fused_activation_function=*/rewriter.getStringAttr("NONE")); + return success(); } }; // Rewrites quantized `stablehlo.constant` to `tfl.pseudo_qconst`. class RewriteQuantizedConstantOp - : public OpRewritePattern::SplitMatchAndRewrite { + : public OpRewritePattern { public: - using SplitMatchAndRewrite::SplitMatchAndRewrite; - - LogicalResult match(stablehlo::ConstantOp op) const override { - return success(IsQuantizedTensorType(op.getOutput().getType())); - } + using OpRewritePattern::OpRewritePattern; - void rewrite(stablehlo::ConstantOp op, - PatternRewriter& rewriter) const override { + LogicalResult matchAndRewrite(stablehlo::ConstantOp op, + PatternRewriter& rewriter) const override { + if (!IsQuantizedTensorType(op.getOutput().getType())) { + return failure(); + } rewriter.replaceOpWithNewOp( op, /*qtype=*/TypeAttr::get(op.getOutput().getType()), /*value=*/op.getValue()); + return success(); } }; @@ -2163,28 +2168,28 @@ class RewriteQuantizedConstantOp // `stablehlo.dot_general` op relies on existing passes for conversion of // StableHLO -> MHLO -> TF -> TFL. 
class RewriteHybridQuantizedDotGeneralOp - : public OpRewritePattern::SplitMatchAndRewrite { + : public OpRewritePattern { public: - using SplitMatchAndRewrite::SplitMatchAndRewrite; + using OpRewritePattern::OpRewritePattern; - LogicalResult match(stablehlo::DotGeneralOp op) const override { + LogicalResult matchAndRewrite(stablehlo::DotGeneralOp op, + PatternRewriter& rewriter) const override { // Lhs and result should not be quantized and rhs should be quantized. - return success(!IsQuantizedTensorType(op->getOperand(0).getType()) && - IsQuantizedTensorType(op->getOperand(1).getType()) && - !IsQuantizedTensorType(op->getResult(0).getType())); - } - - void rewrite(stablehlo::DotGeneralOp op, - PatternRewriter& rewriter) const override { + if (IsQuantizedTensorType(op->getOperand(0).getType()) || + !IsQuantizedTensorType(op->getOperand(1).getType()) || + IsQuantizedTensorType(op->getResult(0).getType())) { + return failure(); + } Value rhs = op.getRhs(); Type lhs_element_type = - op.getLhs().getType().template cast().getElementType(); + llvm::cast(op.getLhs().getType()).getElementType(); Type dequantized_rhs_type = quant::CloneTypeWithNewElementType(rhs.getType(), lhs_element_type); auto dq = rewriter.create( op->getLoc(), /*output=*/dequantized_rhs_type, /*input=*/rhs); rewriter.replaceAllUsesExcept(rhs, dq.getOutput(), dq); + return success(); } }; @@ -2194,26 +2199,24 @@ class RewriteHybridQuantizedDotGeneralOp // Legalization of float `stablehlo.convolution` op relies on existing passes // for conversion of StableHLO -> MHLO -> TF -> TFL. 
class RewriteHybridQuantizedConvolutionOp - : public OpRewritePattern::SplitMatchAndRewrite { + : public OpRewritePattern { public: explicit RewriteHybridQuantizedConvolutionOp(MLIRContext* ctx) - : OpRewritePattern::SplitMatchAndRewrite( - ctx, /*benefit=*/5) {} + : OpRewritePattern(ctx, /*benefit=*/5) {} - LogicalResult match(stablehlo::ConvolutionOp op) const override { + LogicalResult matchAndRewrite(stablehlo::ConvolutionOp op, + PatternRewriter& rewriter) const override { if (failed(MatchConvolutionFormat(op))) { LLVM_DEBUG(llvm::dbgs() << "Failed to match dimension format for convolution_op.\n"); return failure(); } // Lhs and result should not be quantized and rhs should be quantized. - return success(!IsQuantizedTensorType(op->getOperand(0).getType()) && - IsQuantizedTensorType(op->getOperand(1).getType()) && - !IsQuantizedTensorType(op->getResult(0).getType())); - } - - void rewrite(stablehlo::ConvolutionOp op, - PatternRewriter& rewriter) const override { + if (IsQuantizedTensorType(op->getOperand(0).getType()) || + !IsQuantizedTensorType(op->getOperand(1).getType()) || + IsQuantizedTensorType(op->getResult(0).getType())) { + return failure(); + } const bool is_depthwise = IsDepthwiseConvolution(op); Operation* filter_op = op.getRhs().getDefiningOp(); @@ -2236,13 +2239,14 @@ class RewriteHybridQuantizedConvolutionOp op.setDimensionNumbersAttr(new_dimension_numbers); Type lhs_element_type = - op.getOperand(0).getType().template cast().getElementType(); + llvm::cast(op.getOperand(0).getType()).getElementType(); Type dequantized_rhs_type = quant::CloneTypeWithNewElementType( new_filter.getType(), lhs_element_type); auto dq = rewriter.create( op->getLoc(), /*output=*/dequantized_rhs_type, /*input=*/new_filter); rewriter.replaceAllUsesExcept(filter_op->getResult(0), dq.getOutput(), dq); + return success(); } private: @@ -2250,11 +2254,12 @@ class RewriteHybridQuantizedConvolutionOp Type GetNewWeightQuantizedType(MLIRContext* context, Location location, 
ArrayRef new_shape, Type filter_type, bool is_depthwise) const { - auto tensor_type = filter_type.cast(); + auto tensor_type = llvm::cast(filter_type); auto element_type = tensor_type.getElementType(); RankedTensorType new_filter_result_type; - if (element_type.isa()) { - auto per_axis_type = element_type.cast(); + if (llvm::isa(element_type)) { + auto per_axis_type = + llvm::cast(element_type); int64_t kernel_output_feature_dim = GetConvolutionKernelOutputFeatureDimension(is_depthwise); auto new_filter_quantized_type = CreateI8F32UniformQuantizedPerAxisType( @@ -2266,8 +2271,9 @@ class RewriteHybridQuantizedConvolutionOp RankedTensorType::getChecked(location, /*shape=*/new_shape, /*type=*/new_filter_quantized_type); - } else if (element_type.isa()) { - auto per_tensor_type = element_type.cast(); + } else if (llvm::isa(element_type)) { + auto per_tensor_type = + llvm::cast(element_type); new_filter_result_type = RankedTensorType::getChecked(location, /*shape=*/new_shape, diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/utils.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/utils.cc deleted file mode 100644 index b120a6f02e1460..00000000000000 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/utils.cc +++ /dev/null @@ -1,55 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/utils.h" - -#include - -#include "mlir/IR/Location.h" // from @llvm-project -#include "mlir/IR/Types.h" // from @llvm-project -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" -#include "xla/mlir_hlo/utils/hlo_utils.h" - -namespace mlir { -namespace odml { - -mhlo::ConstantOp GetScalarConstOfType(Type ty, Location loc, int64_t raw_value, - OpBuilder* builder) { - return builder->create(loc, - hlo::getScalarOfType(ty, raw_value)); -} - -mhlo::ConstantOp GetScalarNegZeroOfType(Type ty, Location loc, - OpBuilder* builder) { - return builder->create(loc, - hlo::getScalarNegZeroOfType(ty)); -} - -DenseIntElementsAttr GetI64ElementsAttr(ArrayAttr attr) { - RankedTensorType ty = - RankedTensorType::get(static_cast(attr.size()), - IntegerType::get(attr.getContext(), 64)); - return DenseIntElementsAttr::get(ty, attr.getValue()); -} - -DenseIntElementsAttr GetI64ElementsAttr(ArrayRef values, - Builder* builder) { - RankedTensorType ty = RankedTensorType::get( - {static_cast(values.size())}, builder->getIntegerType(64)); - return DenseIntElementsAttr::get(ty, values); -} - -} // namespace odml -} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/utils.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/utils.h deleted file mode 100644 index fc7c2316655df9..00000000000000 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/utils.h +++ /dev/null @@ -1,63 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_UTILS_H_ -#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_UTILS_H_ - -#include - -#include "llvm/ADT/ArrayRef.h" -#include "mlir/IR/Builders.h" // from @llvm-project -#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project -#include "mlir/IR/BuiltinTypes.h" // from @llvm-project -#include "mlir/IR/Location.h" // from @llvm-project -#include "mlir/IR/Types.h" // from @llvm-project -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" - -namespace mlir { -namespace odml { - -// Builds body for reduce op by using the template binary op as the -// reducer op. -template -void BuildReduceBody(Type element_type, Region* body, OpBuilder* builder) { - OpBuilder::InsertionGuard guard(*builder); - Block* block = builder->createBlock(body); - - // Block arguments are scalars of the given element type. - Type type = RankedTensorType::get(/*shape=*/{}, element_type); - Location loc = body->getLoc(); - block->addArguments({type, type}, SmallVector(2, loc)); - - auto reducer = - builder->create(loc, block->getArgument(0), block->getArgument(1)); - builder->create(loc, reducer.getResult()); -} - -mhlo::ConstantOp GetScalarConstOfType(Type ty, Location loc, int64_t raw_value, - OpBuilder* builder); - -mhlo::ConstantOp GetScalarNegZeroOfType(Type ty, Location loc, - OpBuilder* builder); - -// Converts an ArrayAttr to a 1D 64-bit dense elements attribute. 
-DenseIntElementsAttr GetI64ElementsAttr(ArrayAttr attr); -DenseIntElementsAttr GetI64ElementsAttr(llvm::ArrayRef values, - Builder* builder); - -} // namespace odml -} // namespace mlir - -#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_UTILS_H_ diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/utils_test.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/utils_test.cc deleted file mode 100644 index 40d3cc27164427..00000000000000 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/utils_test.cc +++ /dev/null @@ -1,82 +0,0 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/utils.h" - -#include - -#include -#include -#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project -#include "mlir/IR/Location.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/IR/Types.h" // from @llvm-project -#include "mlir/Support/LLVM.h" // from @llvm-project -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" - -namespace mlir { -namespace odml { -namespace { - -TEST(UtilsTest, GetScalarConstOfType) { - MLIRContext context; - context.loadDialect(); - OpBuilder builder(&context); - Location loc = UnknownLoc::get(&context); - Type ty = builder.getI32Type(); - mhlo::ConstantOp op = GetScalarConstOfType(ty, loc, 123, &builder); - EXPECT_EQ(op.getValue().getValues()[0], 123); - - op->destroy(); -} - -TEST(UtilsTest, GetScalarNegZeroOfType) { - MLIRContext context; - context.loadDialect(); - OpBuilder builder(&context); - Location loc = UnknownLoc::get(&context); - Type ty = builder.getF32Type(); - mhlo::ConstantOp op = GetScalarNegZeroOfType(ty, loc, &builder); - EXPECT_EQ(op.getValue().getValues()[0], -0.f); - - op->destroy(); -} - -TEST(UtilsTest, GetI64ElementsAttr) { - MLIRContext context; - context.loadDialect(); - OpBuilder builder(&context); - Location loc = UnknownLoc::get(&context); - SmallVector values = {1, 2, 3}; - auto valuesAttr = builder.getI64ArrayAttr(values); - DenseIntElementsAttr attr = GetI64ElementsAttr(valuesAttr); - EXPECT_THAT(SmallVector(attr.getValues()), - testing::ElementsAreArray(values)); -} - -TEST(UtilsTest, GetI64ElementsAttrBuilder) { - MLIRContext context; - context.loadDialect(); - OpBuilder builder(&context); - Location loc = UnknownLoc::get(&context); - SmallVector values = {1, 2, 3}; - DenseIntElementsAttr attr = GetI64ElementsAttr(values, &builder); - EXPECT_THAT(SmallVector(attr.getValues()), - testing::ElementsAreArray(values)); -} - -} 
// namespace - -} // namespace odml -} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/tests/cleanup_optimization_barrier.mlir b/tensorflow/compiler/mlir/lite/tests/cleanup_optimization_barrier.mlir new file mode 100644 index 00000000000000..12625023255f09 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tests/cleanup_optimization_barrier.mlir @@ -0,0 +1,14 @@ +// RUN: tf-opt %s --tfl-cleanup-optimization-barrier --split-input-file | FileCheck %s + +// CHECK-LABEL: func.func @cleanup_barrier(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { +// CHECK: %0 = tfl.add(%arg0, %cst) <{fused_activation_function = "NONE"}> : (tensor<2x2xf32>, tensor) -> tensor<2x2xf32> +// CHECK: %1 = tfl.add(%0, %cst) <{fused_activation_function = "NONE"}> : (tensor<2x2xf32>, tensor) -> tensor<2x2xf32> +// CHECK: return %1 : tensor<2x2xf32> + +func.func @cleanup_barrier(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { + %cst = arith.constant dense<5.000000e+00> : tensor + %0 = tfl.add(%arg0, %cst) <{fused_activation_function = "NONE"}> : (tensor<2x2xf32>, tensor) -> tensor<2x2xf32> + %1 = stablehlo.optimization_barrier %0 : tensor<2x2xf32> + %2 = tfl.add(%1, %cst) <{fused_activation_function = "NONE"}> : (tensor<2x2xf32>, tensor) -> tensor<2x2xf32> + return %2 : tensor<2x2xf32> +} diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/batched_gather_round_trip.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/batched_gather_round_trip.mlir index 12de9da5939573..adb22ddd009a80 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/batched_gather_round_trip.mlir +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/batched_gather_round_trip.mlir @@ -4,11 +4,14 @@ module { // CHECK-LABEL: func.func public @main func.func public @main(%arg0: tensor<3x2x4x7x9xi32>, %arg1: tensor<4x3x5x2xi32>) -> tensor<4x3x5x8xi32> { - // CHECK-ROUNDTRIP: %[[iota_1:.*]] = "tfl.pseudo_const"() <{{.*}}> : () -> tensor<4x3x5x1xi32 - // CHECK-ROUNDTRIP: %[[iota_2:.*]] = 
"tfl.pseudo_const"() <{{.*}}> : () -> tensor<4x3x5x1xi32> - // CHECK-ROUNDTRIP: %[[concat:.*]] = "tfl.concatenation"(%[[iota_1]], %[[iota_2]], %arg1) <{axis = 3 : i32, fused_activation_function = "NONE"}> : + // CHECK-ROUNDTRIP: %0 = "tfl.pseudo_const"() <{value = dense<{{\[\[\[\[}}0]], {{\[\[}}1]], {{\[\[}}2]]]]> : tensor<1x3x1x1xi32>}> : () -> tensor<1x3x1x1xi32> + // CHECK-ROUNDTRIP: %1 = "tfl.pseudo_const"() <{value = dense<[4, 3, 5, 1]> : tensor<4xi64>}> : () -> tensor<4xi64> + // CHECK-ROUNDTRIP: %2 = "tfl.broadcast_to"(%0, %1) : (tensor<1x3x1x1xi32>, tensor<4xi64>) -> tensor<4x3x5x1xi32> + // CHECK-ROUNDTRIP: %3 = "tfl.pseudo_const"() <{value = dense<{{\[\[\[\[}}0]]], {{\[\[\[}}1]]], {{\[\[\[}}2]]], {{\[\[\[}}3]]]]> : tensor<4x1x1x1xi32>}> : () -> tensor<4x1x1x1xi32> + // CHECK-ROUNDTRIP: %4 = "tfl.broadcast_to"(%3, %1) : (tensor<4x1x1x1xi32>, tensor<4xi64>) -> tensor<4x3x5x1xi32> + // CHECK-ROUNDTRIP: %[[concat:.*]] = "tfl.concatenation"(%2, %4, %arg1) <{axis = 3 : i32, fused_activation_function = "NONE"}> : // CHECK-ROUNDTRIP-SAME: (tensor<4x3x5x1xi32>, tensor<4x3x5x1xi32>, tensor<4x3x5x2xi32>) -> tensor<4x3x5x4xi32> - // CHECK-ROUNDTRIP: %[[gather:.*]] = "stablehlo.gather"(%arg0, %2) <{ + // CHECK-ROUNDTRIP: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{ // CHECK-ROUNDTRIP-SAME: dimension_numbers = #stablehlo.gather< // CHECK-ROUNDTRIP-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1, 2, 3], // CHECK-ROUNDTRIP-SAME: start_index_map = [0, 2, 1, 3], index_vector_dim = 3>, diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/batched_scatter_round_trip.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/batched_scatter_round_trip.mlir index 44d1bb7dd8b72f..7e42ff310c080f 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/batched_scatter_round_trip.mlir +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/batched_scatter_round_trip.mlir @@ -4,11 +4,14 @@ module { // CHECK-LABEL: func.func public @main 
func.func public @main(%arg0: tensor<3x2x4x7x9xi32>, %arg1: tensor<4x3x5x2xi32>, %arg2: tensor<4x3x5x8xi32>) -> tensor<3x2x4x7x9xi32> { - // CHECK-ROUNDTRIP: %[[iota_1:.*]] = "tfl.pseudo_const"() <{{.*}}> : () -> tensor<4x3x5x1xi32 - // CHECK-ROUNDTRIP: %[[iota_2:.*]] = "tfl.pseudo_const"() <{{.*}}> : () -> tensor<4x3x5x1xi32> - // CHECK-ROUNDTRIP: %[[concat:.*]] = "tfl.concatenation"(%[[iota_1]], %[[iota_2]], %arg1) <{axis = 3 : i32, fused_activation_function = "NONE"}> : + // CHECK-ROUNDTRIP: %0 = "tfl.pseudo_const"() <{value = dense<{{\[\[\[\[}}0]], {{\[\[}}1]], {{\[\[}}2]]]]> : tensor<1x3x1x1xi32>}> : () -> tensor<1x3x1x1xi32> + // CHECK-ROUNDTRIP: %1 = "tfl.pseudo_const"() <{value = dense<[4, 3, 5, 1]> : tensor<4xi64>}> : () -> tensor<4xi64> + // CHECK-ROUNDTRIP: %2 = "tfl.broadcast_to"(%0, %1) : (tensor<1x3x1x1xi32>, tensor<4xi64>) -> tensor<4x3x5x1xi32> + // CHECK-ROUNDTRIP: %3 = "tfl.pseudo_const"() <{value = dense<{{\[\[\[\[}}0]]], {{\[\[\[}}1]]], {{\[\[\[}}2]]], {{\[\[\[}}3]]]]> : tensor<4x1x1x1xi32>}> : () -> tensor<4x1x1x1xi32> + // CHECK-ROUNDTRIP: %4 = "tfl.broadcast_to"(%3, %1) : (tensor<4x1x1x1xi32>, tensor<4xi64>) -> tensor<4x3x5x1xi32> + // CHECK-ROUNDTRIP: %[[concat:.*]] = "tfl.concatenation"(%2, %4, %arg1) <{axis = 3 : i32, fused_activation_function = "NONE"}> : // CHECK-ROUNDTRIP-SAME: (tensor<4x3x5x1xi32>, tensor<4x3x5x1xi32>, tensor<4x3x5x2xi32>) -> tensor<4x3x5x4xi32> - // CHECK-ROUNDTRIP: %[[scatter:.*]] = "stablehlo.scatter"(%arg0, %2, %arg2) <{ + // CHECK-ROUNDTRIP: %[[scatter:.*]] = "stablehlo.scatter"(%arg0, %[[concat]], %arg2) <{ // CHECK-ROUNDTRIP-SAME: scatter_dimension_numbers = #stablehlo.scatter // CHECK-ROUNDTRIP-SAME: update_window_dims = [3], inserted_window_dims = [0, 1, 2, 3], // CHECK-ROUNDTRIP-SAME: scatter_dims_to_operand_dims = [0, 2, 1, 3], index_vector_dim = 3>}> diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf-no-runtime-verification.mlir 
b/tensorflow/compiler/mlir/lite/tests/legalize-tf-no-runtime-verification.mlir index 2c17e734c58dad..e0793cbf803c4f 100644 --- a/tensorflow/compiler/mlir/lite/tests/legalize-tf-no-runtime-verification.mlir +++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf-no-runtime-verification.mlir @@ -5,7 +5,6 @@ func.func @broadcast_to_bf16(%arg0: tensor<3xbf16>, %arg1: tensor<2xi64>) -> ten func.return %0: tensor<3x3xbf16> // CHECK-LABEL: broadcast_to_bf16 -// CHECK: [[CST:%.*]] = arith.constant dense<1.000000e+00> : tensor<3x3xbf16> -// CHECK: [[MUL:%.*]] = tfl.mul(%arg0, [[CST]]) <{fused_activation_function = "NONE"}> : (tensor<3xbf16>, tensor<3x3xbf16>) -> tensor<3x3xbf16> -// CHECK: return [[MUL]] : tensor<3x3xbf16> +// CHECK: %0 = "tfl.broadcast_to"(%arg0, %arg1) : (tensor<3xbf16>, tensor<2xi64>) -> tensor<3x3xbf16> +// CHECK: return %0 : tensor<3x3xbf16> } diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir index c0978d484ee11e..c3dc00ca74f1ae 100644 --- a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir @@ -2589,6 +2589,14 @@ func.func @dynamic_update_slice_f16_arg(%arg0: tensor<4x5xf16>, %arg1: tensor<1x // CHECK: "tfl.dynamic_update_slice"(%arg0, %arg1, %arg2) : (tensor<4x5xf16>, tensor<1x5xf16>, tensor<2xi32>) -> tensor<4x5xf16> } +func.func @dynamic_update_slice_i16(%arg0: tensor<4x5xi16>, %arg1: tensor<1x5xi16>, %arg2: tensor<2xi32>) -> tensor<4x5xi16> { + %0 = "tf.XlaDynamicUpdateSlice"(%arg0, %arg1, %arg2) : (tensor<4x5xi16>, tensor<1x5xi16>, tensor<2xi32>) -> tensor<4x5xi16> + func.return %0 : tensor<4x5xi16> + +// CHECK-LABEL:dynamic_update_slice_i16 +// CHECK: "tfl.dynamic_update_slice"(%arg0, %arg1, %arg2) : (tensor<4x5xi16>, tensor<1x5xi16>, tensor<2xi32>) -> tensor<4x5xi16> +} + func.func @testReluI32(%arg0: tensor<1xi32>) -> tensor<1xi32> { %0 = "tf.Relu"(%arg0) : (tensor<1xi32>) -> tensor<1xi32> func.return %0: 
tensor<1xi32> diff --git a/tensorflow/compiler/mlir/lite/tests/ops.mlir b/tensorflow/compiler/mlir/lite/tests/ops.mlir index 772a462747d01a..56b82b9042593f 100644 --- a/tensorflow/compiler/mlir/lite/tests/ops.mlir +++ b/tensorflow/compiler/mlir/lite/tests/ops.mlir @@ -1759,6 +1759,14 @@ func.func @testStridedSliceWithInvalidOutputType(%arg0: tensor<12x2x2x5xf32>, %a // ----- +func.func @testStridedSliceWithInvalidInputRank(%arg0: tensor<12x2x2x5xf32>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>) -> tensor<1x1x1x2x2x5xf32> { + // expected-error @+1 {{op failed to verify that input (with new_axis) must have rank at most 5}} + %0 = "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 6 : i32, shrink_axis_mask = 0 : i32, offset = false} : (tensor<12x2x2x5xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x1x1x2x2x5xf32> + func.return %0 : tensor<1x1x1x2x2x5xf32> +} + +// ----- + // CHECK-LABEL: testOneHot func.func @testOneHot(%arg0: tensor<3xi32>, %arg1: tensor, %arg2: tensor, %arg3: tensor) -> tensor<*xf32> { // CHECK: "tfl.one_hot"(%arg0, %arg1, %arg2, %arg3) <{axis = -1 : i32}> : (tensor<3xi32>, tensor, tensor, tensor) -> tensor<*xf32> @@ -2601,6 +2609,13 @@ func.func @fully_connected(%arg0: tensor<1x37xf32>, %arg1: tensor<40x37xf32>, %a // ----- +func.func @fully_connected_with_int64_num_elements(%arg0: tensor<2048x128xf32>, %arg1: tensor<1049088x128xf32>, %arg2: none) -> tensor<2048x1049088xf32> { + %0 = "tfl.fully_connected"(%arg0, %arg1, %arg2) <{fused_activation_function = "NONE", keep_num_dims = true, weights_format = "DEFAULT"}> : (tensor<2048x128xf32>, tensor<1049088x128xf32>, none) -> tensor<2048x1049088xf32> + func.return %0 : tensor<2048x1049088xf32> +} + +// ----- + func.func @fully_connected_no_bias(%arg0: tensor<2x2x10xf32>, %arg1: tensor<40x40xf32>, %arg2: none) -> tensor<1x40xf32> { %0 = "tfl.fully_connected"(%arg0, %arg1, %arg2) 
{fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<2x2x10xf32>, tensor<40x40xf32>, none) -> tensor<1x40xf32> func.return %0 : tensor<1x40xf32> diff --git a/tensorflow/compiler/mlir/lite/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/tests/optimize.mlir index 0a52298b17e7b9..84db8eb8f2a9a4 100644 --- a/tensorflow/compiler/mlir/lite/tests/optimize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/optimize.mlir @@ -3718,6 +3718,46 @@ func.func @gelu_approximate(%arg0: tensor<3xf32>) -> tensor<3xf32> { // CHECK: "tfl.gelu"(%arg0) <{approximate = true}> : (tensor<3xf32>) -> tensor<3xf32> } +func.func @gelu_approximate_with_mul(%arg0: tensor<3xf32>) -> tensor<3xf32> { + %cst = arith.constant dense<0.797884583> : tensor + %cst_0 = arith.constant dense<5.000000e-01> : tensor + %cst_1 = arith.constant dense<1.000000e+00> : tensor + %cst_3 = arith.constant dense<4.471500e-02> : tensor + %99 = "tfl.mul"(%arg0, %arg0) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32> + %0 = "tfl.mul"(%99, %arg0) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32> + %1 = "tfl.mul"(%0, %cst_3) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor) -> tensor<3xf32> + %2 = "tfl.add"(%arg0, %1) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32> + %3 = "tfl.mul"(%2, %cst) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor) -> tensor<3xf32> + %4 = "tfl.tanh"(%3) : (tensor<3xf32>) -> tensor<3xf32> + %5 = "tfl.add"(%4, %cst_1) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor) -> tensor<3xf32> + %6 = "tfl.mul"(%arg0, %cst_0) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor) -> tensor<3xf32> + %7 = "tfl.mul"(%6, %5) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32> + func.return %7 : tensor<3xf32> + +// CHECK-LABEL:gelu_approximate +// CHECK: "tfl.gelu"(%arg0) 
<{approximate = true}> : (tensor<3xf32>) -> tensor<3xf32> +} + +func.func @gelu_approximate_with_mul2(%arg0: tensor<3xf32>) -> tensor<3xf32> { + %cst = arith.constant dense<0.797884583> : tensor + %cst_0 = arith.constant dense<5.000000e-01> : tensor + %cst_1 = arith.constant dense<1.000000e+00> : tensor + %cst_3 = arith.constant dense<4.471500e-02> : tensor + %99 = "tfl.mul"(%arg0, %arg0) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32> + %0 = "tfl.mul"(%arg0, %99) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32> + %1 = "tfl.mul"(%0, %cst_3) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor) -> tensor<3xf32> + %2 = "tfl.add"(%arg0, %1) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32> + %3 = "tfl.mul"(%2, %cst) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor) -> tensor<3xf32> + %4 = "tfl.tanh"(%3) : (tensor<3xf32>) -> tensor<3xf32> + %5 = "tfl.add"(%4, %cst_1) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor) -> tensor<3xf32> + %6 = "tfl.mul"(%arg0, %cst_0) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor) -> tensor<3xf32> + %7 = "tfl.mul"(%6, %5) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32> + func.return %7 : tensor<3xf32> + +// CHECK-LABEL:gelu_approximate +// CHECK: "tfl.gelu"(%arg0) <{approximate = true}> : (tensor<3xf32>) -> tensor<3xf32> +} + func.func @gelu_approximate1(%arg0: tensor<3xf32>) -> tensor<3xf32> { %cst = arith.constant dense<0.797884583> : tensor %cst_0 = arith.constant dense<5.000000e-01> : tensor @@ -3738,6 +3778,49 @@ func.func @gelu_approximate1(%arg0: tensor<3xf32>) -> tensor<3xf32> { // CHECK: "tfl.gelu"(%arg0) <{approximate = true}> : (tensor<3xf32>) -> tensor<3xf32> } +func.func @gelu_approximate1_with_mul(%arg0: tensor<3xf32>) -> tensor<3xf32> { + %cst = arith.constant dense<0.797884583> : tensor + %cst_0 = arith.constant 
dense<5.000000e-01> : tensor + %cst_1 = arith.constant dense<1.000000e+00> : tensor + %cst_2 = arith.constant dense<3.000000e+00> : tensor + %cst_3 = arith.constant dense<4.471500e-02> : tensor + %99 = "tfl.mul"(%arg0, %arg0) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32> + %0 = "tfl.mul"(%99, %arg0) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32> + %1 = "tfl.mul"(%0, %cst_3) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor) -> tensor<3xf32> + %2 = "tfl.add"(%arg0, %1) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32> + %3 = "tfl.mul"(%2, %cst) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor) -> tensor<3xf32> + %4 = "tfl.tanh"(%3) : (tensor<3xf32>) -> tensor<3xf32> + %5 = "tfl.add"(%4, %cst_1) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor) -> tensor<3xf32> + %6 = "tfl.mul"(%5, %cst_0) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor) -> tensor<3xf32> + %7 = "tfl.mul"(%arg0, %6) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32> + func.return %7 : tensor<3xf32> + +// CHECK-LABEL:gelu_approximate +// CHECK: "tfl.gelu"(%arg0) <{approximate = true}> : (tensor<3xf32>) -> tensor<3xf32> +} + + +func.func @gelu_approximate1_with_mul1(%arg0: tensor<3xf32>) -> tensor<3xf32> { + %cst = arith.constant dense<0.797884583> : tensor + %cst_0 = arith.constant dense<5.000000e-01> : tensor + %cst_1 = arith.constant dense<1.000000e+00> : tensor + %cst_2 = arith.constant dense<3.000000e+00> : tensor + %cst_3 = arith.constant dense<4.471500e-02> : tensor + %99 = "tfl.mul"(%arg0, %arg0) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32> + %0 = "tfl.mul"(%arg0, %99) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32> + %1 = "tfl.mul"(%0, %cst_3) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor) -> 
tensor<3xf32> + %2 = "tfl.add"(%arg0, %1) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32> + %3 = "tfl.mul"(%2, %cst) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor) -> tensor<3xf32> + %4 = "tfl.tanh"(%3) : (tensor<3xf32>) -> tensor<3xf32> + %5 = "tfl.add"(%4, %cst_1) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor) -> tensor<3xf32> + %6 = "tfl.mul"(%5, %cst_0) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor) -> tensor<3xf32> + %7 = "tfl.mul"(%arg0, %6) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32> + func.return %7 : tensor<3xf32> + +// CHECK-LABEL:gelu_approximate +// CHECK: "tfl.gelu"(%arg0) <{approximate = true}> : (tensor<3xf32>) -> tensor<3xf32> +} + func.func @gelu_approximate_no_match(%arg0: tensor<3xf32>) -> tensor<3xf32> { %cst = arith.constant dense<0.797884583> : tensor %cst_0 = arith.constant dense<5.000000e-01> : tensor @@ -4316,11 +4399,11 @@ func.func @FuseExcessBroadcastingOnReshapes(%arg0: tensor<1x8xf32>) -> tensor<1x %1 = "tfl.broadcast_to"(%0, %cst_0) : (tensor<1x1x1x8x1x1xf32>, tensor<6xi32>) -> tensor<1x1x1x8x16x1xf32> %2 = "tfl.reshape"(%1, %cst_1) : (tensor<1x1x1x8x16x1xf32>, tensor<4xi32>) -> tensor<1x1x1x128xf32> return %2 : tensor<1x1x1x128xf32> - // CHECK: %cst = arith.constant dense<1.000000e+00> : tensor<8x16xf32> + // CHECK: %cst = arith.constant dense<[8, 16]> : tensor<2xi64> // CHECK: %cst_0 = arith.constant dense<[1, 1, 1, 128]> : tensor<4xi32> // CHECK: %cst_1 = arith.constant dense<[8, 1]> : tensor<2xi32> // CHECK: %0 = "tfl.reshape"(%arg0, %cst_1) : (tensor<1x8xf32>, tensor<2xi32>) -> tensor<8x1xf32> - // CHECK: %1 = tfl.mul(%0, %cst) <{fused_activation_function = "NONE"}> : (tensor<8x1xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> + // CHECK: %1 = "tfl.broadcast_to"(%0, %cst) : (tensor<8x1xf32>, tensor<2xi64>) -> tensor<8x16xf32> // CHECK: %2 = "tfl.reshape"(%1, %cst_0) : (tensor<8x16xf32>, tensor<4xi32>) -> 
tensor<1x1x1x128xf32> // CHECK: return %2 : tensor<1x1x1x128xf32> } @@ -4342,83 +4425,63 @@ func.func @FuseExcessBroadcastingOnReshapesDynamicShapes(%arg0: tensor, %arg1: tensor<2xi32>) -> tensor<3x3xf32> { %0 = "tfl.broadcast_to"(%arg0, %arg1) : (tensor<3xf32>, tensor<2xi32>) -> tensor<3x3xf32> return %0 : tensor<3x3xf32> - // CHECK: %cst = arith.constant dense<1.000000e+00> : tensor<3x3xf32> - // CHECK: %0 = tfl.mul(%arg0, %cst) <{fused_activation_function = "NONE"}> : (tensor<3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32> - // CHECK: return %0 : tensor<3x3xf32> + // CHECK: tfl.broadcast_to } // CHECK-LABEL: @broadcast_to_i32_low_dim func.func @broadcast_to_i32_low_dim(%arg0: tensor<3xi32>, %arg1: tensor<2xi32>) -> tensor<3x3xi32> { %0 = "tfl.broadcast_to"(%arg0, %arg1) : (tensor<3xi32>, tensor<2xi32>) -> tensor<3x3xi32> return %0 : tensor<3x3xi32> - // CHECK: %cst = arith.constant dense<1> : tensor<3x3xi32> - // CHECK: %0 = tfl.mul(%arg0, %cst) <{fused_activation_function = "NONE"}> : (tensor<3xi32>, tensor<3x3xi32>) -> tensor<3x3xi32> - // CHECK: return %0 : tensor<3x3xi32> + // CHECK: tfl.broadcast_to } // CHECK-LABEL: @broadcast_to_low_dim_with_unknown_shape func.func @broadcast_to_low_dim_with_unknown_shape(%arg0: tensor<3xf32>, %arg1: tensor<*xi32>) -> tensor<3x3xf32> { %0 = "tfl.broadcast_to"(%arg0, %arg1) : (tensor<3xf32>, tensor<*xi32>) -> tensor<3x3xf32> return %0 : tensor<3x3xf32> - // CHECK: %cst = arith.constant dense<1.000000e+00> : tensor<3x3xf32> - // CHECK: %0 = tfl.mul(%arg0, %cst) <{fused_activation_function = "NONE"}> : (tensor<3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32> - // CHECK: return %0 : tensor<3x3xf32> + // CHECK: tfl.broadcast_to } // CHECK-LABEL: @broadcast_to_i16_low_dim func.func @broadcast_to_i16_low_dim(%arg0: tensor<3xi16>, %arg1: tensor<2xi32>) -> tensor<3x3xi16> { %0 = "tfl.broadcast_to"(%arg0, %arg1) : (tensor<3xi16>, tensor<2xi32>) -> tensor<3x3xi16> return %0 : tensor<3x3xi16> - // CHECK: %cst = arith.constant dense<1> : 
tensor<3x3xi16> - // CHECK: %0 = tfl.mul(%arg0, %cst) <{fused_activation_function = "NONE"}> : (tensor<3xi16>, tensor<3x3xi16>) -> tensor<3x3xi16> - // CHECK: return %0 : tensor<3x3xi16> + // CHECK: tfl.broadcast_to } // CHECK-LABEL: @broadcast_to_i32_low_dim_with_unknown_output func.func @broadcast_to_i32_low_dim_with_unknown_output(%arg0: tensor<3xi32>, %arg1: tensor<2xi32>) -> tensor<*xi32> { %0 = "tfl.broadcast_to"(%arg0, %arg1) : (tensor<3xi32>, tensor<2xi32>) -> tensor<*xi32> return %0 : tensor<*xi32> - // CHECK: %cst = arith.constant dense<1> : tensor - // CHECK: %0 = "tfl.fill"(%arg1, %cst) : (tensor<2xi32>, tensor) -> tensor<*xi32> - // CHECK: %1 = tfl.mul(%arg0, %0) <{fused_activation_function = "NONE"}> : (tensor<3xi32>, tensor<*xi32>) -> tensor<*xi32> - // CHECK: return %1 : tensor<*xi32> + // CHECK: tfl.broadcast_to } // CHECK-LABEL: @broadcast_to_ui32 func.func @broadcast_to_ui32(%arg0: tensor, %arg1: tensor<1xi64>) -> tensor<10xui32> { %0 = "tfl.broadcast_to"(%arg0, %arg1) : (tensor, tensor<1xi64>) -> tensor<10xui32> return %0 : tensor<10xui32> - // CHECK: %cst = arith.constant dense<1> : tensor<10xui32> - // CHECK: %0 = tfl.mul(%arg0, %cst) <{fused_activation_function = "NONE"}> : (tensor, tensor<10xui32>) -> tensor<10xui32> - // CHECK: return %0 : tensor<10xui32> + // CHECK: tfl.broadcast_to } // CHECK-LABEL: @broadcast_to_f32 func.func @broadcast_to_f32(%arg0: tensor<3xf32>, %arg1: tensor<2xi32>) -> tensor<3x3xf32> { %0 = "tfl.broadcast_to"(%arg0, %arg1) : (tensor<3xf32>, tensor<2xi32>) -> tensor<3x3xf32> return %0 : tensor<3x3xf32> - // CHECK: %cst = arith.constant dense<1.000000e+00> : tensor<3x3xf32> - // CHECK: %0 = tfl.mul(%arg0, %cst) <{fused_activation_function = "NONE"}> : (tensor<3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32> - // CHECK: return %0 : tensor<3x3xf32> + // CHECK: tfl.broadcast_to } // CHECK-LABEL: @broadcast_to_i32 func.func @broadcast_to_i32(%arg0: tensor<3xi32>, %arg1: tensor<2xi32>) -> tensor<3x3xi32> { %0 = 
"tfl.broadcast_to"(%arg0, %arg1) : (tensor<3xi32>, tensor<2xi32>) -> tensor<3x3xi32> return %0 : tensor<3x3xi32> - // CHECK: %cst = arith.constant dense<1> : tensor<3x3xi32> - // CHECK: %0 = tfl.mul(%arg0, %cst) <{fused_activation_function = "NONE"}> : (tensor<3xi32>, tensor<3x3xi32>) -> tensor<3x3xi32> - // CHECK: return %0 : tensor<3x3xi32> + // CHECK: tfl.broadcast_to } // CHECK-LABEL: @broadcast_to_i32_with_dynamic_shape_and_output func.func @broadcast_to_i32_with_dynamic_shape_and_output(%arg0: tensor<3xi32>, %arg1: tensor<2xi32>) -> tensor<3x?xi32> { %0 = "tfl.broadcast_to"(%arg0, %arg1) : (tensor<3xi32>, tensor<2xi32>) -> tensor<3x?xi32> return %0 : tensor<3x?xi32> - // CHECK: %cst = arith.constant dense<1> : tensor - // CHECK: %0 = "tfl.fill"(%arg1, %cst) : (tensor<2xi32>, tensor) -> tensor<3x?xi32> - // CHECK: %1 = tfl.mul(%arg0, %0) <{fused_activation_function = "NONE"}> : (tensor<3xi32>, tensor<3x?xi32>) -> tensor<3x?xi32> - // CHECK: return %1 : tensor<3x?xi32> + // CHECK: tfl.broadcast_to } // CHECK-LABEL: @broadcast_to_ui32_with_dynamic_output @@ -4610,3 +4673,75 @@ func.func @EliminateBooleanCastCompare(%arg0: tensor<*xi1>) -> (tensor<*xi1>, te // CHECK: %9 = "tfl.zeros_like"(%arg0) : (tensor<*xi1>) -> tensor<*xi1> // CHECK: return %0, %1, %3, %arg0, %arg0, %4, %5, %7, %8, %arg0, %9, %arg0 : tensor<*xi1>, tensor<*xi1>, tensor<*xi1>, tensor<*xi1>, tensor<*xi1>, tensor<*xi1>, tensor<*xi1>, tensor<*xi1>, tensor<*xi1>, tensor<*xi1>, tensor<*xi1>, tensor<*xi1> } + +// CHECK-LABEL: @ReorderTransposeReshapeTranspose +func.func @ReorderTransposeReshapeTranspose(%arg0: tensor<282x2048xf32>) -> tensor<2x1x282x1024xf32> { + %cst = arith.constant dense<[1, 0]> : tensor<2xi32> + %cst_1 = arith.constant dense<[2, 1024, 1, 282]> : tensor<4xi32> + %cst_2 = arith.constant dense<[0, 2, 3, 1]> : tensor<4xi32> + %0 = "tfl.transpose"(%arg0, %cst) : (tensor<282x2048xf32>, tensor<2xi32>) -> tensor<2048x282xf32> + %1 = "tfl.reshape"(%0, %cst_1) : (tensor<2048x282xf32>, 
tensor<4xi32>) -> tensor<2x1024x1x282xf32> + %2 = "tfl.transpose"(%1, %cst_2) : (tensor<2x1024x1x282xf32>, tensor<4xi32>) -> tensor<2x1x282x1024xf32> + return %2: tensor<2x1x282x1024xf32> + + // CHECK: %cst = arith.constant dense<[1, 3, 0, 2]> : tensor<4xi32> + // CHECK-NEXT: %cst_0 = arith.constant dense<[282, 2, 1024, 1]> : tensor<4xi32> + // CHECK-NEXT: %0 = "tfl.reshape"(%arg0, %cst_0) : (tensor<282x2048xf32>, tensor<4xi32>) -> tensor<282x2x1024x1xf32> + // CHECK-NEXT: %1 = "tfl.transpose"(%0, %cst) : (tensor<282x2x1024x1xf32>, tensor<4xi32>) -> tensor<2x1x282x1024xf32> + // CHECK-NEXT: return %1 : tensor<2x1x282x1024xf32> +} + +// CHECK-LABEL: @FullyConnectedSwapOperandsWhenLHSIsConst +func.func @FullyConnectedSwapOperandsWhenLHSIsConst(%arg0: tensor<4x2xf32>, %arg1: none) -> tensor<2x4xf32> { + %cst = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : tensor<2x2xf32> + %0 = "tfl.fully_connected"(%cst, %arg0, %arg1) {asymmetric_quantize_inputs = true, fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<2x2xf32>, tensor<4x2xf32>, none) -> tensor<2x4xf32> + func.return %0 : tensor<2x4xf32> + + // CHECK: %cst = arith.constant dense<[1, 0]> : tensor<2xi32> + // CHECK-NEXT: %cst_0 = arith.constant dense<{{\[}}[1.000000e+00, 2.000000e+00], [3.000000e+00, 4.000000e+00]]> : tensor<2x2xf32> + // CHECK-NEXT: %0 = "tfl.fully_connected"(%arg0, %cst_0, %arg1) <{asymmetric_quantize_inputs = true, fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"}> : (tensor<4x2xf32>, tensor<2x2xf32>, none) -> tensor<4x2xf32> + // CHECK-NEXT: %1 = "tfl.transpose"(%0, %cst) : (tensor<4x2xf32>, tensor<2xi32>) -> tensor<2x4xf32> + // CHECK-NEXT: return %1 : tensor<2x4xf32> +} + +// CHECK-LABEL: @FullyConnectedSwapOperandsWhenLHSIsConstBias +func.func @FullyConnectedSwapOperandsWhenLHSIsConstBias(%arg0: tensor<4x2xf32>) -> tensor<2x4xf32> { + %cst = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : tensor<2x2xf32> + 
%cst_1 = arith.constant dense<2.0> : tensor<2xf32> + %0 = "tfl.fully_connected"(%cst, %arg0, %cst_1) {asymmetric_quantize_inputs = true, fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<2x2xf32>, tensor<4x2xf32>, tensor<2xf32>) -> tensor<2x4xf32> + func.return %0 : tensor<2x4xf32> + + // CHECK: [[cst:%.*]] = arith.constant + // CHECK-NEXT: [[cst_1:%.*]] = arith.constant + // CHECK-NOT: %0 = "tfl.fully_connected"(%arg0, [[cst]], [[cst_1]]) +} + +// CHECK-LABEL: @FullyConnectedSwapOperandsWhenLHSIsConstKeepNumDimsTrue +func.func @FullyConnectedSwapOperandsWhenLHSIsConstKeepNumDimsTrue(%arg0: tensor<4x2xf32>, %arg1: none) -> tensor<2x4xf32> { + %cst = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : tensor<2x2xf32> + %0 = "tfl.fully_connected"(%cst, %arg0, %arg1) {asymmetric_quantize_inputs = true, fused_activation_function = "NONE", keep_num_dims = true, weights_format = "DEFAULT"} : (tensor<2x2xf32>, tensor<4x2xf32>, none) -> tensor<2x4xf32> + func.return %0 : tensor<2x4xf32> + + // CHECK: [[cst:%.*]] = arith.constant + // CHECK-NOT: %0 = "tfl.fully_connected"(%arg0, [[cst]], %arg1) +} + +// CHECK-LABEL: @FullyConnectedSwapOperandsWhenLHSIsConstFusedActivationFunction +func.func @FullyConnectedSwapOperandsWhenLHSIsConstFusedActivationFunction(%arg0: tensor<4x2xf32>, %arg1: none) -> tensor<2x4xf32> { + %cst = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : tensor<2x2xf32> + %0 = "tfl.fully_connected"(%cst, %arg0, %arg1) {asymmetric_quantize_inputs = true, fused_activation_function = "RELU", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<2x2xf32>, tensor<4x2xf32>, none) -> tensor<2x4xf32> + func.return %0 : tensor<2x4xf32> + + // CHECK: [[cst:%.*]] = arith.constant + // CHECK-NOT: %0 = "tfl.fully_connected"(%arg0, [[cst]], %arg1) +} + +// CHECK-LABEL: @FullyConnectedSwapOperandsWhenLHSIsConstLHSRank3 +func.func @FullyConnectedSwapOperandsWhenLHSIsConstLHSRank3(%arg0: tensor<512x512xf32>, %arg1: none) -> 
tensor<1x1x512xf32> { + %cst = arith.constant dense<1.0> : tensor<1x1x512xf32> + %0 = "tfl.fully_connected"(%cst, %arg0, %arg1) {asymmetric_quantize_inputs = true, fused_activation_function = "RELU", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x1x512xf32>, tensor<512x512xf32>, none) -> tensor<1x1x512xf32> + func.return %0 : tensor<1x1x512xf32> + + // CHECK: %0 = "tfl.fully_connected"(%cst, %arg0, %arg1) +} + diff --git a/tensorflow/compiler/mlir/lite/tests/optimize_batch_matmul.mlir b/tensorflow/compiler/mlir/lite/tests/optimize_batch_matmul.mlir index 79f50aaaadab3d..39b1346bcf93d6 100644 --- a/tensorflow/compiler/mlir/lite/tests/optimize_batch_matmul.mlir +++ b/tensorflow/compiler/mlir/lite/tests/optimize_batch_matmul.mlir @@ -170,3 +170,36 @@ func.func @BatchmatmulToReduceSumF32(%arg0: tensor<1x16384x257xf32>) -> (tensor< // CHECK: %[[CONST_DIM:.*]] = "tfl.pseudo_const"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32> // CHECK: %[[RED:.*]] = "tfl.sum"(%arg0, %[[CONST_DIM]]) <{keep_dims = true}> : (tensor<1x16384x257xf32>, tensor<1xi32>) -> tensor<1x1x257xf32> } + +// CHECK-LABEL: FuseBatchMatmulToTransposeNoBatchDims +func.func @FuseBatchMatmulToTransposeNoBatchDims(%arg0: tensor<2048x32x128xf32>, %arg1: tensor<4x128xf32>) -> tensor<4x65536xf32> { + %36 = "tfl.pseudo_const"() <{value = dense<[2, 0, 1]> : tensor<3xi32>}> : () -> tensor<3xi32> + %37 = "tfl.transpose"(%arg0, %36) : (tensor<2048x32x128xf32>, tensor<3xi32>) -> tensor<128x2048x32xf32> + %38 = "tfl.pseudo_const"() <{value = dense<[128, 65536]> : tensor<2xi32>}> : () -> tensor<2xi32> + %39 = "tfl.reshape"(%37, %38) : (tensor<128x2048x32xf32>, tensor<2xi32>) -> tensor<128x65536xf32> + %41 = "tfl.batch_matmul"(%arg1, %39) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : (tensor<4x128xf32>, tensor<128x65536xf32>) -> tensor<4x65536xf32> + return %41 : tensor<4x65536xf32> + // CHECK-NOT: "tfl.transpose" +} + +// CHECK-LABEL: 
FuseBatchMatmulToTransposeWithBatchDims +func.func @FuseBatchMatmulToTransposeWithBatchDims(%arg0: tensor<2048x1x8x32x32xf32>, %arg1: tensor<2048x1x2x32xf32>) -> tensor<2048x1x2x256xf32> { + %104 = "tfl.pseudo_const"() <{value = dense<[0, 1, 4, 2, 3]> : tensor<5xi32>}> : () -> tensor<5xi32> + %106 = "tfl.pseudo_const"() <{value = dense<[2048, 1, 32, 256]> : tensor<4xi32>}> : () -> tensor<4xi32> + %202 = "tfl.transpose"(%arg0, %104) : (tensor<2048x1x8x32x32xf32>, tensor<5xi32>) -> tensor<2048x1x32x8x32xf32> + %203 = "tfl.reshape"(%202, %106) : (tensor<2048x1x32x8x32xf32>, tensor<4xi32>) -> tensor<2048x1x32x256xf32> + %204 = "tfl.batch_matmul"(%arg1, %203) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : (tensor<2048x1x2x32xf32>, tensor<2048x1x32x256xf32>) -> tensor<2048x1x2x256xf32> + return %204 : tensor<2048x1x2x256xf32> + // CHECK-NOT: "tfl.transpose" +} + +// CHECK-LABEL: FuseBatchMatmulToTransposeNegative +func.func @FuseBatchMatmulToTransposeNegative(%arg0: tensor<2048x32x1x8x2xf32>, %arg1: tensor<2048x1x32x2xf32>) -> tensor<2048x1x32x256xf32> { + %88 = "tfl.pseudo_const"() <{value = dense<[0, 2, 4, 1, 3]> : tensor<5xi32>}> : () -> tensor<5xi32> + %90 = "tfl.pseudo_const"() <{value = dense<[2048, 1, 2, 256]> : tensor<4xi32>}> : () -> tensor<4xi32> + %194 = "tfl.transpose"(%arg0, %88) : (tensor<2048x32x1x8x2xf32>, tensor<5xi32>) -> tensor<2048x1x2x32x8xf32> + %195 = "tfl.reshape"(%194, %90) : (tensor<2048x1x2x32x8xf32>, tensor<4xi32>) -> tensor<2048x1x2x256xf32> + %196 = "tfl.batch_matmul"(%arg1, %195) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : (tensor<2048x1x32x2xf32>, tensor<2048x1x2x256xf32>) -> tensor<2048x1x32x256xf32> + return %196 : tensor<2048x1x32x256xf32> + // CHECK: "tfl.transpose" +} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/lite/tests/optimize_broadcast_like.mlir b/tensorflow/compiler/mlir/lite/tests/optimize_broadcast_like.mlir index 514b2812816699..4940eebc701eab 100644 --- 
a/tensorflow/compiler/mlir/lite/tests/optimize_broadcast_like.mlir +++ b/tensorflow/compiler/mlir/lite/tests/optimize_broadcast_like.mlir @@ -1,4 +1,5 @@ -// RUN: tf-opt -tfl-optimize-broadcast-like -split-input-file %s | FileCheck %s +// RUN: tf-opt -tfl-optimize-broadcast-like='unsafe-fuse-dynamic-shaped-broadcast=false' -split-input-file %s | FileCheck %s +// RUN: tf-opt -tfl-optimize-broadcast-like='unsafe-fuse-dynamic-shaped-broadcast=true' -split-input-file %s | FileCheck --check-prefix=UNSAFE-DYNAMIC-CHECK %s // CHECK-LABEL: @broadcast_mul0 func.func @broadcast_mul0(%arg0: tensor<5x7xf32>, %arg1: tensor<7xf32>) -> tensor<5x7xf32> { @@ -1162,3 +1163,72 @@ func.func @broadcast_zeros_like(%arg0: tensor<1x2xf32>) -> (tensor<2x2xf32>) { // CHECK: %[[broadcasted:.*]] = "tfl.broadcast_to"(%[[zeros]], %[[constant]]) : (tensor<1x2xf32>, tensor<2xi32>) -> tensor<2x2xf32> // CHECK: return %[[broadcasted]] } + +// CHECK-LABEL: @broadcast_mul_dynamic_rhs +func.func @broadcast_mul_dynamic_rhs(%arg0: tensor, %arg1: tensor<1x7xf32>) -> tensor { + %shape = "tfl.shape"(%arg0) : (tensor) -> tensor<2xi32> + %0 = "tfl.broadcast_to"(%arg1, %shape) : (tensor<1x7xf32>, tensor<2xi32>) -> tensor + %1 = "tfl.mul"(%arg0, %0) {fused_activation_function = "NONE"} : (tensor, tensor) -> tensor + func.return %1 : tensor + // UNSAFE-DYNAMIC-CHECK: %0 = tfl.mul(%arg0, %arg1) <{fused_activation_function = "NONE"}> : (tensor, tensor<1x7xf32>) -> tensor +} + +// CHECK-LABEL: @broadcast_mul_dynamic_rhs2 +func.func @broadcast_mul_dynamic_rhs2(%arg0: tensor, %arg1: tensor<7xf32>) -> tensor { + %shape = "tfl.shape"(%arg0) : (tensor) -> tensor<2xi32> + %0 = "tfl.broadcast_to"(%arg1, %shape) : (tensor<7xf32>, tensor<2xi32>) -> tensor + %1 = "tfl.mul"(%arg0, %0) {fused_activation_function = "NONE"} : (tensor, tensor) -> tensor + func.return %1 : tensor + // UNSAFE-DYNAMIC-CHECK: %0 = tfl.mul(%arg0, %arg1) <{fused_activation_function = "NONE"}> : (tensor, tensor<7xf32>) -> tensor +} + +// CHECK-LABEL: 
@broadcast_mul_dynamic_lhs +func.func @broadcast_mul_dynamic_lhs(%arg0: tensor<1x7xf32>, %arg1: tensor) -> tensor { + %shape = "tfl.shape"(%arg1) : (tensor) -> tensor<2xi32> + %0 = "tfl.broadcast_to"(%arg0, %shape) : (tensor<1x7xf32>, tensor<2xi32>) -> tensor + %1 = "tfl.mul"(%0, %arg1) {fused_activation_function = "NONE"} : (tensor, tensor) -> tensor + func.return %1 : tensor + // UNSAFE-DYNAMIC-CHECK: %0 = tfl.mul(%arg0, %arg1) <{fused_activation_function = "NONE"}> : (tensor<1x7xf32>, tensor) -> tensor +} + +// CHECK-LABEL: @move_broadcast_through_sum +func.func @move_broadcast_through_sum(%arg0: tensor<1x1x40x100x40x3xf32>) -> tensor<1x4x100x40x3xf32> { + %cst_0 = arith.constant dense<[1, 4, 40, 100, 40, 3]> : tensor<6xi64> + %cst_1 = arith.constant dense<2> : tensor<1xi32> + %0 = "tfl.broadcast_to"(%arg0, %cst_0) : (tensor<1x1x40x100x40x3xf32>, tensor<6xi64>) -> tensor<1x4x40x100x40x3xf32> + %1 = "tfl.sum"(%0, %cst_1) <{keep_dims = false}> : (tensor<1x4x40x100x40x3xf32>, tensor<1xi32>) -> tensor<1x4x100x40x3xf32> + return %1 : tensor<1x4x100x40x3xf32> + // CHECK: %cst = arith.constant dense<[1, 4, 100, 40, 3]> : tensor<5xi32> + // CHECK: %cst_0 = arith.constant dense<2> : tensor<1xi32> + // CHECK: %0 = "tfl.sum"(%arg0, %cst_0) <{keep_dims = false}> : (tensor<1x1x40x100x40x3xf32>, tensor<1xi32>) -> tensor<1x1x100x40x3xf32> + // CHECK: %1 = "tfl.broadcast_to"(%0, %cst) : (tensor<1x1x100x40x3xf32>, tensor<5xi32>) -> tensor<1x4x100x40x3xf32> + // CHECK: return %1 : tensor<1x4x100x40x3xf32> +} + +// CHECK-LABEL: @move_broadcast_through_sum_keep_dims +func.func @move_broadcast_through_sum_keep_dims(%arg0: tensor<1x1x40x100x40x3xf32>) -> tensor<1x4x1x100x40x3xf32> { + %cst_0 = arith.constant dense<[1, 4, 40, 100, 40, 3]> : tensor<6xi64> + %cst_1 = arith.constant dense<2> : tensor<1xi32> + %0 = "tfl.broadcast_to"(%arg0, %cst_0) : (tensor<1x1x40x100x40x3xf32>, tensor<6xi64>) -> tensor<1x4x40x100x40x3xf32> + %1 = "tfl.sum"(%0, %cst_1) <{keep_dims = true}> : 
(tensor<1x4x40x100x40x3xf32>, tensor<1xi32>) -> tensor<1x4x1x100x40x3xf32> + return %1 : tensor<1x4x1x100x40x3xf32> + // CHECK: %cst = arith.constant dense<[1, 4, 1, 100, 40, 3]> : tensor<6xi32> + // CHECK: %cst_0 = arith.constant dense<2> : tensor<1xi32> + // CHECK: %0 = "tfl.sum"(%arg0, %cst_0) <{keep_dims = true}> : (tensor<1x1x40x100x40x3xf32>, tensor<1xi32>) -> tensor<1x1x1x100x40x3xf32> + // CHECK: %1 = "tfl.broadcast_to"(%0, %cst) : (tensor<1x1x1x100x40x3xf32>, tensor<6xi32>) -> tensor<1x4x1x100x40x3xf32> + // CHECK: return %1 : tensor<1x4x1x100x40x3xf32> +} + +// CHECK-LABEL: @move_broadcast_through_sum_neg +func.func @move_broadcast_through_sum_neg(%arg0: tensor<1x1x40x100x40x3xf32>) -> tensor<1x40x100x40x3xf32> { + %cst_0 = arith.constant dense<[1, 4, 40, 100, 40, 3]> : tensor<6xi64> + %cst_1 = arith.constant dense<1> : tensor<1xi32> + %0 = "tfl.broadcast_to"(%arg0, %cst_0) : (tensor<1x1x40x100x40x3xf32>, tensor<6xi64>) -> tensor<1x4x40x100x40x3xf32> + %1 = "tfl.sum"(%0, %cst_1) <{keep_dims = false}> : (tensor<1x4x40x100x40x3xf32>, tensor<1xi32>) -> tensor<1x40x100x40x3xf32> + return %1 : tensor<1x40x100x40x3xf32> + // CHECK: %cst = arith.constant dense<[1, 4, 40, 100, 40, 3]> : tensor<6xi64> + // CHECK: %cst_0 = arith.constant dense<1> : tensor<1xi32> + // CHECK: %0 = "tfl.broadcast_to"(%arg0, %cst) : (tensor<1x1x40x100x40x3xf32>, tensor<6xi64>) -> tensor<1x4x40x100x40x3xf32> + // CHECK: %1 = "tfl.sum"(%0, %cst_0) <{keep_dims = false}> : (tensor<1x4x40x100x40x3xf32>, tensor<1xi32>) -> tensor<1x40x100x40x3xf32> + // CHECK: return %1 : tensor<1x40x100x40x3xf32> +} diff --git a/tensorflow/compiler/mlir/lite/tests/post-quantize.mlir b/tensorflow/compiler/mlir/lite/tests/post-quantize.mlir index 005aec23403c7f..8971ca0d6d3788 100644 --- a/tensorflow/compiler/mlir/lite/tests/post-quantize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/post-quantize.mlir @@ -188,9 +188,21 @@ func.func @FoldPerAxisReshape() -> tensor<1x2x2x!quant.uniform>, value = dense<[[-127, 
127], [-85, -80]]> : tensor<2x2xi8>}> : () -> tensor<2x2x!quant.uniform> %1 = "tfl.reshape"(%0, %cst) : (tensor<2x2x!quant.uniform>, tensor<3xi32>) -> tensor<1x2x2x!quant.uniform> return %1 : tensor<1x2x2x!quant.uniform> - + // CHECK{LITERAL}: %0 = "tfl.pseudo_qconst"() <{qtype = tensor<1x2x2x!quant.uniform>, value = dense<[[[-127, 127], [-85, -80]]]> : tensor<1x2x2xi8>}> : () -> tensor<1x2x2x!quant.uniform> // CHECK-NOT: tfl.reshape // CHECK: return %0 : tensor<1x2x2x!quant.uniform> } + +// CHECK-LABEL: RemoveVolatileQConstOps +func.func @RemoveVolatileQConstOps() -> tensor<640xf32> { + %1 = "tfl.pseudo_qconst"() <{qtype = tensor<640x!quant.uniform>, value = dense<0> : tensor<640xi32>}> {volatile} : () -> tensor<640x!quant.uniform> + %2 = "tfl.dequantize"(%1) : (tensor<640x!quant.uniform>) -> tensor<640xf32> + func.return %2 : tensor<640xf32> + // CHECK: %0 = "tfl.pseudo_qconst"() <{qtype = tensor<640x!quant.uniform>, value = dense<0> : tensor<640xi32>}> {volatile} : () -> tensor<640x!quant.uniform> + // CHECK: return %0 : tensor<640x!quant.uniform> + + // QDQ-CHECK: %cst = arith.constant dense<0.000000e+00> : tensor<640xf32> + // QDQ-CHECK: return %cst : tensor<640xf32> +} diff --git a/tensorflow/compiler/mlir/lite/tests/quantize-strict.mlir b/tensorflow/compiler/mlir/lite/tests/quantize-strict.mlir index 4240ea65988461..43984be64310e3 100644 --- a/tensorflow/compiler/mlir/lite/tests/quantize-strict.mlir +++ b/tensorflow/compiler/mlir/lite/tests/quantize-strict.mlir @@ -73,7 +73,7 @@ func.func @QuantizeConvWithBiasAndReluWeightOnly(%arg0: tensor<1x4x4x3xf32>) -> func.func @QuantizeConvWithBiasAndReluSRQ(%arg0: tensor<1x4x4x3xf32>) -> (tensor<1x4x4x1xf32>) { %cst = arith.constant dense<1.14751196> : tensor<1xf32> - %0 = "tfl.quantize"(%cst) <{qtype = tensor<1x!quant.uniform>}> {volatile} : (tensor<1xf32>) -> tensor<1x!quant.uniform> + %0 = "tfl.quantize"(%cst) <{qtype = tensor<1x!quant.uniform>}> : (tensor<1xf32>) -> tensor<1x!quant.uniform> %1 = 
"tfl.dequantize"(%0) : (tensor<1x!quant.uniform>) -> tensor<1xf32> %cst_0 = arith.constant dense<[[[[1.76285899, -0.257785767, 0.20429258], [1.16310906, 0.23124367, 0.529797196]], [[0.348971426, -0.319283515, -0.772461354], [0.316666812, 1.88180697, -1.78054631]]]]> : tensor<1x2x2x3xf32> %2 = "tfl.quantize"(%arg0) <{qtype = tensor<1x4x4x3x!quant.uniform>}> : (tensor<1x4x4x3xf32>) -> tensor<1x4x4x3x!quant.uniform> @@ -105,3 +105,14 @@ func.func @DQQToRequantize(%arg0: tensor<1x128x128x320x!quant.uniform> } +// ----- + +func.func @VolatileQuantizeConst() -> (tensor<1xf32>) { + %cst = arith.constant dense<1.14751196> : tensor<1xf32> + %0 = "tfl.quantize"(%cst) <{qtype = tensor<1x!quant.uniform>}> {volatile} : (tensor<1xf32>) -> tensor<1x!quant.uniform> + %1 = "tfl.dequantize"(%0) : (tensor<1x!quant.uniform>) -> tensor<1xf32> + return %1 : tensor<1xf32> +// CHECK: %0 = "tfl.pseudo_qconst"() <{qtype = tensor<1x!quant.uniform>, value = dense<20578> : tensor<1xi32>}> {volatile} : () -> tensor<1x!quant.uniform> +// CHECK: %1 = "tfl.dequantize"(%0) : (tensor<1x!quant.uniform>) -> tensor<1xf32> +// CHECK: return %1 : tensor<1xf32> +} diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc index 25789ab44d17be..f508a45c924b52 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc @@ -70,8 +70,14 @@ void AddOptimizationPasses(const tflite::ConverterFlags& converter_flags, pass_manager->addPass(mlir::TFL::CreatePushTransposeThroughEwisePass()); - pass_manager->addNestedPass( - mlir::TFL::Create()); + // Add BroadcastLike optimization pass. + { + mlir::TFL::OptimizeBroadcastLikePassOptions options; + options.unsafe_fuse_dynamic_shaped_broadcast = + pass_config.unsafe_fuse_dynamic_shaped_broadcast; + pass_manager->addNestedPass( + mlir::TFL::Create(options)); + } // Add TFLite optimize pass. 
mlir::TFL::OptimizePassOptions optimize_pass_options; @@ -355,8 +361,13 @@ void AddPostQuantizationStableHloToTfPasses( // broadcasting support. This needs to be run immediately after HLO->TFL // legalization, otherwise the newly generated TFL broadcast ops can fold // and materialize the weights. - pass_manager.addNestedPass( - mlir::TFL::Create()); + { + mlir::TFL::OptimizeBroadcastLikePassOptions options; + options.unsafe_fuse_dynamic_shaped_broadcast = + pass_config.unsafe_fuse_dynamic_shaped_broadcast; + pass_manager.addNestedPass( + mlir::TFL::Create(options)); + } } // folds tf.BroadcastTo ops with subsequent ops if they have built in // broadcasting support. This needs to be run immediately after HLO->TF @@ -637,6 +648,7 @@ void AddPostVariableFreezingTFToTFLConversionPasses( converter_flags.reduce_type_precision()) { pass_manager->addPass(mlir::TFL::CreateReduceTypePrecisionPass()); } + pass_manager->addPass(mlir::TFL::CreateCleanupOptimizationBarrierPass()); // This pass should alway run before the end of the model conversion but // not after the CreateSplitMergedOperandsPass below. diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc index 5b20a6e72f9984..5cc8856aedd0c7 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc @@ -48,7 +48,6 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.h" #include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h" -#include "tensorflow/compiler/mlir/lite/tools/tf_mlir_translate_cl.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_config.h" #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" @@ -58,9 +57,15 @@ limitations under the License. 
#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/errors.h" +using llvm::cl::opt; using mlir::MLIRContext; using mlir::ModuleOp; +// NOLINTNEXTLINE +opt upgrade_legacy("tf-upgrade-legacy", + llvm::cl::desc("Upgrade legacy TF graph behavior"), + llvm::cl::init(false)); + // NOLINTNEXTLINE static llvm::cl::opt weight_quantization( "weight_quantization", diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.cc b/tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.cc index 0f05c371868b8d..7769a0ada95102 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.h" +#include + #include "llvm/Support/CommandLine.h" using llvm::cl::opt; @@ -218,3 +220,73 @@ opt model_origin_framework( "model-origin-framework", llvm::cl::desc("The source model type: PYTORCH, JAX, TENSORFLOW, etc."), llvm::cl::init("UNSET")); + +// NOLINTNEXTLINE +opt input_arrays( + "tf-input-arrays", llvm::cl::desc("Input tensor names, separated by ','"), + llvm::cl::init("")); + +// NOLINTNEXTLINE +opt input_dtypes( + "tf-input-data-types", + llvm::cl::desc("(Optional) Input tensor data types, separated by ','. Use " + "'' if a single data type is skipped. The data type from " + "the import graph is used if it is skipped."), + llvm::cl::init("")); + +// NOLINTNEXTLINE +opt input_shapes( + "tf-input-shapes", + llvm::cl::desc( + "Input tensor shapes. 
Shapes for different tensors are separated by " + "':', and dimension sizes for the same tensor are separated by ','"), + llvm::cl::init("")); + +// NOLINTNEXTLINE +opt output_arrays( + "tf-output-arrays", llvm::cl::desc("Output tensor names, separated by ','"), + llvm::cl::init("")); + +// NOLINTNEXTLINE +opt control_output_arrays( + "tf-control-output-arrays", + llvm::cl::desc("Control output node names, separated by ','"), + llvm::cl::init("")); + +// NOLINTNEXTLINE +opt inference_type( + "tf-inference-type", + llvm::cl::desc( + "Sets the type of real-number arrays in the output file. Only allows " + "float and quantized types"), + llvm::cl::init("")); + +// NOLINTNEXTLINE +opt min_values( + "tf-input-min-values", + llvm::cl::desc( + "Sets the lower bound of the input data. Separated by ','; Each entry " + "in the list should match an entry in -tf-input-arrays. This is " + "used when -tf-inference-type is a quantized type."), + llvm::cl::Optional, llvm::cl::init("")); + +// NOLINTNEXTLINE +opt max_values( + "tf-input-max-values", + llvm::cl::desc( + "Sets the upper bound of the input data. Separated by ','; Each entry " + "in the list should match an entry in -tf-input-arrays. 
This is " + "used when -tf-inference-type is a quantized type."), + llvm::cl::Optional, llvm::cl::init("")); + +// NOLINTNEXTLINE +opt debug_info_file( + "tf-debug-info", + llvm::cl::desc("Path to the debug info file of the input graph def"), + llvm::cl::init("")); + +// NOLINTNEXTLINE +opt enable_shape_inference( + "tf-enable-shape-inference-on-import", + llvm::cl::desc("Enable shape inference on import (temporary)"), + llvm::cl::init(false)); diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.h b/tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.h index c225291360c9df..6095b69d471ad8 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.h +++ b/tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.h @@ -48,6 +48,17 @@ extern llvm::cl::opt enable_dynamic_update_slice; extern llvm::cl::opt preserve_assert_op; extern llvm::cl::opt legalize_custom_tensor_list_ops; extern llvm::cl::opt reduce_type_precision; +extern llvm::cl::opt input_arrays; +extern llvm::cl::opt input_dtypes; +extern llvm::cl::opt input_shapes; +extern llvm::cl::opt output_arrays; +extern llvm::cl::opt control_output_arrays; +extern llvm::cl::opt inference_type; +extern llvm::cl::opt min_values; +extern llvm::cl::opt max_values; +extern llvm::cl::opt debug_info_file; +extern llvm::cl::opt upgrade_legacy; +extern llvm::cl::opt enable_shape_inference; // Import saved model. extern llvm::cl::opt import_saved_model_object_graph; diff --git a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h index 304473e201068a..8d4e3d4c604f28 100644 --- a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h +++ b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h @@ -21,6 +21,7 @@ limitations under the License. 
#include #include +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" diff --git a/tensorflow/compiler/mlir/lite/tools/BUILD b/tensorflow/compiler/mlir/lite/tools/BUILD index 63590fc545fd4d..055877d0b32200 100644 --- a/tensorflow/compiler/mlir/lite/tools/BUILD +++ b/tensorflow/compiler/mlir/lite/tools/BUILD @@ -22,47 +22,3 @@ cc_library( ) # LINT.ThenChange(//tensorflow/lite/tools:command_line_flags) - -cc_library( - name = "translate_cl_options", - srcs = [ - "tf_mlir_translate_cl.cc", - ], - hdrs = [ - "tf_mlir_translate_cl.h", - ], - deps = [ - "@llvm-project//llvm:Support", - ], - alwayslink = 1, -) - -cc_library( - name = "translate_registration", - srcs = [ - "tf_mlir_translate_registration.cc", - ], - deps = [ - "//tensorflow/compiler/mlir/lite/tools:translate_cl_options", - "//tensorflow/compiler/mlir/tensorflow", - "//tensorflow/compiler/mlir/tensorflow/translate:mlir_roundtrip_flags", - "//tensorflow/compiler/mlir/tensorflow/translate/tools:file_tf_mlir_translate", - "//tensorflow/compiler/mlir/tf2xla/api/v2:tf_executor_to_graph", - "//tensorflow/compiler/tf2xla:xla_compiler", - "//tensorflow/compiler/tf2xla/kernels:xla_ops", - "//tensorflow/core:core_cpu_base", - "//tensorflow/core:framework", - "//tensorflow/core:protos_all_cc", - "@com_google_absl//absl/container:flat_hash_set", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:TranslateLib", - "@local_tsl//tsl/platform:protobuf", - "@local_xla//xla/client:client_library", - "@local_xla//xla/client:compile_only_client", - "@local_xla//xla/service/cpu:cpu_compiler", - "@local_xla//xla/service/cpu:cpu_transfer_manager", - "@local_xla//xla/stream_executor/host:host_platform", - "@local_xla//xla/stream_executor/host:host_platform_id", - ], - alwayslink = 1, -) diff --git a/tensorflow/compiler/mlir/lite/transforms/cleanup_optimization_barrier_pass.cc 
b/tensorflow/compiler/mlir/lite/transforms/cleanup_optimization_barrier_pass.cc new file mode 100644 index 00000000000000..8cb785ac86d84f --- /dev/null +++ b/tensorflow/compiler/mlir/lite/transforms/cleanup_optimization_barrier_pass.cc @@ -0,0 +1,55 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/lite/transforms/cleanup_optimization_barrier_pass.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo + +namespace mlir { +namespace TFL { +namespace { + +#define DEBUG_TYPE "cleanup-optimization-barrier" + +// Replaces the shlo.optimization_barrier op with its input. 
+struct CleanupOptimizationBarrier + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(stablehlo::OptimizationBarrierOp op, + PatternRewriter& rewriter) const override { + rewriter.replaceOp(op, op.getOperands()); + return success(); + } +}; +} // end namespace + +void CleanupOptimizationBarrierPass::runOnOperation() { + auto* ctx = &getContext(); + + RewritePatternSet patterns(ctx); + patterns.add(ctx); + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { + return signalPassFailure(); + } +} + +} // end namespace TFL +} // end namespace mlir diff --git a/tensorflow/compiler/mlir/lite/transforms/cleanup_optimization_barrier_pass.h b/tensorflow/compiler/mlir/lite/transforms/cleanup_optimization_barrier_pass.h new file mode 100644 index 00000000000000..3a6bd2a863e016 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/transforms/cleanup_optimization_barrier_pass.h @@ -0,0 +1,58 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_CLEANUP_OPTIMIZATION_BARRIER_PASS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_CLEANUP_OPTIMIZATION_BARRIER_PASS_H_ + +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo // IWYU pragma: keep +#include "tensorflow/compiler/mlir/lite/transforms/pass.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" + +namespace mlir { +namespace TFL { + +// Pass to clean up shlo.optimization_barrier ops. + +class CleanupOptimizationBarrierPass + : public TFL::Pass { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CleanupOptimizationBarrierPass) + + CleanupOptimizationBarrierPass() = default; + CleanupOptimizationBarrierPass(const CleanupOptimizationBarrierPass&) {}; + + void runOnOperation() override; + static llvm::StringRef GetName() { return "CleanupOptimizationBarrierPass"; } + static llvm::StringRef GetArgument() { + return "tfl-cleanup-optimization-barrier"; + } + static llvm::StringRef GetDescription() { + return "Pass to clean up shlo.optimization_barrier ops."; + } + + private: + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } +}; +} // namespace TFL +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_CLEANUP_OPTIMIZATION_BARRIER_PASS_H_ diff --git a/tensorflow/compiler/mlir/lite/transforms/converter_pass_options_setter.cc b/tensorflow/compiler/mlir/lite/transforms/converter_pass_options_setter.cc index f0fb9361980f67..5a3f23fe6df382 100644 --- a/tensorflow/compiler/mlir/lite/transforms/converter_pass_options_setter.cc +++ b/tensorflow/compiler/mlir/lite/transforms/converter_pass_options_setter.cc @@ 
-15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/transforms/converter_pass_options_setter.h" +#include "tensorflow/compiler/mlir/lite/transforms/optimize_broadcast_like_pass_options.h" #include "tensorflow/compiler/mlir/lite/transforms/optimize_pass_options.h" #include "tensorflow/compiler/mlir/lite/transforms/pass_options.h" #include "tensorflow/compiler/mlir/lite/transforms/variable_freezing_pipeline_options.h" @@ -33,6 +34,12 @@ void ConverterPassOptionsSetter::SetOptions( options.enable_tflite_variables = pass_config_.enable_tflite_variables; } +void ConverterPassOptionsSetter::SetOptions( + OptimizeBroadcastLikePassOptions& options) const { + // options.unsafe_fuse_dynamic_shaped_broadcast = + // converter_flags_.unsafe_fuse_dynamic_shaped_broadcast(); +} + void ConverterPassOptionsSetter::SetOptions(EmptyPassOptions& options) const {} } // namespace TFL diff --git a/tensorflow/compiler/mlir/lite/transforms/converter_pass_options_setter.h b/tensorflow/compiler/mlir/lite/transforms/converter_pass_options_setter.h index 01f71afe84ca3f..59151448b92f0a 100644 --- a/tensorflow/compiler/mlir/lite/transforms/converter_pass_options_setter.h +++ b/tensorflow/compiler/mlir/lite/transforms/converter_pass_options_setter.h @@ -26,6 +26,7 @@ namespace TFL { class OptimizePassOptions; class VariableFreezingPipelineOptions; class EmptyPassOptions; +class OptimizeBroadcastLikePassOptions; // PassOptionsSetter to set TFLite Converter Pass/Pipeline Options based on // ConverterFlags and TFL::PassConfig values. 
@@ -40,6 +41,7 @@ class ConverterPassOptionsSetter : public PassOptionsSetter { void SetOptions(OptimizePassOptions& options) const override; void SetOptions(VariableFreezingPipelineOptions& options) const override; void SetOptions(EmptyPassOptions& options) const override; + void SetOptions(OptimizeBroadcastLikePassOptions& options) const override; private: tflite::ConverterFlags converter_flags_; diff --git a/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc b/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc index f1b602a6763aca..995a878cfc47cf 100644 --- a/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc +++ b/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc @@ -54,7 +54,7 @@ namespace { class DefaultQuantParamsPass : public impl::DefaultQuantParamsPassBase { public: - using DefaultQuantParamsPassBase::DefaultQuantParamsPassBase; + DefaultQuantParamsPass() {} explicit DefaultQuantParamsPass(double default_min, double default_max, bool is_signed) { diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td index 26659b157933f2..a148fbd3685f5c 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td @@ -26,20 +26,20 @@ include "tensorflow/compiler/mlir/lite/utils/utils.td" def CreateEmptyBoolAttr : NativeCodeCall<"::mlir::BoolAttr()">; def DenseElementsAttr : ElementsAttrBase< - CPred<"$_self.isa()">, + CPred<"llvm::isa($_self)">, "non-opaque constant tensor">; def F32ElementsAttr : ElementsAttrBase< - CPred<"$_self.cast().getShapedType().getElementType().isF32()">, "float constant tensor">; + CPred<"llvm::cast($_self).getShapedType().getElementType().isF32()">, "float constant tensor">; def Int64ElementsAttr : ElementsAttrBase< - CPred<"$_self.cast().getShapedType().getElementType().isInteger(64)">, "Int 64 constant tensor">; + 
CPred<"llvm::cast($_self).getShapedType().getElementType().isInteger(64)">, "Int 64 constant tensor">; // Extract the ith int element from an ArrayAttr $0 as an 32-bit IntegerAttr // with builder. class ExtractI32At : NativeCodeCall< - "$_builder.getI32IntegerAttr($_self.cast().getValue()[" # i # - "].cast().getInt())">; + "$_builder.getI32IntegerAttr(llvm::cast(llvm::cast($_self).getValue()[" # i # + "]).getInt())">; // Use the tensor type information from $0 and convert min $1, max $2 and // numBits $3 and narrowRange $4 to a QuantizedType. @@ -48,7 +48,7 @@ def ConvertToQuantTypeFromAttrs : NativeCodeCall< // Converts an integer attribute $0 to 32-bit with builder. def convertIntAttrTo32Bit : NativeCodeCall< - "$_builder.getI32IntegerAttr($0.cast().getInt())">; + "$_builder.getI32IntegerAttr(llvm::cast($0).getInt())">; // Builds a constant bool attribute. class GetBoolAttr : @@ -56,15 +56,15 @@ class GetBoolAttr : // Converts an integer attribute $0 to 64-bit with builder. def convertIntAttrTo64Bit : NativeCodeCall< - "$_builder.getI64IntegerAttr($0.cast().getInt())">; + "$_builder.getI64IntegerAttr(llvm::cast($0).getInt())">; // Extracts the single integer element from $_self. def ExtractSingleElementAsInteger : NativeCodeCall< - "ExtractSingleElementAsInteger($_self.cast())">; + "ExtractSingleElementAsInteger(llvm::cast($_self))">; // Extracts the single int32 element from $_self. def ExtractSingleElementAsInt32 : NativeCodeCall< - "$_builder.getI32IntegerAttr(ExtractSingleElementAsInteger($_self.cast()).getInt())">; + "$_builder.getI32IntegerAttr(ExtractSingleElementAsInteger(llvm::cast($_self)).getInt())">; // Converts tensor with int64 to int32. def CreateTFCastToInt32Op : NativeCodeCall< @@ -75,7 +75,7 @@ def CreateInt32ConstOrCast : NativeCodeCall< // Creates an int32 constant op from an integer attribute $0. 
def CreateInt32ConstOpFromIntAttr - : NativeCodeCall<"$_builder.create($_loc, DenseElementsAttr::get(RankedTensorType::get({}, $_builder.getI32Type()), {static_cast($0.cast().getInt())}))">; + : NativeCodeCall<"$_builder.create($_loc, DenseElementsAttr::get(RankedTensorType::get({}, $_builder.getI32Type()), {static_cast(llvm::cast($0).getInt())}))">; //===----------------------------------------------------------------------===// // Nullary ops patterns. @@ -100,8 +100,8 @@ def IsDataFormatNHWC : ConstantAttr; def IsDataFormatNCHW : ConstantAttr; class I32VectorElementsAttr : ElementsAttrBase< - CPred<"$_self.isa() &&" - "$_self.cast().getType()." + CPred<"llvm::isa($_self) &&" + "llvm::cast($_self).getType()." "getElementType().isSignlessInteger(32)">, "32-bit int elements attribute of shape [" # len # "]"> { @@ -123,8 +123,8 @@ def IsAllOnes : AttrConstraint>; // Constraint that attribute is string with value either "SAME" or "VALID" def IsSameOrValid : AttrConstraint< - CPred<"$_self.cast().getValue() == \"SAME\" || " # - "$_self.cast().getValue() == \"VALID\"">, + CPred<"llvm::cast($_self).getValue() == \"SAME\" || " # + "llvm::cast($_self).getValue() == \"VALID\"">, "'SAME' or 'VALID' paddings">; def TFL_GetMirrorPaddingType : NativeCodeCall< @@ -443,8 +443,8 @@ def LegalizeSum : Pat<(TF_SumOp $arg, $axes, BoolAttr:$arg2), def LegalizeTopKV2 : Pat<(TF_TopKV2Op $input, $k, $ignored_sorted), (TFL_TopKV2Op $input, $k)>; -def ReductionDimensionIsLastDim : Constraint().getInt() == " - "$1.getType().cast().getRank() - 1 || $0.cast().getInt() == -1)">>; +def ReductionDimensionIsLastDim : Constraint($0).getInt() == " + "llvm::cast($1.getType()).getRank() - 1 || llvm::cast($0).getInt() == -1)">>; // Legalizes TF_ApproxTopKOp to TFL_TopKV2Op with the following constraints: // 1. 
It computes max k @@ -558,10 +558,10 @@ def LegalizeConv2DBackpropInput : Pat< /*fused_activation_function=*/TFL_AF_None)>; def IsRankZeroAttr - : CPred<"$_self.cast().getType().getRank() == 0">; + : CPred<"llvm::cast($_self).getType().getRank() == 0">; def HasValueZero - : CPred<"$_self.cast()." + : CPred<"llvm::cast($_self)." "getSplatValue<::mlir::IntegerAttr>().getInt() == 0">; // TFLite only supports MatrixSetDiag ops with scalar zero k attribute. diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_variables.td b/tensorflow/compiler/mlir/lite/transforms/legalize_variables.td index 5c26b6ea468565..72ec563930d7d2 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_variables.td +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_variables.td @@ -22,7 +22,7 @@ def HasSupportedElementType : Constraint>; def IsSupportedElementType : - Constraint())">>; + Constraint($0.getType()))">>; def LegalizeVarHandle : Pat< (TF_VarHandleOp:$result $container, $shared_name), diff --git a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc index 2b5b7537f5154c..182d593cb14351 100644 --- a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc +++ b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc @@ -139,7 +139,7 @@ Value CreateI32SplatTensor(Location loc, PatternRewriter *rewriter, Type PrependLeadingDimIfRanked(int64_t dim, Type type, PatternRewriter *rewriter) { Type dtype = getElementTypeOrSelf(type); - if (RankedTensorType ty = type.dyn_cast()) { + if (RankedTensorType ty = llvm::dyn_cast(type)) { llvm::SmallVector shape = {dim}; shape.append(ty.getShape().begin(), ty.getShape().end()); return tensorflow::GetTypeFromTFTensorShape(shape, dtype); @@ -256,7 +256,7 @@ struct ConvertConst : public OpConversionPattern { ConversionPatternRewriter &rewriter) const override { // Verify that the tensor proto contains tensor of type variant 
and scalar // shape. The variant type should hold a TensorList. - auto proto_attr = op.getValue().dyn_cast(); + auto proto_attr = llvm::dyn_cast(op.getValue()); if (!proto_attr) return failure(); tensorflow::Tensor tensor; if (!tensorflow::ConvertToTensor(proto_attr, &tensor).ok()) @@ -270,13 +270,13 @@ struct ConvertConst : public OpConversionPattern { if (!list) return failure(); // Verify output type is variant and contains exactly one ranked subtypes. - auto variant_ty = - getElementTypeOrSelf(op.getType()).dyn_cast(); + auto variant_ty = llvm::dyn_cast( + getElementTypeOrSelf(op.getType())); if (!variant_ty) return failure(); ArrayRef subtypes = variant_ty.getSubtypes(); if (subtypes.size() != 1) return failure(); RankedTensorType list_element_ty = - subtypes.front().dyn_cast(); + llvm::dyn_cast(subtypes.front()); if (!list_element_ty) return failure(); // Extract tensor elements for the TensorList and construct result type @@ -372,7 +372,8 @@ struct ConvertTensorListSetItem loc, tensorflow::GetTypeFromTFTensorShape({1}, shape_dtype), item_rank, scalar_zero); // Create two slice ops. - Type element_type = input.getType().cast().getElementType(); + Type element_type = + llvm::cast(input.getType()).getElementType(); UnrankedTensorType unranked_tensor = UnrankedTensorType::get(element_type); Value scalar_minus_one = CreateI32SplatConst(loc, &rewriter, {}, -1); TF::SliceOp slice1 = @@ -441,7 +442,8 @@ struct ConvertTensorListSetItem // Expand the dimension of item so that it will have the same rank with // input. 
// ExpandDims(item, 0) - Type element_type = input.getType().cast().getElementType(); + Type element_type = + llvm::cast(input.getType()).getElementType(); UnrankedTensorType unranked_tensor = UnrankedTensorType::get(element_type); auto expanded_item = rewriter.create( op.getLoc(), unranked_tensor, item, scalar_zero); @@ -494,7 +496,8 @@ struct ConvertTensorListInitOp : public TensorListOpConverterBase { // looking at the first `TensorListSetItemOp` writing to this tensor list. // Here we assume that the element_shape won't be changed before calling // the first `TensorListSetItemOp`. - if (auto shaped_type = element_shape.getType().dyn_cast()) { + if (auto shaped_type = + llvm::dyn_cast(element_shape.getType())) { if (shaped_type.hasRank() && shaped_type.getRank() == 0) { bool element_shape_acquired = false; auto uses = op.getResult().getUses(); @@ -517,8 +520,8 @@ struct ConvertTensorListInitOp : public TensorListOpConverterBase { if (TF::TensorListSetItemOp set_op = llvm::dyn_cast( inside_use.getOwner())) { - if (auto shaped_type = - set_op.getItem().getType().dyn_cast()) { + if (auto shaped_type = llvm::dyn_cast( + set_op.getItem().getType())) { if (shaped_type.hasStaticShape()) { RankedTensorType type = tensorflow::GetTypeFromTFTensorShape( @@ -592,7 +595,8 @@ struct ConvertTensorListInitOp : public TensorListOpConverterBase { } auto attr = DenseIntElementsAttr::get( - element_shape.getType().cast(), new_element_shape_values); + llvm::cast(element_shape.getType()), + new_element_shape_values); auto new_element_shape = rewriter.create( op.getLoc(), element_shape.getType(), attr); element_shape = new_element_shape; @@ -603,7 +607,7 @@ struct ConvertTensorListInitOp : public TensorListOpConverterBase { Type result_type = UnrankedTensorType::get(element_dtype); Value leading_dim = GetNumElements(op, adaptor.getOperands(), &rewriter); if (auto element_type = - op.element_type().template dyn_cast()) { + llvm::dyn_cast(op.element_type())) { result_rank = 
element_type.getRank() + 1; int64_t leading_dim_v = -1; ElementsAttr element_attr; @@ -662,12 +666,12 @@ struct ConvertTensorListReserve return CreateI32SplatConst(op.getLoc(), rewriter, {1}, attr.getInt()); } if (auto const_op = num_elements.getDefiningOp()) { - return CreateI32SplatConst(op->getLoc(), rewriter, {1}, - (*const_op.getValue() - .cast() - .getValues() - .begin()) - .getSExtValue()); + return CreateI32SplatConst( + op->getLoc(), rewriter, {1}, + (*llvm::cast(const_op.getValue()) + .getValues() + .begin()) + .getSExtValue()); } return rewriter->create( op.getLoc(), tensorflow::GetTypeFromTFTensorShape({1}, shape_dtype), @@ -713,8 +717,8 @@ struct ConvertTensorListPushBack loc, expanded_item_type, item, scalar_zero); Type elem_type = getElementTypeOrSelf(item); - auto handle_dtype = getElementTypeOrSelf(op.getOutputHandle().getType()) - .cast(); + auto handle_dtype = llvm::cast( + getElementTypeOrSelf(op.getOutputHandle().getType())); Type result_type = GetTensorTypeForTensorList(elem_type, handle_dtype, &rewriter); @@ -756,8 +760,8 @@ struct ConvertTensorListResize // Infer result type of this op based on TF's shape inference result. Type elem_type = getElementTypeOrSelf(input_handle); - auto handle_dtype = getElementTypeOrSelf(op.getOutputHandle().getType()) - .cast(); + auto handle_dtype = llvm::cast( + getElementTypeOrSelf(op.getOutputHandle().getType())); Type result_type = GetTensorTypeForTensorList(elem_type, handle_dtype, &rewriter); @@ -952,7 +956,8 @@ struct ConvertTensorListStack // trivial Reshape op (that doesn't actually change the input's shape) and // also populate the shape info to the op result. The shape of the // tensorlist is inferred from `num_elements` and `element_shape`. 
- auto ranked_type = element_shape.getType().dyn_cast(); + auto ranked_type = + llvm::dyn_cast(element_shape.getType()); DenseIntElementsAttr dense_elem_attr; if ((ranked_type && ranked_type.getRank() == 0) || !matchPattern(element_shape, m_Constant(&dense_elem_attr))) { @@ -1013,7 +1018,7 @@ struct ConvertTensorListConcatV2 // First unpack the input tensor along the first dimension. Type input_element_type = getElementTypeOrSelf(input); int64_t num_unpacked = 0; - if (auto type = input.getType().dyn_cast()) { + if (auto type = llvm::dyn_cast(input.getType())) { if (type.getDimSize(0) > 0) { num_unpacked = type.getDimSize(0); } else { @@ -1091,7 +1096,7 @@ struct ConvertYield : public OpConversionPattern { // if `type` is a tensor of variant. Otherwise, returns `type` unmodified. Type VariantToUnrankedTensorType(Type type, Value value) { TF::VariantType variant_ty = - getElementTypeOrSelf(type).dyn_cast(); + llvm::dyn_cast(getElementTypeOrSelf(type)); if (!variant_ty) { return type; } @@ -1102,7 +1107,7 @@ Type VariantToUnrankedTensorType(Type type, Value value) { } Type value_type = value.getType(); Type element_type; - variant_ty = value_type.dyn_cast(); + variant_ty = llvm::dyn_cast(value_type); if (variant_ty && !variant_ty.getSubtypes().empty()) { element_type = variant_ty.getSubtypes()[0].getElementType(); } else { @@ -1114,7 +1119,7 @@ Type VariantToUnrankedTensorType(Type type, Value value) { // Returns true if we can deduce the type is tensorlist. bool IsTensorListType(Type type, std::optional value) { TF::VariantType variant_ty = - getElementTypeOrSelf(type).dyn_cast(); + llvm::dyn_cast(getElementTypeOrSelf(type)); if (!variant_ty) { return false; } @@ -1336,7 +1341,7 @@ llvm::DenseMap MapTensorListResultToArgument(func::FuncOp func) { break; } } - if (auto block_arg = parent.dyn_cast()) { + if (auto block_arg = dyn_cast(parent)) { return block_arg.getArgNumber(); } // Returns -1 if we don't find which this result maps to. 
@@ -1547,7 +1552,7 @@ void LowerStaticTensorListPass::runOnOperation() { // still. auto is_legal = [](Operation *op) { auto is_not_variant = [](Type ty) { - return !ty.cast().getElementType().isa(); + return !isa(cast(ty).getElementType()); }; return llvm::all_of(op->getOperandTypes(), is_not_variant) && llvm::all_of(op->getResultTypes(), is_not_variant); @@ -1555,8 +1560,7 @@ void LowerStaticTensorListPass::runOnOperation() { auto is_set_item_legal = [](Operation *op) { return op->hasAttr("resize_if_index_out_of_bounds") && - op->getAttr("resize_if_index_out_of_bounds") - .cast() + llvm::cast(op->getAttr("resize_if_index_out_of_bounds")) .getValue(); }; diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_batch_matmul.td b/tensorflow/compiler/mlir/lite/transforms/optimize_batch_matmul.td index 85bdf63babcbab..bc82b1f496acfb 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_batch_matmul.td +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_batch_matmul.td @@ -26,8 +26,8 @@ include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" def NotFromDequant : Constraint>; def IsResultRankEqualTo : Constraint().getRank() == " - "$1.getType().cast().getRank()">>; + "llvm::cast($0.getType().front()).getRank() == " + "llvm::cast($1.getType()).getRank()">>; // Fuses TFL_FullyConnectedOp and TFL_TransposeOp Rhs to TFL_BatchMatMulOp when // it's used by TFL_BatchMatMulOp and "transpose_lhs" is true. diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_batch_matmul_pass.cc b/tensorflow/compiler/mlir/lite/transforms/optimize_batch_matmul_pass.cc index 2451089517c549..71ebbab92c1a71 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_batch_matmul_pass.cc +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_batch_matmul_pass.cc @@ -34,6 +34,7 @@ limitations under the License. 
#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/transforms/tflite_passes/optimize_batch_matmul_utils.h" #include "tensorflow/compiler/mlir/lite/utils/utils.h" namespace mlir { @@ -56,7 +57,7 @@ bool NotFromDequant(mlir::Value value) { // Converts batch_matmul operation to fully_connected if rhs is a // constant tensor with rank 2 -struct ConvertBatchMatMulOp2FullyConnectedOp +struct ConvertBatchMatMulOp2FullyConnectedOp_Rank2ConstantRhs : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(TFL::BatchMatMulOp bmm_op, @@ -263,6 +264,127 @@ struct ConvertBatchMatMulOpToReduceSum return false; } }; + +// Pattern to fuse transpose op into RHS of batch_matmul op if the transpose and +// batch_matmul are separated by a reshape op; and the transpose op is used +// exclusively to transpose the contracting dimension and the LHS-Output +// dimension. +// Converts batch_matmul operation to fully_connected if rhs is rank-2 +// else converts it to a BatchMatMul op with adj_y = true and transpose fused +// into RHS. +// +// Example: +// % 0 = "tfl.transpose" // Input: [2048, 32, 128] -> [128, 2048, 32] +// % 1 = "tfl.reshape"(%0) // reshaped [128, 2048, 32] -> [128, 65536] +// % 2 = "tfl.batch_matmul" // LHS: [4, 128], RHS: [128, 65536] -> [4, 65536] +struct FuseRhsTransposeIntoBatchMatMulOp + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(TFL::BatchMatMulOp bmm_op, + PatternRewriter& rewriter) const override { + // Exit the pattern if adj_y is true. + if (bmm_op.getAdjY()) { + return rewriter.notifyMatchFailure( + bmm_op, "Pattern does not apply when adj_y is true."); + } + + // Exit the pattern if the RHS of BatchMatMulOp is not originated from a + // TFL::TransposeOp->TFL::ReshapeOp. 
+ auto reshape_op = bmm_op.getY().getDefiningOp(); + if (!reshape_op) { + return rewriter.notifyMatchFailure( + bmm_op, + "RHS is not originated from a transpose->reshape op pattern."); + } + + auto transpose_op = reshape_op.getInput().getDefiningOp(); + if (!transpose_op) { + return rewriter.notifyMatchFailure( + bmm_op, + "RHS is not originated from a transpose->reshape op pattern."); + } + + // Get the dimensions info of the RHS of BatchMatMulOp. + auto rhs_dimensions_info = GetBatchMatMulRhsDimensionsInfo( + mlir::cast(bmm_op.getY().getType())); + + // Make sure that the reshape op is flattening either the contracting + // dimension or the output dimension. + auto reshape_input_shape = GetShape(reshape_op.getInput()); + if (!HasFlattenedContractingDims(reshape_input_shape, + rhs_dimensions_info) && + !HasFlattenedOutDims(reshape_input_shape, rhs_dimensions_info)) { + return rewriter.notifyMatchFailure( + bmm_op, + "Reshape op is not flattening the contracting dimension or the " + "output dimension."); + } + + // Make sure that the transpose op is only transposing the contracting + // dimensions and the output dimensions. + auto transpose_perm_status_or_value = + GetValueAsIntArray(transpose_op.getPerm()); + auto transpose_input_shape = GetShape(transpose_op.getInput()); + if (transpose_perm_status_or_value.ok() && + !HasTransposedContractingAndOutDims( + transpose_input_shape, transpose_perm_status_or_value.value(), + rhs_dimensions_info)) { + return rewriter.notifyMatchFailure( + bmm_op, + "Transpose op is not transposing the contracting dimension and the " + "output dimension."); + } + + auto rhs_contracting_dimensions = + rhs_dimensions_info.contracting_dimensions(); + auto rhs_out_dimensions = rhs_dimensions_info.out_dimensions(); + auto rhs_batch_dimensions = rhs_dimensions_info.batch_dimensions(); + + // Create a new ReshapeOp, without the TransposeOp, to flatten the + // contracting dimension and the output dimension, as needed. 
+ llvm::SmallVector new_reshape_input_shape; + if (!rhs_dimensions_info.batch_dimensions().AxesArray().empty()) { + for (auto dim_size : rhs_batch_dimensions.SizesArray()) { + new_reshape_input_shape.push_back(dim_size); + } + } + new_reshape_input_shape.push_back(rhs_out_dimensions.SizesArray().front()); + new_reshape_input_shape.push_back( + rhs_contracting_dimensions.SizesArray().front()); + + Value new_reshape_shape_value = rewriter.create( + bmm_op->getLoc(), + GetI32ElementsAttr(new_reshape_input_shape, &rewriter)); + auto new_reshape_value = rewriter.create( + bmm_op->getLoc(), transpose_op.getInput(), new_reshape_shape_value); + + // Replace the BatchMatMulOp with a FullyConnectedOp, if the RHS of BMM has + // no broadcasting dimensions. I.e. RHS of BMM is of Rank 2. + if (rhs_dimensions_info.batch_dimensions().AxesArray().empty()) { + auto no_input = rewriter.create( + bmm_op->getLoc(), rewriter.getNoneType(), rewriter.getUnitAttr()); + auto fc_op = rewriter.create( + bmm_op->getLoc(), ArrayRef{bmm_op.getType()}, + /*input=*/bmm_op.getX(), /*filter=*/new_reshape_value, + /*bias=*/no_input, + /*fused_activation_function=*/rewriter.getStringAttr("NONE"), + /*weights_format=*/rewriter.getStringAttr("DEFAULT"), + /*keep_num_dims=*/rewriter.getBoolAttr(true), + /*asymmetric_quantize_inputs=*/mlir::BoolAttr()); + rewriter.replaceOp(bmm_op, {fc_op.getResult(0)}); + } else { + // Replace the BatchMatMulOp with a BatchMatMulOp with adj_y = true and + // transpose fused into RHS. 
+ auto bmm_op_with_adj_y = rewriter.create( + bmm_op->getLoc(), bmm_op.getType(), bmm_op.getX(), new_reshape_value, + bmm_op.getAdjX(), /*adj_y=*/true, mlir::BoolAttr()); + rewriter.replaceOp(bmm_op, {bmm_op_with_adj_y.getResult()}); + } + + return success(); + } +}; + #include "tensorflow/compiler/mlir/lite/transforms/generated_optimize_batch_matmul.inc" } // namespace @@ -271,8 +393,10 @@ void OptimizeBatchMatmulPass::runOnOperation() { auto* ctx = &getContext(); RewritePatternSet patterns(ctx); - patterns.add(ctx); + patterns + .add( + ctx); TFL::populateWithGenerated(patterns); (void)applyPatternsGreedily(func, std::move(patterns)); } diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_broadcast_like_pass.cc b/tensorflow/compiler/mlir/lite/transforms/optimize_broadcast_like_pass.cc index e85cfef6dd0d87..aed2946db17ba3 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_broadcast_like_pass.cc +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_broadcast_like_pass.cc @@ -42,6 +42,7 @@ limitations under the License. 
#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" // IWYU pragma: keep +#include "tensorflow/compiler/mlir/lite/transforms/optimize_broadcast_like_pass_options.h" #include "tensorflow/compiler/mlir/lite/utils/utils.h" namespace mlir { @@ -55,8 +56,10 @@ using BroadcastedShapeFunction = class ConvertResultsBroadcastableShapeOp : public RewritePattern { public: - explicit ConvertResultsBroadcastableShapeOp(MLIRContext* context) - : RewritePattern(MatchAnyOpTypeTag(), /*PatternBenefit*/ 1, context) {} + explicit ConvertResultsBroadcastableShapeOp( + MLIRContext* context, const OptimizeBroadcastLikePassOptions& options) + : RewritePattern(MatchAnyOpTypeTag(), /*PatternBenefit*/ 1, context), + options_(options) {} LogicalResult matchAndRewrite(Operation* op, PatternRewriter& rewriter) const override; @@ -65,6 +68,9 @@ class ConvertResultsBroadcastableShapeOp : public RewritePattern { LogicalResult RewriteOp( Operation* op, PatternRewriter& rewriter, BroadcastedShapeFunction& get_broadcasted_shape) const; + + private: + const OptimizeBroadcastLikePassOptions& options_; }; // Some tfl ops only support implicit broadcasting up to a certain rank. @@ -191,7 +197,8 @@ LogicalResult ConvertResultsBroadcastableShapeOp::RewriteOp( // Check that the result shape is fully defined. auto result_type = llvm::cast(op->getResultTypes().front()); - if (!result_type || !result_type.hasStaticShape()) + if (!result_type || (!options_.unsafe_fuse_dynamic_shaped_broadcast && + !result_type.hasStaticShape())) return rewriter.notifyMatchFailure( op, "Unsupported result shape for broadcasting on op: " + op->getName().getStringRef()); @@ -224,7 +231,10 @@ LogicalResult ConvertResultsBroadcastableShapeOp::RewriteOp( // Check that the operand of the broadcast has fully defined shape. 
auto broadcast_arg_type = llvm::cast(broadcast_like_op_input.getType()); - if (!broadcast_arg_type || !broadcast_arg_type.hasStaticShape()) continue; + if (!broadcast_arg_type || + (!options_.unsafe_fuse_dynamic_shaped_broadcast && + !broadcast_arg_type.hasStaticShape())) + continue; auto other_arg = op->getOpOperand(1 - i).get(); // If non-splat operand is not fusable affine ops, then no need to apply @@ -238,7 +248,9 @@ LogicalResult ConvertResultsBroadcastableShapeOp::RewriteOp( // Check that the other argument has fully defined shape. auto other_arg_type = llvm::cast(other_arg.getType()); - if (!other_arg_type || !other_arg_type.hasStaticShape()) continue; + if (!other_arg_type || (!options_.unsafe_fuse_dynamic_shaped_broadcast && + !other_arg_type.hasStaticShape())) + continue; // Get the unbroadcasted shapes in the operand order. std::array, 2> operand_shapes; @@ -268,8 +280,9 @@ LogicalResult ConvertResultsBroadcastableShapeOp::RewriteOp( class ConvertResultsBroadcastableBatchMatMulShapeOp : public ConvertResultsBroadcastableShapeOp { public: - explicit ConvertResultsBroadcastableBatchMatMulShapeOp(MLIRContext* context) - : ConvertResultsBroadcastableShapeOp(context) {} + explicit ConvertResultsBroadcastableBatchMatMulShapeOp( + MLIRContext* context, const OptimizeBroadcastLikePassOptions& options) + : ConvertResultsBroadcastableShapeOp(context, options) {} LogicalResult matchAndRewrite(Operation* op, PatternRewriter& rewriter) const override; @@ -384,9 +397,10 @@ void OptimizeBroadcastLikePass::runOnOperation() { RewritePatternSet patterns(&getContext()); auto func = getOperation(); - patterns.add(func.getContext()); - patterns.add( - func.getContext()); + patterns.add(func.getContext(), + GetOptions()); + patterns.add(func.getContext(), + GetOptions()); patterns.add(func.getContext()); TFL::populateWithGenerated(patterns); (void)applyPatternsGreedily(getOperation(), std::move(patterns)); diff --git 
a/tensorflow/compiler/mlir/lite/transforms/optimize_broadcast_like_pass.h b/tensorflow/compiler/mlir/lite/transforms/optimize_broadcast_like_pass.h index f13048a1982641..0b5f8f1f6bc2b1 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_broadcast_like_pass.h +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_broadcast_like_pass.h @@ -16,24 +16,28 @@ limitations under the License. #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/Pass/PassOptions.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/TypeID.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/transforms/optimize_broadcast_like_pass_options.h" #include "tensorflow/compiler/mlir/lite/transforms/pass.h" -#include "tensorflow/compiler/mlir/lite/transforms/pass_options.h" namespace mlir { namespace TFL { // Pass to optimize explicit broadcasting-like patterns. class OptimizeBroadcastLikePass - : public TFL::Pass { + : public TFL::Pass { public: MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OptimizeBroadcastLikePass) OptimizeBroadcastLikePass() = default; OptimizeBroadcastLikePass(const OptimizeBroadcastLikePass&) {}; + explicit OptimizeBroadcastLikePass(const mlir::detail::PassOptions& options) + : Pass(options) {} void runOnOperation() override; static llvm::StringRef GetName() { return "OptimizeBroadcastLikePass"; } diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_broadcast_like_pass_options.h b/tensorflow/compiler/mlir/lite/transforms/optimize_broadcast_like_pass_options.h new file mode 100644 index 00000000000000..7d11f5d74cc4c5 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_broadcast_like_pass_options.h @@ -0,0 +1,41 @@ + +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. 
+Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_OPTIMIZE_BROADCAST_LIKE_PASS_OPTIONS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_OPTIMIZE_BROADCAST_LIKE_PASS_OPTIONS_H_ + +#include "llvm/Support/CommandLine.h" +#include "mlir/Pass/PassOptions.h" // from @llvm-project + +namespace mlir { +namespace TFL { + +//////////////////////////////////////////////////////////////////////////////// +// Pass Options +//////////////////////////////////////////////////////////////////////////////// + +struct OptimizeBroadcastLikePassOptions : public mlir::detail::PassOptions { + mlir::detail::PassOptions::Option unsafe_fuse_dynamic_shaped_broadcast{ + *this, "unsafe-fuse-dynamic-shaped-broadcast", + llvm::cl::desc( + "Enable fusion of dynamic shaped broadcast ops. 
It helps fusing " + "implicit broadcasting ops when output shape has dynamic dimensions, " + "but it may cause incorrect results when broadcasting ops are " + "introduced by explicit broadcasting in the source model."), + llvm::cl::init(false)}; +}; + +} // namespace TFL +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_OPTIMIZE_BROADCAST_LIKE_PASS_OPTIONS_H_ diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_broadcast_like_patterns.td b/tensorflow/compiler/mlir/lite/transforms/optimize_broadcast_like_patterns.td index e97ab85accb93d..945c67090f08fd 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_broadcast_like_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_broadcast_like_patterns.td @@ -364,3 +364,32 @@ def RemoveRedundantBroadcastToOp : Pat< (TFL_BroadcastToOp:$result AnyStaticShapeTensor:$pre_broadcast, $dim), (replaceWithValue $pre_broadcast), [(HasSameStaticShapes $pre_broadcast, $result)]>; + +//////////////////////////////////////////////////////////////////////////////// +// Reorder TFL::SumOp with the TFL::broadcast_to operator. +//////////////////////////////////////////////////////////////////////////////// + +def HasDistinctBroadcastAndReduceAxes : Constraint>; + +// Pattern to transform tfl.sum(tfl.broadcast_to(input, shape=S1), axis=B, keep_dims=true) +// into tfl.broadcast_to(tfl.sum(input, axis=B, keep_dims=true), shape=S2) +// where S1 is intermediate_target_shape_val, B is reduction_indices_val, +// and S2 is the computed final_target_shape_val (shape of original sum). 
+def ReorderBroadcastToAfterSumOp : Pat< + (TFL_SumOp:$original_sum + (TFL_BroadcastToOp:$intermediate_broadcast + AnyStaticShapeTensor:$original_input, + (Arith_ConstantOp $intermediate_target_shape_val)), + (Arith_ConstantOp I32ElementsAttr:$reduction_indices_val), + $keep_dims), + (TFL_BroadcastToOp + (TFL_SumOp + $original_input, + (Arith_ConstantOp $reduction_indices_val), + $keep_dims), + (Arith_ConstantOp (GetShapeAttr $original_sum))), + [(HasOneUse $intermediate_broadcast), + (HasDistinctBroadcastAndReduceAxes + $original_input, $reduction_indices_val, $intermediate_target_shape_val), + ]>; diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_pass.cc b/tensorflow/compiler/mlir/lite/transforms/optimize_pass.cc index 994762b7641ebb..277546ec8e3ae2 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_pass.cc +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_pass.cc @@ -824,21 +824,6 @@ bool IsPermutationNCHW(Value perm) { #include "tensorflow/compiler/mlir/lite/transforms/generated_optimize.inc" -// Returns 1D 32-bit dense elements attribute with the given values. -static DenseIntElementsAttr GetI32ElementsAttr(ArrayRef values, - Builder *builder) { - RankedTensorType ty = mlir::RankedTensorType::get( - {static_cast(values.size())}, builder->getIntegerType(32)); - return DenseIntElementsAttr::get(ty, values); -} - -DenseIntElementsAttr GetI64ElementsAttr(ArrayRef values, - Builder *builder) { - RankedTensorType ty = RankedTensorType::get( - {static_cast(values.size())}, builder->getIntegerType(64)); - return DenseIntElementsAttr::get(ty, values); -} - // Get the number of leading 1s in the shape of the given input. // Ex. input_shape = [1 x 1 x 1 x 1 x 2 x 1] => 4 // returns 0 if the input shape is not static. @@ -992,80 +977,6 @@ struct SqueezeReshapesAroundBroadcastOp } }; -// This pattern matches TFL::BroadcastToOp WITH TENSOR RANK <= 4 and replaces -// it with a MulOp that multiplies the tensor by a splat constant with 1s. 
-struct ConvertTFLBroadcastToMulOp - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(TFL::BroadcastToOp tfl_broadcast_to_op, - PatternRewriter &rewriter) const override { - auto input_type = - mlir::cast(tfl_broadcast_to_op.getInput().getType()); - auto output_type = - mlir::cast(tfl_broadcast_to_op.getOutput().getType()); - auto shape_type = - mlir::cast(tfl_broadcast_to_op.getShape().getType()); - Type element_type = input_type.getElementType(); - - auto loc = tfl_broadcast_to_op->getLoc(); - - // Check that the output type is not dynamic and is less-than-equal to 4D or - // the shape type is static, 1D and has less-than-equal to 4 elements. - bool is_output_shape_dynamic = - (!output_type.hasRank() || (output_type.getRank() > 4) || - (output_type.getNumDynamicDims() > 0)); - bool is_broadcast_shape_dynamic = - (!shape_type.hasStaticShape() || (shape_type.getRank() != 1) || - (shape_type.getDimSize(0) > 4)); - if (is_output_shape_dynamic && is_broadcast_shape_dynamic) - return rewriter.notifyMatchFailure( - loc, "output_rank or broadcast_to shape not supported"); - - // Allow lowering when the input's elements type is F32, BFloat16, I32 or - // I16. - if (!(mlir::isa(element_type) || - element_type.isInteger(32) || element_type.isInteger(16))) - return rewriter.notifyMatchFailure(loc, "element_type_not_supported"); - - // TFL_FillOp is created only if is_output_shape_dynamic is true, otherwise - // a Arith.ConstOp is created. 
- if (is_output_shape_dynamic && - output_type.getElementType().isUnsignedInteger()) { - return rewriter.notifyMatchFailure( - loc, - "Unsigned broadcast_to output with dynamic shape is not supported"); - } - - Value mul_rhs_value; - if (!output_type.hasRank() || (output_type.getNumDynamicDims() > 0)) { - auto status_or_const_op = - CreateConstOpWithSingleValue(&rewriter, loc, input_type, 1); - if (!status_or_const_op.ok()) { - return failure(); - } - - mul_rhs_value = rewriter.create( - loc, output_type, tfl_broadcast_to_op.getShape(), - status_or_const_op.value()); - } else { - auto status_or_const_op = - CreateConstOpWithVectorValue(&rewriter, loc, output_type, 1); - if (!status_or_const_op.ok()) { - return failure(); - } - - mul_rhs_value = status_or_const_op.value(); - } - - auto mul_op = rewriter.create( - loc, output_type, tfl_broadcast_to_op.getInput(), mul_rhs_value, - rewriter.getStringAttr("NONE")); - rewriter.replaceOp(tfl_broadcast_to_op, mul_op.getResult()); - return success(); - } -}; - struct FuseAddAndStridedSlice : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -2531,9 +2442,11 @@ struct EliminateQDQPairs : public OpRewritePattern { // (HasRankAtLeast<2> $bias), // (IsDefinedByFullyConnectedOp $lhs)]>; struct UndoBroadcastFullyConnectedBiasAddWithQDQs - : public OpRewritePattern::SplitMatchAndRewrite { - using SplitMatchAndRewrite::SplitMatchAndRewrite; - LogicalResult match(TFL::AddOp add_op) const override { + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TFL::AddOp add_op, + PatternRewriter &rewriter) const override { if (!add_op->hasOneUse()) { return failure(); } @@ -2572,13 +2485,6 @@ struct UndoBroadcastFullyConnectedBiasAddWithQDQs return failure(); } - return success(); - } - - void rewrite(TFL::AddOp add_op, PatternRewriter &rewriter) const override { - auto dq_op = cast(add_op.getRhs().getDefiningOp()); - auto q_op = 
cast(dq_op.getInput().getDefiningOp()); - auto bias_op = cast(q_op.getInput().getDefiningOp()); auto new_bias = FlattenTo1D(bias_op.getValueAttr()); auto new_bias_type = new_bias.getType(); auto new_bias_op = rewriter.create( @@ -2603,6 +2509,7 @@ struct UndoBroadcastFullyConnectedBiasAddWithQDQs // Remove old bias rewriter.eraseOp(bias_op); + return success(); } }; @@ -2813,6 +2720,233 @@ struct PushTransposeThroughSqueeze : public RewritePattern { } }; +// Helper function to check if a constant tensor attribute has the expected +// integer values +bool matchConstantIntPermutation(Value permValue, + ArrayRef expectedPerm) { + DenseElementsAttr permAttr; + if (!matchPattern(permValue, m_Constant(&permAttr))) { + return false; // Not a constant + } + if (!permAttr.getElementType().isInteger(32) && + !permAttr.getElementType().isInteger(64)) { + // TFLite perms are often i32, but accept i64 too + return false; + } + + auto values = permAttr.getValues(); + if (values.size() != expectedPerm.size()) { + return false; + } + for (size_t i = 0; i < expectedPerm.size(); ++i) { + if (values[i].getSExtValue() != expectedPerm[i]) { + return false; + } + } + return true; +} + +inline DenseIntElementsAttr GetI32ElementsAttr(ArrayRef values, + Builder *builder) { + RankedTensorType ty = mlir::RankedTensorType::get( + {static_cast(values.size())}, builder->getIntegerType(32)); + return DenseIntElementsAttr::get(ty, values); +} + +inline DenseIntElementsAttr GetI32ElementsAttr(ArrayRef values, + Builder *builder) { + llvm::SmallVector new_values; + for (auto el : values) { + new_values.push_back(static_cast(el)); + } + RankedTensorType ty = mlir::RankedTensorType::get( + {static_cast(values.size())}, builder->getIntegerType(32)); + return DenseIntElementsAttr::get(ty, new_values); +} + +// Reorders a Transpose-Reshape-Transpose sequence to +// Reshape-Transpose-Transpose to allow for further optimization. 
+// +// The pattern matches: +// Transpose(Reshape(Transpose(input, perm: [1, 0]))) +// +// and rewrites it to: +// Transpose(Transpose(Reshape(input))) +// +// This reordering allows for further optimization by potentially fusing the +// reshapes and transposes. +struct ReorderTransposeReshapeTranspose + : public OpRewritePattern { + explicit ReorderTransposeReshapeTranspose(MLIRContext *context) + : OpRewritePattern(context, /*benefit=*/0) {} + + LogicalResult matchAndRewrite(TFL::TransposeOp outer_tpose, + PatternRewriter &rewriter) const override { + auto reshape = outer_tpose.getInput().getDefiningOp(); + if (!reshape) return failure(); + + auto inner_tpose = reshape.getInput().getDefiningOp(); + if (!inner_tpose) return failure(); + + auto inner_tpose_shape = + mlir::dyn_cast_or_null(inner_tpose.getType()); + if (!inner_tpose_shape) return failure(); + + auto input = inner_tpose.getInput(); + + auto inner_perm = inner_tpose.getPerm(); + if (!matchConstantIntPermutation(inner_perm, {1, 0})) return failure(); + + int64_t perm0 = inner_tpose_shape.getDimSize(0); + + llvm::SmallVector reshape_shape; + { + DenseIntElementsAttr reshape_shape_attr; + if (!matchPattern(reshape.getShape(), m_Constant(&reshape_shape_attr))) { + return failure(); + } + + for (auto dim : reshape_shape_attr) { + reshape_shape.push_back(static_cast(dim.getSExtValue())); + } + } + + // Consume dimensions until we've equaled the size of the first dim in the + // permuted result of the inner tpose and record the dim. 
+ int32_t dim = -1; + for (auto i = 0, running_total = 1; i < reshape_shape.size(); i++) { + running_total *= reshape_shape[i]; + if (perm0 == running_total) { + dim = i; + } + } + + if (dim == -1) return failure(); + + llvm::SmallVector new_reshape_shape(reshape_shape.size()); + llvm::SmallVector new_inner_perm(reshape_shape.size()); + + int index = 0; + for (auto i = dim + 1; i < reshape_shape.size(); i++) { + new_inner_perm[i] = index; + new_reshape_shape[index++] = reshape_shape[i]; + } + for (auto i = 0; i <= dim; i++) { + new_inner_perm[i] = index; + new_reshape_shape[index++] = reshape_shape[i]; + } + + auto reshape_type = + mlir::dyn_cast_or_null(reshape.getType()); + if (!reshape_type) return failure(); + + auto new_reshape_shape_const = rewriter.create( + reshape.getLoc(), GetI32ElementsAttr(new_reshape_shape, &rewriter)); + + auto new_inner_reshape = rewriter.create( + reshape.getLoc(), + RankedTensorType::get(new_reshape_shape, reshape_type.getElementType()), + input, new_reshape_shape_const.getResult()); + auto new_inner_tpose = rewriter.create( + inner_tpose.getLoc(), reshape_type, new_inner_reshape, + rewriter.create( + inner_tpose.getLoc(), + GetI32ElementsAttr(new_inner_perm, &rewriter))); + + rewriter.replaceOp(reshape, new_inner_tpose); + + return success(); + } +}; + +// Some models produce FullyConnected ops where the LHS is a const and the RHS +// is the activation. This breaks some downstream optimizations (notably input +// caching in XNNPack among other things). This rewrite pattern swaps the +// operands to match the expected order and recomputes a new output shape for +// the resuling op. +// +// This pattern only applies when: +// * input and filter operands are 2D +// * bias = none +// * keep_num_dims = false (implied if input and filter are 2D) +// Support for additional cases to broaden applicability can be added later. +// TODO(b/408313959): Add support for more cases. 
+// +// Note that transposes are added to maintain correctness: +// +// Original: Output[B, O] = FC(Input[B, I](Const), Filter[O, I](Var), Bias=None) +// ~= matmul(C, transpose(V)) +// +// Transformed: +// Intermediate[O, B] = FC(Filter[O, I](Var), Input[B, I](Const), None) +// ~= matmul(V, transpose(C)) +// FinalOutput[B, O] = Transpose(Intermediate[O, B], perm=[1, 0]) +struct FullyConnectedSwapOperandsWhenLHSIsConst + : public OpRewritePattern { + explicit FullyConnectedSwapOperandsWhenLHSIsConst(MLIRContext *context) + : OpRewritePattern(context, /*benefit=*/0) {} + + LogicalResult matchAndRewrite(TFL::FullyConnectedOp fc, + PatternRewriter &rewriter) const override { + if (!mlir::isa(fc.getBias().getType())) return failure(); + + auto input = fc.getInput(); + auto filter = fc.getFilter(); + + if (!matchPattern(input, m_Constant()) || + matchPattern(filter, m_Constant())) + return failure(); + + auto input_type = mlir::dyn_cast(input.getType()); + auto filter_type = mlir::dyn_cast(filter.getType()); + auto output_type = + mlir::dyn_cast(fc.getResult(0).getType()); + + if (!input_type || !filter_type || !output_type) return failure(); + + if (input_type.getRank() != 2 || filter_type.getRank() != 2) + return failure(); + + // Dimensions: B=Batch, I=InputDepth, O=OutputDepth + // Input: [B, I], Filter: [O, I] + // We extract B from the input operand and O from the filter operand + int64_t B = input_type.getDimSize(0); + int64_t O = filter_type.getDimSize(0); + + Type element_type = output_type.getElementType(); + Location loc = fc.getLoc(); + + RankedTensorType intermediate_type = + RankedTensorType::get({O, B}, element_type); + + auto new_fc = rewriter.create( + loc, + /*resultTypes=*/intermediate_type, + /*input=*/filter, // Original Filter V[O, I] + /*filter=*/input, // Original Input C[B, I] + /*bias=*/fc.getBias(), + /*fused_activation_function=*/ + rewriter.getStringAttr(fc.getFusedActivationFunction()), + /*weights_format=*/fc.getWeightsFormatAttr(), + 
/*keep_num_dims=*/rewriter.getBoolAttr(false), + /*asymmetric_quantize_inputs=*/ + fc.getAsymmetricQuantizeInputsAttr() // Propagate quant attr + ); + + RankedTensorType final_shape_type = + RankedTensorType::get({B, O}, element_type); + + Value transposed_result = rewriter.create( + loc, final_shape_type, new_fc.getResult(0), + rewriter.create( + loc, GetI32ElementsAttr(ArrayRef({1, 0}), &rewriter))); + + rewriter.replaceOp(fc, transposed_result); + + return success(); + } +}; + // Adds canonicalization patterns to the list of patterns. void AddCanonicalizationPatterns(MLIRContext *context, RewritePatternSet *patterns) { @@ -2873,8 +3007,9 @@ void OptimizePass::runOnOperation() { OptimizeTopK, FuseAddAndStridedSlice, FuseReshapeAndTransposeAroundBatchMatmul, FuseTransposeReshapeIntoBatchMatmul, MoveReshapeAfterFullyConnected, - EnableFullyConnectedKeepNumDimsBeforeReshape, ConvertTFLBroadcastToMulOp>( - ctx); + EnableFullyConnectedKeepNumDimsBeforeReshape, + ReorderTransposeReshapeTranspose, + FullyConnectedSwapOperandsWhenLHSIsConst>(ctx); if (!GetOptions().disable_fuse_mul_and_fc) { phase_2_patterns.add(ctx); } diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td index cfa1e21619d203..fc09b1f6a55021 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td @@ -27,21 +27,21 @@ include "mlir/IR/CommonAttrConstraints.td" // Checks if the param passed is a F32 ElementsAttr. def F32ElementsAttr : ElementsAttrBase< - CPred<"$_self.isa() && $_self.cast().getShapedType().getElementType().isF32()">, + CPred<"llvm::isa($_self) && llvm::cast($_self).getShapedType().getElementType().isF32()">, "32 bit float constant tensor">; // Checks if the param passed is a float ElementsAttr. 
def FloatElementsAttr : ElementsAttrBase< - CPred<"$_self.isa() && $_self.cast().getShapedType().getElementType().isa()">, + CPred<"llvm::isa($_self) && llvm::isa(llvm::cast($_self).getShapedType().getElementType())">, "float constant tensor">; def ExtractSingleElementAsFloat : NativeCodeCall< - "ExtractSingleElementAsFloat($_self.cast())">; + "ExtractSingleElementAsFloat(llvm::cast($_self))">; // Checks if the value has rank 'n'. class HasRank : Constraint< - CPred<"$0.getType().cast().hasRank() && " - "$0.getType().cast().getRank() == " # n>>; + CPred<"llvm::cast($0.getType()).hasRank() && " + "llvm::cast($0.getType()).getRank() == " # n>>; class FloatValueEquals : Constraint>; @@ -57,9 +57,9 @@ def HasOneUse : Constraint>; def IsPermutationNCHW : Constraint>; def IsBiasShape : Constraint< - CPred<"$0.getType().cast().getRank() == 4 && " - "$0.getType().cast().getShape()[2] == 1 && " - "$0.getType().cast().getShape()[3] == 1">, + CPred<"llvm::cast($0.getType()).getRank() == 4 && " + "llvm::cast($0.getType()).getShape()[2] == 1 && " + "llvm::cast($0.getType()).getShape()[3] == 1">, "has shape consistent with a bias">; def ReshapeNCHWBiasToNHWC : NativeCodeCall<"ReshapeNCHWBiasToNHWC($0, $1)">; @@ -114,7 +114,7 @@ foreach actFnPair = [[TFL_ReluOp, TFL_AF_Relu], } def GetBiasMultiplier: - NativeCodeCall<"GetBiasMultiplier($_builder, $0, $1.cast())">; + NativeCodeCall<"GetBiasMultiplier($_builder, $0, llvm::cast($1))">; class CanFuseConvOrDepthwiseConv : Constraint< CPred<"TFL::CanFuseConvOrDepthwiseConv($0, $1, " # is_depthwise # ")">>; @@ -372,22 +372,22 @@ def MatchHardSwishPattern6 : Pat< // Constraint that the attribute value is less than 'n' class ConstDoubleValueLessThan : Constraint< - CPred<"$0.isa() && " - "$0.cast().getNumElements() == 1 && " - "std::abs(*$0.cast().getValues().begin()) < " + CPred<"llvm::isa($0) && " + "llvm::cast($0).getNumElements() == 1 && " + "std::abs(*llvm::cast($0).getValues().begin()) < " # n>>; // Constraint that the attribute 
value is negative infinity or negative largest. // We use both -inf & flt_min due to the forward compatibility. def ConstAPFloatNegLargestOrNegInfinity : Constraint() && " - "$0.cast().getNumElements() == 1 && " - "(($0.cast().getValues()[0].isLargest() && " - "$0.cast().getValues()[0].isNegative()) || " - "$0.cast().getValues()[0].isNegInfinity())">>; + "llvm::isa($0) && " + "llvm::cast($0).getNumElements() == 1 && " + "((llvm::cast($0).getValues()[0].isLargest() && " + "llvm::cast($0).getValues()[0].isNegative()) || " + "llvm::cast($0).getValues()[0].isNegInfinity())">>; def L2NormValidReduceIndex : Constraint())">>; + "L2NormalizeReduceAxis($0, llvm::cast($1))">>; // Currently L2Normalization doesn't support activation function // in TFLite. @@ -456,9 +456,9 @@ def IsReducedTailOfShape : Constraint>; def Flatten : NativeCodeCall< - "$0.cast()" - ".reshape(RankedTensorType::get({$0.getType().cast().getNumElements()}, " - "$0.getType().cast().getElementType()))">; + "llvm::cast($0)" + ".reshape(RankedTensorType::get({llvm::cast($0.getType()).getNumElements()}, " + "llvm::cast($0.getType()).getElementType()))">; def IsLastDimEqualToNumElements : Constraint>; @@ -725,20 +725,20 @@ foreach ValueOp = [TFL_CeilOp, TFL_ExpOp, TFL_FloorOp, TFL_NegOp, // Returns truncated shape of a ranked-tensor. // Prefix-Truncated, here, means eliminating any contiguous 1s' in the lower // dimentions of the tensor -def GetPrefixTruncatedShape: NativeCodeCall<"GetShape($0, true)">; +def GetPrefixTruncatedShape: NativeCodeCall<"GetShapeAttr($0, true)">; // Returns True if the operand type is RankedTensorType and valid. 
def HasValidRankedTensor : Constraint() && " - "$0.getType().cast().getNumDynamicDims() <= 1">>; + "llvm::isa($0.getType()) && " + "llvm::cast($0.getType()).getNumDynamicDims() <= 1">>; // Check if the truncated shape of the lhs is equal to the shape of rhs def IsPrefixTruncatedShapeEqualTo : Constraint>; + "GetShapeAttr($0, true) == GetShapeAttr($1)">>; def ConvertSqueezeToReshape : Pat< (TFL_SqueezeOp:$squeeze_op $input, $squeeze_dims), - (TFL_ReshapeOp $input, (Arith_ConstantOp (GetShape $squeeze_op))), + (TFL_ReshapeOp $input, (Arith_ConstantOp (GetShapeAttr $squeeze_op))), [(HasValidRankedTensor $squeeze_op)]>; // Pattern to perform the following optimization @@ -793,7 +793,7 @@ def UndoBroadcastConvBiasAdd : Pat< // Pattern to convert a trivial transpose op to a reshape op. def ConvertTrivialTransposeOpToReshapeOp : Pat< (TFL_TransposeOp:$transpose_op $input, (Arith_ConstantOp:$permutation $p1)), - (TFL_ReshapeOp $input, (Arith_ConstantOp (GetShape $transpose_op))), + (TFL_ReshapeOp $input, (Arith_ConstantOp (GetShapeAttr $transpose_op))), [(IsTransposeTrivial $input, $permutation), (AnyStaticShapeTensor $input), (AnyStaticShapeTensor $transpose_op)]>; @@ -810,7 +810,7 @@ def FoldDoubleTranspose : Pat< // Convert expand_dims to reshape if possible. def ConvertExpandDimsToReshape : Pat< (TFL_ExpandDimsOp:$expand_dims_op $input, $dim), - (TFL_ReshapeOp $input, (Arith_ConstantOp (GetShape $expand_dims_op))), + (TFL_ReshapeOp $input, (Arith_ConstantOp (GetShapeAttr $expand_dims_op))), [(AnyStaticShapeTensor $expand_dims_op)]>; // Here, the element type can be any integer or float type. @@ -900,8 +900,8 @@ def RemoveShapeOnlyCast : Pat<(TFL_CastOp:$output $input), // Checks if the operand0's rank is one less than operand1's rank. 
def PReluAlphaRankCheck : Constraint< - CPred<"$0.getType().cast().getRank() == " - "$1.getType().cast().getRank() - 1">>; + CPred<"llvm::cast($0.getType()).getRank() == " + "llvm::cast($1.getType()).getRank() - 1">>; // PReLU pattern from Keras: // f(x) = Relu(x) + (-alpha * Relu(-x)) @@ -979,7 +979,7 @@ def OptimizePow2ToRsqrt : Pat< def CanOptimizeIdentityGatherNdOrScatterNdOp : Constraint(), $2.getType())">>; + "$0, llvm::cast($1), $2.getType())">>; def OptimizeIdentityGatherNdOp : Pat< (TFL_GatherNdOp:$output $params, (Arith_ConstantOp I32ElementsAttr: $indices)), @@ -1013,9 +1013,9 @@ def IsSame : Constraint>; def HasTwoUse : Constraint>; def AxesIsLastDimension : Constraint().getNumElements() == 1 && " - "($0.cast().getValues()[0] == " - "$1.getType().cast().getRank() - 1 || $0.cast().getValues()[0] == -1)">>; + "llvm::cast($0).getNumElements() == 1 && " + "(llvm::cast($0).getValues()[0] == " + "llvm::cast($1.getType()).getRank() - 1 || llvm::cast($0).getValues()[0] == -1)">>; // Convert exp(x)/sum(exp(x)) into softmax. 
def OptimizeToSoftmax : Pat< @@ -1070,10 +1070,10 @@ def FoldNormalizationIntoSoftmaxJaxWithAxisMinus1 : Pat< def HaveSameType : Constraint>; class AllElementsAreF32 : Constraint() && " - "$0.cast().getType().cast().getElementType().isF32() && " - "std::all_of($0.cast().getValues().begin(), " - "$0.cast().getValues().end(), " + "(llvm::isa($0) && " + "llvm::cast(llvm::cast($0).getType()).getElementType().isF32() && " + "std::all_of(llvm::cast($0).getValues().begin(), " + "llvm::cast($0).getValues().end(), " "[](float v){ return v == " #val# ";}))">>; // Optimize X*1 to X @@ -1086,10 +1086,10 @@ def OptimizeMul1ToIdentity : Pat< (AllElementsAreF32<"1.0f"> $constant)]>; class AllElementsAreBool : Constraint() && " - "$0.cast().getType().cast().getElementType().isInteger(1) && " - "std::all_of($0.cast().getValues().begin(), " - "$0.cast().getValues().end(), " + "(llvm::isa($0) && " + "llvm::cast(llvm::cast($0).getType()).getElementType().isInteger(1) && " + "std::all_of(llvm::cast($0).getValues().begin(), " + "llvm::cast($0).getValues().end(), " "[](bool v){ return v == " #val# ";}))">>; // Remove select operators when the result is known in advance. @@ -1225,11 +1225,11 @@ def IsLastDimensionEqualOne : Constraint>; // As above but if shape is not static and rank 2 with last dim 1. 
def IsLastDimensionEqualOneOrDynamicBatchDimRank2 : Constraint< CPred<"IsLastDimensionEqualOne($0) || " - "(!$0.getType().cast().hasStaticShape() && " - " $0.getType().cast().hasRank() && " - " $0.getType().cast().getRank() == 2 && " - " !$0.getType().cast().getShape().empty() && " - " $0.getType().cast().getShape()[1] == 1)">>; + "(!llvm::cast($0.getType()).hasStaticShape() && " + " llvm::cast($0.getType()).hasRank() && " + " llvm::cast($0.getType()).getRank() == 2 && " + " !llvm::cast($0.getType()).getShape().empty() && " + " llvm::cast($0.getType()).getShape()[1] == 1)">>; // Replace // Equal(X, indices) @@ -1250,10 +1250,10 @@ def ReshapeEqualOpToOneHotOp : Pat< (IsOneHotIndexAttribute $series)]>; def F32ElementsVal : Constraint().getElementType().isF32()">, + "llvm::cast($0.getType()).getElementType().isF32()">, "32 bit float tensor">; def I32ElementsVal : Constraint().getElementType().isInteger(32)">, + "llvm::cast($0.getType()).getElementType().isInteger(32)">, "32 bit integer tensor">; def ConvertSingleElementAttrToFloatAttr : @@ -1324,7 +1324,7 @@ def ReplaceOneHotFullyConnectedWithLookup : Pat< (Arith_ConstantOp ConstantAttr, "{1,0}">)), (returnType (GetEmbeddingLookupShape $indices, $filter)) ), - (Arith_ConstantOp (GetShape (GetIthValue<0> $outputs)))), + (Arith_ConstantOp (GetShapeAttr (GetIthValue<0> $outputs)))), [(I32ElementsVal $indices), // lookup is not implemented for i64 (IsNoneType $bias)]>; // Maybe folded into the lookup matrix later @@ -1397,6 +1397,67 @@ def MatchGeluApproximate : Pat< (HasOneUse $pow_out), ]>; +// Alternate pattern for GeluApproximate to match mul(x, mul(x, x)). 
+// 0.5 * x * ( 1 + tanh( sqrt_2dPi * ( x + 0.044715 * mul(x, mul(x, x)) ) ) ) +def MatchGeluApproximate_Mul1 : Pat< + (TFL_MulOp + (TFL_MulOp:$mul_out $arg0, (Arith_ConstantOp F32ElementsAttr:$Cst_1_2), TFL_AF_None), + (TFL_AddOp:$add_out + (TFL_TanhOp:$tanh_out + (TFL_MulOp:$mul_out1 + (TFL_AddOp:$add_out1 $arg0, + (TFL_MulOp:$mul_out2 + (TFL_MulOp:$pow_out $arg0, + (TFL_MulOp:$sqr_out $arg0, $arg0, TFL_AF_None), TFL_AF_None), + (Arith_ConstantOp F32ElementsAttr:$Coeff), TFL_AF_None), TFL_AF_None), + (Arith_ConstantOp F32ElementsAttr:$Cst_sqrt_2dPi), TFL_AF_None)), + (Arith_ConstantOp F32ElementsAttr:$Cst_1), TFL_AF_None), TFL_AF_None), + (TFL_GeluOp $arg0, ConstBoolAttrTrue), + [(FloatValueEquals<"0.5"> $Cst_1_2), + (FloatValueEquals<"1"> $Cst_1), + (FloatValueEquals<"0.797884583"> $Cst_sqrt_2dPi), + (FloatValueEquals<"0.044715"> $Coeff), + (HasOneUse $mul_out), + (HasOneUse $add_out), + (HasOneUse $tanh_out), + (HasOneUse $mul_out1), + (HasOneUse $add_out1), + (HasOneUse $mul_out2), + (HasOneUse $pow_out), + (HasOneUse $sqr_out), + ]>; + +// Alternate pattern for GeluApproximate to match mul(mul(x, x), x). 
+// 0.5 * x * ( 1 + tanh( sqrt_2dPi * ( x + 0.044715 * mul(mul(x, x), x) ) ) ) +def MatchGeluApproximate_Mul2 : Pat< + (TFL_MulOp + (TFL_MulOp:$mul_out $arg0, (Arith_ConstantOp F32ElementsAttr:$Cst_1_2), TFL_AF_None), + (TFL_AddOp:$add_out + (TFL_TanhOp:$tanh_out + (TFL_MulOp:$mul_out1 + (TFL_AddOp:$add_out1 $arg0, + (TFL_MulOp:$mul_out2 + (TFL_MulOp:$pow_out + (TFL_MulOp:$sqr_out $arg0, $arg0, TFL_AF_None), + $arg0, TFL_AF_None), + (Arith_ConstantOp F32ElementsAttr:$Coeff), TFL_AF_None), TFL_AF_None), + (Arith_ConstantOp F32ElementsAttr:$Cst_sqrt_2dPi), TFL_AF_None)), + (Arith_ConstantOp F32ElementsAttr:$Cst_1), TFL_AF_None), TFL_AF_None), + (TFL_GeluOp $arg0, ConstBoolAttrTrue), + [(FloatValueEquals<"0.5"> $Cst_1_2), + (FloatValueEquals<"1"> $Cst_1), + (FloatValueEquals<"0.797884583"> $Cst_sqrt_2dPi), + (FloatValueEquals<"0.044715"> $Coeff), + (HasOneUse $mul_out), + (HasOneUse $add_out), + (HasOneUse $tanh_out), + (HasOneUse $mul_out1), + (HasOneUse $add_out1), + (HasOneUse $mul_out2), + (HasOneUse $pow_out), + (HasOneUse $sqr_out), + ]>; + // Alternate pattern for GeluApproximate (see different order for mul), replaces // x * ( 0.5 * ( 1 + tanh( sqrt_2dPi * ( x + 0.044715 * pow( x, 3 ) ) ) ) ) def MatchGeluApproximate1 : Pat< @@ -1426,6 +1487,67 @@ def MatchGeluApproximate1 : Pat< (HasOneUse $pow_out), ]>; +// Alternate pattern for GeluApproximate1 to match mul(x, mul(x, x)). 
+// x * ( 0.5 * ( 1 + tanh( sqrt_2dPi * ( x + 0.044715 * mul(x, mul(x, x)) ) ) ) ) +def MatchGeluApproximate1_Mul1 : Pat< + (TFL_MulOp $arg0, + (TFL_MulOp:$mul_out + (TFL_AddOp:$add_out + (TFL_TanhOp:$tanh_out + (TFL_MulOp:$mul_out1 + (TFL_AddOp:$add_out1 $arg0, + (TFL_MulOp:$mul_out2 + (TFL_MulOp:$pow_out $arg0, + (TFL_MulOp:$sqr_out $arg0, $arg0, TFL_AF_None), TFL_AF_None), + (Arith_ConstantOp F32ElementsAttr:$Coeff), TFL_AF_None), TFL_AF_None), + (Arith_ConstantOp F32ElementsAttr:$Cst_sqrt_2dPi), TFL_AF_None)), + (Arith_ConstantOp F32ElementsAttr:$Cst_1), TFL_AF_None), (Arith_ConstantOp F32ElementsAttr:$Cst_1_2), TFL_AF_None), TFL_AF_None), + (TFL_GeluOp $arg0, ConstBoolAttrTrue), + [(FloatValueEquals<"0.5"> $Cst_1_2), + (FloatValueEquals<"1"> $Cst_1), + (FloatValueEquals<"0.797884583"> $Cst_sqrt_2dPi), + (FloatValueEquals<"0.044715"> $Coeff), + (HasOneUse $mul_out), + (HasOneUse $add_out), + (HasOneUse $tanh_out), + (HasOneUse $mul_out1), + (HasOneUse $add_out1), + (HasOneUse $mul_out2), + (HasOneUse $pow_out), + (HasOneUse $sqr_out), + ]>; + +// Alternate pattern for GeluApproximate1 to match mul(mul(x, x), x). 
+// x * ( 0.5 * ( 1 + tanh( sqrt_2dPi * ( x + 0.044715 * mul(mul(x, x), x) ) ) ) ) +def MatchGeluApproximate1_Mul2 : Pat< + (TFL_MulOp $arg0, + (TFL_MulOp:$mul_out + (TFL_AddOp:$add_out + (TFL_TanhOp:$tanh_out + (TFL_MulOp:$mul_out1 + (TFL_AddOp:$add_out1 $arg0, + (TFL_MulOp:$mul_out2 + (TFL_MulOp:$pow_out + (TFL_MulOp:$sqr_out $arg0, $arg0, TFL_AF_None), + $arg0, TFL_AF_None), + (Arith_ConstantOp F32ElementsAttr:$Coeff), TFL_AF_None), TFL_AF_None), + (Arith_ConstantOp F32ElementsAttr:$Cst_sqrt_2dPi), TFL_AF_None)), + (Arith_ConstantOp F32ElementsAttr:$Cst_1), TFL_AF_None), (Arith_ConstantOp F32ElementsAttr:$Cst_1_2), TFL_AF_None), TFL_AF_None), + (TFL_GeluOp $arg0, ConstBoolAttrTrue), + [(FloatValueEquals<"0.5"> $Cst_1_2), + (FloatValueEquals<"1"> $Cst_1), + (FloatValueEquals<"0.797884583"> $Cst_sqrt_2dPi), + (FloatValueEquals<"0.044715"> $Coeff), + (HasOneUse $mul_out), + (HasOneUse $add_out), + (HasOneUse $tanh_out), + (HasOneUse $mul_out1), + (HasOneUse $add_out1), + (HasOneUse $mul_out2), + (HasOneUse $pow_out), + (HasOneUse $sqr_out), + ]>; + // For Gelu, replaces // 0.5 * x * ( 1 + erf( x * sqrt_1_2 ) ) def MatchGelu : Pat< @@ -1542,7 +1664,7 @@ def isF32Splat : Constraint< CPred<"IsF32Splat($0)">>; def ExtractF32AtIndex0: NativeCodeCall< - "$_builder.getF32FloatAttr($_self.cast().getValues()[0])">; + "$_builder.getF32FloatAttr(llvm::cast($_self).getValues()[0])">; def FuseLeakyReluConst : Pat< (TFL_SelectOp @@ -1577,16 +1699,16 @@ class ContractingDimsProductEqual : Constraint : Constraint().getShape()" + "(llvm::dyn_cast($0.getType()).getShape()" ".drop_back("#skip_last#").drop_front("#skip_first#") ==" - "$1.getType().dyn_cast().getShape()" + "llvm::dyn_cast($1.getType()).getShape()" ".drop_back("#skip_last#").drop_front("#skip_first#"))">>; // Returns true if the broadcast dimension of a tensor is [1] // here- broadcast dimension is first prefix dimension // excluding the last two dimensions def IsBroadcastDimEqualToOne : Constraint().getShape()[0] == 
1">>; + "llvm::dyn_cast($0.getType()).getShape()[0] == 1">>; // Pattern to fuse/fold the reshape ops around TFL_BatchMatMulOp // This pattern is applied when the rank of rhs is 2 @@ -1831,25 +1953,25 @@ def FuseSliceAndPack4D : Pat<( // Given a value, checks if dim `d` is static. class HasStaticDim : Constraint().isDynamicDim(" # d # ")">>; + "!llvm::cast($0.getType()).isDynamicDim(" # d # ")">>; class IsBalancedPaddingArray : Constraint())">>; + "llvm::cast($0))">>; // Given in_shape, out_shape, stride checks ceil(in_shape[d] / stride) == out_shape[d] def IsSameStridedShape2D : Constraint()," - "$1.getType().cast().getShape())">>; + "llvm::cast($1.getType()).getShape())">>; def IsSameStridedShapeDepthwise : Constraint()," - "$1.getType().cast().getShape())">>; + "llvm::cast($1.getType()).getShape())">>; def IsSameStridedShape3D : Constraint()," - "$1.getType().cast().getShape())">>; + "llvm::cast($1.getType()).getShape())">>; def IsValidPadding : Constraint>; diff --git a/tensorflow/compiler/mlir/lite/transforms/pass_options_setter.h b/tensorflow/compiler/mlir/lite/transforms/pass_options_setter.h index 534b1402dd4cd3..29906014fce292 100644 --- a/tensorflow/compiler/mlir/lite/transforms/pass_options_setter.h +++ b/tensorflow/compiler/mlir/lite/transforms/pass_options_setter.h @@ -22,6 +22,7 @@ namespace TFL { class OptimizePassOptions; class VariableFreezingPipelineOptions; class EmptyPassOptions; +class OptimizeBroadcastLikePassOptions; // Interface for setting options for TFLite Converter Pass/Pipeline Options. 
class PassOptionsSetter { @@ -30,6 +31,7 @@ class PassOptionsSetter { virtual void SetOptions(OptimizePassOptions& options) const = 0; virtual void SetOptions(VariableFreezingPipelineOptions& options) const = 0; virtual void SetOptions(EmptyPassOptions& options) const = 0; + virtual void SetOptions(OptimizeBroadcastLikePassOptions& options) const = 0; }; } // namespace TFL } // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/transforms/passes.h b/tensorflow/compiler/mlir/lite/transforms/passes.h index 4d8ecccaa5f3f7..0f7dc05fb5dee9 100644 --- a/tensorflow/compiler/mlir/lite/transforms/passes.h +++ b/tensorflow/compiler/mlir/lite/transforms/passes.h @@ -23,8 +23,10 @@ limitations under the License. #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project // IWYU pragma: keep #include "tensorflow/compiler/mlir/lite/transforms/canonicalize_boundary_value_pass.h" +#include "tensorflow/compiler/mlir/lite/transforms/cleanup_optimization_barrier_pass.h" #include "tensorflow/compiler/mlir/lite/transforms/optimize_batch_matmul_pass.h" #include "tensorflow/compiler/mlir/lite/transforms/optimize_broadcast_like_pass.h" +#include "tensorflow/compiler/mlir/lite/transforms/optimize_broadcast_like_pass_options.h" #include "tensorflow/compiler/mlir/lite/transforms/optimize_pass.h" #include "tensorflow/compiler/mlir/lite/transforms/pass_registry_utils.h" #include "tensorflow/compiler/mlir/lite/transforms/push_transpose_through_ewise_pass.h" @@ -289,6 +291,11 @@ inline std::unique_ptr CreateCanonicalizeBoundaryValuePass() { std::unique_ptr> CreatePartitionedTopologicalSortPass(); +// Create a pass that cleans up optimization barriers. 
+inline std::unique_ptr CreateCleanupOptimizationBarrierPass() { + return Create(); +} + #define GEN_PASS_DECL_DEFAULTQUANTPARAMSPASS #define GEN_PASS_DECL_LEGALIZETFPASS #define GEN_PASS_DECL_LOWERSTATICTENSORLISTPASS @@ -340,13 +347,14 @@ inline void registerTensorFlowLitePasses() { Register(); Register(); Register(); - Register(); + Register(); Register(); Register(); // Other TFLite Passes Register(); Register(); + Register(); } } // namespace TFL diff --git a/tensorflow/compiler/mlir/lite/transforms/passes.td b/tensorflow/compiler/mlir/lite/transforms/passes.td index 10e3156855ef45..cf2cc345e34dd4 100644 --- a/tensorflow/compiler/mlir/lite/transforms/passes.td +++ b/tensorflow/compiler/mlir/lite/transforms/passes.td @@ -283,7 +283,8 @@ def PrepareTFPass : Pass<"tfl-prepare-tf", "mlir::func::FuncOp"> { let dependentDialects = ["TFL::TensorFlowLiteDialect", "mlir::quant::QuantDialect", "mlir::quantfork::QuantizationForkDialect", - "mhlo::MhloDialect" + "mhlo::MhloDialect", + "stablehlo::StablehloDialect" ]; let options = [ Option<"unfold_batch_matmul_", "unfold_batchmatmul", diff --git a/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc b/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc index 914d426f278d66..65e5368b7faf96 100644 --- a/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc @@ -15,13 +15,21 @@ limitations under the License. // This transformation pass applies some clean up steps after quantization. 
+#include +#include #include #include +#include "llvm/ADT/ArrayRef.h" #include "llvm/Support/Casting.h" +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project @@ -31,6 +39,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_config.h" #include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/dynamic_shape_utils.h" //===----------------------------------------------------------------------===// // The post-quantize Passes. @@ -155,6 +164,92 @@ enum RemoveVolatileOpsType { kPreserveInputsAndOutputs, }; +// Returns a constant tensor with the given scalar/vector value and shape. 
+template +std::optional GetConstTensor(PatternRewriter& rewriter, + Location loc, llvm::ArrayRef vec, + llvm::ArrayRef shape) { + int64_t num_total_elements = 1; + for (int64_t a : shape) { + num_total_elements *= a; + } + + if (vec.size() != num_total_elements) { + return std::nullopt; + } + + auto const_type = tensorflow::GetTypeFromTFTensorShape( + shape, rewriter.getIntegerType(sizeof(T) * 8)); + auto const_attr = DenseElementsAttr::get(const_type, vec); + + auto const_op = + rewriter.create(loc, const_type, const_attr); + return const_op.getResult(); +} + +// Converts a dequantize op to a (scale * (input - zeropoint)). The expectation +// is that the qconst value will be constant folded to retain the original +// constant value. This is essentially a constant fold of the dequantize op, +// provided that the value, zp and scale are all constants. +std::optional ConvertDequantizeOp( + PatternRewriter& rewriter, mlir::Operation* op, + mlir::ShapedType output_type, mlir::Value input_value, + llvm::ArrayRef scale, llvm::ArrayRef zeropoint, + int64_t dim) { + RankedTensorType input_type = + dyn_cast(input_value.getType()); + if (!input_type) return std::nullopt; + + std::optional zp_val; + if (zeropoint.size() == 1) { + auto const_type = + tensorflow::GetTypeFromTFTensorShape({}, rewriter.getF32Type()); + auto const_attr = + DenseElementsAttr::get(const_type, static_cast(zeropoint[0])); + + auto const_op = rewriter.create(op->getLoc(), const_type, + const_attr); + zp_val = const_op.getResult(); + } else { + SmallVector shape; + shape.resize(input_type.getRank(), 1); + shape[dim] = zeropoint.size(); + zp_val = GetConstTensor(rewriter, op->getLoc(), zeropoint, shape); + } + + std::optional scale_val; + if (scale.size() == 1) { + auto const_type = + tensorflow::GetTypeFromTFTensorShape({}, rewriter.getF32Type()); + auto const_attr = + DenseElementsAttr::get(const_type, static_cast(scale[0])); + + auto const_op = rewriter.create(op->getLoc(), const_type, + const_attr);
+ scale_val = const_op.getResult(); + } else { + SmallVector shape; + shape.resize(input_type.getRank(), 1); + shape[dim] = scale.size(); + scale_val = GetConstTensor(rewriter, op->getLoc(), scale, shape); + } + + if (!zp_val || !scale_val) return std::nullopt; + + auto op1_cast_in = + rewriter.create(op->getLoc(), output_type, input_value); + + auto op2_sub_op1 = rewriter.create( + op->getLoc(), output_type, op1_cast_in.getResult(), zp_val.value(), + /*fused_activation_function=*/rewriter.getStringAttr("NONE")); + + return rewriter + .create( + op->getLoc(), output_type, op2_sub_op1.getResult(), scale_val.value(), + /*fused_activation_function=*/rewriter.getStringAttr("NONE")) + .getResult(); +} + // Remove the back-to-back quantize and dequantize ops with volatile attribute. template struct RemoveVolatileOps : public OpRewritePattern { @@ -188,6 +283,48 @@ struct RemoveVolatileOps : public OpRewritePattern { op.replaceAllUsesWith(q.getInput()); return success(); + } else if (auto qconst_op = llvm::dyn_cast_or_null(input_op)) { + if (!qconst_op->getAttr(mlir::quant::kVolatileOpAttrName)) + return failure(); + + auto qtype = + quant::QuantizedType::getQuantizedElementType(qconst_op.getType()); + if (!qtype) return failure(); + SmallVector scale; + SmallVector zeropoint; + int64_t dim = 0; + + if (auto uniform_qtype = + mlir::dyn_cast(qtype)) { + scale.push_back(uniform_qtype.getScale()); + zeropoint.push_back(uniform_qtype.getZeroPoint()); + } else if (auto per_axis_qtype = + mlir::dyn_cast( + qtype)) { + scale.assign(per_axis_qtype.getScales().begin(), + per_axis_qtype.getScales().end()); + zeropoint.assign(per_axis_qtype.getZeroPoints().begin(), + per_axis_qtype.getZeroPoints().end()); + dim = per_axis_qtype.getQuantizedDimension(); + } else { + return failure(); + } + + auto output_type = mlir::cast(op.getOutput().getType()); + + auto const_type = tensorflow::GetTypeFromTFTensorShape( + output_type.getShape(), qtype.getStorageType()); + auto const_op = 
rewriter.create( + op->getLoc(), const_type, qconst_op.getValue()); + + auto new_value = + ConvertDequantizeOp(rewriter, op, output_type, const_op.getResult(), + scale, zeropoint, dim); + if (!new_value) return failure(); + + op.replaceAllUsesWith(new_value.value()); + op->erase(); + return success(); } return failure(); } diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_patterns.td b/tensorflow/compiler/mlir/lite/transforms/prepare_patterns.td index 1afceede5252c4..5cf0b3d4ef65b9 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_patterns.td @@ -20,7 +20,7 @@ include "tensorflow/compiler/mlir/lite/utils/utils.td" def FalseBoolAttr : AttrConstraint>; def DenseElementsAttr : ElementsAttrBase< - CPred<"$_self.isa()">, + CPred<"llvm::isa($_self)">, "non-opaque constant tensor">; def CreateGatherNdOp : NativeCodeCall< diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc index 28f46d1d01b592..7716a926a5ab44 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc @@ -65,7 +65,6 @@ limitations under the License. #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_passes.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_passes.h" #include "tensorflow/compiler/mlir/lite/transforms/dilated_conv.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" @@ -75,6 +74,7 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/lite/utils/shape_and_size_utils.h" #include "tensorflow/compiler/mlir/lite/utils/utils.h" #include "tensorflow/compiler/mlir/lite/utils/validators.h" +#include "tensorflow/compiler/mlir/stablehlo/transforms/legalize_tf_passes.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/einsum.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" @@ -84,6 +84,8 @@ limitations under the License. #include "tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_with_tf2xla_passes.h" #include "tensorflow/compiler/mlir/tf2xla/transforms/passes.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "xla/mlir_hlo/mhlo/transforms/rewriters.h" +#include "xla/mlir_hlo/mhlo/utils/type_conversion.h" #define DEBUG_TYPE "tf-tfl-legalization" @@ -1372,6 +1374,11 @@ LogicalResult ConvertTf2XlaOps(func::FuncOp func, MLIRContext *context) { mlir::odml::PopulateLegalizeHloToTfPatterns(&patterns, context); mhlo::GatherOp::getCanonicalizationPatterns(patterns, context); + // mhlo::PopulateLegalizeTfPatterns emits StableHLO ops, until this pipeline + // handles StableHLO ops directly, we need to convert them to MHLO ops. 
+ stablehlo::StablehloToHloTypeConverter hlo_converter; + stablehlo::populateStablehloToHloPatterns(&patterns, &hlo_converter, context); + return applyPartialConversion(func, target, std::move(patterns)); } diff --git a/tensorflow/compiler/mlir/lite/transforms/quantize.cc b/tensorflow/compiler/mlir/lite/transforms/quantize.cc index ae1674b5862986..eb545059806905 100644 --- a/tensorflow/compiler/mlir/lite/transforms/quantize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/quantize.cc @@ -236,7 +236,7 @@ class StrictQuantizationPattern : public RewritePattern { inputs.reserve(quantizing_op->getNumOperands()); for (auto operand : quantizing_op->getOperands()) { Type operand_type = operand.getType(); - if (operand_type.isa()) { + if (mlir::isa(operand_type)) { inputs.push_back(operand); continue; } @@ -292,7 +292,7 @@ class StrictQuantizationPattern : public RewritePattern { Type result_type = result.getType(); // Add this to the test coverage once we create test ops with none // type results. 
- if (result_type.isa()) { + if (mlir::isa(result_type)) { outputs_replaced.insert({result, enumerated_result.index()}); output_types.push_back(result_type); continue; @@ -582,7 +582,13 @@ class QuantizeConstPattern : public OpRewritePattern { quantized_attr = quant::Quantize(attr, qtype.getValue()); } if (quantized_attr) { - rewriter.replaceOpWithNewOp(op, qtype, quantized_attr); + auto qconst_op = + rewriter.create(op.getLoc(), qtype, quantized_attr); + if (auto volatile_attr = op->getAttr(quant::kVolatileOpAttrName)) { + qconst_op->setAttr(quant::kVolatileOpAttrName, volatile_attr); + } + op.replaceAllUsesWith(qconst_op.getOutput()); + rewriter.eraseOp(op); return success(); } } diff --git a/tensorflow/compiler/mlir/lite/transforms/quantize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/quantize_patterns.td index ae8af0a99cc889..e6781e7ce30b7c 100644 --- a/tensorflow/compiler/mlir/lite/transforms/quantize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/quantize_patterns.td @@ -27,7 +27,7 @@ include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td" // Quantize attribute $0 by using quantization parameter from %1. 
def QuantizeByQuantizedType : NativeCodeCall<"quant::Quantize($0, $1.getValue())">; def F32ElementsAttr : ElementsAttrBase< - CPred<"$_self.cast().getShapedType().getElementType().isF32()">, "float constant tensor">; + CPred<"llvm::cast($_self).getShapedType().getElementType().isF32()">, "float constant tensor">; def HasSameType : Constraint>; diff --git a/tensorflow/compiler/mlir/lite/transforms/tf_legalizations/legalize_tensorlist_pass.cc b/tensorflow/compiler/mlir/lite/transforms/tf_legalizations/legalize_tensorlist_pass.cc index 0fe96f4b0b71cc..5e20684f6a9485 100644 --- a/tensorflow/compiler/mlir/lite/transforms/tf_legalizations/legalize_tensorlist_pass.cc +++ b/tensorflow/compiler/mlir/lite/transforms/tf_legalizations/legalize_tensorlist_pass.cc @@ -71,7 +71,8 @@ ConstBytesAttr CreateListReserveOptions(MLIRContext* context, } std::optional GetSingularVariantBaseType(Value val) { - auto val_t = mlir::getElementTypeOrSelf(val).dyn_cast_or_null(); + auto val_t = llvm::dyn_cast_or_null( + mlir::getElementTypeOrSelf(val)); if (!val_t) { return std::nullopt; } @@ -107,11 +108,13 @@ std::optional CustomOptions(MLIRContext* context, bool HasVariantInputOrOutput(Operation* op) { const bool has_variant_input = llvm::any_of(op->getOperands(), [](Value val) { - return val.getType().cast().getElementType().isa(); + return llvm::isa( + llvm::cast(val.getType()).getElementType()); }); const bool has_variant_output = llvm::any_of(op->getResultTypes(), [](Type t) { - return t.cast().getElementType().isa(); + return llvm::isa( + llvm::cast(t).getElementType()); }); return has_variant_input || has_variant_output; } diff --git a/tensorflow/compiler/mlir/lite/transforms/tflite_passes/optimize_batch_matmul_utils.cc b/tensorflow/compiler/mlir/lite/transforms/tflite_passes/optimize_batch_matmul_utils.cc new file mode 100644 index 00000000000000..e40fb1a85d4e88 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/transforms/tflite_passes/optimize_batch_matmul_utils.cc @@ -0,0 +1,303 @@ 
+/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/lite/transforms/tflite_passes/optimize_batch_matmul_utils.h" + +#include +#include +#include +#include +#include +#include + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project + +namespace mlir { +namespace TFL { + +BatchMatMulDimensionsInfo::BatchMatMulDimensionsInfo(mlir::ShapedType type, + bool is_lhs) + : is_lhs_(is_lhs) { + // BatchMatMulOp has the following shape pattern: B0,...,Bn,L,C and + // B0,...,Bn,C,R. So, there is only one Contracting dimension and one + // output dimension. + const int64_t rank = type.getRank(); + + if (is_lhs) { + contracting_dimensions_.axes.push_back(rank - 1); + contracting_dimensions_.sizes.push_back(type.getDimSize(rank - 1)); + out_dimensions_.axes.push_back(rank - 2); + out_dimensions_.sizes.push_back(type.getDimSize(rank - 2)); + } else { + contracting_dimensions_.axes.push_back(rank - 2); + contracting_dimensions_.sizes.push_back(type.getDimSize(rank - 2)); + out_dimensions_.axes.push_back(rank - 1); + out_dimensions_.sizes.push_back(type.getDimSize(rank - 1)); + } + // Dims 0 and 1 are contracting and output dimensions, hence skipped. 
+ for (int64_t dim = 0; dim < rank - 2; ++dim) { + batch_dimensions_.axes.push_back(dim); + batch_dimensions_.sizes.push_back(type.getDimSize(dim)); + } +} + +const DimensionVector& BatchMatMulDimensionsInfo::batch_dimensions() const { + return batch_dimensions_; +} +const DimensionVector& BatchMatMulDimensionsInfo::contracting_dimensions() + const { + return contracting_dimensions_; +} + +const DimensionVector& BatchMatMulDimensionsInfo::out_dimensions() const { + return out_dimensions_; +} + +bool BatchMatMulDimensionsInfo::is_lhs() const { return is_lhs_; } + +BatchMatMulDimensionsInfo GetBatchMatMulLhsDimensionsInfo( + mlir::ShapedType type) { + return BatchMatMulDimensionsInfo(type, /*is_lhs=*/true); +} + +BatchMatMulDimensionsInfo GetBatchMatMulRhsDimensionsInfo( + mlir::ShapedType type) { + return BatchMatMulDimensionsInfo(type, /*is_lhs=*/false); +} + +bool HasFlattenedContractingDims( + llvm::ArrayRef reshape_input_shape, + const BatchMatMulDimensionsInfo& bmm_dimensions_info) { + // Batch dimensions are not flattened and need to match the LHS/RHS of + // BatchMatMulOp. + auto batch_dimensions = bmm_dimensions_info.batch_dimensions().SizesArray(); + // The batch dimensions are at the front of the input shape. + auto reshape_input_shape_batch_dims = + reshape_input_shape.take_front(batch_dimensions.size()); + + if (!llvm::all_of( + llvm::zip(batch_dimensions, reshape_input_shape_batch_dims), + [](auto dims) { return std::get<0>(dims) == std::get<1>(dims); })) { + return false; + } + + // Out dimensions are assumed to be unflattened and need to match the LHS/RHS + // of BatchMatMulOp. + auto out_dimensions = bmm_dimensions_info.out_dimensions().SizesArray(); + llvm::ArrayRef reshape_input_shape_out_dims; + // The out dimensions are at the end of the input shape for LHS and + // at the front for RHS. 
+ if (bmm_dimensions_info.is_lhs()) { + reshape_input_shape_out_dims = + reshape_input_shape.slice(batch_dimensions.size(), 1); + } else { + reshape_input_shape_out_dims = + reshape_input_shape.take_back(out_dimensions.size()); + } + if (!llvm::all_of( + llvm::zip(out_dimensions, reshape_input_shape_out_dims), + [](auto dims) { return std::get<0>(dims) == std::get<1>(dims); })) { + return false; + } + + auto contracting_dimensions = + bmm_dimensions_info.contracting_dimensions().SizesArray(); + // The contracting dimensions are at the end of the input shape for + // LHS and at the front for RHS. + llvm::ArrayRef reshape_input_shape_contracting_dims; + size_t num_contracting_dims = reshape_input_shape.size() - + batch_dimensions.size() - out_dimensions.size(); + if (bmm_dimensions_info.is_lhs()) { + reshape_input_shape_contracting_dims = + reshape_input_shape.take_back(num_contracting_dims); + } else { + reshape_input_shape_contracting_dims = reshape_input_shape.slice( + batch_dimensions.size(), num_contracting_dims); + } + + return (std::accumulate(reshape_input_shape_contracting_dims.begin(), + reshape_input_shape_contracting_dims.end(), 1, + std::multiplies()) == + contracting_dimensions[0]); +} + +bool HasFlattenedOutDims(llvm::ArrayRef reshape_input_shape, + const BatchMatMulDimensionsInfo& bmm_dimensions_info) { + // Batch dimensions are not flattened and need to match the LHS/RHS of + // BatchMatMulOp. + auto batch_dimensions = bmm_dimensions_info.batch_dimensions().SizesArray(); + // The batch dimensions are at the front of the input shape. 
+ auto reshape_input_shape_batch_dims = + reshape_input_shape.take_front(batch_dimensions.size()); + if (!llvm::all_of( + llvm::zip(batch_dimensions, reshape_input_shape_batch_dims), + [](auto dims) { return std::get<0>(dims) == std::get<1>(dims); })) { + return false; + } + + auto contracting_dimensions = + bmm_dimensions_info.contracting_dimensions().SizesArray(); + // The contracting dimensions are at the end of the input shape for + // LHS and at the front for RHS. + llvm::ArrayRef reshape_input_shape_contracting_dims; + if (bmm_dimensions_info.is_lhs()) { + reshape_input_shape_contracting_dims = + reshape_input_shape.take_back(contracting_dimensions.size()); + } else { + reshape_input_shape_contracting_dims = + reshape_input_shape.slice(batch_dimensions.size(), 1); + } + if (!llvm::all_of( + llvm::zip(contracting_dimensions, + reshape_input_shape_contracting_dims), + [](auto dims) { return std::get<0>(dims) == std::get<1>(dims); })) { + return false; + } + + auto out_dimensions = bmm_dimensions_info.out_dimensions().SizesArray(); + // The out dimensions are at the end of the input shape for LHS and + // at the front for RHS. + llvm::ArrayRef reshape_input_shape_out_dims; + size_t num_out_dims = reshape_input_shape.size() - batch_dimensions.size() - + contracting_dimensions.size(); + if (bmm_dimensions_info.is_lhs()) { + reshape_input_shape_out_dims = + reshape_input_shape.slice(batch_dimensions.size(), num_out_dims); + } else { + reshape_input_shape_out_dims = reshape_input_shape.take_back(num_out_dims); + } + + return (std::accumulate(reshape_input_shape_out_dims.begin(), + reshape_input_shape_out_dims.end(), 1, + std::multiplies()) == out_dimensions[0]); +} + +std::tuple, std::pair> +GetTransposedGroupsIndexRange(llvm::ArrayRef transpose_permutation) { + // If the input vector is empty, return None for both pairs. 
+ if (transpose_permutation.empty()) { + return {{-1, -1}, {-1, -1}}; // Use -1 to indicate None + } + + int group_one_end_idx = -1; + for (int i = 0; i < transpose_permutation.size(); ++i) { + if (transpose_permutation[i] == i) { + group_one_end_idx = i; + } else { + break; + } + } + + // If all dimensions are batch dimensions, i.e. the first group is a + // monotonically increasing sequence, return None for both remaining groups. + if (group_one_end_idx == transpose_permutation.size() - 1) { + return {{-1, -1}, {-1, -1}}; + } + + int group_two_start_idx = group_one_end_idx + 1; + int group_two_end_idx = group_two_start_idx; + int group_three_start_idx = -1; + int group_three_end_idx = -1; + + int group_two_end_idx_value = transpose_permutation.size() - 1; + int group_three_start_idx_value = group_one_end_idx + 1; + + for (int i = group_two_start_idx + 1; i < transpose_permutation.size(); ++i) { + if (transpose_permutation[i] > group_two_end_idx_value || + transpose_permutation[i] <= group_three_start_idx_value || + (transpose_permutation[i] != transpose_permutation[i - 1] + 1)) { + break; + } + group_two_end_idx = i; + } + + group_three_start_idx = group_two_end_idx + 1; + group_three_end_idx = transpose_permutation.size() - 1; + // Fail if the last group is not a monotonically increasing sequence. + for (int i = group_three_start_idx + 1; i < transpose_permutation.size(); + ++i) { + if (transpose_permutation[i] != transpose_permutation[i - 1] + 1) { + return {{-1, -1}, {-1, -1}}; + } + } + + // Handle edge cases where start index might be greater than end index. 
+ if (group_two_start_idx > group_two_end_idx) { + group_two_start_idx = group_two_end_idx; + } + + if (group_three_start_idx > group_three_end_idx) { + group_three_start_idx = group_three_end_idx; + } + if (group_three_start_idx >= transpose_permutation.size()) { + group_three_start_idx = -1; + group_three_end_idx = -1; + } + + return {{group_two_start_idx, group_two_end_idx}, + {group_three_start_idx, group_three_end_idx}}; +} + +bool HasTransposedContractingAndOutDims( + llvm::ArrayRef transpose_input_shape, + llvm::ArrayRef transpose_permutation, + const BatchMatMulDimensionsInfo& bmm_dimensions_info) { + std::tuple, std::pair> + transposed_groups_index_range = + GetTransposedGroupsIndexRange(transpose_permutation); + // Return false if the transpose_permutation is not valid. + if (std::get<0>(transposed_groups_index_range).first == -1 || + std::get<0>(transposed_groups_index_range).second == -1 || + std::get<1>(transposed_groups_index_range).first == -1 || + std::get<1>(transposed_groups_index_range).second == -1) { + return false; + } + + // Check if the broadcast dimensions match the batch dimensions of + // BatchMatMulOp. + if (!bmm_dimensions_info.batch_dimensions().AxesArray().empty() && + bmm_dimensions_info.batch_dimensions().AxesArray().back() != + std::get<0>(transposed_groups_index_range).first - 1) { + return false; + } + + // Accumulating the sizes of the transposed groups should match the sizes of + // the contracting and out dimensions of BatchMatMulOp. 
+ int64_t group_two_dims_size = 1; + int64_t group_three_dims_size = 1; + for (int i = std::get<0>(transposed_groups_index_range).first; + i <= std::get<0>(transposed_groups_index_range).second; ++i) { + group_two_dims_size *= transpose_input_shape[transpose_permutation[i]]; + } + for (int i = std::get<1>(transposed_groups_index_range).first; + i <= std::get<1>(transposed_groups_index_range).second; ++i) { + group_three_dims_size *= transpose_input_shape[transpose_permutation[i]]; + } + + const auto& out_dims = bmm_dimensions_info.out_dimensions().SizesArray()[0]; + const auto& contracting_dims = + bmm_dimensions_info.contracting_dimensions().SizesArray()[0]; + + return bmm_dimensions_info.is_lhs() + ? (group_two_dims_size == out_dims && + group_three_dims_size == contracting_dims) + : (group_two_dims_size == contracting_dims && + group_three_dims_size == out_dims); +} +} // namespace TFL +} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/transforms/tflite_passes/optimize_batch_matmul_utils.h b/tensorflow/compiler/mlir/lite/transforms/tflite_passes/optimize_batch_matmul_utils.h new file mode 100644 index 00000000000000..3eb3de702e1f4a --- /dev/null +++ b/tensorflow/compiler/mlir/lite/transforms/tflite_passes/optimize_batch_matmul_utils.h @@ -0,0 +1,141 @@ +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_TFLITE_PASSES_OPTIMIZE_BATCH_MATMUL_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_TFLITE_PASSES_OPTIMIZE_BATCH_MATMUL_UTILS_H_ + +#include +#include +#include + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project + +namespace mlir { +namespace TFL { + +// LHS and RHS of BatchMatMulOp has shapes following the pattern: +// B0,...,Bn,L,C and B0,...,Bn,C,R. The output shape of BatchMatMulOp is: +// B0,...,Bn,L,R. +// +// LHS and RHS of FullyConnectedOp has shapes following the pattern: +// B0,...,Bn,L,C and R,C. The output shape of FullyConnectedOp is: +// B0,...,Bn,L,R. +// +// The fundamental idea behind seeing transposes and reshapes around +// BatchMatMulOp is that- +// -- BatchMatMulOp is often created as a result of lowering einsum or +// dot_general ops. +// -- einsum and dot_general ops have multiple contracting and output +// dimensions that will to be reshaped and transposed to match the +// BatchMatMulOp's LHS and RHS restrictions. +// +// This file contains utility functions to identify the reshapes and transposes +// around BatchMatMulOp and see if they can be fused. + +// A struct to hold axes and sizes for a set of dimensions. +struct DimensionVector { + llvm::ArrayRef AxesArray() const { return axes; } + llvm::ArrayRef SizesArray() const { return sizes; } + + llvm::SmallVector axes; + llvm::SmallVector sizes; +}; + +// A struct to hold information about dimensions of dot_general operands. 
+class BatchMatMulDimensionsInfo { + public: + BatchMatMulDimensionsInfo(mlir::ShapedType type, bool is_lhs); + const DimensionVector& batch_dimensions() const; + const DimensionVector& contracting_dimensions() const; + // Out dimensions are any dimensions that are neither batch nor contracting + // dimensions, hence will be propagated to output shape. + const DimensionVector& out_dimensions() const; + bool is_lhs() const; + + private: + DimensionVector batch_dimensions_; + DimensionVector contracting_dimensions_; + // Out dimensions are any dimensions that are neither batch nor contracting + // dimensions, hence will be propagated to output shape. + DimensionVector out_dimensions_; + bool is_lhs_; +}; + +// Returns the dimensions info of the LHS of BatchMatMulOp. +BatchMatMulDimensionsInfo GetBatchMatMulLhsDimensionsInfo( + mlir::ShapedType type); + +// Returns the dimensions info of the RHS of BatchMatMulOp. +BatchMatMulDimensionsInfo GetBatchMatMulRhsDimensionsInfo( + mlir::ShapedType type); + +// Returns true if the product of the last few dimensions in the +// `reshape_input_shape` is equal to the contracting dimension of the +// `bmm_dimensions_info`. +bool HasFlattenedContractingDims( + llvm::ArrayRef reshape_input_shape, + const BatchMatMulDimensionsInfo& bmm_dimensions_info); + +// Returns true if the product of the first few dimensions in the +// `reshape_input_shape` is equal to the output dimension of the +// `bmm_dimensions_info`. +bool HasFlattenedOutDims(llvm::ArrayRef reshape_input_shape, + const BatchMatMulDimensionsInfo& bmm_dimensions_info); + +// Returns true if the contracting and output dimensions are transposed in the +// `transpose_permutation`. +bool HasTransposedContractingAndOutDims( + llvm::ArrayRef transpose_input_shape, + llvm::ArrayRef transpose_permutation, + const BatchMatMulDimensionsInfo& bmm_dimensions_info); + +// `transpose_permutation` is the permutation of the input shape of the +// transpose op. 
`transpose_input_shape` is the shape of the input of the +// transpose op. `bmm_dimensions_info` is the dimensions info of the +// BatchMatMulOp. +// +// The dimensions in the transpose_permutation can be split into three groups: +// 1. Batch dimensions +// 2. Contracting dimensions +// 3. Output dimensions +// +// - The number of dimensions and the order of the dimensions in the +// batch-dimensions group is expected to match the batch dimensions of the +// BatchMatMulOp. +// - The number of dimensions in the contracting-dimensions and +// output-dimensions groups can be more than 1. +// - The dimensions in group 1 are expected to be a monotonically increasing +// sequence. +// - The dimensions in group 2 and 3 need not be a monotonically increasing +// sequence. +// - In this function, we only care if the groups 2 and 3 are transposed. +// +// For example, consider the following transpose_permutation- +// [0, 1, 2, 6, 7, 8, 3, 4, 5]. Here all the three groups are monotonically +// increasing. But other permutations like [0, 1, 2, 8, 7, 6, 4, 5, 3] and [0, +// 1, 2, 6, 7, 8, 3, 5, 4] are also valid. +// +// NOTE: The first version of this function will support the case where all the +// three groups are monotonically increasing. +std::tuple, std::pair> +GetTransposedGroupsIndexRange(llvm::ArrayRef transpose_permutation); + +} // namespace TFL +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_TFLITE_PASSES_OPTIMIZE_BATCH_MATMUL_UTILS_H_ diff --git a/tensorflow/compiler/mlir/lite/transforms/tflite_passes/optimize_batch_matmul_utils_test.cc b/tensorflow/compiler/mlir/lite/transforms/tflite_passes/optimize_batch_matmul_utils_test.cc new file mode 100644 index 00000000000000..cf026d8c8169e2 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/transforms/tflite_passes/optimize_batch_matmul_utils_test.cc @@ -0,0 +1,168 @@ +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/lite/transforms/tflite_passes/optimize_batch_matmul_utils.h" + +#include +#include +#include + +#include +#include "llvm/ADT/ArrayRef.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project + +namespace mlir { +namespace TFL { +namespace { + +TEST(OptimizeBatchMatmulUtilsTest, BatchMatMulDimensionsInfo) { + mlir::MLIRContext context; + mlir::ShapedType type = mlir::RankedTensorType::get( + {1, 2, 3, 4, 5}, mlir::Float32Type::get(&context)); + BatchMatMulDimensionsInfo lhs_info(type, /*is_lhs=*/true); + EXPECT_EQ(lhs_info.batch_dimensions().AxesArray(), + llvm::ArrayRef({0, 1, 2})); + EXPECT_EQ(lhs_info.batch_dimensions().SizesArray(), + llvm::ArrayRef({1, 2, 3})); + EXPECT_EQ(lhs_info.contracting_dimensions().AxesArray(), + llvm::ArrayRef({4})); + EXPECT_EQ(lhs_info.contracting_dimensions().SizesArray(), + llvm::ArrayRef({5})); + EXPECT_EQ(lhs_info.out_dimensions().AxesArray(), + llvm::ArrayRef({3})); + EXPECT_EQ(lhs_info.out_dimensions().SizesArray(), + llvm::ArrayRef({4})); + EXPECT_TRUE(lhs_info.is_lhs()); + + BatchMatMulDimensionsInfo rhs_info(type, /*is_lhs=*/false); + EXPECT_EQ(rhs_info.batch_dimensions().AxesArray(), + llvm::ArrayRef({0, 1, 2})); + 
EXPECT_EQ(rhs_info.batch_dimensions().SizesArray(), + llvm::ArrayRef({1, 2, 3})); + EXPECT_EQ(rhs_info.contracting_dimensions().AxesArray(), + llvm::ArrayRef({3})); + EXPECT_EQ(rhs_info.contracting_dimensions().SizesArray(), + llvm::ArrayRef({4})); + EXPECT_EQ(rhs_info.out_dimensions().AxesArray(), + llvm::ArrayRef({4})); + EXPECT_EQ(rhs_info.out_dimensions().SizesArray(), + llvm::ArrayRef({5})); + EXPECT_FALSE(rhs_info.is_lhs()); +} + +TEST(OptimizeBatchMatmulUtilsTest, HasFlattenedContractingDims) { + mlir::MLIRContext context; + mlir::ShapedType type = mlir::RankedTensorType::get( + {1, 2, 3, 4, 50}, mlir::Float32Type::get(&context)); + BatchMatMulDimensionsInfo lhs_info(type, /*is_lhs=*/true); + EXPECT_TRUE(HasFlattenedContractingDims({1, 2, 3, 4, 5, 10}, lhs_info)); + EXPECT_FALSE(HasFlattenedContractingDims({1, 2, 3, 4, 10}, lhs_info)); + + type = mlir::RankedTensorType::get({1, 2, 12, 5}, + mlir::Float32Type::get(&context)); + BatchMatMulDimensionsInfo rhs_info(type, /*is_lhs=*/false); + EXPECT_TRUE(HasFlattenedContractingDims({1, 2, 3, 4, 5}, rhs_info)); + EXPECT_FALSE(HasFlattenedContractingDims({1, 2, 3, 4, 10}, rhs_info)); + + type = mlir::RankedTensorType::get({4, 50}, mlir::Float32Type::get(&context)); + lhs_info = BatchMatMulDimensionsInfo(type, /*is_lhs=*/true); + EXPECT_TRUE(HasFlattenedContractingDims({4, 5, 10}, lhs_info)); + EXPECT_FALSE(HasFlattenedContractingDims({4, 10}, lhs_info)); + + type = mlir::RankedTensorType::get({12, 5}, mlir::Float32Type::get(&context)); + rhs_info = BatchMatMulDimensionsInfo(type, /*is_lhs=*/false); + EXPECT_TRUE(HasFlattenedContractingDims({3, 4, 5}, rhs_info)); + EXPECT_FALSE(HasFlattenedContractingDims({3, 4, 10}, rhs_info)); +} + +TEST(OptimizeBatchMatmulUtilsTest, HasFlattenedOutDims) { + mlir::MLIRContext context; + mlir::ShapedType type = mlir::RankedTensorType::get( + {1, 2, 12, 5}, mlir::Float32Type::get(&context)); + BatchMatMulDimensionsInfo lhs_info(type, /*is_lhs=*/true); + 
EXPECT_TRUE(HasFlattenedOutDims({1, 2, 3, 4, 5}, lhs_info)); + EXPECT_FALSE(HasFlattenedOutDims({1, 2, 3, 4, 10}, lhs_info)); + + type = mlir::RankedTensorType::get({1, 2, 12, 10}, + mlir::Float32Type::get(&context)); + BatchMatMulDimensionsInfo rhs_info(type, /*is_lhs=*/false); + EXPECT_TRUE(HasFlattenedOutDims({1, 2, 12, 5, 2}, rhs_info)); + EXPECT_FALSE(HasFlattenedOutDims({1, 2, 3, 4, 10}, rhs_info)); + + type = mlir::RankedTensorType::get({12, 5}, mlir::Float32Type::get(&context)); + lhs_info = BatchMatMulDimensionsInfo(type, /*is_lhs=*/true); + EXPECT_TRUE(HasFlattenedOutDims({3, 4, 5}, lhs_info)); + EXPECT_FALSE(HasFlattenedOutDims({3, 4, 10}, lhs_info)); + + type = + mlir::RankedTensorType::get({12, 10}, mlir::Float32Type::get(&context)); + rhs_info = BatchMatMulDimensionsInfo(type, /*is_lhs=*/false); + EXPECT_TRUE(HasFlattenedOutDims({12, 5, 2}, rhs_info)); + EXPECT_FALSE(HasFlattenedOutDims({3, 4, 10}, rhs_info)); +} + +TEST(OptimizeBatchMatmulUtilsTest, GetTransposedGroupsIndexRange) { + EXPECT_EQ(GetTransposedGroupsIndexRange({0, 1, 2, 6, 7, 8, 3, 4, 5}), + std::make_tuple(std::make_pair(3, 5), std::make_pair(6, 8))); + EXPECT_EQ(GetTransposedGroupsIndexRange({2, 0, 1}), + std::make_tuple(std::make_pair(0, 0), std::make_pair(1, 2))); + EXPECT_EQ(GetTransposedGroupsIndexRange({0, 1, 2, 3, 7, 8, 4, 5, 6}), + std::make_tuple(std::make_pair(4, 5), std::make_pair(6, 8))); + EXPECT_EQ(GetTransposedGroupsIndexRange({0, 1, 2, 3, 8, 7, 4, 5, 6}), + std::make_tuple(std::make_pair(-1, -1), std::make_pair(-1, -1))); + EXPECT_EQ(GetTransposedGroupsIndexRange({0, 1, 2}), + std::make_tuple(std::make_pair(-1, -1), std::make_pair(-1, -1))); + EXPECT_EQ(GetTransposedGroupsIndexRange({0, 1, 2}), + std::make_tuple(std::make_pair(-1, -1), std::make_pair(-1, -1))); + EXPECT_EQ(GetTransposedGroupsIndexRange({}), + std::make_tuple(std::make_pair(-1, -1), std::make_pair(-1, -1))); +} + +TEST(OptimizeBatchMatmulUtilsTest, HasTransposedContractingAndOutDims) { + mlir::MLIRContext 
context; + mlir::ShapedType type = mlir::RankedTensorType::get( + {1, 2, 3, 504, 120}, mlir::Float32Type::get(&context)); + BatchMatMulDimensionsInfo lhs_info(type, /*is_lhs=*/true); + EXPECT_TRUE(HasTransposedContractingAndOutDims( + {1, 2, 3, 4, 5, 6, 7, 8, 9}, {0, 1, 2, 6, 7, 8, 3, 4, 5}, lhs_info)); + EXPECT_FALSE(HasTransposedContractingAndOutDims( + {1, 2, 3, 4, 5, 6, 7, 8, 9}, {0, 1, 2, 8, 7, 6, 4, 5, 3}, lhs_info)); + + BatchMatMulDimensionsInfo rhs_info(type, /*is_lhs=*/false); + EXPECT_TRUE(HasTransposedContractingAndOutDims( + {1, 2, 3, 4, 5, 6, 7, 8, 9}, {0, 1, 2, 6, 7, 8, 3, 4, 5}, rhs_info)); + EXPECT_FALSE(HasTransposedContractingAndOutDims( + {1, 2, 3, 4, 5, 6, 7, 8, 9}, {0, 1, 2, 8, 7, 6, 4, 5, 3}, rhs_info)); + + type = + mlir::RankedTensorType::get({504, 120}, mlir::Float32Type::get(&context)); + lhs_info = BatchMatMulDimensionsInfo(type, /*is_lhs=*/true); + EXPECT_TRUE(HasTransposedContractingAndOutDims({4, 5, 6, 7, 8, 9}, + {3, 4, 5, 0, 1, 2}, lhs_info)); + EXPECT_FALSE(HasTransposedContractingAndOutDims( + {4, 5, 6, 7, 8, 9}, {5, 4, 3, 1, 2, 0}, lhs_info)); + + rhs_info = BatchMatMulDimensionsInfo(type, /*is_lhs=*/false); + EXPECT_TRUE(HasTransposedContractingAndOutDims({4, 5, 6, 7, 8, 9}, + {3, 4, 5, 0, 1, 2}, rhs_info)); + EXPECT_FALSE(HasTransposedContractingAndOutDims( + {4, 5, 6, 7, 8, 9}, {5, 4, 3, 1, 2, 0}, rhs_info)); +} + +} // namespace +} // namespace TFL +} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/utils/utils.h b/tensorflow/compiler/mlir/lite/utils/utils.h index 6779dac5ad5e8c..88088b5799e701 100644 --- a/tensorflow/compiler/mlir/lite/utils/utils.h +++ b/tensorflow/compiler/mlir/lite/utils/utils.h @@ -20,15 +20,20 @@ limitations under the License. 
#include #include #include +#include #include #include +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/ErrorHandling.h" #include "mlir/Dialect/Traits.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project @@ -59,13 +64,28 @@ inline bool IsPosInfiniteValue(APFloat value) { return value.isInfinity(); } +// Returns 1D 32-bit dense elements attribute with the given values. +inline DenseIntElementsAttr GetI32ElementsAttr(ArrayRef values, + Builder* builder) { + RankedTensorType ty = mlir::RankedTensorType::get( + {static_cast(values.size())}, builder->getIntegerType(32)); + return DenseIntElementsAttr::get(ty, values); +} + +inline DenseIntElementsAttr GetI64ElementsAttr(ArrayRef values, + Builder* builder) { + RankedTensorType ty = RankedTensorType::get( + {static_cast(values.size())}, builder->getIntegerType(64)); + return DenseIntElementsAttr::get(ty, values); +} + // Returns true if all tensor value in `values` has static shape and same shape. inline bool OpHasSameStaticShapes(Operation* op) { auto values = op->getOperands(); int operand_num = 0; ArrayRef shape; for (Value value : values) { - auto shaped_type = value.getType().dyn_cast(); + auto shaped_type = mlir::dyn_cast(value.getType()); if (!shaped_type || !shaped_type.hasStaticShape()) { return false; } @@ -165,7 +185,7 @@ inline bool IsTransposeTrivial(llvm::ArrayRef input_shape, // Returns the permutation that maps the input shape to the output shape. // This is only valid for trivial reshape ops. 
inline DenseElementsAttr GetPermutationFromTrivialReshape( - ShapedType input_type, ShapedType output_type) { + mlir::ShapedType input_type, mlir::ShapedType output_type) { ArrayRef in_shape = input_type.getShape(); ArrayRef out_shape = output_type.getShape(); @@ -209,8 +229,8 @@ inline DenseElementsAttr GetPermutationFromTrivialReshape( // Returns true if the reshape op is equivalent to a transpose op. // This is true if the reshape op is a trivial reshape op, meaning no change in // the order of non-identity dimensions. -inline bool IsReshapeEquivalentToTranspose(ShapedType input_type, - ShapedType output_type) { +inline bool IsReshapeEquivalentToTranspose(mlir::ShapedType input_type, + mlir::ShapedType output_type) { std::vector in_shape{input_type.getShape().vec()}; std::vector out_shape{output_type.getShape().vec()}; @@ -229,14 +249,14 @@ inline bool IsReshapeEquivalentToTranspose(ShapedType input_type, // Checks if all elements in the constant attribute value are 1. inline bool IsAllOnesConstant(Attribute value) { - auto values = value.cast().getValues(); + auto values = mlir::cast(value).getValues(); return !std::any_of(values.begin(), values.end(), [](int32_t element_value) { return element_value != 1; }); } // Checks if all elements in the constant attribute value are non-negative. inline bool HasNonNegativeValues(Attribute value) { - auto values = value.cast().getValues(); + auto values = mlir::cast(value).getValues(); return !std::any_of( values.begin(), values.end(), [](const APInt& element_value) { return element_value.isNegative(); }); @@ -244,8 +264,8 @@ inline bool HasNonNegativeValues(Attribute value) { // Utility function to get the offset between two dense attribute values. 
inline TypedAttr GetOffSet(Attribute begin, Attribute end) { - auto begin_values = begin.cast().getValues(); - auto end_values = end.cast().getValues(); + auto begin_values = mlir::cast(begin).getValues(); + auto end_values = mlir::cast(end).getValues(); SmallVector offsets; if (begin_values.size() == end_values.size()) { @@ -283,7 +303,7 @@ inline bool AreLastTwoDimsTransposed(Value permutation) { // Gets the new type after transposing the last 2 dimensions. inline Type TransposeLastTwoDims(Type type) { - auto shaped_type = type.dyn_cast(); + auto shaped_type = mlir::dyn_cast(type); if (!shaped_type.hasStaticShape() || shaped_type.getRank() < 2) { return nullptr; } @@ -299,9 +319,9 @@ inline Type TransposeLastTwoDims(Type type) { // Returns a ShapedType for a permutation and the shape of input after // applying the permutation to the given shape through a transpose. -inline ShapedType GetTransposedType(Value input, - llvm::ArrayRef permutation_array) { - auto input_type = input.getType().cast(); +inline mlir::ShapedType GetTransposedType( + Value input, llvm::ArrayRef permutation_array) { + auto input_type = mlir::cast(input.getType()); if (permutation_array.size() != input_type.getRank()) { return nullptr; } @@ -341,41 +361,67 @@ inline DenseElementsAttr GetExpandedShapeAttr(Value input_val, int n) { // Return the resultant shape type if the shape of the supplied attribute/value // is expanded by n leading 1s'. -inline ShapedType GetExpandedShapeType(Value input_val, int n) { +inline mlir::ShapedType GetExpandedShapeType(Value input_val, int n) { auto expanded_shape = GetExpandedShape(input_val, n); return RankedTensorType::get( SmallVector{expanded_shape.begin(), expanded_shape.end()}, mlir::cast(input_val.getType()).getElementType()); } -// Returns shape of a ranked tensor. -// Precondition: output_val's is ranked tensor. -// Returns a truncated shape when `truncate` is set to true. 
-inline DenseElementsAttr GetShape(Value output_val, bool truncate = false) { - auto output_shape = output_val.getType().dyn_cast().getShape(); +// Returns shape of a ranked tensor as a SmallVector. +// Precondition: input_value's is ranked tensor. +// Returns a squeezed shape when `squeeze_leading_ones` is set to true. +inline SmallVector GetShape(Value input_value, + bool squeeze_leading_ones = false) { + auto output_shape = + mlir::dyn_cast(input_value.getType()).getShape(); SmallVector shape; shape.reserve(output_shape.size()); - bool needs_truncation = true; + bool can_squeeze = true; for (size_t dim_idx = 0; dim_idx < output_shape.size(); ++dim_idx) { int64_t dim = output_shape[dim_idx]; - if (truncate && needs_truncation && dim == 1) { + if (squeeze_leading_ones && can_squeeze && dim == 1) { continue; - } else if (needs_truncation && dim != 1) { - needs_truncation = false; + } else if (can_squeeze && dim != 1) { + can_squeeze = false; } shape.push_back(ShapedType::isDynamic(dim) ? -1 : static_cast(dim)); } + return shape; +} + +// Returns shape of a ranked tensor as a DenseElementsAttr. +// Precondition: input_value's is ranked tensor. +// Returns a squeezed shape when `squeeze_leading_ones` is set to true. +inline DenseElementsAttr GetShapeAttr(Value input_value, + bool squeeze_leading_ones = false) { + SmallVector shape = GetShape(input_value, squeeze_leading_ones); return mlir::DenseElementsAttr::get( RankedTensorType::get( {static_cast(shape.size())}, - mlir::IntegerType::get(output_val.getContext(), 32)), + mlir::IntegerType::get(input_value.getContext(), 32)), llvm::ArrayRef(shape)); } +// Returns the value of a constant attribute as an int array, if the value is +// not a constant, returns an error status. 
+inline absl::StatusOr> GetValueAsIntArray(Value value) { + DenseElementsAttr values_const_attr; + if (!matchPattern(value, m_Constant(&values_const_attr))) { + return absl::InvalidArgumentError("Value is not a constant."); + } + + SmallVector values; + for (const auto& value : values_const_attr.getValues()) { + values.push_back(value.getSExtValue()); + } + return values; +} + //////////////////////////////////////////////////////////////////////////////// ///////////////// OP BROADCASTING UTILITIES //////////////////////////////////// //////////////////////////////////////////////////////////////////////////////// @@ -416,6 +462,136 @@ DenseElementsAttr GetScalarOfType(Type ty, T raw_value) { llvm_unreachable("unsupported type"); } +// Checks if reduction axes and broadcast axes are disjoint. +// Broadcast axes are derived by comparing the shape of `input_val` to the shape +// represented by `target_shape_attr` according to standard broadcasting rules. +// Returns true if the sets of axes are disjoint, false otherwise or on error. +inline bool AreBroadcastAndReductionAxesIndependent( + mlir::Value input_val, const mlir::Attribute& indices_attr, + const mlir::Attribute& target_shape_attr) { + // 1. Get input type and shape. + // Use llvm::dyn_cast for safer casting. + auto ranked_input_type = + llvm::dyn_cast(input_val.getType()); + if (!ranked_input_type) { + // Consider logging or error emission if builder context is + // available/needed. + return false; // Expect ranked type. + } + llvm::ArrayRef input_shape = ranked_input_type.getShape(); + const int64_t input_rank = ranked_input_type.getRank(); + + // 2. Validate and extract reduction axes. + // Use llvm::dyn_cast for safer casting. + auto indices = llvm::dyn_cast(indices_attr); + if (!indices || !indices.getElementType().isIntOrIndex()) { + return false; // Invalid indices attribute. + } + + // Use std::set for efficient storage and lookup of axes. 
+ std::set reduction_axes_set; + if (!indices.empty()) { // Only process if there are reduction axes. + if (input_rank == 0) { + // It's invalid to specify reduction axes for a scalar (rank 0) input. + return false; + } + + // Iterate using range-based for loop and structured binding (if applicable) + // or direct value access. + for (const mlir::APInt& axis_val : indices.getValues()) { + int64_t axis = + axis_val.getSExtValue(); // Use sign extension for neg axes. + + // Normalize axis and check bounds. + if (axis < -input_rank || axis >= input_rank) { + return false; // Axis out of bounds. + } + if (axis < 0) { + axis += input_rank; // Convert negative axis to positive. + } + reduction_axes_set.insert(axis); + } + } + + // If there are no reduction axes, they are trivially independent of any + // broadcast axes. + if (reduction_axes_set.empty()) { + return true; + } + + // 3. Validate and extract target shape for broadcast. + // Use llvm::dyn_cast for safer casting. + auto target_shape_value_attr = + llvm::dyn_cast(target_shape_attr); + if (!target_shape_value_attr || + !target_shape_value_attr.getElementType().isIntOrIndex()) { + return false; // Invalid target shape attribute. + } + + // Use llvm::SmallVector for efficient shape storage. + llvm::SmallVector target_shape_vec; + target_shape_vec.reserve( + target_shape_value_attr.getNumElements()); // Pre-allocate + for (const mlir::APInt& shape_val : + target_shape_value_attr.getValues()) { + // Assuming shape dimensions should be non-negative, consider getZExtValue. + // However, getSExtValue is safe if intermediate calculations handle signs. + target_shape_vec.push_back(shape_val.getSExtValue()); + } + // Use llvm::ArrayRef for safe, non-owning view of the shape vector. + llvm::ArrayRef target_shape = target_shape_vec; + const int64_t target_rank = target_shape.size(); + + // 4. Determine broadcast axes based on standard broadcasting rules. 
+ std::set broadcast_axes_set; + const int64_t max_rank = std::max(input_rank, target_rank); + + // Iterate through dimensions, aligning from the right (trailing dimensions). + for (int64_t i = 0; i < max_rank; ++i) { + // Calculate indices relative to the end of the shape arrays. + const int64_t input_dim_idx = input_rank - 1 - i; + const int64_t target_dim_idx = target_rank - 1 - i; + + // Treat dimensions missing due to lower rank as having size 1. + const int64_t input_dim = + (input_dim_idx >= 0) ? input_shape[input_dim_idx] : 1; + const int64_t target_dim = + (target_dim_idx >= 0) ? target_shape[target_dim_idx] : 1; + + // Check for incompatible shapes (dimensions differ and neither is 1). + // This indicates an invalid broadcast according to NumPy rules. + if (input_dim != target_dim && input_dim != 1 && target_dim != 1) { + // Consider if the specific broadcast op allows other behaviors (e.g., + // -1). For standard rules, this is an incompatibility. + return false; + } + + // An axis in the *input* tensor is involved in broadcasting if its size is + // 1 and the corresponding target dimension size is greater than 1. + if (input_dim == 1 && target_dim > 1) { + // Ensure the axis index is valid for the input tensor's rank. + if (input_dim_idx >= 0) { + broadcast_axes_set.insert(input_dim_idx); + } + // Note: If input_dim_idx < 0, broadcasting occurs due to rank difference, + // but it doesn't correspond to an axis *within* the original input + // tensor. + } + } + + // 5. Check for intersection between the set of reduction axes and the set of + // broadcast axes derived above. + for (int64_t reduction_axis : reduction_axes_set) { + if (broadcast_axes_set.count(reduction_axis)) { + // Found an axis that is present in both sets. + return false; + } + } + + // 6. No overlapping axes were found. 
+ return true; +} + } // namespace TFL } // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/utils/utils.td b/tensorflow/compiler/mlir/lite/utils/utils.td index d7029fe5ca7939..7583d48618f4fc 100644 --- a/tensorflow/compiler/mlir/lite/utils/utils.td +++ b/tensorflow/compiler/mlir/lite/utils/utils.td @@ -25,9 +25,9 @@ include "mlir/IR/PatternBase.td" //////////////////////////////////////////////////////////////////////////////// def IsQuantized : Constraint() && " - "$0.getType().dyn_cast().getElementType()" - ".isa()">>; + "llvm::dyn_cast($0.getType()) && " + "llvm::isa(" + "llvm::dyn_cast($0.getType()).getElementType())">>; def IsNotQuantized : Constraint>; @@ -38,42 +38,42 @@ def IsNotQuantized : Constraint>; // Checks if the rank of the value is less than or equal to the rank of the // other value. def IsRankLessThanEqualTo : Constraint().getRank() <= " - "$1.getType().cast().getRank()">>; + "llvm::cast($0.getType()).getRank() <= " + "llvm::cast($1.getType()).getRank()">>; // Checks if the value has rank at most 'n'. 
class HasRankAtMost : Constraint< - CPred<"$0.getType().cast().hasRank() && " - "$0.getType().cast().getRank() <= " # n>>; + CPred<"llvm::cast($0.getType()).hasRank() && " + "llvm::cast($0.getType()).getRank() <= " # n>>; //////////////////////////////////////////////////////////////////////////////// ///////////////// DENSE UTILITIES ///////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////// -def DenseFPElementsAttrPred : CPred<"$_self.isa()">; -def DenseIntElementsAttrPred : CPred<"$_self.isa()">; +def DenseFPElementsAttrPred : CPred<"llvm::isa($_self)">; +def DenseIntElementsAttrPred : CPred<"llvm::isa($_self)">; //////////////////////////////////////////////////////////////////////////////// ///////////////// SPLAT CONSTANT UTILITIES ///////////////////////////////////// //////////////////////////////////////////////////////////////////////////////// def DenseElementsAttrIsSplatPred - : CPred<"$_self.cast().isSplat()">; + : CPred<"llvm::cast($_self).isSplat()">; class DenseFPElementsAttrSplatValueEqualToPred - : CPred<"$_self.cast().getSplatValue()" + : CPred<"llvm::cast($_self).getSplatValue()" ".getValueAsDouble() == " # val>; class DenseFPElementsAttrSplatValueEqualToPredWithTolerance - : CPred<"std::abs($_self.cast().getSplatValue()" + : CPred<"std::abs(llvm::cast($_self).getSplatValue()" ".getValueAsDouble() - " # val # ") <= "#tolerance>; class DenseIntElementsAttrSplatValueEqualToPred - : CPred<"$_self.isa() && " - "$_self.cast().getElementType()" - " .isa() && " - "$_self.cast().isSplat() && " - "$_self.cast().getSplatValue()" + : CPred<"llvm::isa($_self) && " + "llvm::isa(" + "llvm::cast($_self).getElementType()) && " + "llvm::cast($_self).isSplat() && " + "llvm::cast($_self).getSplatValue()" " .getValue().getSExtValue() == " # val>; // AttrConstraint to match a floating point dense elements attribute with a @@ -110,8 +110,8 @@ def SplatIntElementsAttr : ElementsAttrBase< def 
GetScalarElementsAttrFromSplat : NativeCodeCall< "DenseElementsAttr::get(" " RankedTensorType::get({}," - " $0.cast().getType().getElementType())," - " $0.cast().getSplatValue())">; + " llvm::cast($0).getType().getElementType())," + " llvm::cast($0).getSplatValue())">; //////////////////////////////////////////////////////////////////////////////// ///////////////// OP BROADCASTING UTILITIES //////////////////////////////////// @@ -129,10 +129,10 @@ def OperandsDontBroadcastToOutputType : Constraint().hasStaticShape() && " - "$1.getType().cast().hasStaticShape() && " - "$0.getType().cast().getShape() ==" - "$1.getType().cast().getShape()">, + CPred<"llvm::cast($0.getType()).hasStaticShape() && " + "llvm::cast($1.getType()).hasStaticShape() && " + "llvm::cast($0.getType()).getShape() ==" + "llvm::cast($1.getType()).getShape()">, "have the same static shape">; def CreateNoneValue : NativeCodeCall< @@ -140,7 +140,7 @@ def CreateNoneValue : NativeCodeCall< // Returns shape of a ranked tensor. // if called without a ranked tensor it will fail. -def GetShape: NativeCodeCall<"GetShape($0)">; +def GetShapeAttr: NativeCodeCall<"GetShapeAttr($0)">; // Return the resultant shape if the shape of the supplied attribute/value is // expanded by n leading 1s'. @@ -159,7 +159,7 @@ def IsAllOnesConstant : Constraint>; // the permutation is a cyclic permutation of the original shape with only the // identity dimensions permuted. def IsTransposeTrivial : Constraint().getShape(), $1)">>; + "TFL::IsTransposeTrivial(llvm::cast($0.getType()).getShape(), $1)">>; // Constraint that checks if the transpose op is a no-op. def IsTransposeNoop : Constraint>; @@ -169,15 +169,15 @@ def IsTransposeNoop : Constraint>; // the order of non-identity dimensions. 
def IsReshapeEquivalentToTranspose : Constraint()," - "$1.getType().cast())">>; + "llvm::cast($0.getType())," + "llvm::cast($1.getType()))">>; // Returns the permutation of the trivial reshape op, this will be used to // construct the transpose op. def GetPermutationFromTrivialReshape : NativeCodeCall< "TFL::GetPermutationFromTrivialReshape(" - "$0.getType().cast()," - "$1.getType().cast())">; + "llvm::cast($0.getType())," + "llvm::cast($1.getType()))">; // Constraint that checks if all values in offset between two // attributes are non-negative. @@ -191,12 +191,12 @@ def GetOffSet : NativeCodeCall<"TFL::GetOffSet($0, $1)">; // Attribute Constraint that checks if the attribute value is zero. def ZeroIntAttr - : AttrConstraint().getInt() == 0">>; + : AttrConstraint($_self).getInt() == 0">>; // Checks if the value has rank at most 'n'. class HasRankAtLeast : Constraint< - CPred<"$0.getType().cast().hasRank() && " - "$0.getType().cast().getRank() >= " # n>>; + CPred<"llvm::cast($0.getType()).hasRank() && " + "llvm::cast($0.getType()).getRank() >= " # n>>; // Accepts two inputs and check if both have the same element type. def SameElementType : Constraint< @@ -227,7 +227,7 @@ def AreLastTwoDimsTransposed : Constraint>; // Checks if the param passed is of NoneType. -def IsNoneType : Constraint()">>; +def IsNoneType : Constraint($0.getType())">>; def ConstantLikePred : CPred<"::mlir::matchPattern($0, ::mlir::m_Constant())">; def IsConstantLike : Constraint; diff --git a/tensorflow/compiler/mlir/lite/utils/utils_test.cc b/tensorflow/compiler/mlir/lite/utils/utils_test.cc new file mode 100644 index 00000000000000..f4e37480b2b035 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/utils/utils_test.cc @@ -0,0 +1,128 @@ +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/lite/utils/utils.h" + +#include + +#include +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project + +namespace mlir { +namespace TFL { +namespace { + +// Test fixture for AreBroadcastAndReductionAxesIndependent function. +class BroadcastAndReductionAxesIndependentTest : public ::testing::Test { + protected: + BroadcastAndReductionAxesIndependentTest() : builder_(&context_) { + context_.loadDialect(); + } + + // Builds an mlir::Value representing a tensor with the given shape. + Value BuildTensor(ArrayRef shape) { + return builder_.create( + builder_.getUnknownLoc(), + RankedTensorType::get(shape, builder_.getF32Type()), + builder_.getZeroAttr( + RankedTensorType::get(shape, builder_.getF32Type()))); + } + + // Builds a DenseElementsAttr representing an integer array. 
+ DenseElementsAttr BuildIntArrayAttr(ArrayRef values) { + return DenseElementsAttr::get( + RankedTensorType::get({static_cast(values.size())}, + builder_.getI32Type()), + values); + } + + MLIRContext context_; + OpBuilder builder_; +}; + +TEST_F(BroadcastAndReductionAxesIndependentTest, IndependentAxes) { + Value input_tensor = BuildTensor({2, 1, 4, 1}); + DenseElementsAttr reduction_axes = BuildIntArrayAttr({0, 2}); + DenseElementsAttr target_shape = BuildIntArrayAttr({2, 3, 4, 5}); + + EXPECT_TRUE(AreBroadcastAndReductionAxesIndependent( + input_tensor, reduction_axes, target_shape)); + input_tensor.getDefiningOp()->destroy(); +} + +TEST_F(BroadcastAndReductionAxesIndependentTest, OverlappingAxes) { + Value input_tensor = BuildTensor({1, 3, 4, 5}); + DenseElementsAttr reduction_axes = BuildIntArrayAttr({0, 2}); + DenseElementsAttr target_shape = BuildIntArrayAttr({2, 3, 4, 5}); + + EXPECT_FALSE(AreBroadcastAndReductionAxesIndependent( + input_tensor, reduction_axes, target_shape)); + input_tensor.getDefiningOp()->destroy(); +} + +TEST_F(BroadcastAndReductionAxesIndependentTest, EmptyReductionAxes) { + Value input_tensor = BuildTensor({1, 3, 1, 5}); + DenseElementsAttr reduction_axes = BuildIntArrayAttr({}); + DenseElementsAttr target_shape = BuildIntArrayAttr({2, 3, 4, 5}); + + EXPECT_TRUE(AreBroadcastAndReductionAxesIndependent( + input_tensor, reduction_axes, target_shape)); + input_tensor.getDefiningOp()->destroy(); +} + +TEST_F(BroadcastAndReductionAxesIndependentTest, UnrankedInput) { + Value input_tensor = builder_.create( + builder_.getUnknownLoc(), builder_.getF32Type(), + builder_.getZeroAttr(builder_.getF32Type())); + DenseElementsAttr reduction_axes = BuildIntArrayAttr({0, 2}); + DenseElementsAttr target_shape = BuildIntArrayAttr({2, 3, 4, 5}); + + EXPECT_FALSE(AreBroadcastAndReductionAxesIndependent( + input_tensor, reduction_axes, target_shape)); + input_tensor.getDefiningOp()->destroy(); +} + +TEST_F(BroadcastAndReductionAxesIndependentTest, 
InvalidReductionAxesType) { + Value input_tensor = BuildTensor({2, 3, 4, 5}); + DenseElementsAttr reduction_axes = DenseElementsAttr::get( + RankedTensorType::get({2}, builder_.getF32Type()), {1.0f, 2.0f}); + DenseElementsAttr target_shape = BuildIntArrayAttr({1, 3, 1, 5}); + + EXPECT_FALSE(AreBroadcastAndReductionAxesIndependent( + input_tensor, reduction_axes, target_shape)); + input_tensor.getDefiningOp()->destroy(); +} + +TEST_F(BroadcastAndReductionAxesIndependentTest, InvalidTargetShapeType) { + Value input_tensor = BuildTensor({2, 3, 4, 5}); + DenseElementsAttr reduction_axes = BuildIntArrayAttr({0, 2}); + DenseElementsAttr target_shape = DenseElementsAttr::get( + RankedTensorType::get({2}, builder_.getF32Type()), {1.0f, 2.0f}); + + EXPECT_FALSE(AreBroadcastAndReductionAxesIndependent( + input_tensor, reduction_axes, target_shape)); + input_tensor.getDefiningOp()->destroy(); +} + +} // namespace +} // namespace TFL + +} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/utils/variables_utils.cc b/tensorflow/compiler/mlir/lite/utils/variables_utils.cc index 0cab3ff3db32fd..fe13b43c0163ba 100644 --- a/tensorflow/compiler/mlir/lite/utils/variables_utils.cc +++ b/tensorflow/compiler/mlir/lite/utils/variables_utils.cc @@ -29,17 +29,15 @@ namespace utils { bool IsSupportedVariableType(Operation* op) { ShapedType type; if (llvm::isa(op)) { - type = op->getResult(0).getType().cast(); + type = llvm::cast(op->getResult(0).getType()); } else if (llvm::isa(op)) { - type = op->getOperand(1).getType().cast(); + type = llvm::cast(op->getOperand(1).getType()); } else if (llvm::isa(op)) { - type = op->getResult(0) - .getType() - .cast() - .getElementType() - .cast() - .GetSubtypes() - .back(); + type = + llvm::cast( + llvm::cast(op->getResult(0).getType()).getElementType()) + .GetSubtypes() + .back(); } return IsSupportedVariableType(type); } @@ -47,13 +45,13 @@ bool IsSupportedVariableType(Operation* op) { bool IsSupportedVariableType(ShapedType type) { auto 
element_type = type.getElementType(); // Check complex types. - if (auto complex_type = element_type.dyn_cast()) { + if (auto complex_type = llvm::dyn_cast(element_type)) { auto complex_element_type = complex_type.getElementType(); if (complex_element_type.isF32() || complex_element_type.isF64()) return true; } // Check quantized types. - if (auto quant_type = element_type.dyn_cast()) { + if (auto quant_type = llvm::dyn_cast(element_type)) { // TFLite supports QI16, QI32, QI8, and QUI8 if ((quant_type.getStorageTypeIntegralWidth() == 16 && quant_type.isSigned()) || diff --git a/tensorflow/compiler/mlir/quantization/common/BUILD b/tensorflow/compiler/mlir/quantization/common/BUILD index e71f093f070b03..60423044c8a535 100644 --- a/tensorflow/compiler/mlir/quantization/common/BUILD +++ b/tensorflow/compiler/mlir/quantization/common/BUILD @@ -25,6 +25,37 @@ td_library( ], ) +cc_library( + name = "tf_lift_as_function_call", + srcs = ["tf_lift_as_function_call.cc"], + hdrs = ["tf_lift_as_function_call.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + ":tf_attrs_and_constraints", + "//tensorflow/compiler/mlir/quantization/common/tf_quantization_lib", + "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", + "//tensorflow/compiler/mlir/quantization/stablehlo:stablehlo_type_utils", + "//tensorflow/compiler/mlir/quantization/tensorflow/cc:quantization_unit_loc", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops", + "//tensorflow/compiler/mlir/tensorflow:xla_call_module_attrs", + "//tensorflow/core:framework_lite", + "//tensorflow/core/ir/types:Dialect", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + 
"@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@stablehlo//:version", + ], +) + cc_library( name = "lift_as_function_call", srcs = ["lift_as_function_call.cc"], @@ -138,6 +169,34 @@ cc_library( ], ) +cc_library( + name = "tf_attrs_and_constraints", + srcs = [ + "tf_attrs_and_constraints.cc", + ], + hdrs = [ + "tf_attrs_and_constraints.h", + ], + compatible_with = get_compatible_with_portable(), + deps = [ + ":tf_uniform_quantized_types", + "//tensorflow/compiler/mlir/quantization/common/tf_quantization_lib", + "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:xla_call_module_attrs", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:DerivedAttributeOpInterface", + "@llvm-project//mlir:Dialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@stablehlo//:stablehlo_ops", + ], +) + cc_library( name = "attrs_and_constraints", srcs = [ @@ -201,6 +260,19 @@ td_library( ], ) +cc_library( + name = "tf_uniform_quantized_types", + srcs = ["tf_uniform_quantized_types.cc"], + hdrs = ["tf_uniform_quantized_types.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:Support", + ], +) + cc_library( name = "uniform_quantized_types", srcs = ["uniform_quantized_types.cc"], diff --git a/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.td b/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.td index 1921345d601283..b6085d30f656c4 100644 --- a/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.td +++ 
b/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.td @@ -17,7 +17,7 @@ include "mlir/IR/PatternBase.td" include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td" def DenseElementsAttr : ElementsAttrBase< - CPred<"$_self.isa()">, + CPred<"llvm::isa($_self)">, "non-opaque constant tensor">; // Checks if the data format is "NHWC". @@ -31,13 +31,13 @@ def IsConstTensor : Constraint($0.getDefin // Checks if the element value has a float type. def IsFloatElementsAttr : ElementsAttrBase< - CPred<"$_self.isa() && " - "getElementTypeOrSelf($_self.cast().getType()).isa()">, + CPred<"llvm::isa($_self) && " + "llvm::isa(getElementTypeOrSelf(llvm::cast($_self).getType()))">, "float constant tensor">; // Checks if the boolean value is false. def IsFalseBoolAttr : AttrConstraint< - CPred<"!$_self.cast().getValue()">>; + CPred<"!llvm::cast($_self).getValue()">>; // Checks if the value has only one user. def HasOneUse : Constraint>; @@ -63,7 +63,7 @@ def IsBF16ElementType : Constraint< // Checks if the value has the type of UniformQuantizedType. def IsUniformQuantizedType : Constraint< - CPred<"getElementTypeOrSelf($0).isa()">>; + CPred<"llvm::isa(getElementTypeOrSelf($0))">>; // Checks if the given two values have the same type. def AreTheSameElementType : Constraint< @@ -75,12 +75,12 @@ def AreTheSameValue : Constraint< // Checks if the value has rank. def HasRank : Constraint< - CPred<"$0.getType().cast().hasRank()">>; + CPred<"llvm::cast($0.getType()).hasRank()">>; // Checks if the value has rank of `n`. class HasRankOf : Constraint< - CPred<"$0.getType().cast().hasRank() && " - "$0.getType().cast().getRank() == " # n>, + CPred<"llvm::cast($0.getType()).hasRank() && " + "llvm::cast($0.getType()).getRank() == " # n>, "Checks if the value has rank of 'n'.">; // Checks if the value has static shape. 
diff --git a/tensorflow/compiler/mlir/quantization/common/ir/BUILD b/tensorflow/compiler/mlir/quantization/common/ir/BUILD index 2821bb96a66950..162c14c4ad70f9 100644 --- a/tensorflow/compiler/mlir/quantization/common/ir/BUILD +++ b/tensorflow/compiler/mlir/quantization/common/ir/BUILD @@ -25,30 +25,18 @@ td_library( gentbl_cc_library( name = "QuantOpsIncGen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-op-decls"], - "QuantOps.h.inc", - ), - ( - ["-gen-op-defs"], - "QuantOps.cc.inc", - ), - ( - [ - "-gen-dialect-decls", - "-dialect=quantization", - ], - "QuantOpsDialect.h.inc", - ), - ( - [ - "-gen-dialect-defs", - "-dialect=quantization", - ], - "QuantOpsDialect.cc.inc", - ), - ], + tbl_outs = { + "QuantOps.h.inc": ["-gen-op-decls"], + "QuantOps.cc.inc": ["-gen-op-defs"], + "QuantOpsDialect.h.inc": [ + "-gen-dialect-decls", + "-dialect=quantization", + ], + "QuantOpsDialect.cc.inc": [ + "-gen-dialect-defs", + "-dialect=quantization", + ], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "QuantOps.td", deps = [":QuantizationOpsTdFiles"], @@ -57,15 +45,10 @@ gentbl_cc_library( gentbl_cc_library( name = "QuantPassIncGen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=tfquant", - ], - "Passes.h.inc", - ), - ], + tbl_outs = {"Passes.h.inc": [ + "-gen-pass-decls", + "-name=tfquant", + ]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "Passes.td", deps = ["@llvm-project//mlir:PassBaseTdFiles"], diff --git a/tensorflow/compiler/mlir/quantization/common/lift_as_function_call.cc b/tensorflow/compiler/mlir/quantization/common/lift_as_function_call.cc index bf0cf8aa2ba9de..c4d1fc32a70549 100644 --- a/tensorflow/compiler/mlir/quantization/common/lift_as_function_call.cc +++ b/tensorflow/compiler/mlir/quantization/common/lift_as_function_call.cc @@ -491,7 +491,7 @@ bool IsEinsumSupportedByXlaDotV2(StringAttr equation_attr) { rhs_out_idx_start >= batch_dim_size; } 
-absl::StatusOr GetQuantizationMethod(absl::Nonnull op) { +absl::StatusOr GetQuantizationMethod(Operation* absl_nonnull op) { const auto quantization_method_attr = op->getAttrOfType(kQuantizationMethodAttr); if (!quantization_method_attr) { @@ -509,7 +509,7 @@ absl::StatusOr GetQuantizationMethod(absl::Nonnull op) { return quantization_method; } -Method GetQuantizationMethodOrDefault(absl::Nonnull op) { +Method GetQuantizationMethodOrDefault(Operation* absl_nonnull op) { absl::StatusOr method = GetQuantizationMethod(op); if (method.status().code() == absl::StatusCode::kInternal) { // This indicates that the `Method` protobuf string is corrupt, but this diff --git a/tensorflow/compiler/mlir/quantization/common/lift_as_function_call.h b/tensorflow/compiler/mlir/quantization/common/lift_as_function_call.h index 22e0307f4a9e17..b9faba72f147ec 100644 --- a/tensorflow/compiler/mlir/quantization/common/lift_as_function_call.h +++ b/tensorflow/compiler/mlir/quantization/common/lift_as_function_call.h @@ -70,14 +70,14 @@ bool IsEinsumSupportedByXlaDotV2(StringAttr equation_attr); // `absl::InternalError` when parsing the attribute to `Method` failed. // `op` must be non-null. absl::StatusOr<::stablehlo::quantization::Method> GetQuantizationMethod( - absl::Nonnull op); + Operation* absl_nonnull op); // Gets the quantization method from `op`. It is retrieved from the // `kQuantizationMethodAttr` string attribute. Returns a default instance of // `Method` iff the attribute doesn't exist or the attribute contains an invalid // textproto for `Method`. `op` must be non-null. ::stablehlo::quantization::Method GetQuantizationMethodOrDefault( - absl::Nonnull op); + Operation* absl_nonnull op); // Creates a function to wrap the section between arguments and results. 
// The generated function call op type will be decided by the given call_op_type diff --git a/tensorflow/compiler/mlir/quantization/common/quantization_lib/BUILD b/tensorflow/compiler/mlir/quantization/common/quantization_lib/BUILD index b6b1d17d17a4a7..36b7152c15ff02 100644 --- a/tensorflow/compiler/mlir/quantization/common/quantization_lib/BUILD +++ b/tensorflow/compiler/mlir/quantization/common/quantization_lib/BUILD @@ -102,16 +102,10 @@ td_library( gentbl_cc_library( name = "quantization_interfaces_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-op-interface-decls"], - "quantization_interface.h.inc", - ), - ( - ["-gen-op-interface-defs"], - "quantization_interface.cc.inc", - ), - ], + tbl_outs = { + "quantization_interface.h.inc": ["-gen-op-interface-decls"], + "quantization_interface.cc.inc": ["-gen-op-interface-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "quantization.td", deps = [ diff --git a/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization.td b/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization.td index 0f9b6a74762f9b..706eb8552eb1ff 100644 --- a/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization.td +++ b/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization.td @@ -31,12 +31,12 @@ include "mlir/Dialect/Quant/IR/QuantBase.td" // explicit signedness check to differentiate the signed/unsigned constraints // predicates from one another at the TD level. 
class QuantizedType params, bit signed> - : Type()">, - CPred<"$_self.cast()" # + : Type($_self)">, + CPred<"llvm::cast($_self)" # ".getStorageTypeIntegralWidth() == " # !head(params)>, - Or<[CPred<"$_self.cast()" # + Or<[CPred<"llvm::cast($_self)" # ".getStorageType().isSignlessInteger()">, - CPred<"$_self.cast()" # + CPred<"llvm::cast($_self)" # ".getStorageType().isSignedInteger() == " # signed>]>]>, "Q" # !if (signed, "I", "UI") # !head(params) # " type"> { string name = n; diff --git a/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h b/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h index 94169e3e9436c1..51dbc257d3b7d8 100644 --- a/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h +++ b/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h @@ -200,7 +200,7 @@ bool QuantizableOpSupportsFloatOutputType(Operation* op); // Specialized version of location to string for flatbuffer exported locations. 
inline std::string GetTensorNameFromLoc(Location loc) { - if (auto name_loc = loc.dyn_cast()) { + if (auto name_loc = llvm::dyn_cast(loc)) { return name_loc.getName().str(); } return ""; @@ -218,7 +218,7 @@ struct ConvertStatsToQDQs : public OpRewritePattern { LogicalResult matchAndRewrite(quantfork::StatisticsOp op, PatternRewriter& rewriter) const override { - Type expressed = op.getType().cast().getElementType(); + Type expressed = llvm::cast(op.getType()).getElementType(); quant::QuantizedType quant_type; SmallVector mins, maxs; @@ -226,7 +226,8 @@ struct ConvertStatsToQDQs : public OpRewritePattern { // Per axis quantization (or per channel quantization) int stats_num = op.getAxisStats()->getNumElements(); if (stats_num == 0 || stats_num % 2 != 0) return failure(); - auto stats = op.getAxisStats()->dyn_cast(); + auto stats = + llvm::dyn_cast(op.getAxisStats().value()); if (!stats) return failure(); for (auto it = stats.begin(), e = stats.end(); it != e; ++it) { @@ -255,7 +256,7 @@ struct ConvertStatsToQDQs : public OpRewritePattern { quant_type = DownCastScale(quant_type, mins, maxs, op->getLoc()); } } else if (auto stats = - op.getLayerStats().dyn_cast()) { + llvm::dyn_cast(op.getLayerStats())) { // Per tensor quantization auto statValues = stats.getValues(); double rmin = FloatAttr::getValueAsDouble(statValues[0]); @@ -481,7 +482,7 @@ class QuantizationPattern : public RewritePattern { } if (!nodes_blocklist.empty()) { - if (auto name_loc = quantizing_op->getLoc().dyn_cast()) { + if (auto name_loc = llvm::dyn_cast(quantizing_op->getLoc())) { std::string sloc = name_loc.getName().str(); if (!sloc.empty() && (nodes_blocklist.find(sloc) != nodes_blocklist.end())) { @@ -503,12 +504,13 @@ class QuantizationPattern : public RewritePattern { inputs.reserve(quantizing_op->getNumOperands()); for (auto operand : quantizing_op->getOperands()) { Type operand_type = operand.getType(); - if (operand_type.isa()) { + if (llvm::isa(operand_type)) { inputs.push_back(operand); 
continue; } - auto ele_type = operand.getType().cast().getElementType(); + auto ele_type = + llvm::cast(operand.getType()).getElementType(); if (static_cast(this) ->AllowDynamicRangeQuantizedOperand(quantizing_op, custom_map)) { @@ -568,13 +570,13 @@ class QuantizationPattern : public RewritePattern { Type result_type = result.getType(); // Add this to the test coverage once we create test ops with none // type results. - if (result_type.isa()) { + if (llvm::isa(result_type)) { outputs_replaced.insert({result, enumerated_result.index()}); output_types.push_back(result_type); continue; } Type result_ele_type = - result.getType().cast().getElementType(); + llvm::cast(result.getType()).getElementType(); // If the user is the QuantizeOp, it must be the only user. if (result.hasOneUse() && llvm::isa(*result.user_begin())) { @@ -648,11 +650,9 @@ class QuantizationPattern : public RewritePattern { } for (int i = 0, e = quantized_op->getNumResults(); i < e; ++i) { - if (!quantizing_op->getResult(i) - .getType() - .cast() - .getElementType() - .isa()) { + if (!llvm::isa( + llvm::cast(quantizing_op->getResult(i).getType()) + .getElementType())) { continue; } CreateVerifier(quantizing_op, quantized_op, rewriter, i, @@ -673,9 +673,7 @@ class QuantizationPattern : public RewritePattern { void RewireFloatModelBackbone(Operation* quantized_op, Operation* float_op) const { for (int i = 0, e = quantized_op->getNumResults(); i < e; ++i) { - if (!float_op->getResult(i) - .getType() - .cast() + if (!llvm::cast(float_op->getResult(i).getType()) .getElementType() .isF32()) { continue; @@ -768,14 +766,14 @@ struct ConvertUnsignedToSigned : public OpRewritePattern { auto flags = quant::QuantizationFlags::Signed; QType new_qtype; - if (auto uqtype = qtype.template dyn_cast()) { + if (auto uqtype = llvm::dyn_cast(qtype)) { new_qtype = quant::UniformQuantizedType::getChecked( op.getLoc(), flags, qtype.getStorageType(), qtype.getExpressedType(), uqtype.getScale(), uqtype.getZeroPoint() - 
offset, uqtype.getStorageTypeMin() - offset, uqtype.getStorageTypeMax() - offset); - } else if (auto aqtype = qtype.template dyn_cast< - quant::UniformQuantizedPerAxisType>()) { + } else if (auto aqtype = + llvm::dyn_cast(qtype)) { auto zero_points = aqtype.getZeroPoints(); llvm::SmallVector new_zero_points(zero_points.begin(), zero_points.end()); diff --git a/tensorflow/compiler/mlir/quantization/common/tf_attrs_and_constraints.cc b/tensorflow/compiler/mlir/quantization/common/tf_attrs_and_constraints.cc new file mode 100644 index 00000000000000..c19b7680b36c10 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/common/tf_attrs_and_constraints.cc @@ -0,0 +1,184 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/compiler/mlir/quantization/common/tf_attrs_and_constraints.h" + +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/MathExtras.h" +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/IRMapping.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo // IWYU pragma: keep +#include "tensorflow/compiler/mlir/quantization/common/tf_uniform_quantized_types.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/xla_call_module_attrs.h" + +namespace mlir::tf_quant { + +using ::mlir::stablehlo::DotGeneralOp; + +bool HasStaticShape(Value value) { + auto shaped_type = mlir::dyn_cast(value.getType()); + if (!shaped_type) return false; + + return shaped_type.hasStaticShape(); +} + +bool HasStaticShapeAtDims(Value value, const ArrayRef dims) { + auto shaped_type = mlir::dyn_cast(value.getType()); + if (!shaped_type || !shaped_type.hasRank()) return false; + + for (auto dim : dims) { + if (shaped_type.isDynamicDim(dim)) return false; + } + return true; +} + +Type CloneTypeWithNewElementType(Type old_type, Type element_type) { + if (!mlir::isa(old_type)) return {}; + + return mlir::cast(old_type).clone(element_type); +} + +SmallVector CloneOpWithReplacedOperands( + OpBuilder& builder, Operation* op, const 
ArrayRef new_operands) { + IRMapping mapping; + for (const auto& arg : enumerate(new_operands)) { + mapping.map(op->getOperand(arg.index()), arg.value()); + } + return builder.clone(*op, mapping)->getResults(); +} + +FailureOr CastI64ToI32(const int64_t value) { + if (!llvm::isInt<32>(value)) { + DEBUG_WITH_TYPE( + "mlir-quant-attrs-and-constraints", + llvm::dbgs() + << "Tried to cast " << value + << "from int64 to int32, but lies out of range of int32.\n"); + return failure(); + } + return static_cast(value); +} + +FailureOr> CastI64ArrayToI32( + const ArrayRef int64_array) { + SmallVector int32_array{}; + int32_array.reserve(int64_array.size()); + + for (const int64_t i64 : int64_array) { + FailureOr cast_i32 = CastI64ToI32(i64); + if (failed(cast_i32)) return failure(); + + int32_array.push_back(*cast_i32); + } + return int32_array; +} + +StringRef GetEntryFunctionName(TF::XlaCallModuleOp op) { + if (!op->hasAttrOfType( + TF::kStablehloEntryFunctionAttrName)) { + return StringRef(); + } + return op + ->getAttrOfType(TF::kStablehloEntryFunctionAttrName) + .getValue(); +} + +bool IsHybridQuantizedOp(Operation* op) { + if ((op->getNumOperands() != 2 && op->getNumOperands() != 3) || + op->getResultTypes().size() != 1) { + return false; + } + Type lhs_type = op->getOperand(0).getType(); + Type rhs_type = op->getOperand(1).getType(); + Type result_type = op->getResult(0).getType(); + return !IsQuantizedTensorType(lhs_type) && IsQuantizedTensorType(rhs_type) && + !IsQuantizedTensorType(result_type); +} + +absl::StatusOr IsDotGeneralFullyConnected(DotGeneralOp dot_general_op) { + if (dot_general_op == nullptr) + return absl::InvalidArgumentError( + "Given dot_general op cannot be null when checking " + "`IsDotGeneralBatchMatmul`."); + const ::mlir::stablehlo::DotDimensionNumbersAttr dot_dimension_numbers = + dot_general_op.getDotDimensionNumbers(); + const ArrayRef lhs_contracting_dims = + dot_dimension_numbers.getLhsContractingDimensions(); + const ArrayRef 
rhs_contracting_dims = + dot_dimension_numbers.getRhsContractingDimensions(); + const int64_t input_rank = + mlir::dyn_cast(dot_general_op.getOperand(0).getType()) + .getRank(); + const int64_t filter_rank = + mlir::dyn_cast(dot_general_op.getOperand(1).getType()) + .getRank(); + // The following conditions are such requirements: + // - rank(lhs) is 1 or 2 + // - rank(rhs) = 2 + // - size(lhs_contracting_dimensions) = 1 + // - size(rhs_contracting_dimensions) = 1 + // - lhs_contracting_dimension = last dimension of lhs. + // - `stablehlo.dot_general` should not have `lhs_batching_dim`. + // - quantization_dimension(rhs) should not be in + // `rhs_contracting_dimensions`. + // https://github.com/openxla/stablehlo/blob/main/docs/spec.md#dot_general + const bool has_proper_rank = + (input_rank == 1 || input_rank == 2) && filter_rank == 2; + const bool has_proper_contracting_dim = + lhs_contracting_dims.size() == 1 && rhs_contracting_dims.size() == 1 && + lhs_contracting_dims[0] == input_rank - 1; + const bool is_not_batch_op = + dot_dimension_numbers.getLhsBatchingDimensions().empty(); + const bool has_proper_quantization_dimension = + absl::c_find(rhs_contracting_dims, filter_rank) == + rhs_contracting_dims.end(); + return has_proper_rank && has_proper_contracting_dim && is_not_batch_op && + has_proper_quantization_dimension; +} + +std::optional GetDotGeneralQuantizationDim( + DotGeneralOp dot_general_op) { + if (dot_general_op == nullptr) return std::nullopt; + const int64_t filter_rank = + mlir::dyn_cast(dot_general_op.getOperand(1).getType()) + .getRank(); + + // To quantize rhs per-channel, we currently only consider the case where + // `stablehlo.dot_general` is legalizable to `tfl.fully_connected`. 
+ const bool is_per_axis_quantizable = + IsDotGeneralFullyConnected(dot_general_op).value(); + if (!is_per_axis_quantizable) return std::nullopt; + return filter_rank - 1; +} + +bool ContainsConvOrDot(StringRef str) { + return str.contains("_conv") || str.contains("_dot_general"); +} + +} // namespace mlir::tf_quant diff --git a/tensorflow/compiler/mlir/quantization/common/tf_attrs_and_constraints.h b/tensorflow/compiler/mlir/quantization/common/tf_attrs_and_constraints.h new file mode 100644 index 00000000000000..d542996e522f8c --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/common/tf_attrs_and_constraints.h @@ -0,0 +1,260 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_TF_ATTRS_AND_CONSTRAINTS_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_TF_ATTRS_AND_CONSTRAINTS_H_ + +#include +#include +#include +#include + +#include "absl/status/statusor.h" +#include "llvm/Support/Debug.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo +#include "tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_utils.h" // IWYU pragma: keep +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/xla_call_module_attrs.h" + +namespace mlir::tf_quant { + +constexpr char kAttrMapAttribute[] = "attr_map"; + +// Name of the string attribute attached to `XlaCallModuleOp`, which is the +// textproto representation of `Method`. +inline constexpr StringRef kQuantizationMethodAttr = "_quantization_method"; + +// Permutation from the NHWC tensor format to NCHW. This is an inverse +// permutation of `kNchwToNhwcPermutation`. +inline constexpr std::array kNhwcToNchwPermutation = {0, 3, 1, 2}; + +// Permutation from the NCHW tensor format to NHWC. This is an inverse +// permutation of `kNchwToNhwcPermutation`. 
+inline constexpr std::array kNchwToNhwcPermutation = {0, 2, 3, 1}; + +// Permutation from the OIHW (== (output features, input features, height, +// width)) tensor format to HWIO. This is commonly used to transpose convolution +// weights represented as OIHW format to HWIO, which is more desirable for +// certain downstream optimization passes (e.g. XLA). +inline constexpr std::array kOihwToHwioPermutation = {2, 3, 1, 0}; + +// Returns true if the value has static shape. +bool HasStaticShape(Value value); + +// Returns true if the value has static shape at given dims. +bool HasStaticShapeAtDims(Value value, ArrayRef dims); + +// Whether `value` has known rank of `rank`. Returns false when it is not a +// `ShapedType` or its rank is unknown. +inline bool HasRankOf(Value value, const int64_t rank) { + auto shaped_type = mlir::dyn_cast_or_null(value.getType()); + return shaped_type && shaped_type.hasRank() && shaped_type.getRank() == rank; +} + +// Creates a new type that has the shape from the `old_type` and the element +// type from the `element_type`. +Type CloneTypeWithNewElementType(Type old_type, Type element_type); + +// Creates an array with integer/float type. +template || std::is_same_v), void>> +Value CreateConstValue(OpBuilder& builder, const Location loc, + const SmallVector& shape, + const SmallVector& values) { + if constexpr (std::is_integral_v) { + auto shape_type = + RankedTensorType::get(shape, builder.getIntegerType(sizeof(T) * 8)); + + const auto attr = DenseIntElementsAttr::get(shape_type, values); + return builder.create(loc, attr); + } + + const auto type = RankedTensorType::get(shape, builder.getF32Type()); + const auto value_attr = DenseFPElementsAttr::get(type, values); + return builder.create(loc, value_attr); +} + +// Creates a 1D array with integer/float type. 
+template +Value Create1DConstValue(OpBuilder& builder, const Location loc, + const SmallVector& values) { + return CreateConstValue(builder, loc, + {static_cast(values.size())}, values); +} + +// Creates a scalar with integer / float type. +template +Value CreateScalarConstValue(OpBuilder& builder, const Location loc, + const T value) { + return CreateConstValue(builder, loc, /*shape=*/{}, {value}); +} + +// Checks if the value is a constant and return its splat value. +template || std::is_same_v), void>> +bool GetSplatValue(Value value, T& splat_value) { + if constexpr (std::is_integral_v) { + DenseIntElementsAttr value_attr; + if (!matchPattern(value, m_Constant(&value_attr)) || + !value_attr.isSplat()) { + return false; + } + splat_value = value_attr.getSplatValue(); + return true; + } + + DenseFPElementsAttr value_attr; + if (!matchPattern(value, m_Constant(&value_attr)) || !value_attr.isSplat()) { + return false; + } + splat_value = value_attr.getSplatValue(); + return true; +} + +// Checks if the value is a constant and its splat value is equal to x. +template +bool IsSplatValueEqual(Value value, const T x) { + T splat_value; + if (!GetSplatValue(value, splat_value)) return false; + + return splat_value == x; +} + +// Checks if two values are constants and their splat values are equal. +template +bool AreSplatValuesEqual(Value x, Value y) { + T splat_x, splat_y; + if (!GetSplatValue(x, splat_x) || !GetSplatValue(y, splat_y)) { + return false; + } + + return splat_x == splat_y; +} + +// Clones an operation with new operands while keeping attributes. +SmallVector CloneOpWithReplacedOperands(OpBuilder& builder, + Operation* op, + ArrayRef new_operands); + +// Tries casting `op` with a concrete op type `T`. If the cast fails or `op` is +// a `nullptr`, returns `failure` and prints a debugging message identifying +// the cast attempt as `name`. 
+template +FailureOr TryCast(Operation* op, const StringRef name) { + auto cast_op = dyn_cast_or_null(op); + if (cast_op) { + return cast_op; + } else { + DEBUG_WITH_TYPE("mlir-quant-attrs-and-constraints", + llvm::dbgs() << "Failed to match " << name << " (" + << T::getOperationName() << ").\n"); + return failure(); + } +} + +FailureOr CastI64ToI32(int64_t value); + +// Tries to cast an array of int64 to int32. If any of the element in the +// array is not in the range of int32, returns failure(). +FailureOr> CastI64ArrayToI32( + ArrayRef int64_array); + +// Returns the first operation with the given type in the function. +template +OpType FindOperationOfType(func::FuncOp function) { + for (auto op : function.getBody().getOps()) { + return op; + } + return nullptr; +} + +// Returns the first user of the given operation, optionally of the given +// type if provided. If there is no user or user of type, return nullptr. +template +Operation* FindUserOfType(Operation* op) { + for (Operation* user : op->getUsers()) { + if (isa(user)) { + return user; + } + } + return nullptr; +} + +// Returns the first user of the given operation, optionally of the given +// type if provided. If there is no user or user of type, return nullptr. +template +Operation* FindOperandOfType(Operation* op) { + for (Value operand_value : op->getOperands()) { + if (isa(operand_value.getDefiningOp())) { + return operand_value.getDefiningOp(); + } + } + return nullptr; +} + +// Returns the function attribute for the given call op which is lifted for +// quantization. +inline FlatSymbolRefAttr GetFuncAttr(TF::PartitionedCallOp call_op) { + return mlir::dyn_cast(call_op.getFAttr()); +} + +inline FlatSymbolRefAttr GetFuncAttr(TF::XlaCallModuleOp call_op) { + return call_op->getAttrOfType( + TF::kStablehloEntryFunctionAttrName); +} + +// Returns the entry function name for the given tf.XlaCallModule op. Returns +// empty string if such attribute does not exist. 
+StringRef GetEntryFunctionName(TF::XlaCallModuleOp op); + +// Checks whether the given op contains QuantizationTrait::FullyQuantizable. +inline bool HasQuantizableTrait(Operation* op) { + return op->hasAttrOfType(kQuantTraitAttrName) && + op->getAttrOfType(kQuantTraitAttrName).getValue().str() == + QuantTraitValues[QuantizationTrait::FullyQuantizable]; +} + +// Returns true if `op` has two operands and one result and only second operand +// is quantized. +bool IsHybridQuantizedOp(Operation* op); + +// Returns whether a given `stablehlo.dot_general` can be legalizable to +// `tfl.fully_connected`. +absl::StatusOr IsDotGeneralFullyConnected( + ::mlir::stablehlo::DotGeneralOp dot_general_op); + +// Returns the quantization dimension for a given `stablehlo.dot_general` op, +// or `std::nullopt` if the given op is not per-channel quantizable. +std::optional GetDotGeneralQuantizationDim( + ::mlir::stablehlo::DotGeneralOp dot_general_op); + +// Checks if a `StringRef` contains 'conv' or 'dot_general'. +bool ContainsConvOrDot(StringRef str); + +} // namespace mlir::tf_quant + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_TF_ATTRS_AND_CONSTRAINTS_H_ diff --git a/tensorflow/compiler/mlir/quantization/common/tf_lift_as_function_call.cc b/tensorflow/compiler/mlir/quantization/common/tf_lift_as_function_call.cc new file mode 100644 index 00000000000000..602e077d095faf --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/common/tf_lift_as_function_call.cc @@ -0,0 +1,550 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/quantization/common/tf_lift_as_function_call.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/base/nullability.h" +#include "absl/container/flat_hash_set.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/Sequence.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/ErrorHandling.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/IR/TypeRange.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/ValueRange.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "stablehlo/dialect/Version.h" // from @stablehlo +#include "tensorflow/compiler/mlir/quantization/common/tf_attrs_and_constraints.h" +#include "tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_utils.h" +#include 
"tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/utils/stablehlo_type_utils.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/cc/quantization_unit_loc.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/xla_call_module_attrs.h" +#include "tensorflow/core/ir/types/dialect.h" +#include "tensorflow/core/platform/mutex.h" +#include "tsl/platform/protobuf.h" // IWYU pragma: keep + +namespace mlir::tf_quant { + +using ::stablehlo::quantization::Method; +using ::tsl::protobuf::TextFormat; + +// Default version number for native serialization. +constexpr int64_t kDefaultVersion = 9; +// Default platform for XlaCallModuleOp. +constexpr StringRef kPlatformCpu = "CPU"; +// Name of `tf.XlaCallModule`'s dictionary attribute for keeping the +// deserialized stablehlo module's attributes. +constexpr StringRef kStablehloModuleAttrsAttrName = "_stablehlo_module_attrs"; +// Attribute required for running shape refinement pass enabled in XlaCallModule +// version 8 and above. +constexpr StringRef kUsesShapePolymorphismAttr = "jax.uses_shape_polymorphism"; + +bool IsInLiftedFunc(Operation* op) { + if (op == nullptr) return false; + return op->getParentOfType()->hasAttr(kFusedFunctionAttr); +} + +bool IsInStableHloOpRegion(Operation* op) { + if (op == nullptr) return false; + auto parent_op = op->getParentOp(); + return parent_op != nullptr && quant::stablehlo::IsStablehloOp(parent_op); +} + +// Inserts the function to the symbol table of the module thread-safely. 
+StringAttr InsertToSymbolTable(Operation& module, Operation& function, + const StringRef func_name) { + static tensorflow::mutex* mtx = new tensorflow::mutex(); + tensorflow::mutex_lock lock(*mtx); + + SymbolTable symbol_table(&module); + std::string unique_name = func_name.str(); + int32_t uniquing_counter = 0; + while (symbol_table.lookup(unique_name) != nullptr) { + ++uniquing_counter; + unique_name = absl::StrCat(func_name.str(), "_", uniquing_counter); + } + function.setAttr("sym_name", + StringAttr::get(module.getContext(), unique_name)); + return symbol_table.insert(&function); +} + +// Creates the TF::PartitionedCallOp with the given arguments and output types. +// This function call op is for invoking the TF subgraphs. +ValueRange CreateTFPartitionedCallOp(OpBuilder& builder, + const Location location, + const StringRef func_name, + const TypeRange output_types, + const ValueRange args) { + TF::PartitionedCallOp call_op = builder.create( + location, output_types, args, + /*args_attrs=*/nullptr, /*res_attrs=*/nullptr, + FlatSymbolRefAttr::get(builder.getStringAttr(func_name)), + /*config=*/"", /*config_proto=*/"", /*executor_type=*/""); + + // Set the attribute to annotate this function call op as a quantizable spot. + call_op->setAttr( + kQuantTraitAttrName, + builder.getStringAttr(StringRef( + std::string(QuantTraitValues[QuantizationTrait::FullyQuantizable])))); + + return call_op.getOutput(); +} + +// Creates the TF::XlaCallModuleOp with the given arguments and output types. +// This function call op is for invoking the StableHLO subgraphs. +ValueRange CreateTFXlaCallModuleOp(OpBuilder& builder, const Location location, + const StringRef func_name, + const TypeRange output_types, + const ValueRange args) { + MLIRContext* ctx = builder.getContext(); + // Collect the shapes of the output to fill up the Sout attribute. 
+ SmallVector shape_attrs; + for (const Type result_type : output_types) { + shape_attrs.push_back( + tf_type::ShapeAttr::get(ctx, mlir::cast(result_type))); + } + auto empty_array_attr = ArrayAttr::get(ctx, {}); + auto platforms = ArrayAttr::get(ctx, {StringAttr::get(ctx, kPlatformCpu)}); + + auto call_op = builder.create( + location, + /*output=*/output_types, + /*args=*/args, + /*version=*/kDefaultVersion, /*module=*/"", + /*Sout=*/ArrayAttr::get(ctx, shape_attrs), + /*dim_args_spec=*/empty_array_attr, + /*platforms=*/platforms, + /*function_list=*/empty_array_attr, + /*has_token_input_output=*/false, + /*disabled_checks=*/empty_array_attr); + + // Set the function name. This will be controlled by the + // XlaCallModuleSerialization related passes directly, which means that the + // function name can be changed by those passes. + call_op->setAttr(TF::kStablehloEntryFunctionAttrName, + FlatSymbolRefAttr::get(builder.getStringAttr(func_name))); + + // Set target version to WEEK_4 since this is an offline quantizer. + std::string target_version = + mlir::vhlo::Version::fromCompatibilityRequirement( + vhlo::Version::CompatibilityRequirement::WEEK_4) + .toString(); + call_op->setAttr(TF::kStablehloVersionAttrName, + builder.getStringAttr(target_version)); + + // Store the custom attribute to restore the function name when loading it + // back in the post calibration stage. As mentioned above, the above entry + // function attribute is not reliable. + call_op->setAttr(kOriginalStablehloEntryFunctionAttrName, + builder.getStringAttr(func_name)); + + // Set the attribute to annotate this function call op as a quantizable spot. + call_op->setAttr( + kQuantTraitAttrName, + builder.getStringAttr(StringRef( + std::string(QuantTraitValues[QuantizationTrait::FullyQuantizable])))); + + // Set jax.uses_shape_polymorphism=true to enable shape refinement at runtime. + // This is needed for native serialization version >= 8. 
+ call_op->setAttr(kStablehloModuleAttrsAttrName, + builder.getDictionaryAttr(builder.getNamedAttr( + kUsesShapePolymorphismAttr, builder.getBoolAttr(true)))); + + return call_op.getOutput(); +} + +// Creates the function call op based on the given call_op_type argument. +ValueRange CreateFunctionCallOp(OpBuilder& builder, const Location location, + const FunctionCallOpType call_op_type, + const StringRef func_name, + const TypeRange output_types, + const ValueRange args) { + switch (call_op_type) { + case FunctionCallOpType::TFXlaCallModuleOp: + return CreateTFXlaCallModuleOp(builder, location, func_name, output_types, + args); + case FunctionCallOpType::TFPartitionedCallOp: + return CreateTFPartitionedCallOp(builder, location, func_name, + output_types, args); + } +} + +// Finds ops in the paths from arguments to results. The ops is listed in an +// order that the former ops shouldn't have any dependencies on the later ones. +SmallVector FindOpsFromArgumentsToResults( + const ArrayRef arguments, const ArrayRef results) { + std::queue value_queue; + for (Value result : results) { + value_queue.push(result); + } + absl::flat_hash_set argument_set; + for (Value argument : arguments) { + argument_set.insert(argument.getImpl()); + } + + // Searching for ops from results to arguments. Duplicate ops in the op stack + // are intentional in order to make sure the op on the top of the stack + // doesn't depends on any ops below it. + std::stack op_stack; + while (!value_queue.empty()) { + Value current_value = value_queue.front(); + value_queue.pop(); + + Operation* defining_node = current_value.getDefiningOp(); + if (defining_node == nullptr) continue; + op_stack.push(defining_node); + for (Value arg : defining_node->getOperands()) { + if (!argument_set.contains(arg.getImpl())) { + value_queue.push(arg); + } + } + } + + // Remove duplicate ops from the op stack. 
+ SmallVector sorted_ops; + absl::flat_hash_set unique_ops; + while (!op_stack.empty()) { + Operation* current_op = op_stack.top(); + op_stack.pop(); + if (unique_ops.contains(current_op)) continue; + sorted_ops.push_back(current_op); + unique_ops.insert(current_op); + } + return sorted_ops; +} + +// Finds the name of each attribute in `attributes` and set the attr_map +// attribute which maps an attribute identifier to its attribute name. The +// identifier is the order of that attribute in `attributes`. This map +// is then used to set attributes in the quantized functions in the +// QuantizeCompositeFunctionsPass. +// For example, for tf.MatMul with `attributes` = {{"transpose_a", false}, +// {"transpose_b", false}}, the generated attr_map is +// "0:transpose_a,1:transpose_b", where 0 and 1 are the respective attribute +// identifiers. +// This function returns success if all attributes could be found. +LogicalResult SetAttributeMap(MLIRContext& context, + const ArrayRef attributes, + const ArrayRef ops) { + // A map to find which operation an attribute belongs to. + // The key for this map uses the entire NamedAttribute object, i.e. the + // {attribute_name, attribute_value} pair. + llvm::SmallDenseMap attr_to_op_map; + for (Operation* op : ops) { + for (const NamedAttribute named_attr : op->getAttrs()) { + attr_to_op_map.insert({named_attr, op}); + } + } + + for (int idx : llvm::seq(0, attributes.size())) { + const NamedAttribute& attribute = attributes[idx]; + // Skip the following steps if the attribute value is `NullAttribute`. 
+ if (const auto string_attr = + mlir::dyn_cast_or_null(attribute.getValue()); + string_attr != nullptr && + string_attr.getValue() == kNullAttributeValue) { + continue; + } + + if (std::find_if( + attr_to_op_map.begin(), attr_to_op_map.end(), [&](auto attr_op) { + return std::get<0>(attr_op).getName() == attribute.getName(); + }) == attr_to_op_map.end()) { + emitError(UnknownLoc::get(&context), + "Could not find attribute: " + attribute.getName().str()); + return failure(); + } + + Operation* owner_op; + for (const auto& [attr, val] : attr_to_op_map) { + if (attr.getName() == attribute.getName()) owner_op = val; + } + if (quant::stablehlo::IsStablehloOp(owner_op)) { + owner_op->setAttr(StringRef(attribute.getName()), attribute.getValue()); + } else { + owner_op = attr_to_op_map[attribute]; + + std::string new_attr_map_str{}; + if (owner_op->hasAttr(kAttrMapAttribute)) { + new_attr_map_str = + owner_op->getAttrOfType(kAttrMapAttribute).str(); + absl::StrAppend(&new_attr_map_str, ","); + } + + // Append ":". Ex) "0:transpose_a". + const std::string identifier = std::to_string(idx); + const StringAttr attribute_name = attribute.getName(); + absl::StrAppend(&new_attr_map_str, identifier, ":", attribute_name.str()); + owner_op->setAttr(kAttrMapAttribute, + StringAttr::get(&context, new_attr_map_str)); + } + } + return success(); +} + +// Creates a function to wrap the section between arguments and results. +SmallVector LiftAsFunctionCall( + OpBuilder& builder, const Location location, + const FunctionCallOpType call_op_type, const StringRef func_name, + const ArrayRef arguments, const ArrayRef results, + const ArrayRef attributes) { + MLIRContext* context = builder.getContext(); + if (results.empty()) { + emitError(UnknownLoc::get(context), "No result values specified"); + return {}; + } + Operation* result_op = results[0].getDefiningOp(); + auto module = result_op->getParentOfType(); + + // Create a private function and copy all ops between arguments and results. 
+ auto current_func = result_op->getParentOfType(); + auto guard = OpBuilder::InsertionGuard(builder); + builder.setInsertionPointAfter(current_func); + TypeRange arg_types{ValueRange{arguments}}; + TypeRange result_types{ValueRange{results}}; + auto func_type = FunctionType::get(context, arg_types, result_types); + + SmallVector arg_locs; + for (Value arg : arguments) { + arg_locs.push_back(arg.getLoc()); + } + + auto wrap_func = builder.create(location, func_name, func_type); + wrap_func.setVisibility(SymbolTable::Visibility::Private); + // The callee function for TF::XlaCallModuleOp must have this attribute. + if (call_op_type == FunctionCallOpType::TFXlaCallModuleOp) { + wrap_func->setAttr(TF::kFromXlaCallModuleAttrName, builder.getUnitAttr()); + } + wrap_func->setAttr(kFusedFunctionAttr, builder.getUnitAttr()); + builder.createBlock(&wrap_func.getBody(), wrap_func.begin(), arg_types, + arg_locs); + + IRMapping mapping; + for (int32_t i : llvm::seq(0, arguments.size())) { + mapping.map(arguments[i], wrap_func.getArgument(i)); + } + + auto cloning_ops = FindOpsFromArgumentsToResults(arguments, results); + // Set the location of call op to QuantizationUnitLoc if found. + Location call_op_loc = location; + for (Operation* op : cloning_ops) { + std::optional unit = + quant::FindQuantizationUnitFromLoc(op->getLoc()); + if (unit.has_value()) { + call_op_loc = + quant::QuantizationUnitLoc(builder.getContext(), unit.value()); + } + } + + if (failed(SetAttributeMap(*context, attributes, cloning_ops))) { + current_func.emitError() << "Some attributes couldn't be found."; + } + for (Operation* op : cloning_ops) { + builder.clone(*op, mapping); + } + + SmallVector return_values; + for (Value result : results) { + return_values.push_back(mapping.lookupOrNull(result)); + } + builder.create(location, return_values); + + // Create a function call to the newly created function. 
+ StringAttr new_func_name = + InsertToSymbolTable(*module, *wrap_func, func_name); + builder.setInsertionPointAfter(result_op); + ValueRange new_results = + CreateFunctionCallOp(builder, call_op_loc, call_op_type, + new_func_name.getValue(), result_types, arguments); + return SmallVector(new_results.begin(), new_results.end()); +} + +SmallVector LiftAsFunctionCall(OpBuilder& builder, + const Location location, + const FunctionCallOpType call_op_type, + const StringRef func_name, + const ArrayRef arguments, + const ArrayRef results) { + SmallVector attributes; + return LiftAsFunctionCall(builder, location, call_op_type, func_name, + arguments, results, attributes); +} + +SmallVector AppendToVector(const ArrayRef arguments, + Value append) { + SmallVector ret(arguments); + ret.push_back(append); + return ret; +} + +// Check if the given einsum equation is supported by XlaDotV2. +// Conditions: +// 1. Two inputs & one output. +// 2. No ... in the equation. +// 3. Batch dimensions should be the same, or only the left equation should have +// the batch dimension. This condition is from the XlaDotV2 specification. It +// could process the following equation by setting the attributes properly: +// abc,cd->abd. +// 4. The output should be in the form: [batch dims][lhs dims][rhs dims] +bool IsEinsumSupportedByXlaDotV2(StringAttr equation_attr) { + StringRef equation = equation_attr.getValue(); + + if (!absl::StrContains(equation, "->") || !absl::StrContains(equation, ",") || + absl::StrContains(equation, ".")) { + return false; + } + + // Parse equation. 
+ int idx_arrow = equation.find("->"); + StringRef calc_eq = equation.substr(0, idx_arrow); + StringRef out_eq = equation.substr(idx_arrow + 2); + + int idx_comma = calc_eq.find(','); + StringRef lhs_eq = calc_eq.substr(0, idx_comma); + StringRef rhs_eq = calc_eq.substr(idx_comma + 1); + + if (absl::StrContains(rhs_eq, ",")) return false; + + int lhs_out_idx_start = out_eq.size(); + int lhs_out_idx_end = -1; + int rhs_out_idx_start = out_eq.size(); + int rhs_out_idx_end = -1; + int lhs_batch_dim_size = 0; + int rhs_batch_dim_size = 0; + for (const char c : lhs_eq) { + if (absl::StrContains(out_eq, c) && absl::StrContains(rhs_eq, c)) { + lhs_batch_dim_size++; + } else if (absl::StrContains(out_eq, c)) { + const int out_idx = out_eq.find(c); + if (out_idx < lhs_out_idx_end) { + // Left-hand equation is reversed in the output. + return false; + } + lhs_out_idx_start = std::min(lhs_out_idx_start, out_idx); + lhs_out_idx_end = std::max(lhs_out_idx_end, out_idx); + } + } + + for (const char c : rhs_eq) { + if (absl::StrContains(out_eq, c) && absl::StrContains(lhs_eq, c)) { + rhs_batch_dim_size++; + } else if (absl::StrContains(out_eq, c)) { + int out_idx = out_eq.find(c); + if (out_idx < rhs_out_idx_end) { + return false; + } + if (out_idx < rhs_out_idx_start) rhs_out_idx_start = out_idx; + if (out_idx > rhs_out_idx_end) rhs_out_idx_end = out_idx; + } + } + + if (lhs_batch_dim_size != rhs_batch_dim_size && lhs_batch_dim_size != 0 && + rhs_batch_dim_size != 0) { + // Batch dimension does not match. + return false; + } + + // All the lhs equations should come first. + if (lhs_out_idx_end > rhs_out_idx_start) return false; + + // All the lhs out dim and rhs out dim should be larger than the batch dims, + // and they should not be mixed. 
+ int batch_dim_size = std::max(rhs_batch_dim_size, lhs_batch_dim_size); + return lhs_out_idx_start >= batch_dim_size && + rhs_out_idx_start >= batch_dim_size; +} + +absl::StatusOr GetQuantizationMethod(Operation* absl_nonnull op) { + const auto quantization_method_attr = + op->getAttrOfType(kQuantizationMethodAttr); + if (!quantization_method_attr) { + return absl::InvalidArgumentError(absl::StrCat( + "Attribute ", kQuantizationMethodAttr.str(), " is not found.")); + } + + Method quantization_method; + const std::string method_txtpb = quantization_method_attr.getValue().str(); + if (!TextFormat::ParseFromString(method_txtpb, &quantization_method)) { + return absl::InternalError( + absl::StrCat("Failed to parse Method from textproto: ", method_txtpb)); + } + + return quantization_method; +} + +Method GetQuantizationMethodOrDefault(Operation* absl_nonnull op) { + absl::StatusOr method = GetQuantizationMethod(op); + if (method.status().code() == absl::StatusCode::kInternal) { + // This indicates that the `Method` protobuf string is corrupt, but this + // function ignores it and returns the default instance. + op->emitError(absl::StrCat("Failed to get quantization method: ", + method.status().ToString())); + } + return method.ok() ? 
*method : Method::default_instance(); +} + +bool HasWeightOnlyPtqMethod(TF::XlaCallModuleOp xla_call_module_op) { + Method method = GetQuantizationMethodOrDefault(xla_call_module_op); + return method.has_weight_only_ptq(); +} + +bool IsWeightOnlyQuantizableOp(const Operation& op) { + if (auto call_op = dyn_cast(op)) { + StringRef entry_function_name = GetEntryFunctionName(call_op); + absl::StatusOr quantization_method = GetQuantizationMethod(call_op); + return ContainsConvOrDot(entry_function_name) && quantization_method.ok() && + quantization_method->has_weight_only_ptq(); + } + return false; +} + +SmallVector GetSortedFunctions(ModuleOp module_op) { + auto iterator_range = module_op.getOps(); + SmallVector func_ops(iterator_range.begin(), + iterator_range.end()); + absl::c_sort(func_ops, [](func::FuncOp op1, func::FuncOp op2) { + return op1.getName() < op2.getName(); + }); + return func_ops; +} + +} // namespace mlir::tf_quant diff --git a/tensorflow/compiler/mlir/quantization/common/tf_lift_as_function_call.h b/tensorflow/compiler/mlir/quantization/common/tf_lift_as_function_call.h new file mode 100644 index 00000000000000..b421ec3c672d65 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/common/tf_lift_as_function_call.h @@ -0,0 +1,114 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_TF_LIFT_AS_FUNCTION_CALL_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_TF_LIFT_AS_FUNCTION_CALL_H_ + +#include "absl/base/nullability.h" +#include "absl/status/statusor.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" + +namespace mlir::tf_quant { + +// This attribute will be set for functions created by this pass. +// Presence of this attribute will mark the function as quantization target. +inline constexpr StringRef kFusedFunctionAttr = "tf_quant.composite_function"; +// The keyword to detect if this is a `NullAttribute`. +inline constexpr StringRef kNullAttributeValue = "N/A"; + +// Prefixes attached to lifted functions. +constexpr StringRef kQuantizedFuncPrefix = "quantized_"; +constexpr StringRef kCompositeFuncPrefix = "composite_"; + +// The attribute will be used for TF::XlaCallModuleOp to restore the original +// function name when loading it back. +inline constexpr StringRef kOriginalStablehloEntryFunctionAttrName = + "_original_entry_function"; + +// FunctionCallOpType to be generated as the function call operator when +// function lifting will happen. +enum FunctionCallOpType { TFPartitionedCallOp = 0, TFXlaCallModuleOp = 1 }; + +// Checks if an op is inside a lifted function. +// If the given op pointer is a nullptr, returns false. 
+bool IsInLiftedFunc(Operation* op); + +// Checks if the op is inside a StableHLO op with region. +// If the given op pointer is a nullptr, returns false. +bool IsInStableHloOpRegion(Operation* op); + +// Checks if a given einsum op is supported for XlaDotV2 quantization. +bool IsEinsumSupportedByXlaDotV2(StringAttr equation_attr); + +// Gets the quantization method from `op`. It is retrieved from the +// `kQuantizationMethodAttr` string attribute. Returns +// `absl::InvalidArgumentError` when the attribute doesn't exist. Returns +// `absl::InternalError` when parsing the attribute to `Method` failed. +// `op` must be non-null. +absl::StatusOr<::stablehlo::quantization::Method> GetQuantizationMethod( + Operation* absl_nonnull op); + +// Gets the quantization method from `op`. It is retrieved from the +// `kQuantizationMethodAttr` string attribute. Returns a default instance of +// `Method` iff the attribute doesn't exist or the attribute contains an invalid +// textproto for `Method`. `op` must be non-null. +::stablehlo::quantization::Method GetQuantizationMethodOrDefault( + Operation* absl_nonnull op); + +// Creates a function to wrap the section between arguments and results. +// The generated function call op type will be decided by the given call_op_type +// argument. Currently, it supports TF::XlaCallModuleOp and +// TF::PartitionedCallOp function call op generations. +SmallVector LiftAsFunctionCall(OpBuilder& builder, Location location, + FunctionCallOpType call_op_type, + StringRef func_name, + ArrayRef arguments, + ArrayRef results, + ArrayRef attributes); + +// Same as above but with empty attributes. +SmallVector LiftAsFunctionCall(OpBuilder& builder, Location location, + FunctionCallOpType call_op_type, + StringRef func_name, + ArrayRef arguments, + ArrayRef results); + +// Add the second argument to the first argument, which is expected to be an +// argument list. +// Used to attach bias to einsum argument list. 
+SmallVector AppendToVector(ArrayRef arguments, Value append); + +// Checks if the `Method` attatched to the given `tf.XlaCallModule` op has +// `WeightOnlyPtq`. +bool HasWeightOnlyPtqMethod(TF::XlaCallModuleOp xla_call_module_op); + +// Checks if an op is a `tf.XlaCallModule` op, contains 'conv' or 'dot_general' +// in its name and has `Method` with `WeightOnlyPtq`. +bool IsWeightOnlyQuantizableOp(const Operation& op); + +// Lists the functions in a ModuleOp sorted by their names. +SmallVector GetSortedFunctions(ModuleOp module_op); + +} // namespace mlir::tf_quant + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_TF_LIFT_AS_FUNCTION_CALL_H_ diff --git a/tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/BUILD b/tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/BUILD index 8079760d548d5f..2ce3b743dcd766 100644 --- a/tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/BUILD +++ b/tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/BUILD @@ -30,7 +30,6 @@ cc_library( deps = [ ":tf_quantization_config", ":tf_quantization_interfaces_inc_gen", - "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", "//tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy:portable_tensor_utils", "//tensorflow/compiler/mlir/quantization/common/ir:QuantOps", "//tensorflow/compiler/mlir/tools/optimize:quantization_utils", diff --git a/tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization.td b/tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization.td index 050b87e8c08834..3909495ef239fb 100644 --- a/tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization.td +++ b/tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization.td @@ -31,12 +31,12 @@ include "mlir/Dialect/Quant/IR/QuantBase.td" // explicit signedness check to differentiate the signed/unsigned constraints // predicates from one another at the TD level. 
class TFQuantizedType params, bit signed> - : Type()">, - CPred<"$_self.cast()" # + : Type($_self)">, + CPred<"llvm::cast($_self)" # ".getStorageTypeIntegralWidth() == " # !head(params)>, - Or<[CPred<"$_self.cast()" # + Or<[CPred<"llvm::cast($_self)" # ".getStorageType().isSignlessInteger()">, - CPred<"$_self.cast()" # + CPred<"llvm::cast($_self)" # ".getStorageType().isSignedInteger() == " # signed>]>]>, "Q" # !if (signed, "I", "UI") # !head(params) # " type"> { string name = n; diff --git a/tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_utils.cc b/tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_utils.cc index 80a6a5c9c9b442..2beccf116125d9 100644 --- a/tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_utils.cc +++ b/tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_utils.cc @@ -45,10 +45,10 @@ limitations under the License. #include "mlir/IR/OpDefinition.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project -#include "tensorflow/compiler/mlir/lite/quantization/ir/QuantizeUtils.h" #include "tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/portable_tensor_utils.h" #include "tensorflow/compiler/mlir/quantization/common/ir/FakeQuantSupport.h" #include "tensorflow/compiler/mlir/quantization/common/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/quantization/common/ir/QuantizeUtils.h" #include "tensorflow/compiler/mlir/quantization/common/ir/UniformSupport.h" #include "tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_traits.h" #include "tensorflow/compiler/mlir/tools/optimize/quantization_utils.h" @@ -712,7 +712,7 @@ ElementsAttr Quantize(const Attribute real_value, const Type tensor_type) { quant::QuantizedType::getQuantizedElementType(tensor_type)) { Type converted_type; return dyn_cast_or_null( - 
quantfork::quantizeAttr(real_value, q_type, converted_type)); + mlir::quant::ir::quantizeAttr(real_value, q_type, converted_type)); } return {}; } diff --git a/tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_utils.h b/tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_utils.h index 926adebdab3764..ecafcf473e33ce 100644 --- a/tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_utils.h +++ b/tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_utils.h @@ -200,7 +200,7 @@ bool QuantizableOpSupportsFloatOutputType(Operation* op); // Specialized version of location to string for flatbuffer exported locations. inline std::string GetTensorNameFromLoc(Location loc) { - if (auto name_loc = loc.dyn_cast()) { + if (auto name_loc = llvm::dyn_cast(loc)) { return name_loc.getName().str(); } return ""; @@ -219,7 +219,7 @@ struct ConvertStatsToQDQs LogicalResult matchAndRewrite(mlir::quant::ir::StatisticsOp op, PatternRewriter& rewriter) const override { - Type expressed = op.getType().cast().getElementType(); + Type expressed = llvm::cast(op.getType()).getElementType(); quant::QuantizedType quant_type; SmallVector mins, maxs; @@ -227,7 +227,7 @@ struct ConvertStatsToQDQs // Per axis quantization (or per channel quantization) int stats_num = op.getAxisStats()->getNumElements(); if (stats_num == 0 || stats_num % 2 != 0) return failure(); - auto stats = op.getAxisStats()->dyn_cast(); + auto stats = llvm::dyn_cast(*op.getAxisStats()); if (!stats) return failure(); for (auto it = stats.begin(), e = stats.end(); it != e; ++it) { @@ -256,7 +256,7 @@ struct ConvertStatsToQDQs quant_type = DownCastScale(quant_type, mins, maxs, op->getLoc()); } } else if (auto stats = - op.getLayerStats().dyn_cast()) { + llvm::dyn_cast(op.getLayerStats())) { // Per tensor quantization auto statValues = stats.getValues(); double rmin = FloatAttr::getValueAsDouble(statValues[0]); @@ -482,7 
+482,7 @@ class QuantizationPattern : public RewritePattern { } if (!nodes_blocklist.empty()) { - if (auto name_loc = quantizing_op->getLoc().dyn_cast()) { + if (auto name_loc = llvm::dyn_cast(quantizing_op->getLoc())) { std::string sloc = name_loc.getName().str(); if (!sloc.empty() && (nodes_blocklist.find(sloc) != nodes_blocklist.end())) { @@ -504,12 +504,13 @@ class QuantizationPattern : public RewritePattern { inputs.reserve(quantizing_op->getNumOperands()); for (auto operand : quantizing_op->getOperands()) { Type operand_type = operand.getType(); - if (operand_type.isa()) { + if (isa(operand_type)) { inputs.push_back(operand); continue; } - auto ele_type = operand.getType().cast().getElementType(); + auto ele_type = + llvm::cast(operand.getType()).getElementType(); if (static_cast(this) ->AllowDynamicRangeQuantizedOperand(quantizing_op, custom_map)) { @@ -569,13 +570,13 @@ class QuantizationPattern : public RewritePattern { Type result_type = result.getType(); // Add this to the test coverage once we create test ops with none // type results. - if (result_type.isa()) { + if (isa(result_type)) { outputs_replaced.insert({result, enumerated_result.index()}); output_types.push_back(result_type); continue; } Type result_ele_type = - result.getType().cast().getElementType(); + llvm::cast(result.getType()).getElementType(); // If the user is the QuantizeOp, it must be the only user. 
if (result.hasOneUse() && llvm::isa(*result.user_begin())) { @@ -649,11 +650,9 @@ class QuantizationPattern : public RewritePattern { } for (int i = 0, e = quantized_op->getNumResults(); i < e; ++i) { - if (!quantizing_op->getResult(i) - .getType() - .cast() - .getElementType() - .isa()) { + if (!isa( + cast(quantizing_op->getResult(i).getType()) + .getElementType())) { continue; } CreateVerifier(quantizing_op, quantized_op, rewriter, i, @@ -674,9 +673,7 @@ class QuantizationPattern : public RewritePattern { void RewireFloatModelBackbone(Operation* quantized_op, Operation* float_op) const { for (int i = 0, e = quantized_op->getNumResults(); i < e; ++i) { - if (!float_op->getResult(i) - .getType() - .cast() + if (!llvm::cast(float_op->getResult(i).getType()) .getElementType() .isF32()) { continue; @@ -769,14 +766,14 @@ struct ConvertUnsignedToSigned : public OpRewritePattern { auto flags = quant::QuantizationFlags::Signed; QType new_qtype; - if (auto uqtype = qtype.template dyn_cast()) { + if (auto uqtype = llvm::dyn_cast(qtype)) { new_qtype = quant::UniformQuantizedType::getChecked( op.getLoc(), flags, qtype.getStorageType(), qtype.getExpressedType(), uqtype.getScale(), uqtype.getZeroPoint() - offset, uqtype.getStorageTypeMin() - offset, uqtype.getStorageTypeMax() - offset); - } else if (auto aqtype = qtype.template dyn_cast< - quant::UniformQuantizedPerAxisType>()) { + } else if (auto aqtype = + llvm::dyn_cast(qtype)) { auto zero_points = aqtype.getZeroPoints(); llvm::SmallVector new_zero_points(zero_points.begin(), zero_points.end()); diff --git a/tensorflow/compiler/mlir/quantization/common/tf_uniform_quantized_types.cc b/tensorflow/compiler/mlir/quantization/common/tf_uniform_quantized_types.cc new file mode 100644 index 00000000000000..da812387fc1be3 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/common/tf_uniform_quantized_types.cc @@ -0,0 +1,232 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/quantization/common/tf_uniform_quantized_types.h" + +#include + +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/MathExtras.h" +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project + +#define DEBUG_TYPE "uniform-quantized-types" + +namespace mlir { +namespace tf_quant { + +using quant::QuantizedType; +using quant::UniformQuantizedPerAxisType; +using quant::UniformQuantizedType; + +UniformQuantizedType CreateI8F32UniformQuantizedType(const Location loc, + MLIRContext& context, + const double scale, + const int64_t zero_point, + const bool narrow_range) { + return UniformQuantizedType::getChecked( + loc, /*flags=*/quant::QuantizationFlags::Signed, + /*storageType=*/IntegerType::get(&context, /*width=*/8), + /*expressedType=*/Float32Type::get(&context), scale, zero_point, + /*storageTypeMin=*/llvm::minIntN(8) + (narrow_range ? 
1 : 0), + /*storageTypeMax=*/llvm::maxIntN(8)); +} + +UniformQuantizedType CreateI32F32UniformQuantizedType( + const Location loc, MLIRContext& context, const double scale, + const int64_t zero_point) { + return UniformQuantizedType::getChecked( + loc, /*flags=*/quant::QuantizationFlags::Signed, + /*storageType=*/IntegerType::get(&context, /*width=*/32), + /*expressedType=*/Float32Type::get(&context), scale, zero_point, + /*storageTypeMin=*/llvm::minIntN(32), + /*storageTypeMax=*/llvm::maxIntN(32)); +} + +UniformQuantizedPerAxisType CreateI8F32UniformQuantizedPerAxisType( + const Location loc, MLIRContext& context, const ArrayRef scales, + const ArrayRef zero_points, const int quantization_dimension, + const bool narrow_range) { + return UniformQuantizedPerAxisType::getChecked( + loc, /*flags=*/quant::QuantizationFlags::Signed, + /*storageType=*/IntegerType::get(&context, /*width=*/8), + /*expressedType=*/Float32Type::get(&context), SmallVector(scales), + SmallVector(zero_points), quantization_dimension, + /*storageTypeMin=*/llvm::minIntN(8) + (narrow_range ? 
1 : 0), + /*storageTypeMax=*/llvm::maxIntN(8)); +} + +UniformQuantizedPerAxisType CreateI32F32UniformQuantizedPerAxisType( + const Location loc, MLIRContext& context, const ArrayRef scales, + const ArrayRef zero_points, const int quantization_dimension) { + return UniformQuantizedPerAxisType::getChecked( + loc, /*flags=*/quant::QuantizationFlags::Signed, + /*storageType=*/IntegerType::get(&context, /*width=*/32), + /*expressedType=*/Float32Type::get(&context), SmallVector(scales), + SmallVector(zero_points), quantization_dimension, + /*storageTypeMin=*/llvm::minIntN(32), + /*storageTypeMax=*/llvm::maxIntN(32)); +} + +bool IsStorageTypeI8(const QuantizedType quantized_type) { + const Type storage_type = quantized_type.getStorageType(); + return storage_type.isInteger(/*width=*/8); +} + +bool IsStorageTypeI32(const QuantizedType quantized_type) { + const Type storage_type = quantized_type.getStorageType(); + return storage_type.isInteger(/*width=*/32); +} + +bool IsExpressedTypeF32(const QuantizedType quantized_type) { + const Type expressed_type = quantized_type.getExpressedType(); + return mlir::isa(expressed_type); +} + +bool IsI8F32UniformQuantizedType(const Type type) { + const UniformQuantizedType quantized_type = + mlir::dyn_cast_or_null(type); + if (!quantized_type) { + LLVM_DEBUG(llvm::dbgs() + << "Expected a uniform quantized type. Got: " << type << ".\n"); + return false; + } + + if (!IsStorageTypeI8(quantized_type)) { + LLVM_DEBUG(llvm::dbgs() << "Expected an i8 storage type. Got: " + << quantized_type << ".\n"); + return false; + } + + if (!IsExpressedTypeF32(quantized_type)) { + LLVM_DEBUG(llvm::dbgs() << "Expected an f32 expressed type. 
Got: " + << quantized_type << ".\n"); + return false; + } + + return true; +} + +bool IsI8F32UniformQuantizedPerAxisType(const Type type) { + const UniformQuantizedPerAxisType quantized_per_axis_type = + mlir::dyn_cast_or_null(type); + if (!quantized_per_axis_type) { + LLVM_DEBUG(llvm::dbgs() + << "Expected a uniform quantized type. Got: " << type << ".\n"); + return false; + } + + if (!IsStorageTypeI8(quantized_per_axis_type)) { + LLVM_DEBUG(llvm::dbgs() << "Expected an i8 storage type. Got: " + << quantized_per_axis_type << ".\n"); + return false; + } + + if (!IsExpressedTypeF32(quantized_per_axis_type)) { + LLVM_DEBUG(llvm::dbgs() << "Expected an f32 expressed type. Got: " + << quantized_per_axis_type << ".\n"); + return false; + } + + return true; +} + +bool IsI32F32UniformQuantizedType(const Type type) { + const UniformQuantizedType quantized_type = + mlir::dyn_cast_or_null(type); + if (!quantized_type) { + LLVM_DEBUG(llvm::dbgs() + << "Expected a uniform quantized type. Got: " << type << ".\n"); + return false; + } + + if (!IsStorageTypeI32(quantized_type)) { + LLVM_DEBUG(llvm::dbgs() << "Expected an i32 storage type. Got: " + << quantized_type << ".\n"); + return false; + } + + if (!IsExpressedTypeF32(quantized_type)) { + LLVM_DEBUG(llvm::dbgs() << "Expected an f32 expressed type. Got: " + << quantized_type << ".\n"); + return false; + } + + return true; +} + +bool IsI32F32UniformQuantizedPerAxisType(const Type type) { + const UniformQuantizedPerAxisType quantized_per_axis_type = + mlir::dyn_cast_or_null(type); + if (!quantized_per_axis_type) { + LLVM_DEBUG(llvm::dbgs() + << "Expected a uniform quantized type. Got: " << type << ".\n"); + return false; + } + + if (!IsStorageTypeI32(quantized_per_axis_type)) { + LLVM_DEBUG(llvm::dbgs() << "Expected an i32 storage type. Got: " + << quantized_per_axis_type << ".\n"); + return false; + } + + if (!IsExpressedTypeF32(quantized_per_axis_type)) { + LLVM_DEBUG(llvm::dbgs() << "Expected an f32 expressed type. 
Got: " + << quantized_per_axis_type << ".\n"); + return false; + } + + return true; +} + +// Determines whether the storage type of a quantized type is supported by +// `tfl.quantize` or `tfl.dequantize` ops. ui8, i8 and i16 are supported. +bool IsSupportedByTfliteQuantizeOrDequantizeOps(IntegerType storage_type) { + if (storage_type.getWidth() == 8 || + (storage_type.isSigned() && storage_type.getWidth() == 16)) { + return true; + } + LLVM_DEBUG(llvm::dbgs() + << "Uniform quantize / dequantize op only supports ui8, i8 or " + "i16 for the storage type of uniform quantized type. Got: " + << storage_type << ".\n"); + return false; +} + +bool IsQuantizedTensorType(Type type) { + if (!mlir::isa(type)) { + return false; + } + Type element_type = mlir::cast(type).getElementType(); + return mlir::isa(element_type); +} + +bool IsOpFullyQuantized(Operation* op) { + return llvm::all_of(op->getOperandTypes(), IsQuantizedTensorType) && + llvm::all_of(op->getResultTypes(), IsQuantizedTensorType); +} + +bool IsOpNotQuantized(Operation* op) { + return !llvm::any_of(op->getOperandTypes(), IsQuantizedTensorType) && + !llvm::any_of(op->getResultTypes(), IsQuantizedTensorType); +} + +} // namespace tf_quant +} // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/common/tf_uniform_quantized_types.h b/tensorflow/compiler/mlir/quantization/common/tf_uniform_quantized_types.h new file mode 100644 index 00000000000000..e0bec5c2630a44 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/common/tf_uniform_quantized_types.h @@ -0,0 +1,116 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_TF_UNIFORM_QUANTIZED_TYPES_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_TF_UNIFORM_QUANTIZED_TYPES_H_ + +#include + +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project + +namespace mlir { +namespace tf_quant { + +// Creates a `UniformQuantizedType` with the given `scale` and `zero_point` +// values. The produced type has f32 as its expressed type and i8 as its +// storage type. The available values use the full range of the storage value, +// i.e. [-128, 127]. Assumes asymmetric quantization, meaning the zero point +// value can be a non-zero value. +// If `narrow_range` is set true (ex: for weights), a restricted range of +// integers will be used for symmetric mapping, i.e. [-127, 127]. +quant::UniformQuantizedType CreateI8F32UniformQuantizedType( + Location loc, MLIRContext& context, double scale, int64_t zero_point, + bool narrow_range = false); + +// Creates a `UniformQuantizedType` with the given `scale` and `zero_point` +// values. The produced type has f32 as its expressed type and i32 as its +// storage type. 
The available values use the full range of the storage value. +// Assumes asymmetric quantization, meaning the zero point value can be +// a non-zero value. +quant::UniformQuantizedType CreateI32F32UniformQuantizedType( + Location loc, MLIRContext& context, double scale, int64_t zero_point); + +// Creates a `UniformQuantizedPerAxisType` with the given `scales` and +// `zero_points` values. The produced type has f32 as its expressed type and +// i8 as its storage type. The available values use the full range of the +// storage value, i.e. [-128, 127]. Assumes asymmetric quantization, meaning the +// zero point values can be non-zero values. +// If `narrow_range` is set true (ex: for weights), a restricted range of +// integers will be used for symmetric mapping, i.e. [-127, 127]. +quant::UniformQuantizedPerAxisType CreateI8F32UniformQuantizedPerAxisType( + Location loc, MLIRContext& context, ArrayRef scales, + ArrayRef zero_points, int quantization_dimension, + bool narrow_range = false); + +// Creates a `UniformQuantizedPerAxisType` with the given `scales` and +// `zero_points` values. The produced type has f32 as its expressed type and +// i32 as its storage type. The available values use the full range of the +// storage value. Assumes asymmetric quantization, meaning the +// zero point values can be non-zero values. +quant::UniformQuantizedPerAxisType CreateI32F32UniformQuantizedPerAxisType( + Location loc, MLIRContext& context, ArrayRef scales, + ArrayRef zero_points, int quantization_dimension); + +bool IsStorageTypeI8(quant::QuantizedType quantized_type); + +bool IsStorageTypeI32(quant::QuantizedType quantized_type); + +bool IsExpressedTypeF32(quant::QuantizedType quantized_type); + +// Given a value, extract the `ElementType`. +// `value` should be a non-null `TensorType`. 
+inline Type GetElementType(const Value value) { + return mlir::cast(value.getType()).getElementType(); +} + +// Returns true iff `type` is a uniform quantized type whose storage type is +// 8-bit integer and expressed type is f32. +bool IsI8F32UniformQuantizedType(Type type); + +// Returns true iff `type` is a uniform quantized per-axis (per-channel) type +// whose storage type is 8-bit integer and expressed type is f32. +bool IsI8F32UniformQuantizedPerAxisType(Type type); + +// Returns true iff `type` is a uniform quantized type whose storage type is +// 32-bit integer and expressed type is f32. +bool IsI32F32UniformQuantizedType(Type type); + +// Returns true iff `type` is a uniform quantized per-axis (per-channel) type +// whose storage type is 32-bit integer and expressed type is f32. +bool IsI32F32UniformQuantizedPerAxisType(Type type); + +// Determines whether the storage type of a quantized type is supported by +// `tfl.quantize` or `tfl.dequantize` ops. ui8, i8 and i16 are supported. +bool IsSupportedByTfliteQuantizeOrDequantizeOps(IntegerType storage_type); + +// Returns true if a type is quantized tensor type. +bool IsQuantizedTensorType(Type type); + +// Returns true if all operands and results are quantized. +bool IsOpFullyQuantized(Operation* op); + +// Returns true iff none among operand and result tensors are quantized. 
+bool IsOpNotQuantized(Operation* op); + +} // namespace tf_quant +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_TF_UNIFORM_QUANTIZED_TYPES_H_ diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/BUILD index ec79c4f83f5d26..f63287d859b46a 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/BUILD @@ -27,16 +27,24 @@ package( ) gentbl_cc_library( - name = "stablehlo_passes_inc_gen", + name = "tf_stablehlo_passes_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - [ - "-gen-pass-decls", - ], - "passes/passes.h.inc", - ), + tbl_outs = {"passes/tf_passes.h.inc": [ + "-gen-pass-decls", + ]}, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "passes/tf_passes.td", + deps = [ + "@llvm-project//mlir:PassBaseTdFiles", ], +) + +gentbl_cc_library( + name = "stablehlo_passes_inc_gen", + compatible_with = get_compatible_with_portable(), + tbl_outs = {"passes/passes.h.inc": [ + "-gen-pass-decls", + ]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "passes/passes.td", deps = [ @@ -44,10 +52,102 @@ gentbl_cc_library( ], ) +cc_library( + name = "tf_passes", + srcs = [ + "passes/lift_quantizable_spots_as_functions_fusion.inc", + "passes/lift_quantizable_spots_as_functions_simple.inc", + "passes/remove_sharding_custom_call.inc", + "passes/tf_convert_func_to_bfloat16.cc", + "passes/tf_convert_shape_constraint_to_assert.cc", + "passes/tf_convert_xla_call_module_op_to_bfloat16.cc", + "passes/tf_defer_activation_transpose.cc", + "passes/tf_fold_constant_transpose.cc", + "passes/tf_insert_calibration_statistics_saver.cc", + "passes/tf_insert_weight_param.cc", + "passes/tf_lift_quantizable_spots_as_functions.cc", + "passes/tf_merge_fusion_with_dequantize.cc", + "passes/tf_nchw_convolution_to_nhwc.cc", + "passes/tf_optimize_graph.cc", + "passes/tf_post_quantize.cc", + 
"passes/tf_prepare_quantize.cc", + "passes/tf_quantize.cc", + "passes/tf_quantize_composite_functions.cc", + "passes/tf_quantize_weight.cc", + "passes/tf_remove_sharding_custom_call.cc", + "passes/tf_replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.cc", + "passes/tf_restore_function_name.cc", + "passes/tf_unfuse_mhlo_batch_norm.cc", + "passes/tf_unwrap_xla_call_module_op.cc", + "passes/tf_xla_call_module_to_call.cc", + ], + hdrs = [ + "passes/tf_passes.h", + ], + deps = [ + ":bfloat16_type", + ":fill_quantization_options", + ":lift_quantizable_spots_as_functions_fusion_inc_gen", + ":lift_quantizable_spots_as_functions_simple_inc_gen", + ":optimize_graph_inc_gen", + ":quantization_config_proto_cc", + ":quantization_options_proto_cc", + ":remove_sharding_custom_call_inc_gen", + ":stablehlo_type_utils", + ":tf_quantization_patterns", + ":tf_stablehlo_passes_inc_gen", + "//tensorflow/compiler/mlir/quantization/common:func", + "//tensorflow/compiler/mlir/quantization/common:tf_attrs_and_constraints", + "//tensorflow/compiler/mlir/quantization/common:tf_lift_as_function_call", + "//tensorflow/compiler/mlir/quantization/common:uniform_quantized_types", + "//tensorflow/compiler/mlir/quantization/common/ir:QuantOps", + "//tensorflow/compiler/mlir/quantization/common/tf_quantization_lib", + "//tensorflow/compiler/mlir/quantization/common/tf_quantization_lib:tf_quantization_config", + "//tensorflow/compiler/mlir/quantization/stablehlo/cc:permutation", + "//tensorflow/compiler/mlir/quantization/stablehlo/ops:tf_stablehlo_op_quant_spec", + "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", + "//tensorflow/compiler/mlir/quantization/tensorflow:tf_quant_ops", + "//tensorflow/compiler/mlir/quantization/tensorflow/cc:run_passes", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:xla_call_module_attrs", + "//tensorflow/core:portable_gif_internal", + "//tensorflow/core/ir/types:Dialect", + 
"@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@eigen_archive//:eigen3", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:Rewrite", + "@llvm-project//mlir:ShapeDialect", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:TransformUtils", + "@local_tsl//tsl/platform:path", + "@local_tsl//tsl/platform:protobuf", + "@local_tsl//tsl/platform:regexp", + "@local_xla//xla/mlir_hlo", + "@local_xla//xla/mlir_hlo:mhlo_passes", + "@stablehlo//:chlo_ops", + "@stablehlo//:stablehlo_ops", + "@stablehlo//:stablehlo_passes", + "@stablehlo//:stablehlo_portable_api", + "@stablehlo//:stablehlo_serialization", + "@stablehlo//:version", + ], +) + cc_library( name = "passes", srcs = [ "passes/convert_func_to_bfloat16.cc", + "passes/convert_shape_constraint_to_assert.cc", "passes/convert_xla_call_module_op_to_bfloat16.cc", "passes/defer_activation_transpose.cc", "passes/fold_constant_transpose.cc", @@ -138,6 +238,7 @@ cc_library( "@llvm-project//mlir:Rewrite", "@llvm-project//mlir:ShapeDialect", "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "@local_tsl//tsl/platform:path", @@ -150,11 +251,42 @@ cc_library( "@local_xla//xla/tsl/protobuf:protos_all_cc", "@stablehlo//:chlo_ops", "@stablehlo//:stablehlo_ops", + "@stablehlo//:stablehlo_passes", "@stablehlo//:stablehlo_portable_api", "@stablehlo//:stablehlo_serialization", ], ) +cc_library( + name = "tf_quantization_patterns", + srcs = ["passes/tf_quantization_patterns.cc"], + hdrs = [ + "passes/tf_quantization_patterns.h", + ], + compatible_with = get_compatible_with_portable(), 
+ deps = [ + ":quantization_config_proto_cc", + "//tensorflow/compiler/mlir/quantization/common:tf_attrs_and_constraints", + "//tensorflow/compiler/mlir/quantization/common:tf_lift_as_function_call", + "//tensorflow/compiler/mlir/quantization/common:uniform_quantized_types", + "//tensorflow/compiler/mlir/quantization/common/ir:QuantOps", + "//tensorflow/compiler/mlir/quantization/common/tf_quantization_lib", + "//tensorflow/compiler/mlir/quantization/stablehlo/ops:tf_stablehlo_op_quant_spec", + "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/platform:path", + "@com_google_absl//absl/container:flat_hash_set", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:Support", + "@stablehlo//:stablehlo_ops", + ], +) + cc_library( name = "quantization_patterns", srcs = ["passes/quantization_patterns.cc"], @@ -209,12 +341,7 @@ td_library( gentbl_cc_library( name = "lift_quantizable_spots_as_functions_simple_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "passes/lift_quantizable_spots_as_functions_simple.inc", - ), - ], + tbl_outs = {"passes/lift_quantizable_spots_as_functions_simple.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "passes/lift_quantizable_spots_as_functions_simple.td", deps = [ @@ -226,12 +353,7 @@ gentbl_cc_library( gentbl_cc_library( name = "lift_quantizable_spots_as_functions_fusion_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "passes/lift_quantizable_spots_as_functions_fusion.inc", - ), - ], + tbl_outs = {"passes/lift_quantizable_spots_as_functions_fusion.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = 
"passes/lift_quantizable_spots_as_functions_fusion.td", deps = [ @@ -243,12 +365,7 @@ gentbl_cc_library( gentbl_cc_library( name = "optimize_graph_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "passes/optimize_graph.inc", - ), - ], + tbl_outs = {"passes/optimize_graph.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "passes/optimize_graph.td", deps = [ @@ -260,12 +377,7 @@ gentbl_cc_library( gentbl_cc_library( name = "remove_sharding_custom_call_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "passes/remove_sharding_custom_call.inc", - ), - ], + tbl_outs = {"passes/remove_sharding_custom_call.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "passes/remove_sharding_custom_call.td", deps = [ @@ -276,15 +388,10 @@ gentbl_cc_library( gentbl_cc_library( name = "bridge_passes_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=Bridge", - ], - "passes/bridge/passes.h.inc", - ), - ], + tbl_outs = {"passes/bridge/passes.h.inc": [ + "-gen-pass-decls", + "-name=Bridge", + ]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "passes/bridge/passes.td", deps = [ @@ -365,12 +472,7 @@ td_library( gentbl_cc_library( name = "optimize_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "passes/bridge/optimize.inc", - ), - ], + tbl_outs = {"passes/bridge/optimize.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "passes/bridge/optimize.td", deps = [":optimize_td_files"], @@ -495,15 +597,10 @@ cc_library( gentbl_cc_library( name = "stablehlo_test_passes_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=Test", - ], - "passes/testing/passes.h.inc", - ), - ], + tbl_outs = {"passes/testing/passes.h.inc": [ + 
"-gen-pass-decls", + "-name=Test", + ]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "passes/testing/passes.td", deps = [ @@ -768,8 +865,10 @@ tf_cc_binary( ":bridge_passes", ":passes", ":test_passes", + ":tf_passes", "//tensorflow/compiler/mlir:init_mlir", "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", + "//tensorflow/compiler/mlir/quantization/common/ir:QuantOps", "//tensorflow/compiler/mlir/quantization/stablehlo/cc:pass_pipeline", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes", diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/component.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/component.cc index f18cf0f7df7fe8..a6e8fa86e9d183 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/component.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/component.cc @@ -104,8 +104,8 @@ absl::Status RunCalibrationPasses( } CalibrationComponent::CalibrationComponent( - absl::Nonnull ctx, - absl::Nonnull py_function_lib, + MLIRContext* absl_nonnull ctx, + const PyFunctionLibrary* absl_nonnull py_function_lib, const absl::string_view src_saved_model_path, absl::flat_hash_map function_aliases, std::unordered_set tags, diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/component.h b/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/component.h index 03d2dd933732d4..d55f5afda362fa 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/component.h +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/component.h @@ -57,9 +57,9 @@ class CalibrationComponent : public Component { // `representative_dataset_file_map` contains information about the // calibration dataset. 
CalibrationComponent( - absl::Nonnull ctx, - absl::Nonnull - py_function_lib, + MLIRContext* absl_nonnull ctx, + const tensorflow::quantization::PyFunctionLibrary* absl_nonnull + py_function_lib, absl::string_view src_saved_model_path, absl::flat_hash_map function_aliases, std::unordered_set tags, @@ -88,12 +88,12 @@ class CalibrationComponent : public Component { absl::StatusOr ImportCalibratedSavedModel( absl::string_view calibrated_saved_model_path); - absl::Nonnull ctx_; + MLIRContext* absl_nonnull ctx_; // Contains function implementations from the python layer. Should be injected // from the python level using pybind11. - absl::Nonnull - py_function_lib_; + const tensorflow::quantization::PyFunctionLibrary* absl_nonnull + py_function_lib_; // Path to the pre-calibrated SavedModel. std::string src_saved_model_path_; diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.cc index 1bbf67389366f5..c5fc8b5b3d8d8e 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.cc @@ -116,15 +116,13 @@ void AddXlaCallModuleOpDeserializationPasses(OpPassManager& pm) { } void AddShapeLegalizationPasses(OpPassManager& pm) { - pm.addPass(mhlo::createStablehloLegalizeToHloPass()); + // TODO: We may need to make a parent pass here that does + // shape->StableHLO+cstr because the stablehlo pass requires that the ops made + // by cstr are legal. pm.addNestedPass( - mhlo::createShapeLegalizeToHloPass(/*legalizeConstraints=*/true)); - // The following 2 passes are used to clean up the spurious UnrealizedCast ops - // and shape.assuming regions leftover from the ShapeLegalizeToHlo pass. See - // pass definition for details. 
+ createConvertShapeToStablehloWithConstraintsPass()); pm.addPass(createReconcileUnrealizedCastsPass()); pm.addNestedPass(mlir::createCanonicalizerPass()); - pm.addPass(mhlo::createHloLegalizeToStablehloPass()); } void AddStablehloQuantToIntPasses(OpPassManager& pm) { diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/post_calibration.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/post_calibration.cc index 45213c10b3b7a6..ec4a10af74bca2 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/post_calibration.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/post_calibration.cc @@ -38,7 +38,7 @@ using ::stablehlo::quantization::QuantizationSpecs; using ::tensorflow::quantization::RunPasses; PostCalibrationComponent::PostCalibrationComponent( - absl::Nonnull ctx) + MLIRContext* absl_nonnull ctx) : ctx_(ABSL_DIE_IF_NULL(ctx)) {} // Crash OK absl::StatusOr PostCalibrationComponent::Run( diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/post_calibration.h b/tensorflow/compiler/mlir/quantization/stablehlo/cc/post_calibration.h index 6e3762817e16a1..6692047628f08e 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/post_calibration.h +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/post_calibration.h @@ -39,7 +39,7 @@ class PostCalibrationComponent : public Component { // debugging purposes. 
static constexpr absl::string_view kName = "quant_ptq_post_calibration"; - explicit PostCalibrationComponent(absl::Nonnull ctx); + explicit PostCalibrationComponent(MLIRContext* absl_nonnull ctx); absl::StatusOr Run( ModuleOp module_op, @@ -51,7 +51,7 @@ class PostCalibrationComponent : public Component { const ::stablehlo::quantization::PipelineConfig& pipeline_config) const; private: - absl::Nonnull ctx_; + MLIRContext* absl_nonnull ctx_; }; } // namespace mlir::quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/pre_calibration.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/pre_calibration.cc index bd7cab73d90c22..3de90290df20e5 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/pre_calibration.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/pre_calibration.cc @@ -30,8 +30,7 @@ namespace mlir::quant::stablehlo { using ::stablehlo::quantization::QuantizationConfig; using ::tensorflow::quantization::RunPasses; -PreCalibrationComponent::PreCalibrationComponent( - absl::Nonnull ctx) +PreCalibrationComponent::PreCalibrationComponent(MLIRContext* absl_nonnull ctx) : ctx_(ABSL_DIE_IF_NULL(ctx)) {} // Crash OK absl::StatusOr PreCalibrationComponent::Run( diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/pre_calibration.h b/tensorflow/compiler/mlir/quantization/stablehlo/cc/pre_calibration.h index bdc61bafa569df..705f8b95bda1a9 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/pre_calibration.h +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/pre_calibration.h @@ -38,14 +38,14 @@ class PreCalibrationComponent : public Component { // debugging purposes. 
static constexpr absl::string_view kName = "quant_ptq_pre_calibration"; - explicit PreCalibrationComponent(absl::Nonnull ctx); + explicit PreCalibrationComponent(MLIRContext* absl_nonnull ctx); absl::StatusOr Run( ModuleOp, const ::stablehlo::quantization::QuantizationConfig& config) override; private: - absl::Nonnull ctx_; + MLIRContext* absl_nonnull ctx_; }; } // namespace mlir::quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/static_range_ptq.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/static_range_ptq.cc index 47aaf31216568b..ca103374638399 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/static_range_ptq.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/static_range_ptq.cc @@ -56,8 +56,8 @@ using ::tensorflow::quantization::ExportedModel; using ::tensorflow::quantization::PyFunctionLibrary; StaticRangePtqComponent::StaticRangePtqComponent( - absl::Nonnull ctx, - absl::Nonnull py_function_library, + MLIRContext* absl_nonnull ctx, + const PyFunctionLibrary* absl_nonnull py_function_library, const absl::string_view src_saved_model_path, std::vector signature_keys, std::unordered_set tags, diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/static_range_ptq.h b/tensorflow/compiler/mlir/quantization/stablehlo/cc/static_range_ptq.h index 69bd9da6733c0c..104df9aa50da60 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/static_range_ptq.h +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/static_range_ptq.h @@ -51,9 +51,9 @@ class StaticRangePtqComponent : public Component { // `CalibrationComponent`. For detailed explanation of each argument, see the // comment of `CalibrationComponent`'s constructor. 
StaticRangePtqComponent( - absl::Nonnull ctx, - absl::Nonnull - py_function_library, + MLIRContext* absl_nonnull ctx, + const tensorflow::quantization::PyFunctionLibrary* absl_nonnull + py_function_library, absl::string_view src_saved_model_path, std::vector signature_keys, std::unordered_set tags, @@ -69,7 +69,7 @@ class StaticRangePtqComponent : public Component { private: // A non-owning `MLIRContext`. This `MLIRContext` should exceed the lifetime // of `StaticRangePtqComponent`. - absl::Nonnull ctx_; + MLIRContext* absl_nonnull ctx_; // This component consists of three sub-components, `PreCalibrationComponent`, // `CalibrationComponent`, and `PostCalibrationComponent`. std::array, 3> sub_components_; diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/weight_only_ptq.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/weight_only_ptq.cc index 3f8215edc605cd..ec780bf8cf9a22 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/weight_only_ptq.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/weight_only_ptq.cc @@ -53,7 +53,7 @@ using ::tensorflow::quantization::ExportedModel; using ::tensorflow::quantization::PyFunctionLibrary; using ::tensorflow::quantization::RunPasses; -WeightOnlyPtqComponent::WeightOnlyPtqComponent(absl::Nonnull ctx) +WeightOnlyPtqComponent::WeightOnlyPtqComponent(MLIRContext* absl_nonnull ctx) : ctx_(ABSL_DIE_IF_NULL(ctx)) {} // Crash OK absl::StatusOr WeightOnlyPtqComponent::Run( diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/weight_only_ptq.h b/tensorflow/compiler/mlir/quantization/stablehlo/cc/weight_only_ptq.h index bf23e93246c700..ba18d729042d9f 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/weight_only_ptq.h +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/weight_only_ptq.h @@ -42,14 +42,14 @@ class WeightOnlyPtqComponent : public Component { // Used for debugging purposes. 
static constexpr absl::string_view kName = "quant_ptq_weight_only"; - explicit WeightOnlyPtqComponent(absl::Nonnull ctx); + explicit WeightOnlyPtqComponent(MLIRContext* absl_nonnull ctx); absl::StatusOr Run( ModuleOp module_op, const ::stablehlo::quantization::QuantizationConfig& config) override; private: - absl::Nonnull ctx_; + MLIRContext* absl_nonnull ctx_; }; // Runs weight-only quantization on a SavedModel at diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/instrumentations/save_report.cc b/tensorflow/compiler/mlir/quantization/stablehlo/instrumentations/save_report.cc index e1a705cdbb24f6..edba8f60408636 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/instrumentations/save_report.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/instrumentations/save_report.cc @@ -38,8 +38,8 @@ std::optional OptionalStringViewToOptionalString( } // Whether the pass is `QuantizeCompositeFunctionPass`. -bool IsQuantizeCompositeFunctionPass(absl::Nullable pass, - absl::Nullable op) { +bool IsQuantizeCompositeFunctionPass(Pass* absl_nullable pass, + Operation* absl_nullable op) { // It is known that `op` is `ModuleOp` when `pass` is // `QuantizeCompositeFunctionPass`, but the check is still performed to be // defensive. @@ -52,7 +52,7 @@ bool IsQuantizeCompositeFunctionPass(absl::Nullable pass, // * After running `QuantizeCompositeFunctionPass`. // * The pass is run on `ModuleOp`. // * `file_path` is not `nullopt`. 
-bool ShouldSaveReport(absl::Nullable pass, absl::Nullable op, +bool ShouldSaveReport(Pass* absl_nullable pass, Operation* absl_nullable op, const std::optional& file_path) { return file_path != std::nullopt && IsQuantizeCompositeFunctionPass(pass, op); } diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/ops/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/ops/BUILD index 61da2af4d3fb58..798d0ecc1396cf 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/ops/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/ops/BUILD @@ -9,6 +9,31 @@ package( licenses = ["notice"], ) +cc_library( + name = "tf_stablehlo_op_quant_spec", + srcs = [ + "tf_stablehlo_op_quant_spec.cc", + ], + hdrs = ["tf_stablehlo_op_quant_spec.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + "//tensorflow/compiler/mlir/quantization/common:tf_attrs_and_constraints", + "//tensorflow/compiler/mlir/quantization/common:tf_lift_as_function_call", + "//tensorflow/compiler/mlir/quantization/common/ir:QuantOps", + "//tensorflow/compiler/mlir/quantization/common/tf_quantization_lib", + "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", + "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", + "//tensorflow/compiler/mlir/tensorflow", + "@com_google_absl//absl/status:statusor", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@local_tsl//tsl/platform:protobuf", + "@stablehlo//:stablehlo_ops", + ], +) + cc_library( name = "stablehlo_op_quant_spec", srcs = [ diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/ops/tf_stablehlo_op_quant_spec.cc b/tensorflow/compiler/mlir/quantization/stablehlo/ops/tf_stablehlo_op_quant_spec.cc new file mode 100644 index 00000000000000..d2e413af3e9275 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/ops/tf_stablehlo_op_quant_spec.cc @@ -0,0 +1,184 @@ +/* 
Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/quantization/stablehlo/ops/tf_stablehlo_op_quant_spec.h" + +#include + +#include "absl/status/statusor.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo +#include "tensorflow/compiler/mlir/quantization/common/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/quantization/common/tf_attrs_and_constraints.h" +#include "tensorflow/compiler/mlir/quantization/common/tf_lift_as_function_call.h" +#include "tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_utils.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" +#include 
"tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tsl/platform/protobuf.h" // IWYU pragma: keep + +// To be used with LLVM_DEBUG. +#define DEBUG_TYPE "stablehlo_opt_quant_spec" + +namespace mlir::tf_quant::stablehlo { +namespace { + +using ::mlir::stablehlo::DotGeneralOp; +using ::stablehlo::quantization::Method; +using ::stablehlo::quantization::StaticRangePtq; + +// Whether it represents a lifted function (i.e. `op` is the corresponding +// `XlaCallModuleOp`) that is explicitly marked `NoQuantization`. +bool IsDenylistedLiftedFunction(Operation* op) { + if (auto xla_call_module_op = dyn_cast_or_null(op); + xla_call_module_op != nullptr) { + absl::StatusOr method = GetQuantizationMethod(xla_call_module_op); + if (method.ok() && method->has_no_quantization()) { + return true; + } + } + return false; +} + +// Populates `spec.coeff_op_quant_dim` according to `xla_call_module_op`'s +// `_quantization_method` attribute. If there is an input `QuantizedType` with +// `dimension_specs` set, which represents the quantization dimension for the +// input, then the corresponding operand index -> quantization dimension mapping +// is set for `spec`. +// TODO: b/323478683 - Duplicate tracking of config will be eliminated. +// `OpQuantSpec` will be deprecated and `Method` will be used instead. +void PopulateCoeffOpQuantDimIfPerChannelQuantized( + TF::XlaCallModuleOp xla_call_module_op, OpQuantSpec& spec) { + absl::StatusOr method = GetQuantizationMethod(xla_call_module_op); + if (method.ok() && method->has_static_range_ptq()) { + // TODO: b/331145946 - Use `Method` accessors. + const StaticRangePtq& static_range_ptq_spec = method->static_range_ptq(); + // Look for quantized dimension specs for each quantized type and + // populate `coeff_op_quant_dim`. 
+ for (const auto& [operand_idx, quantized_type] : + static_range_ptq_spec.input_quantized_types()) { + if (quantized_type.has_dimension_specs()) { + spec.coeff_op_quant_dim[operand_idx] = + quantized_type.dimension_specs().dimension(); + } + } + } +} + +} // namespace + +std::unique_ptr GetStableHloOpQuantSpec(Operation* op) { + auto spec = std::make_unique(); + if (auto call_op = dyn_cast_or_null(op)) { + auto entry_function = + call_op->getAttrOfType("_entry_function"); + StringRef function_name = entry_function.getValue(); + if (!function_name.starts_with("composite_")) { + return spec; + } + + if (function_name.contains("conv")) { + // Looks up `Method` to see if it should be per-channel quantized and + // populates the spec accordingly. + PopulateCoeffOpQuantDimIfPerChannelQuantized(call_op, *spec); + + if (function_name.contains("with_bias")) { + spec->biases_params[2] = {{0, 1}, GetUniformQuantizedTypeForBias}; + } + } else if (function_name.contains("dot_general")) { + const auto module_op = call_op->getParentOfType(); + + const SymbolTable symbol_table(module_op); + auto entry_func_op = + dyn_cast_or_null(symbol_table.lookup(function_name)); + auto dot_general_op = *entry_func_op.getOps().begin(); + if (auto optional_dim = GetDotGeneralQuantizationDim(dot_general_op); + optional_dim) { + spec->coeff_op_quant_dim[1] = optional_dim.value(); + } else { + spec->coeff_op_quant_dim[1] = -1; + } + if (function_name.contains("with_bias")) { + spec->biases_params[2] = {{0, 1}, GetUniformQuantizedTypeForBias}; + } + } + for (const auto [operand_idx, per_channel_dim] : spec->coeff_op_quant_dim) { + spec->quantizable_operands.insert(operand_idx); + } + } + return spec; +} + +std::unique_ptr GetStableHloQuantConstraints(Operation* op) { + auto scale_spec = std::make_unique(); + if (llvm::isa(op)) { + scale_spec->has_same_scale_requirement = true; + } + if (llvm::isa(op)) { + scale_spec->has_same_operand_and_result_type_requirement = true; + } + return scale_spec; +} + 
+bool IsOpQuantizableStableHlo(Operation* op) { + if (isa(op)) { + // Constant ops do not have QuantizableResult attribute but can be + // quantized. + return true; + } else if (op->hasTrait() || + isa(op)) { + // Terminators, qcast and decast are not quantizable. + return false; + } + + // `op` is not quantizable when it is an `XlaCallModuleOp` representing lifted + // function whose `_quantization_method` attribute is marked `NoQuantization`. + // This means this quantizable unit has been explicitly denylisted by the + // user. + if (IsDenylistedLiftedFunction(op)) { + LLVM_DEBUG(llvm::errs() << "Denylisted quantizable unit: \n" << op << "\n"); + return false; + } + + if (GetStableHloQuantConstraints(op)->has_same_scale_requirement) { + return true; + } + + const bool attr_enforced_quantizable = + op->hasAttrOfType(kQuantTraitAttrName) && + op->getAttrOfType(kQuantTraitAttrName).getValue().str() == + QuantTraitValues[QuantizationTrait::FullyQuantizable]; + return attr_enforced_quantizable; +} + +} // namespace mlir::tf_quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/ops/tf_stablehlo_op_quant_spec.h b/tensorflow/compiler/mlir/quantization/stablehlo/ops/tf_stablehlo_op_quant_spec.h new file mode 100644 index 00000000000000..2c6ca14b5f0a15 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/ops/tf_stablehlo_op_quant_spec.h @@ -0,0 +1,41 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_OPS_TF_STABLEHLO_OP_QUANT_SPEC_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_OPS_TF_STABLEHLO_OP_QUANT_SPEC_H_ + +#include + +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_utils.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" + +namespace mlir::tf_quant::stablehlo { + +// Returns StableHLO quantization specs for an op. +std::unique_ptr GetStableHloOpQuantSpec(Operation* op); + +// Returns quantization constraints (ex: fixed output, same scale) given +// a StableHLO op. +std::unique_ptr GetStableHloQuantConstraints(Operation* op); + +// Checks if an op is quantizable in StableHLO quantizer. Argument op is not +// necessarily a StableHLO op. +bool IsOpQuantizableStableHlo(Operation* op); + +} // namespace mlir::tf_quant::stablehlo + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_OPS_TF_STABLEHLO_OP_QUANT_SPEC_H_ diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/convert_shape_constraint_to_assert.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/convert_shape_constraint_to_assert.cc new file mode 100644 index 00000000000000..d63dfdeaec7514 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/convert_shape_constraint_to_assert.cc @@ -0,0 +1,218 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include + +#include "llvm/ADT/STLExtras.h" +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project +#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/TypeRange.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo +#include "stablehlo/transforms/Passes.h" // from @stablehlo +#include 
"tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h" // IWYU pragma: keep + +namespace mlir::quant::stablehlo { + +#define GEN_PASS_DEF_CONVERTSHAPETOSTABLEHLOWITHCONSTRAINTSPASS +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h.inc" + +namespace { +using ::mlir::stablehlo::AndOp; +using ::mlir::stablehlo::CompareOp; +using ::mlir::stablehlo::ComparisonDirection; +using ::mlir::stablehlo::ConcatenateOp; +using ::mlir::stablehlo::ConstantOp; +using ::mlir::stablehlo::CustomCallOp; +using ::mlir::stablehlo::OrOp; +using ::mlir::stablehlo::ReshapeOp; +using ::mlir::stablehlo::SliceOp; + +// Cast from index-based shape representation used in the Shape dialect to the +// i32-based representation used in HLO: +// * index => tensor. +// * tensor => tensor. +// * All i32-based types from above => themselves. +// There is no convenient op that can express this, so we're using +// unrealized_conversion_cast (with the idea that all these casts will +// annihilate at the end of the pass). +Value castToI32(PatternRewriter& rewriter, Location loc, Value value) { + Type resultType; + if (value.getType().isIndex()) + resultType = RankedTensorType::get({}, rewriter.getI32Type()); + if (auto valueType = mlir::dyn_cast(value.getType())) { + if (!valueType.hasStaticShape()) return {}; + if (valueType.getElementType().isInteger(32)) return value; + if (valueType.getElementType().isIndex()) + resultType = + RankedTensorType::get(valueType.getShape(), rewriter.getI32Type()); + } + if (!resultType) return {}; + auto cast = + rewriter.create(loc, resultType, value); + return cast.getResult(0); +} + +// Pads input tensor by X ones from the left. The number X is +// determined by input pad. Result is tensor<(X+N) x i32>, where the first X +// elements are ones. 
+Value padFromLeft(PatternRewriter& rewriter, Location loc, Value input, + int64_t pad) { + Value padI32 = rewriter.create( + loc, DenseIntElementsAttr::get( + RankedTensorType::get({pad}, rewriter.getI32Type()), 1)); + return rewriter.create(loc, ValueRange{padI32, input}, + /*dimension=*/0); +} + +void insertShapeAssertionCustomCall(OpBuilder builder, Location loc, + Value assert) { + auto customCall = + builder.create(loc, TypeRange{}, ValueRange{assert}); + customCall.setCallTargetName("shape_assertion"); + customCall.setHasSideEffect(true); + customCall->setAttr("error_message", + builder.getStringAttr("Shape assertion failed")); +} + +struct ConvertCstrBroadcastableOp + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(shape::CstrBroadcastableOp op, + PatternRewriter& rewriter) const override { + // As defined, op inputs must be 1D tensor or !shape.shape. + // We only support inputs of two 1D tensors. + if (op.getShapes().size() != 2) return failure(); + auto shape1 = castToI32(rewriter, op.getLoc(), op.getShapes().front()); + auto shape2 = castToI32(rewriter, op.getLoc(), op.getShapes().back()); + if (!shape1 || !shape2) return failure(); + auto tensorType1 = mlir::dyn_cast(shape1.getType()); + auto tensorType2 = mlir::dyn_cast(shape2.getType()); + if (!tensorType1 || !tensorType2) return failure(); + + // If the two operand shapes are of different sizes, the smaller one is + // padded with 1's from the left. + int32_t rank = + std::max(tensorType1.getDimSize(0), tensorType2.getDimSize(0)); + if (tensorType1.getDimSize(0) < tensorType2.getDimSize(0)) { + shape1 = + padFromLeft(rewriter, op.getLoc(), shape1, + tensorType2.getDimSize(0) - tensorType1.getDimSize(0)); + } else if (tensorType1.getDimSize(0) > tensorType2.getDimSize(0)) { + shape2 = + padFromLeft(rewriter, op.getLoc(), shape2, + tensorType1.getDimSize(0) - tensorType2.getDimSize(0)); + } + + // Compute if each dim is broadcastable. 
A dim is broadcastable iff + // dimSize1 == dimSize2 or dimSize1 == 1 or dimSize2 == 1 + auto allOne = rewriter.create( + op.getLoc(), DenseIntElementsAttr::get( + RankedTensorType::get({rank}, rewriter.getI32Type()), + static_cast(1))); + Value dimSize1Is1 = rewriter.create(op.getLoc(), shape1, allOne, + ComparisonDirection::EQ); + Value dimSize2Is1 = rewriter.create(op.getLoc(), shape2, allOne, + ComparisonDirection::EQ); + Value eitherDimSizeIs1 = + rewriter.create(op.getLoc(), dimSize1Is1, dimSize2Is1); + Value dimSizeEq = rewriter.create(op.getLoc(), shape1, shape2, + ComparisonDirection::EQ); + Value dimBroadcastable = + rewriter.create(op.getLoc(), eitherDimSizeIs1, dimSizeEq); + + // Iterate over each dim to check that all dims are broadcastable. + auto boolType = RankedTensorType::get({1}, rewriter.getI1Type()); + Value allBroadcastable = rewriter.create( + op.getLoc(), DenseIntElementsAttr::get(boolType, true)); + for (auto i = 0; i < rank; ++i) { + Value broadcastable = rewriter.create( + op.getLoc(), dimBroadcastable, rewriter.getDenseI64ArrayAttr(i), + rewriter.getDenseI64ArrayAttr(i + 1), + rewriter.getDenseI64ArrayAttr(1)); + allBroadcastable = + rewriter.create(op.getLoc(), allBroadcastable, broadcastable); + } + Value allBroadcastableScalar = rewriter.create( + op.getLoc(), RankedTensorType::get({}, rewriter.getI1Type()), + allBroadcastable); + + // Add CustomCallOp and replace Cstr op with const witness, which is useful + // for canonicalizer to remove the shape.assuming region. 
+ insertShapeAssertionCustomCall(rewriter, op->getLoc(), + allBroadcastableScalar); + rewriter.replaceOpWithNewOp(op.getOperation(), true); + return success(); + } +}; + +bool hasIndexStyle(Value value) { + if (value.getType().isIndex()) return true; + auto type = mlir::dyn_cast(value.getType()); + return type && type.getElementType().isIndex(); +} + +struct ConvertShapeToStablehloWithConstraintsPass + : public impl::ConvertShapeToStablehloWithConstraintsPassBase< + ConvertShapeToStablehloWithConstraintsPass> { + void runOnOperation() override { + ConversionTarget target(getContext()); + target.addIllegalDialect(); + target.addIllegalDialect(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addDynamicallyLegalDialect<::mlir::stablehlo::StablehloDialect>( + [](Operation* op) { + return !llvm::any_of(op->getOperands(), hasIndexStyle); + }); + target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); + + RewritePatternSet patterns(&getContext()); + ::mlir::stablehlo::populateShapeToStablehloPatterns(&getContext(), + &patterns); + + patterns.add(&getContext()); + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace +} // namespace mlir::quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/defer_activation_transpose.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/defer_activation_transpose.cc index b16f8787e9ea35..1a6663f4a7356c 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/defer_activation_transpose.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/defer_activation_transpose.cc @@ -17,6 +17,7 @@ limitations under the License. 
#include #include "absl/base/nullability.h" +#include "llvm/Support/LogicalResult.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project @@ -47,7 +48,7 @@ using ::mlir::stablehlo::TransposeOp; // Returns `success()` if `op` is a `TransposeOp` with permutation attribute // equivalent to `permuation`. -LogicalResult IsTransposeOpWithPermuation(absl::Nullable op, +LogicalResult IsTransposeOpWithPermuation(Operation* absl_nullable op, const ArrayRef permutation) { auto transpose_op = dyn_cast_or_null(op); return success(transpose_op != nullptr && transpose_op.getPermutation() == @@ -89,8 +90,8 @@ void DeferRhsTransposeForBinaryOp(OpT op, PatternRewriter& rewriter) { // "Climbs up" the `op` if `op` is a `BraodcastInDimOp` and returns the defining // op of its operand. Returns `op` otherwise. May return `nullptr` when the // `BroadcastInDimOp`'s operand is a block argument. -absl::Nullable SkipUpwardsOptionalBroadcastInDimOp( - absl::Nonnull op) { +Operation* absl_nullable SkipUpwardsOptionalBroadcastInDimOp( + Operation* absl_nonnull op) { if (auto broadcast_in_dim_op = dyn_cast_or_null(op); broadcast_in_dim_op != nullptr) { return broadcast_in_dim_op.getOperand().getDefiningOp(); @@ -98,12 +99,12 @@ absl::Nullable SkipUpwardsOptionalBroadcastInDimOp( return op; } -class DeferActivationTransposeForAddOp - : public OpRewritePattern::SplitMatchAndRewrite { +class DeferActivationTransposeForAddOp : public OpRewritePattern { public: - using SplitMatchAndRewrite::SplitMatchAndRewrite; + using OpRewritePattern::OpRewritePattern; - LogicalResult match(AddOp op) const override { + LogicalResult matchAndRewrite(AddOp op, + PatternRewriter& rewriter) const override { // Only supports the case for 2D convolution. 
const Value lhs = op.getOperand(0); if (!HasRankOf(lhs, /*rank=*/4)) return failure(); @@ -120,12 +121,13 @@ class DeferActivationTransposeForAddOp } // Match LHS permutation that converts: NHWC -> NCHW. - return IsTransposeOpWithPermuation(lhs.getDefiningOp(), - kNhwcToNchwPermutation); - } + if (IsTransposeOpWithPermuation(lhs.getDefiningOp(), kNhwcToNchwPermutation) + .failed()) { + return failure(); + } - void rewrite(AddOp op, PatternRewriter& rewriter) const override { DeferRhsTransposeForBinaryOp(op, rewriter); + return success(); } }; @@ -134,12 +136,12 @@ class DeferActivationTransposeForAddOp // to the result. The reduce function should be equivalent to // `stablehlo.maximum`, representing max pooling. class DeferActivationTransposeForMaxPoolReduceWindowOp - : public OpRewritePattern< - mlir::stablehlo::ReduceWindowOp>::SplitMatchAndRewrite { + : public OpRewritePattern { public: - using SplitMatchAndRewrite::SplitMatchAndRewrite; + using OpRewritePattern::OpRewritePattern; - LogicalResult match(mlir::stablehlo::ReduceWindowOp op) const override { + LogicalResult matchAndRewrite(mlir::stablehlo::ReduceWindowOp op, + PatternRewriter& rewriter) const override { if (failed(MatchMaxPoolReduceWindowOp(op))) return failure(); // Match only when the lhs is connected to a transpose. @@ -148,13 +150,12 @@ class DeferActivationTransposeForMaxPoolReduceWindowOp if (!HasRankOf(lhs, /*rank=*/4)) return failure(); // Match input permutation that converts: NHWC -> NCHW. - return IsTransposeOpWithPermuation(lhs.getDefiningOp(), - kNhwcToNchwPermutation); - } + if (IsTransposeOpWithPermuation(lhs.getDefiningOp(), kNhwcToNchwPermutation) + .failed()) { + return failure(); + } - // Pushes the transpose op at the input to the result. - void rewrite(mlir::stablehlo::ReduceWindowOp op, - PatternRewriter& rewriter) const override { + // Pushes the transpose op at the input to the result. 
auto transpose_op = cast(op.getOperand(0).getDefiningOp()); const auto result_type = mlir::cast(op.getResult(0).getType()); @@ -194,6 +195,7 @@ class DeferActivationTransposeForMaxPoolReduceWindowOp rewriter); rewriter.replaceAllUsesWith(op.getResult(0), result_transpose_op); + return success(); } private: @@ -242,12 +244,12 @@ class DeferActivationTransposeForMaxPoolReduceWindowOp // Rewrites `maximum(transpose(%rhs), %lhs)` patterns to // `transpose(maximum(%rhs, transpose(%lhs)))`. -class DeferActivationTransposeForMaxOp - : public OpRewritePattern::SplitMatchAndRewrite { +class DeferActivationTransposeForMaxOp : public OpRewritePattern { public: - using SplitMatchAndRewrite::SplitMatchAndRewrite; + using OpRewritePattern::OpRewritePattern; - LogicalResult match(MaxOp op) const override { + LogicalResult matchAndRewrite(MaxOp op, + PatternRewriter& rewriter) const override { Value input = op.getOperand(0); if (!HasRankOf(input, /*rank=*/4)) return failure(); @@ -258,12 +260,13 @@ class DeferActivationTransposeForMaxOp return failure(); } - return IsTransposeOpWithPermuation(input.getDefiningOp(), - kNhwcToNchwPermutation); - } - - void rewrite(MaxOp op, PatternRewriter& rewriter) const override { + if (IsTransposeOpWithPermuation(input.getDefiningOp(), + kNhwcToNchwPermutation) + .failed()) { + return failure(); + } DeferRhsTransposeForBinaryOp(op, rewriter); + return success(); } }; diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/fold_constant_transpose.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/fold_constant_transpose.cc index 79872b57e1574e..197fb1c868afb3 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/fold_constant_transpose.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/fold_constant_transpose.cc @@ -116,12 +116,12 @@ class DenseElementsTransposer { }; class FoldTransposedConstantOp - : public OpRewritePattern< - mlir::stablehlo::TransposeOp>::SplitMatchAndRewrite { + : public 
OpRewritePattern { public: - using SplitMatchAndRewrite::SplitMatchAndRewrite; + using OpRewritePattern::OpRewritePattern; - LogicalResult match(mlir::stablehlo::TransposeOp op) const override { + LogicalResult matchAndRewrite(mlir::stablehlo::TransposeOp op, + PatternRewriter& rewriter) const override { Value operand = op.getOperand(); auto const_op = dyn_cast_or_null(operand.getDefiningOp()); @@ -133,14 +133,9 @@ class FoldTransposedConstantOp return failure(); } - return success( - mlir::isa_and_nonnull(const_op.getValue())); - } - - void rewrite(mlir::stablehlo::TransposeOp op, - PatternRewriter& rewriter) const override { - auto const_op = - cast(op.getOperand().getDefiningOp()); + if (!mlir::isa_and_nonnull(const_op.getValue())) { + return failure(); + } const auto value_attr = mlir::cast(const_op.getValue()); @@ -169,7 +164,8 @@ class FoldTransposedConstantOp combined_loc, new_value_attr); rewriter.replaceAllUsesWith(op, new_const_op); - }; + return success(); + } }; } // namespace diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/insert_weight_param.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/insert_weight_param.cc index b3c309c76adb79..fb2e5caba7b59f 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/insert_weight_param.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/insert_weight_param.cc @@ -80,15 +80,13 @@ class InsertWeightParamPass // Inserts quantization parameters for weights for hybrid quantization of // `stablehlo.convolution` and `stablehlo.dot_general`. 
class InsertWeightParamPattern - : public OpTraitRewritePattern< - OpTrait::ConstantLike>::SplitMatchAndRewrite { + : public OpTraitRewritePattern { public: explicit InsertWeightParamPattern(MLIRContext* context) - : SplitMatchAndRewrite(Pattern::MatchTraitOpTypeTag(), - TypeID::get(), 1, context) { - } + : OpTraitRewritePattern(context) {} - LogicalResult match(Operation* op) const override { + LogicalResult matchAndRewrite(Operation* op, + PatternRewriter& rewriter) const override { if (op->getNumResults() != 1) { return failure(); } @@ -96,27 +94,11 @@ class InsertWeightParamPattern if (!type || !type.getElementType().isF32()) { return failure(); } - return success( - op->hasOneUse() && - IsWeightQuantizableFunction(*op->getUses().begin(), type.getRank())); - } - - // Checks if the operand is second operand of `tf.XlaCallModule` op for - // `stablehlo.convolution` or `stablehlo.dot_general` with fully_quantizable - // trait. - static bool IsWeightQuantizableFunction(OpOperand& operand, int64_t rank) { - if (operand.getOperandNumber() != 1) { - return false; - } - Operation* user = operand.getOwner(); - if (!IsWeightOnlyQuantizableOp(*user)) { - return false; + if (!op->hasOneUse() || + !IsWeightQuantizableFunction(*op->getUses().begin(), type.getRank())) { + return failure(); } - Method method = GetQuantizationMethodOrDefault(user); - return HasValidWeightOnlyPtqMethod(method.weight_only_ptq(), rank); - } - void rewrite(Operation* op, PatternRewriter& rewriter) const override { Operation* quantizable_op = *op->getUsers().begin(); DenseFPElementsAttr attr; matchPattern(op->getResult(0), m_Constant(&attr)); @@ -144,7 +126,7 @@ class InsertWeightParamPattern op->emitError( "Failed to get weight quantization parameters for weight-only " "quantization."); - return; + return failure(); } const Type expressed_type = op->getResult(0).getType(); @@ -157,6 +139,22 @@ class InsertWeightParamPattern auto dq = rewriter.create(op->getLoc(), expressed_type, q); 
quantizable_op->setOperand(1, dq.getResult()); + return success(); + } + + // Checks if the operand is second operand of `tf.XlaCallModule` op for + // `stablehlo.convolution` or `stablehlo.dot_general` with fully_quantizable + // trait. + static bool IsWeightQuantizableFunction(OpOperand& operand, int64_t rank) { + if (operand.getOperandNumber() != 1) { + return false; + } + Operation* user = operand.getOwner(); + if (!IsWeightOnlyQuantizableOp(*user)) { + return false; + } + Method method = GetQuantizationMethodOrDefault(user); + return HasValidWeightOnlyPtqMethod(method.weight_only_ptq(), rank); } private: @@ -221,7 +219,7 @@ class InsertWeightParamPattern dimension_numbers.getRhsContractingDimensions(); ArrayRef rhs_batching_dims = dimension_numbers.getRhsBatchingDimensions(); - int64_t rank = dot.getRhs().getType().cast().getRank(); + int64_t rank = mlir::cast(dot.getRhs().getType()).getRank(); for (int i = 0; i < rank; ++i) { // Return the first non-contracting, non-batching dimension of rhs. 
if (llvm::find(rhs_contracting_dims, i) == rhs_contracting_dims.end() && @@ -230,7 +228,7 @@ class InsertWeightParamPattern } } } - return op.getOperand(1).getType().cast().getRank() - 1; + return mlir::cast(op.getOperand(1).getType()).getRank() - 1; } }; diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/nchw_convolution_to_nhwc.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/nchw_convolution_to_nhwc.cc index 3b2b20bc2e4c52..4bb871a56886b3 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/nchw_convolution_to_nhwc.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/nchw_convolution_to_nhwc.cc @@ -48,12 +48,12 @@ class NchwConvolutionToNhwcPass // * Src dimension numbers: [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1] // * Dst dimension numbers: [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f] class RewriteNchwConvolutionToNhwc - : public OpRewritePattern< - mlir::stablehlo::ConvolutionOp>::SplitMatchAndRewrite { + : public OpRewritePattern { public: - using SplitMatchAndRewrite::SplitMatchAndRewrite; + using OpRewritePattern::OpRewritePattern; - LogicalResult match(mlir::stablehlo::ConvolutionOp op) const override { + LogicalResult matchAndRewrite(mlir::stablehlo::ConvolutionOp op, + PatternRewriter& rewriter) const override { // Handles 2D convolutions only. 
if (!HasRankOf(op.getOperand(0), /*rank=*/4) || !HasRankOf(op.getOperand(1), /*rank=*/4)) { @@ -63,13 +63,14 @@ class RewriteNchwConvolutionToNhwc if (!IsOpNotQuantized(op)) return failure(); const ConvDimensionNumbersAttr dimension_nums = op.getDimensionNumbers(); - return success(MatchInputDimensionNumbers(dimension_nums) && - MatchKernelDimensionNumbers(dimension_nums) && - MatchOutputDimensionNumbers(dimension_nums)); - } + const bool dimension_nums_matched = + MatchInputDimensionNumbers(dimension_nums) && + MatchKernelDimensionNumbers(dimension_nums) && + MatchOutputDimensionNumbers(dimension_nums); + if (!dimension_nums_matched) { + return failure(); + } - void rewrite(mlir::stablehlo::ConvolutionOp op, - PatternRewriter& rewriter) const override { // Transpose the input tensor: [b, f, 0, 1] => [b, 0, 1, f] Value input = op->getOperand(0); const TensorType new_input_tensor_type = GetTransposedTensorType( @@ -130,6 +131,7 @@ class RewriteNchwConvolutionToNhwc rewriter.getDenseI64ArrayAttr(kNhwcToNchwPermutation)); rewriter.replaceAllUsesWith(op, output_transpose_op); + return success(); } private: diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td b/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td index da59c218a56926..e6108ca6d13e02 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td @@ -153,6 +153,15 @@ def ConvertXlaCallModuleOpToBfloat16Pass : Pass<"stablehlo-convert-xla-call-modu ]; } +def ConvertShapeToStablehloWithConstraintsPass : Pass<"stablehlo-convert-shape-to-stablehlo-with-constraints", "mlir::func::FuncOp"> { + let summary = "Convert shape.cstr_broadcastable to stablehlo.custom_call @shape_assertion"; + let dependentDialects = [ + "mlir::shape::ShapeDialect", + "mlir::tensor::TensorDialect", + "mlir::stablehlo::StablehloDialect", + ]; +} + def OptimizeGraphPass : Pass<"optimize-graph", "ModuleOp"> { let 
summary = "Optimize the sub-optimal patterns after quantization."; let dependentDialects = ["mlir::stablehlo::StablehloDialect",]; diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.cc index 232115e53d3219..d6a88055c8c855 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.cc @@ -668,16 +668,16 @@ void ReplaceQuantizedXlaCallModuleOpWithQuantizedCallOp( template >> -class XlaCallModuleOpToCallOp - : public OpRewritePattern::SplitMatchAndRewrite { +class XlaCallModuleOpToCallOp : public OpRewritePattern { public: explicit XlaCallModuleOpToCallOp( MLIRContext& ctx, const bool enable_per_channel_quantized_weight) - : OpRewritePattern::SplitMatchAndRewrite(&ctx), + : OpRewritePattern::OpRewritePattern(&ctx), enable_per_channel_quantized_weight_( enable_per_channel_quantized_weight) {} - LogicalResult match(TF::XlaCallModuleOp op) const override { + LogicalResult matchAndRewrite(TF::XlaCallModuleOp op, + PatternRewriter& rewriter) const override { ModuleOp module_op = op->getParentOfType(); // Ignore ops without quantization method. @@ -698,22 +698,20 @@ class XlaCallModuleOpToCallOp return failure(); } Method quantization_method = GetQuantizationMethodOrDefault(op); - return FuncBodyRewritePatternT(enable_per_channel_quantized_weight_) - .match(entry_func_op, quantization_method); - } + if (FuncBodyRewritePatternT(enable_per_channel_quantized_weight_) + .match(entry_func_op, quantization_method) + .failed()) { + return failure(); + } - void rewrite(TF::XlaCallModuleOp xla_call_module_op, - PatternRewriter& rewriter) const override { // TODO: b/331145946 - Each quantization method should be valid // (GetQuantizationMethodOrDefault swallows invalid method attribute). Check // the validity in `match()`. 
Use accessors to achieve this. - const Method quantization_method = - GetQuantizationMethodOrDefault(xla_call_module_op); - ReplaceQuantizedXlaCallModuleOpWithQuantizedCallOp( - *rewriter.getContext(), rewriter, xla_call_module_op, + *rewriter.getContext(), rewriter, op, FuncBodyRewritePatternT(enable_per_channel_quantized_weight_), quantization_method); + return success(); } private: @@ -726,14 +724,22 @@ class XlaCallModuleOpToCallOp // Quantizes only when the nested region consists of ops whose quantization // parameters can be propagated from outside. class QuantizeOpWithRegionPattern - : public OpRewritePattern< - quantfork::DequantizeCastOp>::SplitMatchAndRewrite { + : public OpRewritePattern { public: explicit QuantizeOpWithRegionPattern(MLIRContext& ctx) - : OpRewritePattern::SplitMatchAndRewrite( - &ctx) {}; + : OpRewritePattern(&ctx) {}; - LogicalResult match(quantfork::DequantizeCastOp op) const final { + LogicalResult matchAndRewrite(quantfork::DequantizeCastOp op, + PatternRewriter& rewriter) const final { + if (match(op).failed()) { + return failure(); + } + rewrite(op, rewriter); + return success(); + } + + private: + LogicalResult match(quantfork::DequantizeCastOp op) const { // Match only when there is one user of the dequantize op. if (!op.getResult().hasOneUse()) { return failure(); @@ -762,7 +768,7 @@ class QuantizeOpWithRegionPattern } void rewrite(quantfork::DequantizeCastOp op, - PatternRewriter& rewriter) const final { + PatternRewriter& rewriter) const { // Rewrite the floating-point ops to the quantized version, by fusing // preceding dequantize ops and succeding quantize ops. for (Operation* op_with_region : op.getResult().getUsers()) { @@ -849,7 +855,6 @@ class QuantizeOpWithRegionPattern } } - private: // Checks if an op is quantizable in a nested region. 
bool IsOpQuantizableInNestedRegion(Operation& op) const { return isa(op); diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/remove_sharding_custom_call.td b/tensorflow/compiler/mlir/quantization/stablehlo/passes/remove_sharding_custom_call.td index 70ee6dc077ee11..0ff3ece326d242 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/remove_sharding_custom_call.td +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/remove_sharding_custom_call.td @@ -15,7 +15,7 @@ limitations under the License. include "stablehlo/dialect/StablehloOps.td" class IsStringAttrOf : Constraint< - CPred<"::llvm::isa_and_nonnull($_self) && $_self.cast().getValue() == \"" # value # "\"">, + CPred<"::llvm::isa_and_nonnull($_self) && llvm::cast($_self).getValue() == \"" # value # "\"">, "Is a string attribute whose value is \"" # value # "\"" >; diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_convert_func_to_bfloat16.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_convert_func_to_bfloat16.cc new file mode 100644 index 00000000000000..d4f2d88ea34f32 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_convert_func_to_bfloat16.cc @@ -0,0 +1,232 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/TypeRange.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h" // IWYU pragma: keep +#include "tensorflow/compiler/mlir/quantization/stablehlo/utils/bfloat16_type.h" +#include "tensorflow/core/platform/bfloat16.h" + +namespace mlir::tf_quant::stablehlo { +namespace { + +class BFloat16TypeConverter : public TypeConverter { + public: + BFloat16TypeConverter() { + addConversion([](const Type type) -> Type { + return quant::stablehlo::IsLargeFloatType(type) + ? quant::stablehlo::ToBfloat16Type(type) + : type; + }); + } +}; + +// This helper function makes legality check easier. 
Both convert ops in the +// patterns below are considered legal: +// - `BitcastConvertOp` (i32 -> f32) + `ConvertOp` (f32 -> bf16) +// - `ConvertOp` (bf16 -> f32) -> `BitcastConvertOp` (f32 -> i32) +template +bool IsConvertOpLegal(ConvertOp convert_op, BFloat16TypeConverter& converter) { + if (!converter.isLegal(convert_op.getOperand().getType())) { + auto other_convert_op = dyn_cast_or_null( + convert_op.getOperand().getDefiningOp()); + return other_convert_op && + converter.isLegal(other_convert_op.getOperand().getType()); + } else if (!converter.isLegal(convert_op.getResult().getType())) { + if (!convert_op.getResult().hasOneUse()) { + return false; + } + auto other_convert_op = dyn_cast_or_null( + *convert_op.getResult().getUsers().begin()); + return other_convert_op && + converter.isLegal(other_convert_op.getResult().getType()); + } + return true; +} + +class BFloat16TypeConversionTarget : public ConversionTarget { + public: + explicit BFloat16TypeConversionTarget(MLIRContext& ctx, + BFloat16TypeConverter& converter) + : ConversionTarget(ctx), converter_(converter) { + markUnknownOpDynamicallyLegal([this](Operation* op) { + // The FuncOp type can contain types that the op's operand and result + // types do not contain. 
+ if (auto func = dyn_cast(op)) { + if (!converter_.isSignatureLegal(func.getFunctionType())) return false; + } else if (auto bitcast_convert_op = + dyn_cast(op)) { + return IsConvertOpLegal(bitcast_convert_op, + converter_); + } else if (auto convert_op = dyn_cast(op)) { + return IsConvertOpLegal(convert_op, + converter_); + } + return converter_.isLegal(op); + }); + } + + private: + BFloat16TypeConverter& converter_; +}; + +class BFloat16TypePattern : public ConversionPattern { + public: + BFloat16TypePattern(TypeConverter& converter, MLIRContext* ctx) + : ConversionPattern(converter, MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {} + + LogicalResult matchAndRewrite( + Operation* op, const ArrayRef operands, + ConversionPatternRewriter& rewriter) const override { + if (getTypeConverter()->isLegal(op)) { + return failure(); + } + if (isa(op)) { + // Skip `BitcastConvertOp`, which is handled by the other pattern. + return failure(); + } + + // Update the results. + SmallVector new_results; + if (failed(getTypeConverter()->convertTypes(op->getResultTypes(), + new_results))) + return failure(); + + // Update the regions. The dialect conversion framework wants new regions to + // be created and updated, rather than updating the old op. Thus we use an + // OperationState so we can add regions to the new op. + OperationState state(op->getLoc(), op->getName().getStringRef(), operands, + new_results, op->getAttrs(), op->getSuccessors()); + for (Region& region : op->getRegions()) { + auto new_region = std::make_unique(op); + rewriter.inlineRegionBefore(region, *new_region, new_region->begin()); + if (failed(rewriter.convertRegionTypes(new_region.get(), + *getTypeConverter()))) { + return failure(); + } + state.addRegion(std::move(new_region)); + } + + // Convert value of ConstantOp to bfloat16. 
+ if (auto const_op = dyn_cast(op)) { + const auto values = const_op.getValue().tryGetValues(); + if (!values.has_value()) { + return failure(); + } + const SmallVector bfloat16_values(values->begin(), + values->end()); + state.attributes.set( + const_op.getValueAttrName(), + DenseFPElementsAttr::get( + mlir::dyn_cast(const_op.getValue().getType()) + .clone(rewriter.getBF16Type()), + bfloat16_values)); + } + + rewriter.replaceOp(op, rewriter.create(state)->getResults()); + + return success(); + } +}; + +class BitcastConvertOpPattern + : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + mlir::stablehlo::BitcastConvertOp op, + mlir::stablehlo::BitcastConvertOpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + const bool is_input_legal = + getTypeConverter()->isLegal(op.getOperand().getType()); + const bool is_output_legal = + getTypeConverter()->isLegal(op.getResult().getType()); + if (is_input_legal && is_output_legal) { + return failure(); + } else if (is_input_legal) { + // output is f32, we bitcast_convert to f32 and then convert to bf16. + const Value output = rewriter.create( + op->getLoc(), op.getResult().getType(), adaptor.getOperand()); + rewriter.replaceOpWithNewOp( + op, getTypeConverter()->convertType(op.getResult().getType()), + output); + } else if (is_output_legal) { + // input is f32, we convert from bf16 and then bitcast_convert. + const Value output = rewriter.create( + op->getLoc(), op.getOperand().getType(), adaptor.getOperand()); + rewriter.replaceOpWithNewOp( + op, op.getResult().getType(), output); + } else { + // Both input/output are f32. Convert to no-op. 
+ rewriter.replaceOp(op, adaptor.getOperand()); + } + return success(); + } +}; +} // namespace + +#define GEN_PASS_DEF_CONVERTFUNCTOBFLOAT16PASS +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h.inc" +namespace { +class ConvertFuncToBfloat16Pass + : public impl::ConvertFuncToBfloat16PassBase { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ConvertFuncToBfloat16Pass) + + explicit ConvertFuncToBfloat16Pass() = default; + + private: + void runOnOperation() override; +}; + +void ConvertFuncToBfloat16Pass::runOnOperation() { + func::FuncOp func_op = getOperation(); + MLIRContext* context = func_op.getContext(); + RewritePatternSet patterns(context); + + BFloat16TypeConverter converter; + patterns.add(converter, + context); + populateFunctionOpInterfaceTypeConversionPattern(patterns, + converter); + BFloat16TypeConversionTarget target(*context, converter); + if (failed(applyPartialConversion(func_op.getOperation(), target, + std::move(patterns)))) { + return signalPassFailure(); + } +} + +} // namespace +} // namespace mlir::tf_quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_convert_shape_constraint_to_assert.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_convert_shape_constraint_to_assert.cc new file mode 100644 index 00000000000000..bc9f247c7195cf --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_convert_shape_constraint_to_assert.cc @@ -0,0 +1,215 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include "llvm/ADT/STLExtras.h" +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project +#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/TypeRange.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo +#include "stablehlo/transforms/Passes.h" // from @stablehlo +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h" // IWYU pragma: keep + +namespace mlir::tf_quant::stablehlo { + +#define GEN_PASS_DEF_CONVERTSHAPETOSTABLEHLOWITHCONSTRAINTSPASS +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h.inc" + +namespace { +using ::mlir::stablehlo::AndOp; +using ::mlir::stablehlo::CompareOp; +using ::mlir::stablehlo::ComparisonDirection; +using ::mlir::stablehlo::ConcatenateOp; +using 
::mlir::stablehlo::ConstantOp; +using ::mlir::stablehlo::CustomCallOp; +using ::mlir::stablehlo::OrOp; +using ::mlir::stablehlo::ReshapeOp; +using ::mlir::stablehlo::SliceOp; + +// Cast from index-based shape representation used in the Shape dialect to the +// i32-based representation used in HLO: +// * index => tensor. +// * tensor => tensor. +// * All i32-based types from above => themselves. +// There is no convenient op that can express this, so we're using +// unrealized_conversion_cast (with the idea that all these casts will +// annihilate at the end of the pass). +Value castToI32(PatternRewriter& rewriter, Location loc, Value value) { + Type resultType; + if (value.getType().isIndex()) + resultType = RankedTensorType::get({}, rewriter.getI32Type()); + if (auto valueType = mlir::dyn_cast(value.getType())) { + if (!valueType.hasStaticShape()) return {}; + if (valueType.getElementType().isInteger(32)) return value; + if (valueType.getElementType().isIndex()) + resultType = + RankedTensorType::get(valueType.getShape(), rewriter.getI32Type()); + } + if (!resultType) return {}; + auto cast = + rewriter.create(loc, resultType, value); + return cast.getResult(0); +} + +// Pads input tensor by X ones from the left. The number X is +// determined by input pad. Result is tensor<(X+N) x i32>, where the first X +// elements are ones. 
+Value padFromLeft(PatternRewriter& rewriter, Location loc, Value input, + int64_t pad) { + Value padI32 = rewriter.create( + loc, DenseIntElementsAttr::get( + RankedTensorType::get({pad}, rewriter.getI32Type()), 1)); + return rewriter.create(loc, ValueRange{padI32, input}, + /*dimension=*/0); +} + +void insertShapeAssertionCustomCall(OpBuilder builder, Location loc, + Value assert) { + auto customCall = + builder.create(loc, TypeRange{}, ValueRange{assert}); + customCall.setCallTargetName("shape_assertion"); + customCall.setHasSideEffect(true); + customCall->setAttr("error_message", + builder.getStringAttr("Shape assertion failed")); +} + +struct ConvertCstrBroadcastableOp + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(shape::CstrBroadcastableOp op, + PatternRewriter& rewriter) const override { + // As defined, op inputs must be 1D tensor or !shape.shape. + // We only support inputs of two 1D tensors. + if (op.getShapes().size() != 2) return failure(); + auto shape1 = castToI32(rewriter, op.getLoc(), op.getShapes().front()); + auto shape2 = castToI32(rewriter, op.getLoc(), op.getShapes().back()); + if (!shape1 || !shape2) return failure(); + auto tensorType1 = mlir::dyn_cast(shape1.getType()); + auto tensorType2 = mlir::dyn_cast(shape2.getType()); + if (!tensorType1 || !tensorType2) return failure(); + + // If the two operand shapes are of different sizes, the smaller one is + // padded with 1's from the left. + int32_t rank = + std::max(tensorType1.getDimSize(0), tensorType2.getDimSize(0)); + if (tensorType1.getDimSize(0) < tensorType2.getDimSize(0)) { + shape1 = + padFromLeft(rewriter, op.getLoc(), shape1, + tensorType2.getDimSize(0) - tensorType1.getDimSize(0)); + } else if (tensorType1.getDimSize(0) > tensorType2.getDimSize(0)) { + shape2 = + padFromLeft(rewriter, op.getLoc(), shape2, + tensorType1.getDimSize(0) - tensorType2.getDimSize(0)); + } + + // Compute if each dim is broadcastable. 
A dim is broadcastable iff + // dimSize1 == dimSize2 or dimSize1 == 1 or dimSize2 == 1 + auto allOne = rewriter.create( + op.getLoc(), DenseIntElementsAttr::get( + RankedTensorType::get({rank}, rewriter.getI32Type()), + static_cast(1))); + Value dimSize1Is1 = rewriter.create(op.getLoc(), shape1, allOne, + ComparisonDirection::EQ); + Value dimSize2Is1 = rewriter.create(op.getLoc(), shape2, allOne, + ComparisonDirection::EQ); + Value eitherDimSizeIs1 = + rewriter.create(op.getLoc(), dimSize1Is1, dimSize2Is1); + Value dimSizeEq = rewriter.create(op.getLoc(), shape1, shape2, + ComparisonDirection::EQ); + Value dimBroadcastable = + rewriter.create(op.getLoc(), eitherDimSizeIs1, dimSizeEq); + + // Iterate over each dim to check that all dims are broadcastable. + auto boolType = RankedTensorType::get({1}, rewriter.getI1Type()); + Value allBroadcastable = rewriter.create( + op.getLoc(), DenseIntElementsAttr::get(boolType, true)); + for (auto i = 0; i < rank; ++i) { + Value broadcastable = rewriter.create( + op.getLoc(), dimBroadcastable, rewriter.getDenseI64ArrayAttr(i), + rewriter.getDenseI64ArrayAttr(i + 1), + rewriter.getDenseI64ArrayAttr(1)); + allBroadcastable = + rewriter.create(op.getLoc(), allBroadcastable, broadcastable); + } + Value allBroadcastableScalar = rewriter.create( + op.getLoc(), RankedTensorType::get({}, rewriter.getI1Type()), + allBroadcastable); + + // Add CustomCallOp and replace Cstr op with const witness, which is useful + // for canonicalizer to remove the shape.assuming region. 
+ insertShapeAssertionCustomCall(rewriter, op->getLoc(), + allBroadcastableScalar); + rewriter.replaceOpWithNewOp(op.getOperation(), true); + return success(); + } +}; + +bool hasIndexStyle(Value value) { + if (value.getType().isIndex()) return true; + auto type = mlir::dyn_cast(value.getType()); + return type && type.getElementType().isIndex(); +} + +struct ConvertShapeToStablehloWithConstraintsPass + : public impl::ConvertShapeToStablehloWithConstraintsPassBase< + ConvertShapeToStablehloWithConstraintsPass> { + void runOnOperation() override { + ConversionTarget target(getContext()); + target.addIllegalDialect(); + target.addIllegalDialect(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addDynamicallyLegalDialect<::mlir::stablehlo::StablehloDialect>( + [](Operation* op) { + return !llvm::any_of(op->getOperands(), hasIndexStyle); + }); + target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); + + RewritePatternSet patterns(&getContext()); + ::mlir::stablehlo::populateShapeToStablehloPatterns(&getContext(), + &patterns); + + patterns.add(&getContext()); + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace +} // namespace mlir::tf_quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_convert_xla_call_module_op_to_bfloat16.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_convert_xla_call_module_op_to_bfloat16.cc new file mode 100644 index 00000000000000..2db14f7470f049 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_convert_xla_call_module_op_to_bfloat16.cc @@ -0,0 +1,146 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "stablehlo/dialect/Serialization.h" // from @stablehlo +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo // IWYU pragma: keep +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h" // IWYU pragma: keep +#include 
"tensorflow/compiler/mlir/quantization/stablehlo/utils/bfloat16_type.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" + +namespace mlir::tf_quant::stablehlo { + +absl::StatusOr ConvertSerializedStableHloModuleToBfloat16( + const StringRef serialized_stablehlo_module) { + // StableHLO module is empty often because the XlaCallModuleOp is already + // deserialized, e.g. after invoking XlaCallModuleDeserializationPass. We + // don't handle this situation. + if (serialized_stablehlo_module.empty()) { + return absl::InvalidArgumentError("StableHLO module is empty."); + } + + MLIRContext context; + OwningOpRef stablehlo_module_op = + mlir::stablehlo::deserializePortableArtifact(serialized_stablehlo_module, + &context); + auto version = + mlir::stablehlo::getPortableArtifactVersion(serialized_stablehlo_module); + if (failed(version)) { + return absl::InternalError( + "Failed to get the deserialized StableHLO version, XlaCallModuleOp " + "must have a valid StableHLO module serialized using " + "stablehlo::serializePortableArtifact APIs."); + } + + // Convert the StableHLO module to bfloat16. 
+ PassManager pm(&context); + pm.addNestedPass(createConvertFuncToBfloat16Pass()); + if (failed(pm.run(stablehlo_module_op.get()))) { + return absl::InternalError( + "Failed to convert StableHLO module to bfloat16."); + } + + std::string bytecode; + llvm::raw_string_ostream os(bytecode); + if (failed(mlir::stablehlo::serializePortableArtifact( + stablehlo_module_op.get(), version.value().toString(), os))) { + return absl::InternalError("Failed to serialize StableHLO module."); + } + return bytecode; +} + +#define GEN_PASS_DEF_CONVERTXLACALLMODULEOPTOBFLOAT16PASS +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h.inc" + +namespace { +class ConvertXlaCallModuleOpToBfloat16Pass + : public impl::ConvertXlaCallModuleOpToBfloat16PassBase< + ConvertXlaCallModuleOpToBfloat16Pass> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( + ConvertXlaCallModuleOpToBfloat16Pass) + + explicit ConvertXlaCallModuleOpToBfloat16Pass() = default; + + private: + void runOnOperation() override; +}; + +void ConvertXlaCallModuleOpToBfloat16Pass::runOnOperation() { + Operation* func_op = getOperation(); + SymbolTableCollection symbol_table; + OpBuilder builder(&getContext()); + + auto result = func_op->walk([&](TF::XlaCallModuleOp op) { + // Converts the serialized StableHLO module to bfloat16. + auto result = + ConvertSerializedStableHloModuleToBfloat16(op.getModuleAttr()); + if (!result.ok()) { + llvm::errs() << "Failed to convert StableHLO module to bfloat16: " + << result.status().message(); + return WalkResult::interrupt(); + } + op.setModuleAttr(StringAttr::get(&getContext(), *result)); + + // Convert the `tf.XlaCallModuleOp` to bfloat16 and add casts around it. 
+ builder.setInsertionPoint(op); + for (auto& op_operand : op->getOpOperands()) { + if (quant::stablehlo::IsLargeFloatType(op_operand.get().getType())) { + op_operand.set(builder.create( + op->getLoc(), + quant::stablehlo::ToBfloat16Type(op_operand.get().getType()), + op_operand.get())); + } + } + builder.setInsertionPointAfter(op); + for (auto op_result : op->getOpResults()) { + if (quant::stablehlo::IsLargeFloatType(op_result.getType())) { + const Type original_type = op_result.getType(); + op_result.setType(quant::stablehlo::ToBfloat16Type(original_type)); + const Value cast = + builder.create(op->getLoc(), original_type, op_result); + op_result.replaceAllUsesExcept(cast, cast.getDefiningOp()); + } + } + return WalkResult::advance(); + }); + + if (result.wasInterrupted()) return signalPassFailure(); +} + +} // namespace +} // namespace mlir::tf_quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_defer_activation_transpose.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_defer_activation_transpose.cc new file mode 100644 index 00000000000000..f2816f4a700c72 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_defer_activation_transpose.cc @@ -0,0 +1,294 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include +#include + +#include "absl/base/nullability.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/IRMapping.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo +#include "tensorflow/compiler/mlir/quantization/common/tf_attrs_and_constraints.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/permutation.h" + +namespace mlir::tf_quant::stablehlo { + +#define GEN_PASS_DEF_DEFERACTIVATIONTRANSPOSEPASS +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h.inc" + +namespace { + +using ::mlir::stablehlo::AddOp; +using ::mlir::stablehlo::BroadcastInDimOp; +using ::mlir::stablehlo::MaxOp; +using ::mlir::stablehlo::TransposeOp; + +// Returns `success()` if `op` is a `TransposeOp` with permutation attribute +// equivalent to `permuation`. +LogicalResult IsTransposeOpWithPermuation(Operation* absl_nullable op, + const ArrayRef permutation) { + auto transpose_op = dyn_cast_or_null(op); + return success(transpose_op != nullptr && transpose_op.getPermutation() == + ArrayRef(permutation)); +} + +// Convenience function to create a `TransposeOp` with a given `permutation`. +// The Location is set as `input`'s loc. 
+TransposeOp CreateTransposeOp(Value input, const ArrayRef permutation, + PatternRewriter& rewriter) { + return rewriter.create( + input.getLoc(), input, rewriter.getDenseI64ArrayAttr(permutation)); +} + +// Defers the transpose of the left-hand side (LHS) to the right-hand side and +// the result of a binary operation. In detail, this rewrites the +// `op(transpose(%rhs), %lhs)` to `transpose(op(%rhs, transpose(%lhs)))`. The +// LHS transpose permutation must be a NCHW->NHWC permutation. +template +void DeferRhsTransposeForBinaryOp(OpT op, PatternRewriter& rewriter) { + auto transpose_op = cast(op.getOperand(0).getDefiningOp()); + Value lhs_pre_transpose = transpose_op.getOperand(); + + // NCHW -> NHWC for the right-hand side, to match the operand's shape. + Value rhs = op.getOperand(1); + TransposeOp rhs_transpose_op = CreateTransposeOp( + /*input=*/rhs, kNchwToNhwcPermutation, rewriter); + + auto new_binary_op = + rewriter.create(op.getLoc(), lhs_pre_transpose, rhs_transpose_op); + + // NHWC -> NCHW for the output, to match the shapes of `op`'s users. + TransposeOp output_transpose_op = CreateTransposeOp( + /*input=*/new_binary_op, kNhwcToNchwPermutation, rewriter); + + rewriter.replaceAllUsesWith(op.getResult(), output_transpose_op); +} + +// "Climbs up" the `op` if `op` is a `BraodcastInDimOp` and returns the defining +// op of its operand. Returns `op` otherwise. May return `nullptr` when the +// `BroadcastInDimOp`'s operand is a block argument. 
+Operation* absl_nullable SkipUpwardsOptionalBroadcastInDimOp( + Operation* absl_nonnull op) { + if (auto broadcast_in_dim_op = dyn_cast_or_null(op); + broadcast_in_dim_op != nullptr) { + return broadcast_in_dim_op.getOperand().getDefiningOp(); + } + return op; +} + +class DeferActivationTransposeForAddOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(AddOp op, + PatternRewriter& rewriter) const override { + // Only supports the case for 2D convolution. + const Value lhs = op.getOperand(0); + if (!HasRankOf(lhs, /*rank=*/4)) return failure(); + + const Value rhs = op.getOperand(1); + Operation* rhs_op = rhs.getDefiningOp(); + if (rhs_op == nullptr) return failure(); + + // Ignore the optional `BroadcastInDimOp` in between the constant and RHS. + rhs_op = SkipUpwardsOptionalBroadcastInDimOp(rhs_op); + + if (rhs_op == nullptr || !rhs_op->hasTrait()) { + return failure(); + } + + // Match LHS permutation that converts: NHWC -> NCHW. + if (IsTransposeOpWithPermuation(lhs.getDefiningOp(), kNhwcToNchwPermutation) + .failed()) { + return failure(); + } + + DeferRhsTransposeForBinaryOp(op, rewriter); + return success(); + } +}; + +// Rewrites the `reduce_window(transpose(%activation), %init_value)` patterns to +// `transpose(reduce_window(%activation), %init_value)`, deferring the transpose +// to the result. The reduce function should be equivalent to +// `stablehlo.maximum`, representing max pooling. +class DeferActivationTransposeForMaxPoolReduceWindowOp + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(mlir::stablehlo::ReduceWindowOp op, + PatternRewriter& rewriter) const override { + if (failed(MatchMaxPoolReduceWindowOp(op))) return failure(); + + // Match only when the lhs is connected to a transpose. + // Only supports the case commonly appearing for 2D convolutions. 
+ Value lhs = op.getOperand(0); + if (!HasRankOf(lhs, /*rank=*/4)) return failure(); + + // Match input permutation that converts: NHWC -> NCHW. + if (IsTransposeOpWithPermuation(lhs.getDefiningOp(), kNhwcToNchwPermutation) + .failed()) { + return failure(); + } + + // Pushes the transpose op at the input to the result. + auto transpose_op = cast(op.getOperand(0).getDefiningOp()); + + const auto result_type = mlir::cast(op.getResult(0).getType()); + const SmallVector new_result_shape = + quant::Permute(result_type.getShape(), kNchwToNhwcPermutation); + + const TensorType new_result_type = + result_type.cloneWith(new_result_shape, result_type.getElementType()); + + // Create a new `stablehlo.reduce_window` with all relevant attributes + // permutated to match the new operand & result type. + auto new_reduce_window_op = + rewriter.create( + op.getLoc(), new_result_type, transpose_op.getOperand(), + /*init_value=*/op.getOperand(1), + /*window_dimensions=*/ + PermuteI64ArrayAttr(rewriter, op.getWindowDimensions(), + kNchwToNhwcPermutation), + /*window_strides=*/ + PermuteI64ArrayAttr(rewriter, op.getWindowStrides(), + kNchwToNhwcPermutation), + /*base_dilations=*/ + PermuteI64ArrayAttr(rewriter, op.getBaseDilations(), + kNchwToNhwcPermutation), + /*window_dilations=*/ + PermuteI64ArrayAttr(rewriter, op.getWindowDilations(), + kNchwToNhwcPermutation), + /*padding=*/DenseIntElementsAttr(nullptr)); + + // Clone the reduce body. It is not affected by the permutation. + IRMapping mapping; + op.getBody().cloneInto(&new_reduce_window_op.getBody(), mapping); + + // Introduce a transpose to the result to match the shapes of `op`'s uses. + TransposeOp result_transpose_op = CreateTransposeOp( + /*input=*/new_reduce_window_op.getResult(0), kNhwcToNchwPermutation, + rewriter); + + rewriter.replaceAllUsesWith(op.getResult(0), result_transpose_op); + return success(); + } + + private: + // Permutes `array_attr` with `permutation`. 
The number of elements in + // `array_attr` and `permutation` must be equal. Returns a null attribute + // if `array_attr` is null. + DenseI64ArrayAttr PermuteI64ArrayAttr( + PatternRewriter& rewriter, + const std::optional> array_attr, + const ArrayRef permutation) const { + if (!array_attr.has_value()) return DenseI64ArrayAttr(nullptr); + + return rewriter.getDenseI64ArrayAttr( + quant::Permute(array_attr.value(), permutation)); + } + + LogicalResult MatchMaxPoolReduceWindowOp( + mlir::stablehlo::ReduceWindowOp op) const { + // TODO: b/321099943 - Support explicit padding. + if (HasPadding(op)) return failure(); + + // Check that the reduce-window body is a max operation. + return success(IsMaxFunction(op.getBody().front())); + } + + // Whether `block` semantically corresponds to a `stablehlo.maximum` op. + bool IsMaxFunction(Block& block) const { + if (block.getNumArguments() != 2) return false; + + auto return_op = cast(block.getTerminator()); + if (return_op.getNumOperands() != 1) return false; + + auto max_op = dyn_cast_or_null( + return_op.getOperands().front().getDefiningOp()); + if (!max_op) return false; + + return (max_op.getLhs() == block.getArgument(0)) && + (max_op.getRhs() == block.getArgument(1)); + } + + // Whether `op` has the `padding` attribute (which is optional). + bool HasPadding(mlir::stablehlo::ReduceWindowOp op) const { + return op.getPadding() != std::nullopt; + } +}; + +// Rewrites `maximum(transpose(%rhs), %lhs)` patterns to +// `transpose(maximum(%rhs, transpose(%lhs)))`. 
+class DeferActivationTransposeForMaxOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(MaxOp op, + PatternRewriter& rewriter) const override { + Value input = op.getOperand(0); + if (!HasRankOf(input, /*rank=*/4)) return failure(); + + const Value max_value = op.getOperand(1); + Operation* max_value_op = max_value.getDefiningOp(); + if (max_value_op == nullptr || + !max_value_op->hasTrait()) { + return failure(); + } + + if (IsTransposeOpWithPermuation(input.getDefiningOp(), + kNhwcToNchwPermutation) + .failed()) { + return failure(); + } + DeferRhsTransposeForBinaryOp(op, rewriter); + return success(); + } +}; + +} // namespace + +class DeferActivationTransposePass + : public impl::DeferActivationTransposePassBase< + DeferActivationTransposePass> { + private: + void runOnOperation() override; +}; + +void DeferActivationTransposePass::runOnOperation() { + func::FuncOp func_op = getOperation(); + MLIRContext& ctx = getContext(); + + RewritePatternSet patterns(&ctx); + patterns.add(&ctx); + if (failed(applyPatternsGreedily(func_op, std::move(patterns)))) { + func_op->emitWarning() << "Failed to converge patterns: " << getArgument(); + } +} + +} // namespace mlir::tf_quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_fold_constant_transpose.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_fold_constant_transpose.cc new file mode 100644 index 00000000000000..4de2b0ee026b20 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_fold_constant_transpose.cc @@ -0,0 +1,195 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo // IWYU pragma: keep +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/permutation.h" + +namespace mlir::tf_quant::stablehlo { + +#define GEN_PASS_DEF_FOLDCONSTANTTRANSPOSEPASS +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h.inc" + +namespace { + +// Returns contiguous offset (address) of the position represented by `indices` +// in a `shape` shaped tensor. Assumes row-major order. `indices` and `shape` +// should have the same size. +// Example: Index (2, 3) of a (4, 5)-shaped tensor has the contiguous offset of +// 2 * 5 + 3 = 13. 
+int64_t GetContiguousOffset(const ArrayRef indices, + const ArrayRef shape) { + int64_t contiguous_offset = 0; + int64_t base_offset = 1; + for (auto [i, dimension] : llvm::reverse(llvm::zip_equal(indices, shape))) { + contiguous_offset += base_offset * i; + base_offset *= dimension; + } + + return contiguous_offset; +} + +// Performs transposition of a tensor represented as a contiguous element array. +// Assumes row-major order. The shape of the input tensor and the desired +// permutation is registered during construction, and calling `TransposeValues` +// returns the transposed tensor values. +class DenseElementsTransposer { + public: + DenseElementsTransposer(const ArrayRef original_shape, + const ArrayRef permutation) + : rank_(original_shape.size()), + original_shape_(original_shape), + target_shape_(quant::Permute(original_shape, permutation)), + permutation_(permutation) {} + + // Transposes `values` with the permutation. Returns the transposed values. + SmallVector TransposeValues(const ArrayRef values) const { + SmallVector transposed_values(values.size()); + SmallVector current_indices = {}; + TransposeRecursively(values, transposed_values, current_indices); + + return transposed_values; + } + + // Returns the shape after permutation. + SmallVector GetTargetShape() const { return target_shape_; } + + private: + // Helper function that performs transposition recursively by mapping each set + // of indices from the original values to the target values. + void TransposeRecursively(const ArrayRef original_values, + const MutableArrayRef target_values, + SmallVector& current_indices) const { + // Map an element from `original_values` to `target_values` when a set of + // indices is formed. 
+ if (current_indices.size() == rank_) { + const int64_t original_index = + GetContiguousOffset(current_indices, original_shape_); + + const SmallVector target_indices = + quant::Permute(current_indices, permutation_); + const int64_t target_index = + GetContiguousOffset(target_indices, target_shape_); + + target_values[target_index] = original_values[original_index]; + return; + } + + // Recursively iterate by selecting the index of the next dimension. + const int next_shape_idx = current_indices.size(); + for (int i = 0; i < original_shape_[next_shape_idx]; ++i) { + current_indices.push_back(i); + TransposeRecursively(original_values, target_values, current_indices); + current_indices.pop_back(); + } + } + + int rank_; // Rank of the input values. + SmallVector original_shape_; // Shape of the original tensor. + SmallVector target_shape_; // Shape of the target tensor. + SmallVector permutation_; +}; + +class FoldTransposedConstantOp + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(mlir::stablehlo::TransposeOp op, + PatternRewriter& rewriter) const override { + Value operand = op.getOperand(); + auto const_op = + dyn_cast_or_null(operand.getDefiningOp()); + if (!const_op) return failure(); + + // Only support float tensors. + auto tensor_type = mlir::dyn_cast_or_null(const_op.getType()); + if (!tensor_type || !tensor_type.getElementType().isF32()) { + return failure(); + } + + if (!mlir::isa_and_nonnull(const_op.getValue())) { + return failure(); + } + + const auto value_attr = + mlir::cast(const_op.getValue()); + const ArrayRef original_shape = + value_attr.getShapedType().getShape(); + + const SmallVector original_values = + llvm::to_vector(value_attr.getValues()); + + // Fold the constant value by transposing the values according to the + // `TransposeOp`'s permutation attribute. 
+ const DenseElementsTransposer transposer(original_shape, + op.getPermutation()); + SmallVector transposed_values = + transposer.TransposeValues(original_values); + + // Create a new constant op with the transposed values. + const Location combined_loc = + rewriter.getFusedLoc({const_op.getLoc(), op.getLoc()}); + auto new_value_type = + RankedTensorType::getChecked(combined_loc, transposer.GetTargetShape(), + /*elementType=*/rewriter.getF32Type()); + auto new_value_attr = + DenseFPElementsAttr::get(new_value_type, std::move(transposed_values)); + auto new_const_op = rewriter.create( + combined_loc, new_value_attr); + + rewriter.replaceAllUsesWith(op, new_const_op); + return success(); + } +}; + +} // namespace + +class FoldConstantTransposePass + : public impl::FoldConstantTransposePassBase { + public: + using impl::FoldConstantTransposePassBase< + FoldConstantTransposePass>::FoldConstantTransposePassBase; + + private: + void runOnOperation() override; +}; + +void FoldConstantTransposePass::runOnOperation() { + func::FuncOp func_op = getOperation(); + MLIRContext& ctx = getContext(); + + RewritePatternSet patterns(&ctx); + patterns.add(&ctx); + if (failed(applyPatternsGreedily(func_op, std::move(patterns)))) { + func_op.emitError("Failed to fold constant->transpose pattern."); + signalPassFailure(); + } +} + +} // namespace mlir::tf_quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_insert_calibration_statistics_saver.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_insert_calibration_statistics_saver.cc new file mode 100644 index 00000000000000..1f4fa95533c3aa --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_insert_calibration_statistics_saver.cc @@ -0,0 +1,190 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quant_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" // IWYU pragma: keep +#include "tsl/platform/path.h" + +namespace mlir::tf_quant::stablehlo { +namespace { + +std::string GetOutputFilePath(absl::string_view calibration_data_dir, + absl::string_view func_name, + int32_t output_file_idx) { + return tsl::io::JoinPath(calibration_data_dir, + llvm::Twine(func_name) + .concat("_") + .concat(std::to_string(output_file_idx)) + .concat(".pb") + .str()); +} + +// Finds `CustomAggregator` ops and collects their outputs and attributes. 
+void FindCustomAggregatorOps( + Region& region, + const std::unordered_set& aggregator_ops_to_ignore, + SmallVector& statistics_outputs, SmallVector& ids, + SmallVector& calibration_methods) { + for (auto op : region.getOps()) { + if (aggregator_ops_to_ignore.count(op.getId().str())) continue; + + ids.push_back(op.getId()); + calibration_methods.push_back(op.getCalibrationMethod()); + statistics_outputs.push_back(op.getMin()); + statistics_outputs.push_back(op.getMax()); + statistics_outputs.push_back(op.getHistogram()); + } +} + +// Inserts a `CalibrationStatisticsSaverOp` to the end of the region. +LogicalResult InsertCalibrationStatisticsSaverOp( + Region& region, MLIRContext& ctx, absl::string_view output_file_path, + const std::unordered_set& aggregator_ops_to_ignore) { + SmallVector statistics_outputs; + SmallVector ids; + SmallVector calibration_methods; + FindCustomAggregatorOps(region, aggregator_ops_to_ignore, statistics_outputs, + ids, calibration_methods); + if (statistics_outputs.empty()) return failure(); + + OpBuilder builder(&ctx); + // Set the insertion point right before the return op. + builder.setInsertionPoint(®ion.back().back()); + + StringAttr output_file_path_attr = builder.getStringAttr(output_file_path); + ArrayAttr ids_attr = builder.getStrArrayAttr(ids); + ArrayAttr calibration_methods_attr = + builder.getI32ArrayAttr(calibration_methods); + builder.create( + region.getLoc(), statistics_outputs, output_file_path_attr, ids_attr, + calibration_methods_attr); + return success(); +} + +// Returns true if the op contains a `CalibrationStatisticsSaverOp`. +bool ContainCalibrationStatisticsSaverOp(Operation* op) { + // Check the region for CaseRegionOp, IfRegionOp and WhileRegionOp. + for (Region& region : op->getRegions()) { + if (!region.getOps().empty()) { + return true; + } + } + + SymbolTable symbol_table(op->getParentOfType()); + // Check the functions associated to CaseOp, IfOp and WhileOp. 
+ for (const NamedAttribute& attr : op->getAttrs()) { + FlatSymbolRefAttr symbol_attr = + dyn_cast_or_null(attr.getValue()); + if (!symbol_attr) continue; + + func::FuncOp target_func = dyn_cast_or_null( + symbol_table.lookup(symbol_attr.getValue())); + if (!target_func) continue; + + if (!target_func.getBody() + .getOps() + .empty()) { + return true; + } + } + return false; +} + +} // namespace + +#define GEN_PASS_DECL_INSERTCALIBRATIONSTATISTICSSAVERPASS +#define GEN_PASS_DEF_INSERTCALIBRATIONSTATISTICSSAVERPASS +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h.inc" + +class InsertCalibrationStatisticsSaverPass + : public impl::InsertCalibrationStatisticsSaverPassBase< + InsertCalibrationStatisticsSaverPass> { + public: + using impl::InsertCalibrationStatisticsSaverPassBase< + InsertCalibrationStatisticsSaverPass>:: + InsertCalibrationStatisticsSaverPassBase; + + private: + void runOnOperation() override; +}; + +void InsertCalibrationStatisticsSaverPass::runOnOperation() { + ModuleOp module_op = getOperation(); + MLIRContext& ctx = getContext(); + + std::unordered_set aggregator_ops_to_ignore( + aggregator_ops_to_ignore_.begin(), aggregator_ops_to_ignore_.end()); + + // Insert CalibrationStatisticsSaverOp to the end of each region. + for (auto func_op : module_op.getOps()) { + int32_t output_file_idx = 0; + StringRef func_name = func_op.getSymName(); + + func_op.walk([&output_file_idx, &ctx, &func_name, &aggregator_ops_to_ignore, + this](Operation* op) { + for (Region& region : op->getRegions()) { + if (succeeded(InsertCalibrationStatisticsSaverOp( + region, ctx, + GetOutputFilePath(calibration_data_dir_, func_name, + output_file_idx), + aggregator_ops_to_ignore))) { + ++output_file_idx; + }; + } + }); + } + + // Control flow ops that contains CalibrationStatisticsSaver ops must be set + // to stateful, otherwise the op will not be executed. 
+ OpBuilder builder(&ctx); + module_op.walk([&builder](Operation* op) { + if (op->hasAttrOfType("is_stateless") && + ContainCalibrationStatisticsSaverOp(op)) { + op->setAttr("is_stateless", builder.getBoolAttr(false)); + } + }); +} + +std::unique_ptr> +CreateInsertCalibrationStatisticsSaverPass( + StringRef calibration_data_dir, + const std::vector& aggregator_ops_to_ignore) { + InsertCalibrationStatisticsSaverPassOptions options = { + .aggregator_ops_to_ignore_ = llvm::to_vector(aggregator_ops_to_ignore), + .calibration_data_dir_ = calibration_data_dir.str(), + }; + return std::make_unique(options); +} + +} // namespace mlir::tf_quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_insert_weight_param.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_insert_weight_param.cc new file mode 100644 index 00000000000000..d6d4a90930512e --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_insert_weight_param.cc @@ -0,0 +1,249 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include + +#include "llvm/ADT/STLExtras.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo // IWYU pragma: keep +#include "tensorflow/compiler/mlir/quantization/common/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/quantization/common/tf_attrs_and_constraints.h" +#include "tensorflow/compiler/mlir/quantization/common/tf_lift_as_function_call.h" 
+#include "tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_utils.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h" // IWYU pragma: keep +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" + +namespace mlir::tf_quant::stablehlo { + +#define GEN_PASS_DEF_INSERTWEIGHTPARAMPASS +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h.inc" + +namespace { + +using ::stablehlo::quantization::Method; +using ::stablehlo::quantization::QuantizedType; +using ::stablehlo::quantization::WeightOnlyPtq; + +// Inserts quantization parameters of weights for weight-only quantization and +// dynamic range quantization of `stablehlo.convolution` and +// `stablehlo.dot_general`. +class InsertWeightParamPass + : public impl::InsertWeightParamPassBase { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(InsertWeightParamPass) + + using impl::InsertWeightParamPassBase< + InsertWeightParamPass>::InsertWeightParamPassBase; + + private: + void runOnOperation() override; +}; + +// Inserts quantization parameters for weights for hybrid quantization of +// `stablehlo.convolution` and `stablehlo.dot_general`. 
+class InsertWeightParamPattern + : public OpTraitRewritePattern { + public: + explicit InsertWeightParamPattern(MLIRContext* context) + : OpTraitRewritePattern(context) {} + + LogicalResult matchAndRewrite(Operation* op, + PatternRewriter& rewriter) const override { + if (op->getNumResults() != 1) { + return failure(); + } + auto type = mlir::cast(op->getResult(0).getType()); + if (!type || !type.getElementType().isF32()) { + return failure(); + } + if (!op->hasOneUse() || + !IsWeightQuantizableFunction(*op->getUses().begin(), type.getRank())) { + return failure(); + } + + Operation* quantizable_op = *op->getUsers().begin(); + DenseFPElementsAttr attr; + matchPattern(op->getResult(0), m_Constant(&attr)); + + Method method = GetQuantizationMethodOrDefault(quantizable_op); + const WeightOnlyPtq& weight_only_ptq = method.weight_only_ptq(); + + Type weight_type; + if (IsPerTensor(weight_only_ptq)) { + weight_type = + dyn_cast(GetUniformQuantizedTypeForWeight( + attr, /*symmetric=*/true, /*num_bits=*/8, /*is_signed=*/true, + /*narrow_range=*/true, /*legacy_float_scale=*/false)); + } else { + int quantization_dimension = GetQuantizationDimension( + weight_only_ptq, cast(quantizable_op)); + weight_type = GetUniformQuantizedPerAxisTypeForWeight( + attr, quantization_dimension, /*symmetric=*/true, /*num_bits=*/8, + /*is_signed=*/true, + /*narrow_range=*/true, /*legacy_float_scale=*/false); + } + + auto quant_type = dyn_cast(weight_type); + if (!quant_type) { + op->emitError( + "Failed to get weight quantization parameters for weight-only " + "quantization."); + return failure(); + } + + const Type expressed_type = op->getResult(0).getType(); + const Type quantized_type = + quant_type.castFromExpressedType(expressed_type); + + rewriter.setInsertionPointAfter(op); + auto q = rewriter.create( + op->getLoc(), quantized_type, op->getResult(0)); + auto dq = rewriter.create( + op->getLoc(), expressed_type, q); + quantizable_op->setOperand(1, dq.getResult()); + return success(); + 
} + + // Checks if the operand is second operand of `tf.XlaCallModule` op for + // `stablehlo.convolution` or `stablehlo.dot_general` with fully_quantizable + // trait. + static bool IsWeightQuantizableFunction(OpOperand& operand, int64_t rank) { + if (operand.getOperandNumber() != 1) { + return false; + } + Operation* user = operand.getOwner(); + if (!IsWeightOnlyQuantizableOp(*user)) { + return false; + } + Method method = GetQuantizationMethodOrDefault(user); + return HasValidWeightOnlyPtqMethod(method.weight_only_ptq(), rank); + } + + private: + static bool HasValidWeightOnlyPtqMethod(const WeightOnlyPtq& weight_only_ptq, + int64_t rank) { + const auto& input_quantized_types = weight_only_ptq.input_quantized_types(); + if (IsPerTensor(weight_only_ptq)) { + return true; + } + // `input_quantized_types` should contain spec for quantization type of the + // second operand, which is weight. + const QuantizedType& quantized_type = input_quantized_types.at(1); + if (const auto& specs = quantized_type.dimension_specs(); + specs.has_dimension()) { + return specs.dimension() >= 0 && specs.dimension() < rank; + } + return true; + } + + static bool IsPerTensor(const WeightOnlyPtq& weight_only_ptq) { + const auto& input_quantized_types = weight_only_ptq.input_quantized_types(); + if (input_quantized_types.empty()) { + return true; + } + auto weight_type = input_quantized_types.find(1); + if (weight_type == input_quantized_types.end()) { + return true; + } + return weight_type->second.has_per_tensor(); + } + + static int GetQuantizationDimension(const WeightOnlyPtq& weight_only_ptq, + TF::XlaCallModuleOp op) { + const QuantizedType& quantized_type = + weight_only_ptq.input_quantized_types().at(1); + if (quantized_type.dimension_specs().has_dimension()) { + return quantized_type.dimension_specs().dimension(); + } + return GetDefaultQuantizationDimension(op); + } + + // Determines quantization dimension of weights for given `tf.XlaCallModule` + // op. 
For convolution, returns output feature dimension of the kernel. For + // dot_general, returns the first non-contracting dimension, non-batching + // dimension. If such dimension does not exists, returns the last dimension of + // rhs. + static int64_t GetDefaultQuantizationDimension(TF::XlaCallModuleOp op) { + const StringRef function_name = GetEntryFunctionName(op); + const auto module_op = op->getParentOfType(); + const SymbolTable symbol_table(module_op); + func::FuncOp func = symbol_table.lookup(function_name); + + if (function_name.contains("conv")) { + return (*(func.getOps().begin())) + .getDimensionNumbers() + .getKernelOutputFeatureDimension(); + } else if (function_name.contains("dot_general")) { + auto dot = *(func.getOps().begin()); + const ::mlir::stablehlo::DotDimensionNumbersAttr dimension_numbers = + dot.getDotDimensionNumbers(); + ArrayRef rhs_contracting_dims = + dimension_numbers.getRhsContractingDimensions(); + ArrayRef rhs_batching_dims = + dimension_numbers.getRhsBatchingDimensions(); + int64_t rank = cast(dot.getRhs().getType()).getRank(); + for (int i = 0; i < rank; ++i) { + // Return the first non-contracting, non-batching dimension of rhs. 
+ if (llvm::find(rhs_contracting_dims, i) == rhs_contracting_dims.end() && + llvm::find(rhs_batching_dims, i) == rhs_batching_dims.end()) { + return i; + } + } + } + return cast(op.getOperand(1).getType()).getRank() - 1; + } +}; + +void InsertWeightParamPass::runOnOperation() { + func::FuncOp func = getOperation(); + MLIRContext* context = func.getContext(); + RewritePatternSet patterns(context); + + patterns.add(context); + + if (failed(applyPatternsGreedily(func, std::move(patterns)))) { + signalPassFailure(); + } +} + +} // namespace + +} // namespace mlir::tf_quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_lift_quantizable_spots_as_functions.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_lift_quantizable_spots_as_functions.cc new file mode 100644 index 00000000000000..bdd9255d90995c --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_lift_quantizable_spots_as_functions.cc @@ -0,0 +1,243 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include +#include + +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Debug.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Rewrite/FrozenRewritePatternSet.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo // IWYU pragma: keep +#include "tensorflow/compiler/mlir/quantization/common/tf_attrs_and_constraints.h" +#include "tensorflow/compiler/mlir/quantization/common/tf_lift_as_function_call.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tsl/platform/protobuf.h" // IWYU pragma: keep +#include "tsl/platform/regexp.h" // IWYU pragma: keep + +#define DEBUG_TYPE "lift_quantizable_spots_as_functions" + +namespace mlir::tf_quant::stablehlo { + +#define GEN_PASS_DEF_LIFTQUANTIZABLESPOTSASFUNCTIONSPASS +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h.inc" + +namespace { + +using ::stablehlo::quantization::FunctionNameMatcherSpec; +using ::stablehlo::quantization::Method; +using ::stablehlo::quantization::QuantizationSpec; +using ::stablehlo::quantization::QuantizationSpecs; +using ::tsl::protobuf::TextFormat; + +// 
TODO - b/303543789: Move the helper functions below to a separate util. +// Fetches the default or null attribute, used for pattern matching. +Attribute DefaultOrNullAttr(OpBuilder& builder, const Attribute& attr) { + if (attr) return attr; + return builder.getStringAttr(kNullAttributeValue); +} + +// Checks whether the value of a constant equals the given float, regardless +// of the tensor dimension. +bool FloatValueEquals(const Attribute& attr, const double value) { + const auto fp_attr = mlir::dyn_cast_or_null(attr); + if (!fp_attr) return false; + + if (fp_attr.isSplat()) { + return fp_attr.getSplatValue().isExactlyValue(value); + } + return llvm::all_of(fp_attr.getValues(), [value](const APFloat& f) { + return f.isExactlyValue(value); + }); +} + +inline void TrimTrailingWhitespaces(std::string& str) { + while (!str.empty() && str.back() == ' ') { + str.pop_back(); + } +} + +// Lifts quantizable units as separate functions, thereby identifying the +// boundaries of quantizable subgraphs. `QuantizationSpecs` influences how +// quantizable units are lifted. +// +// FileCheck test cases using various `QuantizationSpecs` can be seen at +// `TestLiftQuantizableSpotsAsFunctionsWithQuantizationSpecsPass`. +class LiftQuantizableSpotsAsFunctionsPass + : public impl::LiftQuantizableSpotsAsFunctionsPassBase< + LiftQuantizableSpotsAsFunctionsPass> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( + LiftQuantizableSpotsAsFunctionsPass) + + LiftQuantizableSpotsAsFunctionsPass() = default; + + // Constructor with explicit user-provided `QuantizationSpecs`. + explicit LiftQuantizableSpotsAsFunctionsPass( + QuantizationSpecs quantization_specs) + : quantization_specs_(std::move(quantization_specs)) {} + + private: + void runOnOperation() override; + + // No explicit quantization spec is specified by default. Implicitly this + // means that all quantizable units will be identified and lifted. 
+ QuantizationSpecs quantization_specs_{}; +}; + +namespace simple_patterns { +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions_simple.inc" +} + +namespace fusion_patterns { +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions_fusion.inc" +} + +// Quantizable Unit matcher that uses lifted function's name for matching. +class FunctionNameMatcher { + public: + explicit FunctionNameMatcher(const FunctionNameMatcherSpec& spec) + : match_regex_(GetMatchRegex(spec)) {} + + // Returns `true` when matched with the entry function of + // `xla_call_module_op`. + bool Match(TF::XlaCallModuleOp xla_call_module_op) const { + if (match_regex_ == nullptr) return false; + + const std::string lifted_func_name = + xla_call_module_op->getAttrOfType("_entry_function") + .getValue() + .str(); + + return RE2::FullMatch(lifted_func_name, *match_regex_); // NOLINT + } + + private: + // Returns an owned `RE2` object that corresponds to the `spec`. Returns + // `nullptr` if the `spec` is invalid. + // NOLINTNEXTLINE - RE2 included via TSL regexp.h + std::unique_ptr GetMatchRegex(const FunctionNameMatcherSpec& spec) { + const std::string& regex = spec.regex(); + if (regex.empty()) return nullptr; + + return std::make_unique(regex); // NOLINT + } + + // Regex object used for matching against a lifted function's name. + std::unique_ptr match_regex_; // NOLINT +}; + +// Converts `Method` to a single-line textproto representation. Returns +// `failure()` when converting to textproto failed. 
+FailureOr QuantizationMethodToTextProto(const Method& method) { + TextFormat::Printer printer; + printer.SetSingleLineMode(true); + + std::string method_txtpb; + if (!printer.PrintToString(method, &method_txtpb)) { + LLVM_DEBUG(llvm::dbgs() << "Failed to convert Method to textproto\n."); + return failure(); + } + + // Single line mode might have an extra space at the end, due to the internal + // details of `Printer`. + TrimTrailingWhitespaces(method_txtpb); + + return method_txtpb; +} + +// Applies quantization spec to all matched lifted functions. At this point only +// denylisting (`NoQuantization`) will be applied if specs is nonempty. +// TODO: b/307620778 - Support more advanced selective quantization methods. +LogicalResult ApplyQuantizationSpec(const QuantizationSpec& spec, + ModuleOp module_op) { + const Method& quantization_method = spec.method(); + + FailureOr quantization_method_txtpb = + QuantizationMethodToTextProto(quantization_method); + if (failed(quantization_method_txtpb)) return failure(); + + const FunctionNameMatcher matcher(spec.matcher().function_name()); + // Iterate over all XlaCallModuleOp in all FuncOps. + for (auto func : module_op.getOps()) { + for (auto xla_call_module_op : func.getOps()) { + if (!matcher.Match(xla_call_module_op)) continue; + + // Set the text representation of `Method` to matched + // `TF::XlaCallModuleOp`. + xla_call_module_op->setAttr( + kQuantizationMethodAttr, + StringAttr::get(module_op.getContext(), + std::move(*quantization_method_txtpb))); + } + } + return success(); +} + +void LiftQuantizableSpotsAsFunctionsPass::runOnOperation() { + MLIRContext* ctx = &getContext(); + RewritePatternSet patterns(ctx); + ModuleOp module_op = getOperation(); + + simple_patterns::populateWithGenerated(patterns); + fusion_patterns::populateWithGenerated(patterns); + FrozenRewritePatternSet frozen_patterns(std::move(patterns)); + + // Iterate over the sorted list of functions to keep order deterministic. 
+ for (func::FuncOp func : GetSortedFunctions(module_op)) { + if (failed(applyPatternsGreedily(func, frozen_patterns))) { + func.emitError() + << "quant-stablehlo-lift-quantizable-spots-as-functions failed."; + signalPassFailure(); + } + } + + // Remove all attr_map attributes. + module_op.walk([](Operation* op) { op->removeAttr(kAttrMapAttribute); }); + + // Perform selective quantization. Iterates over the quantization specs and + // applies quantization methods to each matched lifted function. + for (const QuantizationSpec& spec : quantization_specs_.specs()) { + if (failed(ApplyQuantizationSpec(spec, module_op))) { + signalPassFailure(); + return; + } + } +} + +} // namespace + +// Creates `LiftQuantizableSpotsAsFunctionsPass` with user-defined +// `QuantizationSpecs`. +std::unique_ptr> +CreateLiftQuantizableSpotsAsFunctionsPass( + const QuantizationSpecs& quantization_specs) { + return std::make_unique( + quantization_specs); +} + +} // namespace mlir::tf_quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_merge_fusion_with_dequantize.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_merge_fusion_with_dequantize.cc new file mode 100644 index 00000000000000..8b15d55136aae7 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_merge_fusion_with_dequantize.cc @@ -0,0 +1,146 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include + +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "stablehlo/dialect/ChloOps.h" // from @stablehlo +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo + +namespace mlir::tf_quant::stablehlo { + +#define GEN_PASS_DEF_MERGEFUSIONWITHDEQUANTIZEPASS +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h.inc" + +namespace { + +class MergeFusionWithDequantizePass + : public impl::MergeFusionWithDequantizePassBase< + MergeFusionWithDequantizePass> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MergeFusionWithDequantizePass) + + explicit MergeFusionWithDequantizePass() = default; + + private: + void runOnOperation() override; +}; + +class MergeFusionWithUniformDequantizePattern + : public OpRewritePattern { + public: + explicit MergeFusionWithUniformDequantizePattern(MLIRContext* context) + : OpRewritePattern(context) {} + LogicalResult matchAndRewrite(func::CallOp call_op, + PatternRewriter& rewriter) const override { + if 
(call_op.getNumResults() != 1) return failure(); + auto users = call_op->getUsers(); + for (auto user : users) { + if (!llvm::isa(user)) { + return failure(); + } + } + auto func_name = call_op.getCallee(); + if (!func_name.starts_with("quantized_")) return failure(); + if (call_op->getNumResults() != 1) return failure(); + if (!mlir::isa( + getElementTypeOrSelf(call_op->getResult(0).getType()))) + return failure(); + + // Fetch the callee function. + SymbolTable symbol_table(call_op->getParentOfType()); + auto func_op = + dyn_cast_or_null(symbol_table.lookup(func_name)); + if (!func_op) return failure(); + // The quantized fusion should have requantize and return ops at the end. + auto return_op = dyn_cast_or_null( + func_op.getRegion().getBlocks().front().getTerminator()); + if (!return_op) return failure(); + auto req_op = llvm::dyn_cast_or_null( + return_op.getOperands()[0].getDefiningOp()); + if (!req_op) return failure(); + + // Create a new func.call op with f32 output. + auto new_call_op = call_op.clone(); + new_call_op->getResult(0).setType( + mlir::cast(call_op.getResult(0).getType()) + .clone(rewriter.getF32Type())); + rewriter.setInsertionPoint(call_op); + rewriter.insert(new_call_op); + + // Remove the dequantize ops and replace uses by the new func.call op. + SmallVector users_to_erase; + for (auto user : users) { + llvm::dyn_cast(user) + .replaceAllUsesWith(new_call_op.getResult(0)); + users_to_erase.push_back(user); + } + for (auto user : users_to_erase) rewriter.eraseOp(user); + rewriter.eraseOp(call_op); + func_op.eraseResult(0); + func_op.insertResult(0, new_call_op.getResult(0).getType(), + /*resultAttrs=*/nullptr); + + // Modify the quantized fused function to do dequantize+relu(6). 
+ rewriter.setInsertionPoint(req_op); + Value new_result = rewriter.create( + req_op.getLoc(), func_op.getResultTypes()[0], req_op.getOperand()); + if (func_name.contains("_relu6_")) { + auto min = rewriter.create( + req_op.getLoc(), rewriter.getF32FloatAttr(0)); + auto max = rewriter.create( + req_op.getLoc(), rewriter.getF32FloatAttr(6)); + new_result = rewriter.create( + req_op.getLoc(), min, new_result, max); + } else if (func_name.contains("_relu_")) { + auto min = rewriter.create( + req_op.getLoc(), rewriter.getF32FloatAttr(0)); + new_result = rewriter.create( + req_op.getLoc(), min, new_result, nullptr); + } + return_op->setOperand(0, new_result); + rewriter.eraseOp(req_op); + + return success(); + } +}; + +void MergeFusionWithDequantizePass::runOnOperation() { + ModuleOp module_op = getOperation(); + MLIRContext* ctx = module_op.getContext(); + RewritePatternSet patterns(ctx); + patterns.add(ctx); + if (failed(applyPatternsGreedily(module_op, std::move(patterns)))) { + signalPassFailure(); + } +} + +} // namespace +} // namespace mlir::tf_quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_nchw_convolution_to_nhwc.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_nchw_convolution_to_nhwc.cc new file mode 100644 index 00000000000000..4088b84937c7cf --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_nchw_convolution_to_nhwc.cc @@ -0,0 +1,191 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo +#include "tensorflow/compiler/mlir/quantization/common/tf_attrs_and_constraints.h" +#include "tensorflow/compiler/mlir/quantization/common/uniform_quantized_types.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/permutation.h" + +namespace mlir::tf_quant::stablehlo { + +#define GEN_PASS_DEF_NCHWCONVOLUTIONTONHWCPASS +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h.inc" + +namespace { + +using ::mlir::stablehlo::ConvDimensionNumbersAttr; + +class NchwConvolutionToNhwcPass + : public impl::NchwConvolutionToNhwcPassBase { + private: + void runOnOperation() override; +}; + +// Rewrites NCHW convolution to NHWC. +// * Src dimension numbers: [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1] +// * Dst dimension numbers: [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f] +class RewriteNchwConvolutionToNhwc + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(mlir::stablehlo::ConvolutionOp op, + PatternRewriter& rewriter) const override { + // Handles 2D convolutions only. 
+ if (!HasRankOf(op.getOperand(0), /*rank=*/4) || + !HasRankOf(op.getOperand(1), /*rank=*/4)) { + return failure(); + } + + if (!quant::IsOpNotQuantized(op)) return failure(); + + const ConvDimensionNumbersAttr dimension_nums = op.getDimensionNumbers(); + const bool dimension_nums_matched = + MatchInputDimensionNumbers(dimension_nums) && + MatchKernelDimensionNumbers(dimension_nums) && + MatchOutputDimensionNumbers(dimension_nums); + if (!dimension_nums_matched) { + return failure(); + } + + // Transpose the input tensor: [b, f, 0, 1] => [b, 0, 1, f] + Value input = op->getOperand(0); + const TensorType new_input_tensor_type = GetTransposedTensorType( + mlir::cast(input.getType()), kNchwToNhwcPermutation); + + auto input_transpose_op = rewriter.create( + op.getLoc(), /*resultType0=*/new_input_tensor_type, /*operand=*/input, + rewriter.getDenseI64ArrayAttr(kNchwToNhwcPermutation)); + + // Transpose the filter tensor: [o, i, 0, 1] => [0, 1, i, o] + Value filter = op->getOperand(1); + const TensorType new_filter_tensor_type = GetTransposedTensorType( + mlir::cast(filter.getType()), kOihwToHwioPermutation); + + auto filter_transpose_op = rewriter.create( + op.getLoc(), /*resultType0=*/new_filter_tensor_type, /*operand=*/filter, + rewriter.getDenseI64ArrayAttr(kOihwToHwioPermutation)); + + // [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f] + const auto new_dimension_nums = rewriter.getAttr( + /*inputBatchDimension=*/0, /*inputFeatureDimension=*/3, + /*inputSpatialDimensions=*/SmallVector{1, 2}, + /*kernelInputFeatureDimension=*/2, /*kernelOutputFeatureDimension=*/3, + /*kernelSpatialDimensions=*/SmallVector{0, 1}, + /*outputBatchDimension=*/0, /*outputFeatureDimension=*/3, + /*outputSpatialDimensions=*/SmallVector{1, 2}); + + // Determine the shape of the output tensor: [b, f, 0, 1] => [b, 0, 1, f] + auto output_tensor_type = + mlir::cast(op->getResult(0).getType()); + const TensorType new_conv_output_tensor_type = + GetTransposedTensorType(output_tensor_type, 
kNchwToNhwcPermutation); + + // window_strides, padding, lhs_dilation, rhs_dilation, window_reversal are + // reused without modification because the ordering of spatial dimensions + // is not modified (i.e. before: [b, f, 0, 1], after: [b, 0, 1, f] => the + // spatial dimension is still ordered as {0, 1}). + auto new_convolution_op = rewriter.create( + op.getLoc(), /*resultType0=*/new_conv_output_tensor_type, + /*lhs=*/input_transpose_op, + /*rhs=*/filter_transpose_op, + /*window_strides=*/op.getWindowStridesAttr(), + /*padding=*/op.getPaddingAttr(), + /*lhs_dilation=*/op.getLhsDilationAttr(), + /*rhs_dilation=*/op.getRhsDilationAttr(), + /*window_reversal=*/op.getWindowReversalAttr(), + /*dimension_numbers=*/new_dimension_nums, + /*feature_group_count=*/op.getFeatureGroupCountAttr(), + /*batch_group_count=*/op.getBatchGroupCountAttr(), + /*precision_config=*/op.getPrecisionConfigAttr()); + + // Transpose the output of the `ConvolutionOp` back to the original op's + // output shape so that users' shapes match. + // [b, 0, 1, f] => [b, f, 0, 1] + auto output_transpose_op = rewriter.create( + new_convolution_op.getLoc(), /*resultType0=*/output_tensor_type, + /*operand=*/new_convolution_op, + rewriter.getDenseI64ArrayAttr(kNhwcToNchwPermutation)); + + rewriter.replaceAllUsesWith(op, output_transpose_op); + return success(); + } + + private: + // Matches input dimensions corresponding to: [b, f, 0, 1]. + bool MatchInputDimensionNumbers( + const ConvDimensionNumbersAttr dimension_numbers) const { + return dimension_numbers.getInputBatchDimension() == 0 && + dimension_numbers.getInputFeatureDimension() == 1 && + dimension_numbers.getInputSpatialDimensions() == + ArrayRef{2, 3}; + } + + // Matches kernel dimensions corresponding to: [o, i, 0, 1]. 
+ bool MatchKernelDimensionNumbers( + const ConvDimensionNumbersAttr dimension_numbers) const { + return dimension_numbers.getKernelInputFeatureDimension() == 1 && + dimension_numbers.getKernelOutputFeatureDimension() == 0 && + dimension_numbers.getKernelSpatialDimensions() == + ArrayRef{2, 3}; + } + + // Matches output dimensions corresponding to: [b, f, 0, 1]. + bool MatchOutputDimensionNumbers( + const ConvDimensionNumbersAttr dimension_numbers) const { + return dimension_numbers.getOutputBatchDimension() == 0 && + dimension_numbers.getOutputFeatureDimension() == 1 && + dimension_numbers.getOutputSpatialDimensions() == + ArrayRef{2, 3}; + } + + // Returns a new tensor type with the shape transposed according to the + // permutation. The rank of `type` and the size of `permutation` must be + // equal. + TensorType GetTransposedTensorType( + const TensorType type, const ArrayRef permutation) const { + const SmallVector after_shape = + quant::Permute(type.getShape(), permutation); + return type.cloneWith(after_shape, type.getElementType()); + } +}; + +} // namespace + +void NchwConvolutionToNhwcPass::runOnOperation() { + func::FuncOp func_op = getOperation(); + MLIRContext& ctx = getContext(); + + RewritePatternSet patterns(&ctx); + patterns.add(&ctx); + + if (failed(applyPatternsGreedily(func_op, std::move(patterns)))) { + func_op.emitError() << "Failed to run NchwConvolutionToNhwcPass."; + signalPassFailure(); + } +} + +} // namespace mlir::tf_quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_optimize_graph.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_optimize_graph.cc new file mode 100644 index 00000000000000..0bb7b660e11096 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_optimize_graph.cc @@ -0,0 +1,55 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo // IWYU pragma: keep +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h" // IWYU pragma: keep + +namespace mlir::tf_quant::stablehlo { + +#define GEN_PASS_DEF_OPTIMIZEGRAPHPASS +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h.inc" + +namespace { + +class OptimizeGraphPass + : public impl::OptimizeGraphPassBase { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OptimizeGraphPass) + + explicit OptimizeGraphPass() = default; + + private: + void runOnOperation() override; +}; + +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/optimize_graph.inc" + +void OptimizeGraphPass::runOnOperation() { + RewritePatternSet patterns(&getContext()); + populateWithGenerated(patterns); + auto func = getOperation(); + if (failed(applyPatternsGreedily(func, std::move(patterns)))) { + signalPassFailure(); + } +} +} // namespace + +} // namespace mlir::tf_quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h 
b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h new file mode 100644 index 00000000000000..dd62e6f278065c --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h @@ -0,0 +1,61 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_PASSES_TF_PASSES_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_PASSES_TF_PASSES_H_ + +#include +#include +#include + +#include "absl/status/statusor.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_options.pb.h" + +namespace mlir::tf_quant::stablehlo { + +// Creates a pass that quantizes weight component of StableHLO graph. +std::unique_ptr> CreateQuantizeWeightPass( + const ::stablehlo::quantization::QuantizationComponentSpec& + quantization_component_spec = {}); + +// Converts a serialized StableHLO module to bfloat16 and output serialized +// module. 
+absl::StatusOr ConvertSerializedStableHloModuleToBfloat16( + StringRef serialized_stablehlo_module); + +std::unique_ptr> +CreateLiftQuantizableSpotsAsFunctionsPass( + const ::stablehlo::quantization::QuantizationSpecs& quantization_specs); + +// Creates a pass that inserts CalibrationStatisticsSaverOp. +std::unique_ptr> +CreateInsertCalibrationStatisticsSaverPass( + StringRef calibration_data_dir, + const std::vector& aggregator_ops_to_ignore); + +// Adds generated pass default constructors or options definitions. +#define GEN_PASS_DECL +// Adds generated pass registration functions. +#define GEN_PASS_REGISTRATION +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h.inc" + +} // namespace mlir::tf_quant::stablehlo + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_PASSES_TF_PASSES_H_ diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.td b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.td new file mode 100644 index 00000000000000..fd47b5d8ec68b5 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.td @@ -0,0 +1,248 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +include "mlir/Pass/PassBase.td" + +def QuantizeWeightPass : Pass<"tf-stablehlo-quantize-weight", "mlir::func::FuncOp"> { + let summary = "Quantizes the weight component of StableHLO graph."; + let dependentDialects = ["mlir::stablehlo::StablehloDialect"]; + let constructor = "mlir::tf_quant::stablehlo::CreateQuantizeWeightPass()"; +} + +def UnfuseMhloBatchNormPass : Pass<"tf-stablehlo-unfuse-mhlo-batch-norm", "mlir::func::FuncOp"> { + let summary = "Unfuses batch normalization into arithmetic ops."; +} + +def LiftQuantizableSpotsAsFunctionsPass : Pass<"tf-stablehlo-lift-quantizable-spots-as-functions", "mlir::ModuleOp"> { + let summary = "Replace quantization candidates with composite functions into the module."; + let description = [{ + Mark frequent fusible patterns as functions for quantization targets. + In addition to bringing performance benefits by reducing q/dq op overhead in non-full quantization, + this brings higher accuracy by keeping a smaller range when quantizing ops + that disperse values. (ex: convolution, dot_general) + }]; + let dependentDialects = [ + "mlir::func::FuncDialect", + "mlir::stablehlo::StablehloDialect", + "TF::TensorFlowDialect", + ]; +} + +def ReplaceStablehloOpsInMainFunctionWithXlaCallModuleOpsPass : Pass<"tf-stablehlo-replace-stablehlo-ops-in-main-function-with-xla-call-module-ops", "mlir::ModuleOp"> { + let summary = "Replaces the StableHLO ops with a separate XlaCallModuleOps."; + let description = [{ + Replaces the StableHLO ops in the main function block with + tf.XlaCallModuleOps as separate subgraphs. Wires them back to the main + function block to be compatible with SavedModel structure.
+ }]; +} + +def RestoreFunctionNamePass : Pass<"tf-stablehlo-restore-function-name", "ModuleOp"> { + let summary = "Restores function name from XlaCallModule op."; +} + +def QuantizeCompositeFunctionsPass : Pass<"tf-stablehlo-quantize-composite-functions", "ModuleOp"> { + let summary = "Quantize composite functions with QDQ input / outputs."; + let options = [ + Option<"enable_per_channel_quantized_weight_", + "enable-per-channel-quantized-weight", + "bool", /*default=*/"true", + "Whether to enable per-channel quantized weights.">, + Option<"mlir_dump_file_name_", "mlir-dump-file-name", + "std::optional", /*default=*/"std::nullopt", + "MLIR dump file name.">, + Option<"merge_fusion_with_dequantize_", + "merge-fusion-with-dequantize", + "bool", /*default=*/"false", + "Whether to merge quantized conv/dot_general fusion with subsequent dequantize.">, + ]; + let dependentDialects = [ + "mlir::arith::ArithDialect", + "mlir::stablehlo::StablehloDialect", + "mlir::quant::QuantDialect", + "mlir::quant::ir::TFQuantDialect", + "TF::TensorFlowDialect", + ]; +} + +def PrepareQuantizePass : Pass<"tf-stablehlo-prepare-quantize", "mlir::ModuleOp"> { + let summary = "Prepare StableHLO dialect for static range quantization by converting quantfork.stats into quantfork.qcast and dcast ops."; + let options = [ + Option<"enable_per_channel_quantized_weight_", + "enable-per-channel-quantized-weight", + "bool", /*default=*/"true", + "Whether to enable per-channel quantized weights.">, + Option<"bit_width_", "bit-width", "int", /*default=*/"8", + "Bitwidth of quantized integer"> + ]; + let dependentDialects = [ + "mlir::stablehlo::StablehloDialect", + "mlir::quant::QuantDialect", + "mlir::quant::ir::TFQuantDialect", + "mlir::arith::ArithDialect", + ]; +} + +def QuantizePass : Pass<"tf-stablehlo-quantize", "mlir::ModuleOp"> { + let summary = "Applies static-range quantization on ops by converting quantfork.qcast, quantfork.dcast, and float op into uniform quantized ops ."; + let options = 
[ + Option<"enable_per_channel_quantized_weight_", + "enable-per-channel-quantized-weight", + "bool", /*default=*/"true", + "Whether to enable per-channel quantized weights.">, + ]; + let dependentDialects = [ + "mlir::stablehlo::StablehloDialect", + "mlir::quant::QuantDialect", + "mlir::quant::ir::TFQuantDialect", + ]; +} + +def PostQuantizePass : Pass<"tf-stablehlo-post-quantize", "mlir::func::FuncOp"> { + let summary = "Apply clean-up after quantization."; + let dependentDialects = [ + "mlir::stablehlo::StablehloDialect", + "mlir::quant::ir::TFQuantDialect", + ]; +} + +def XlaCallModuleToCallPass : Pass<"tf-stablehlo-xla-call-module-to-call", "ModuleOp"> { + let summary = "Convert XlaCallModuleOp to func.call op"; + let dependentDialects = [ + "TF::TensorFlowDialect", + ]; +} + +def MergeFusionWithDequantizePass : Pass<"tf-stablehlo-merge-fusion-with-dequantize", "mlir::ModuleOp"> { + let summary = "Merge quantized conv/dot_general fusion with subsequent dequantize."; + let dependentDialects = [ + "chlo::ChloDialect", + "mlir::stablehlo::StablehloDialect", + ]; +} + +def UnwrapXlaCallModuleOpPass : Pass<"tf-stablehlo-unwrap-xla-call-module-op", "ModuleOp"> { + let summary = "Unwrap XlaCallModuleOps into inline functions if not used for quantizing fused patterns."; + let dependentDialects = ["TF::TensorFlowDialect"]; +} + +def ConvertFuncToBfloat16Pass : Pass<"tf-stablehlo-convert-func-to-bfloat16", "mlir::func::FuncOp"> { + let summary = "Convert a StableHLO function to bfloat16"; + let dependentDialects = ["mlir::stablehlo::StablehloDialect"]; +} + +def ConvertXlaCallModuleOpToBfloat16Pass : Pass<"tf-stablehlo-convert-xla-call-module-op-to-bfloat16", "mlir::func::FuncOp"> { + let summary = "Convert serialized XlaCallModuleOp to bfloat16"; + let dependentDialects = [ + "TF::TensorFlowDialect", + "mlir::quant::QuantDialect", + "mlir::shape::ShapeDialect", + "mlir::stablehlo::StablehloDialect", + ]; +} + +def ConvertShapeToStablehloWithConstraintsPass : 
Pass<"tf-stablehlo-convert-shape-to-stablehlo-with-constraints", "mlir::func::FuncOp"> { + let summary = "Convert shape.cstr_broadcastable to stablehlo.custom_call @shape_assertion"; + let dependentDialects = [ + "mlir::shape::ShapeDialect", + "mlir::tensor::TensorDialect", + "mlir::stablehlo::StablehloDialect", + ]; +} + +def OptimizeGraphPass : Pass<"tf-optimize-graph", "ModuleOp"> { + let summary = "Optimize the sub-optimal patterns after quantization."; + let dependentDialects = ["mlir::stablehlo::StablehloDialect",]; +} + +def NchwConvolutionToNhwcPass : Pass<"tf-stablehlo-nchw-convolution-to-nhwc", "mlir::func::FuncOp"> { + let summary = "Converts stablehlo.convolution op of NCHW format to -> NHWC."; + let description = [{ + Matches `ConvolutionOp`s with NCHW format and converts it to NHWC + format by inserting `TransposeOp`s to input, filter, and output tensors. + In terms of dimension numbers, this matches + `[b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1]` format and converts it to + `[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]` format. + + This pass is useful to convert models that conventionally use the NCHW + format to target hardware that is more NHWC-friendly. + }]; + let dependentDialects = ["mlir::stablehlo::StablehloDialect"]; +} + +def DeferActivationTransposePass : Pass<"tf-stablehlo-defer-activation-transpose", "mlir::func::FuncOp"> { + let summary = "Merges stablehlo.transpose for activations."; + let description = [{ + Defers activation transposes (e.g. LHS of `stablehlo.add`) to the output and + optionally inserts `stablehlo.transpose`s to match the shape of operands. + This is useful when recursively pushing down the extra `stablehlo.transpose` + inserted to activation tensors after running `NchwConvolutionToNhwcPass`. + + Currently only converts limited cases that appear in NCHW->NHWC 2D + convolution conversion, to avoid introducing unwanted pessimizations.
+ }]; + let dependentDialects = ["mlir::stablehlo::StablehloDialect"]; +} + +def InsertWeightParamPass : Pass<"tf-stablehlo-insert-weight-param", "mlir::func::FuncOp"> { + let summary = "Insert quantization parameters of weights for weight-only quantization and dynamic range quantization."; + let dependentDialects = [ + "mlir::stablehlo::StablehloDialect", + "TF::TensorFlowDialect", + "mlir::quant::QuantDialect", + "mlir::quant::ir::TFQuantDialect", + ]; +} + +def FoldConstantTransposePass : Pass<"tf-stablehlo-fold-constant-transpose", "mlir::func::FuncOp"> { + let summary = "Folds stablehlo.constant -> stablehlo.transpose patterns."; + let description = [{ + Finds patterns where a `stablehlo.constant` is directly followed by a + `stablehlo.transpose` and folds them into a single `stablehlo.constant`. + This is considered an aggressive optimization, but it is useful to eliminate + `stablehlo.constant`->`stablehlo.transpose` patterns which are often + by-products of other shape conversion optimizations, such as NCHW->NHWC + convolution conversion. + }]; + let dependentDialects = ["mlir::stablehlo::StablehloDialect"]; +} + +def RemoveShardingCustomCallPass : Pass<"tf-stablehlo-remove-sharding-custom-call", "mlir::func::FuncOp"> { + let summary = "Removes `stablehlo.custom_call @Sharding`"; + let description = [{ + Finds `stablehlo.custom_call @Sharding` and removes all instances of them, + replacing the usages by its operand. This is used where sharding doesn't + make much sense or sharding custom calls are incompatible, e.g. on-device + targets. 
+ }]; + let dependentDialects = ["mlir::stablehlo::StablehloDialect"]; +} + +def InsertCalibrationStatisticsSaverPass : Pass<"tf-stablehlo-insert-calibration-statistics-saver", "ModuleOp"> { + let summary = "Inserts `CalibrationStatisticsSaver` op to collect and save calibration statistics."; + let description = [{ + Finds all `CustomAggregator` ops in each function and adds a single + `CalibrationStatisticsSaver` op at the end of the function to collect their + statistics. + }]; + let options = [ + ListOption<"aggregator_ops_to_ignore_", "aggregator-ops-to-ignore", "std::string", + "Ops to ignore when inserting CalibrationStatisticsSaver.">, + Option<"calibration_data_dir_", "calibration-data-dir", + "std::string", /*default=*/"", + "The directory to save calibration data.">, + ]; + let dependentDialects = ["TF::TensorFlowDialect"]; +} diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_post_quantize.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_post_quantize.cc new file mode 100644 index 00000000000000..82e85a0c347062 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_post_quantize.cc @@ -0,0 +1,160 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+==============================================================================*/ + +#include + +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo +#include "tensorflow/compiler/mlir/quantization/common/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_traits.h" +#include "tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_utils.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h" // IWYU pragma: keep + +namespace mlir::tf_quant::stablehlo { + +#define GEN_PASS_DEF_POSTQUANTIZEPASS +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h.inc" + +namespace { + +// Applies clean-up patterns after quantization. +class PostQuantizePass : public impl::PostQuantizePassBase { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PostQuantizePass) + + explicit PostQuantizePass() = default; + + private: + void runOnOperation() override; +}; + +// TODO: b/305815328 - Consider preserving leading and trailing QDQs for +// ModifyIONodesPass in TFLite use cases. +// Removes the back-to-back quantize and dequantize ops with volatile attribute. 
+class RemoveVolatileQdqPattern + : public OpRewritePattern { + public: + explicit RemoveVolatileQdqPattern(MLIRContext* context) + : OpRewritePattern(context) {} + + LogicalResult matchAndRewrite(mlir::quant::ir::DequantizeCastOp op, + PatternRewriter& rewriter) const override { + auto input_op = op.getArg().getDefiningOp(); + if (auto q = + llvm::dyn_cast_or_null(input_op)) { + if (!q->getAttr(kVolatileOpAttrName)) return failure(); + + // If the quantize op is a requantize op, it is being used in other scale + // adjustments and should be kept. Instead, move dequantize op before the + // requantize op to remove the unnecessary requantize op. + if (const QuantizedType qtype = + QuantizedType::getQuantizedElementType(q.getArg().getType())) { + rewriter.setInsertionPoint(op); + rewriter.replaceOpWithNewOp( + op, op.getResult().getType(), q.getArg()); + return success(); + } + + op.replaceAllUsesWith(q.getArg()); + return success(); + } + return failure(); + } +}; + +// Replaces constant and uniform_quantize ops with single quantized constant op. +class QuantizeConstPattern + : public OpRewritePattern { + public: + explicit QuantizeConstPattern(MLIRContext* context) + : OpRewritePattern(context) {} + + LogicalResult matchAndRewrite(mlir::stablehlo::UniformQuantizeOp op, + PatternRewriter& rewriter) const override { + DenseFPElementsAttr attr; + if (matchPattern(op.getOperand(), m_Constant(&attr))) { + const Type qtype = op.getResult().getType(); + ElementsAttr quantized_attr = Quantize(attr, qtype); + if (quantized_attr) { + rewriter.replaceOpWithNewOp( + op, qtype, quantized_attr); + return success(); + } + } + return failure(); + } +}; + +// Replaces quantfork.dcast with stablehlo.uniform_dequantize. 
+class ConvertDequantizeCastToUniformDequantizePattern + : public OpRewritePattern { + public: + explicit ConvertDequantizeCastToUniformDequantizePattern(MLIRContext* context) + : OpRewritePattern(context) {} + LogicalResult matchAndRewrite(mlir::quant::ir::DequantizeCastOp dq_op, + PatternRewriter& rewriter) const override { + rewriter.replaceOpWithNewOp( + dq_op, dq_op.getResult().getType(), dq_op.getArg()); + return success(); + } +}; + +// Replaces quantfork.qcast with stablehlo.uniform_quantize. +class ConvertQuantizeCastToUniformQuantizePattern + : public OpRewritePattern { + public: + explicit ConvertQuantizeCastToUniformQuantizePattern(MLIRContext* context) + : OpRewritePattern(context) {} + LogicalResult matchAndRewrite(mlir::quant::ir::QuantizeCastOp q_op, + PatternRewriter& rewriter) const override { + rewriter.replaceOpWithNewOp( + q_op, q_op.getResult().getType(), q_op.getArg()); + return success(); + } +}; + +void PostQuantizePass::runOnOperation() { + RewritePatternSet patterns(&getContext()); + func::FuncOp func = getOperation(); + MLIRContext* ctx = func.getContext(); + // TODO: b/307463853 - Consider splitting passes for each pattern set. + patterns.add, + RemoveVolatileQdqPattern>(ctx); + if (failed(applyPatternsGreedily(func, std::move(patterns)))) { + signalPassFailure(); + } + + RewritePatternSet patterns_2(&getContext()); + patterns_2 + .add(ctx); + if (failed(applyPatternsGreedily(func, std::move(patterns_2)))) { + signalPassFailure(); + } +} + +} // namespace +} // namespace mlir::tf_quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_prepare_quantize.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_prepare_quantize.cc new file mode 100644 index 00000000000000..b7976e35c7f406 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_prepare_quantize.cc @@ -0,0 +1,200 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include + +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo +#include "tensorflow/compiler/mlir/quantization/common/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_driver.h" +#include "tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_utils.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/ops/tf_stablehlo_op_quant_spec.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h" 
// IWYU pragma: keep +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" + +namespace mlir { +namespace tf_quant { +namespace stablehlo { + +#define GEN_PASS_DEF_PREPAREQUANTIZEPASS +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h.inc" + +namespace { + +// Applies prepare quantization on the model in TF dialect. This pass runs +// before the quantization pass and propagates the quantization parameters +// across ops. This step is necessary for post-training quantization and also +// making the quantization rule for some operations in the quantization-aware +// training quantization simpler. +class PrepareQuantizePass + : public impl::PrepareQuantizePassBase { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PrepareQuantizePass) + + using impl::PrepareQuantizePassBase< + PrepareQuantizePass>::PrepareQuantizePassBase; + + explicit PrepareQuantizePass(const bool enable_per_channel_quantized_weight, + const int bit_width) { + enable_per_channel_quantized_weight_ = enable_per_channel_quantized_weight; + bit_width_ = bit_width; + } + + void runOnOperation() override; +}; + +// Merges consecutive QuantizeCast ops. See b/246655213 for details. +// For example, the following case: +// %1 = quantfork.QuantizeCastOp(%0) : f32 -> qtype1 +// %2 = quantfork.QuantizeCastOp(%1) : qtype1 -> qtype2 +// %3 = quantfork.QuantizedOp1(%1) +// %4 = quantfork.QuantizedOp2(%2) +// will be transformed to: +// %1 = quantfork.QuantizeCastOp(%0) : f32 -> qtype1 +// %2 = quantfork.QuantizeCastOp(%0) : f32 -> qtype2 +// %3 = quantfork.QuantizedOp1(%1) +// %4 = quantfork.QuantizedOp2(%2) +// Converting from f32 -> qtype1 -> qtype2 will add unexpected quantization +// loss for %2. This pattern avoids that by converting from f32 -> qtype2 +// directly.
+class MergeConsecutiveQuantizeCast + : public mlir::OpRewritePattern { + public: + explicit MergeConsecutiveQuantizeCast(MLIRContext* context) + : OpRewritePattern(context) {} + + private: + LogicalResult matchAndRewrite(mlir::quant::ir::QuantizeCastOp q_op, + PatternRewriter& rewriter) const override { + auto preceding_qcast = + q_op.getArg().getDefiningOp(); + if (!preceding_qcast) return failure(); + + auto new_qcast = rewriter.create( + q_op.getLoc(), q_op.getType(), preceding_qcast.getArg()); + new_qcast->setAttr(kVolatileOpAttrName, rewriter.getUnitAttr()); + q_op->replaceAllUsesWith(new_qcast); + return success(); + } +}; + +class ConvertTFConstOpToArithConstOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(TF::ConstOp op, + PatternRewriter& rewriter) const override { + rewriter.replaceOpWithNewOp(op, op.getValue()); + return success(); + } +}; + +class ConvertStablehloConstToArithConstOp + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(mlir::stablehlo::ConstantOp op, + PatternRewriter& rewriter) const override { + rewriter.replaceOpWithNewOp(op, op.getValue()); + return success(); + } +}; + +class ConvertArithConstToStablehloConstOp + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(arith::ConstantOp op, + PatternRewriter& rewriter) const override { + rewriter.replaceOpWithNewOp(op, op.getValue()); + return success(); + } +}; + +void PrepareQuantizePass::runOnOperation() { + ModuleOp module_op = getOperation(); + MLIRContext* ctx = module_op.getContext(); + + auto func_op_quant_spec = GetStableHloOpQuantSpec; + auto func_op_quant_scale_spec = GetStableHloQuantConstraints; + + for (auto func_op : module_op.getOps()) { + // The function might contain more stats ops than required, and it will + // introduce requantize if the calibration stats have conflicts. 
This tries + // to remove all the redundant stats ops. + RemoveRedundantStatsOps(func_op, func_op_quant_spec, + func_op_quant_scale_spec); + + RewritePatternSet patterns(ctx); + // Convert quant stats to int8 quantization parameters. + // Currently, only activation stats are imported, so narrow_range = false. + patterns.add>( + bit_width_, + /*narrow_range=*/false, + /*is_signed=*/true, + /*legacy_float_scale=*/false, ctx); + // Convert all constants to arith::ConstantOp as quantization driver can + // deal with the arith::ConstantOp instances. + patterns.add(ctx); + patterns.add(ctx); + if (failed(applyPatternsGreedily(func_op, std::move(patterns)))) { + signalPassFailure(); + } + + // Finally, the quantization parameters can be propagated to the rest of the + // values (tensors). + ApplyQuantizationParamsPropagation( + func_op, /*is_signed=*/true, bit_width_, + !enable_per_channel_quantized_weight_, func_op_quant_spec, + func_op_quant_scale_spec, + /*infer_tensor_ranges=*/true, /*legacy_float_scale=*/false, + /*is_qdq_conversion=*/false); + + // Restore constants as stablehlo::ConstantOp. + RewritePatternSet patterns_2(ctx); + patterns_2 + .add( + ctx); + if (failed(applyPatternsGreedily(func_op, std::move(patterns_2)))) { + signalPassFailure(); + } + } +} + +} // namespace + +// Creates an instance of the TensorFlow dialect PrepareQuantize pass. 
+std::unique_ptr> CreatePrepareQuantizePass( + const bool enable_per_channel_quantized_weight, const int bit_width) { + return std::make_unique( + enable_per_channel_quantized_weight, bit_width); +} + +} // namespace stablehlo +} // namespace tf_quant +} // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_quantization_patterns.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_quantization_patterns.cc new file mode 100644 index 00000000000000..028d7e861d219a --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_quantization_patterns.cc @@ -0,0 +1,1039 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_quantization_patterns.h" + +#include +#include +#include +#include +#include + +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Debug.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project +#include "mlir/IR/Block.h" // from @llvm-project +#include "mlir/IR/BlockSupport.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo // IWYU pragma: keep +#include "tensorflow/compiler/mlir/quantization/common/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/quantization/common/tf_attrs_and_constraints.h" +#include "tensorflow/compiler/mlir/quantization/common/tf_lift_as_function_call.h" +#include 
"tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_traits.h" +#include "tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_utils.h" +#include "tensorflow/compiler/mlir/quantization/common/uniform_quantized_types.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/ops/tf_stablehlo_op_quant_spec.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" + +#define DEBUG_TYPE "populate-quantization-patterns" + +namespace mlir::tf_quant::stablehlo { + +namespace { + +using ::mlir::stablehlo::AddOp; +using ::mlir::stablehlo::BroadcastInDimOp; +using ::mlir::stablehlo::ConcatenateOp; +using ::mlir::stablehlo::ConvolutionOp; +using ::mlir::stablehlo::DotGeneralOp; +using ::mlir::stablehlo::DynamicBroadcastInDimOp; +using ::mlir::stablehlo::GatherOp; +using ::mlir::stablehlo::GetDimensionSizeOp; +using ::mlir::stablehlo::ReshapeOp; +using ::mlir::stablehlo::UniformQuantizeOp; +using ::mlir::tf_quant::FindUserOfType; +using ::mlir::tf_quant::TryCast; +using ::stablehlo::quantization::Method; +using ::stablehlo::quantization::QuantizedDimension; +using ::stablehlo::quantization::QuantizedType; +using ::stablehlo::quantization::StaticRangePtq; + +constexpr StringRef kEntryFuncAttrName = "_entry_function"; + +// Returns broadcasted user op of an input op. Returns null if +// the op is not broadcasted or not the intended type. +// Supports both static broadcast and dynamic broadcast. +// Note that the patterns below differ from lifted patterns as +// ShapeLegalizeToHloPass is run prior to running this pass. +// +// Dynamically broadcasted bias due to unknown input batch size +// usually has the following pattern.
In the example below, +// the input operand would be stablehlo.convolution op, and return value would +// be stablehlo.add op. +// +// ``` +// %0 = stablehlo.constant dense<3> +// %1 = stablehlo.constant dense<4> +// %2 = stablehlo.constant dense<2> +// %3 = stablehlo.convolution(%%arg0, %%arg1) : +// (tensor, tensor<2x3x3x2xf32>) -> tensor +// %4 = stablehlo.get_dimension_size %3, dim = 0 : +// (tensor) -> tensor +// %5 = stablehlo.reshape %4 : +// (tensor) -> tensor<1xi32> +// %6 = stablehlo.concatenate %5, %0, %1, %2, dim = 0 : +// (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) +// -> tensor<4xi32> +// %7 = stablehlo.dynamic_broadcast_in_dim %arg2, %6 +// %8 = stablehlo.add %3, %7 +// ``` +// +// Statically broadcasted bias will be broadcasted to match the accumulation. +// ``` +// %3 = stablehlo.convolution(%%arg0, %%arg1) : +// (tensor, tensor<2x3x3x2xf32>) -> tensor +// %4 = stablehlo.broadcast_in_dim %arg2, %3 +// %5 = stablehlo.add %3, %4 +// ``` +template +Operation* GetBroadcastedUserOp(Operation* op) { + // Broadcast bias for known input shape. + auto broadcast_in_dim_op = FindUserOfType(op); + if (broadcast_in_dim_op != nullptr) { + auto target_op = FindUserOfType(broadcast_in_dim_op); + if (target_op != nullptr) return target_op; + } + // Broadcast bias for unknown input shape. + auto get_dimension_size_op = FindUserOfType(op); + if (get_dimension_size_op == nullptr) return nullptr; + + auto reshape_op = FindUserOfType(get_dimension_size_op); + if (reshape_op == nullptr) return nullptr; + + auto concatenate_op = FindUserOfType(reshape_op); + if (concatenate_op == nullptr) return nullptr; + + auto dynamic_broadcast_in_dim_op = + FindUserOfType(concatenate_op); + if (dynamic_broadcast_in_dim_op == nullptr) return nullptr; + + auto target_op = FindUserOfType(dynamic_broadcast_in_dim_op); + return target_op; +} + +// Gets the corresponding quantized function name from the given function name. 
+// Example: "composite_dot_general_fn_1" => "quantized_dot_general_fn" +std::string GetQuantizedFunctionName(const StringRef func_name) { + return Twine(kQuantizedFuncPrefix) + .concat(func_name.rsplit(kCompositeFuncPrefix).second) + .str(); +} + +// Returns true if `xla_call_module_op` is quantized. To be considered +// quantized, it should meet three conditions: +// 1. At least one of the inputs and outputs should be a uniform quantized type. +// 2. `xla_call_module_op` should have the `kQuantTraitAttrName` attribute. +// 3. It should also have the `kEntryFuncAttrName` attribute, which points to +// the function that `xla_call_module_op` represents. +bool IsQuantizedXlaCallModuleOp(TF::XlaCallModuleOp xla_call_module_op) { + return !quant::IsOpNotQuantized(xla_call_module_op) && + xla_call_module_op->hasAttr(kQuantTraitAttrName) && + xla_call_module_op->hasAttr(kEntryFuncAttrName); +} + +// Returns the entry function, i.e. the callee of `xla_call_module_op`. +func::FuncOp GetEntryFuncOp(TF::XlaCallModuleOp xla_call_module_op, + const SymbolTable symbol_table) { + const auto entry_function_symbol_ref = + xla_call_module_op->getAttrOfType(kEntryFuncAttrName); + + return dyn_cast_or_null( + symbol_table.lookup(entry_function_symbol_ref.getValue())); +} + +// Replaces the function type of `entry_func_op` to a quantized one, matching +// the input and output types of `xla_call_module_op`. +void SetQuantizedFunctionType(PatternRewriter& rewriter, + func::FuncOp entry_func_op, + TF::XlaCallModuleOp xla_call_module_op) { + SmallVector arg_types; + SmallVector arg_locs; + for (const Value arg : xla_call_module_op.getArgs()) { + arg_types.push_back(arg.getType()); + arg_locs.push_back(arg.getLoc()); + } + + SmallVector output_types; + for (const Value output : xla_call_module_op.getOutput()) { + output_types.push_back(output.getType()); + } + + entry_func_op.setFunctionType( + rewriter.getFunctionType(arg_types, output_types)); + + // Replace argument types and locs. 
+ Block& entry = entry_func_op->getRegion(0).front(); + for (auto [arg, arg_type, arg_loc] : + llvm::zip_equal(entry.getArguments(), arg_types, arg_locs)) { + arg.setType(arg_type); + arg.setLoc(arg_loc); + } +} + +// Creates a UniformQuantize op and sets it as return op. +// The requantize scale and zero point should be determined from the +// `entry_func_op`'s output, containing information on layerStats of the +// entire function. +void CreateAndReturnUniformQuantizeOp(PatternRewriter& rewriter, Operation& op, + func::FuncOp entry_func_op, + const Type func_result_type) { + // Add i32 -> i8 requantization. + UniformQuantizeOp uniform_quant_op = rewriter.create( + op.getLoc(), func_result_type, op.getResults()); + cast(entry_func_op.getBody().front().getTerminator()) + .setOperand(0, uniform_quant_op); +} + +template +// Creates a quantized bias pattern for static and dynamic shape case +// and sets the quantized bias as the return op. +void CreateAndReturnQuantizedBiasPattern( + Operation* op, PatternRewriter& rewriter, func::FuncOp entry_func_op, + const Type func_result_type, const Type accumulation_quantized_element_type, + GemmStyleOp gemm_style_op) { + const Value bias_op = op->getOperand(1); + Value add_op_result = op->getResult(0); + + // Broadcast bias value if unmatched with output shape. + auto bcast_op = TryCast(bias_op.getDefiningOp(), + /*name=*/"broadcast_in_dim_op"); + + if (failed(bcast_op)) { + bcast_op = TryCast( + bias_op.getDefiningOp(), + /*name=*/"dynamic_broadcast_in_dim_op"); + } + // Update the bias type for both static and dynamic broadcasts. 
+ if (succeeded(bcast_op)) { + Value bcast_op_result = (*bcast_op)->getResult(0); + auto bcast_op_result_type = + mlir::cast(bcast_op_result.getType()); + const ArrayRef bcast_shape = bcast_op_result_type.getShape(); + const TensorType new_bcast_op_result_type = bcast_op_result_type.cloneWith( + bcast_shape, accumulation_quantized_element_type); + bcast_op_result.setType(new_bcast_op_result_type); + } + + const auto add_op_result_type = + mlir::cast(add_op_result.getType()); + const ArrayRef add_op_shape = add_op_result_type.getShape(); + // For quantized bias add case, lhs, rhs, and result have the same types. + const TensorType new_add_op_result_type = add_op_result_type.cloneWith( + add_op_shape, accumulation_quantized_element_type); + add_op_result.setType(new_add_op_result_type); + + AddOp bias_add_op = + rewriter.create(gemm_style_op->getLoc(), gemm_style_op, bias_op); + + CreateAndReturnUniformQuantizeOp(rewriter, *bias_add_op, entry_func_op, + func_result_type); +} + +// An interface representing patterns that quantizes an entry function's body. +// The entry function's signatures should have already been quantized at the +// point of rewriting. +class EntryFuncBodyQuantizationPattern { + public: + virtual ~EntryFuncBodyQuantizationPattern() = default; + + // Returns `success()` if `entry_func_op`'s body is eligible for rewriting. At + // this point `entry_func_op`'s signature has not been reset with quantized + // types. + virtual LogicalResult match(func::FuncOp entry_func_op, + const Method& quantization_method) const = 0; + + // Rewrites the `entry_func_op`'s body. + virtual void rewrite(func::FuncOp entry_func_op, + const Method& quantization_method, + PatternRewriter& rewriter) const = 0; +}; + +// Gemm Style Op: glossary/gemm. +template +// Match for all gemm_style op and check for possible fusions. 
+LogicalResult MatchGemmStyleOp(func::FuncOp entry_func_op) { + const auto op_iterator_range = entry_func_op.getOps(); + if (op_iterator_range.empty()) { + LLVM_DEBUG(llvm::dbgs() << "Function does not have " + << GemmStyleOp::getOperationName() << " op.\n"); + return failure(); + } + if (!isa( + (*op_iterator_range.begin()).getResult().getType())) { + LLVM_DEBUG(llvm::dbgs() << GemmStyleOp::getOperationName() + << " op must have ranked tensor type.\n"); + return failure(); + } + + MutableArrayRef operands = + entry_func_op.getBody().getArguments(); + // Function must have input, filter, and optionally bias. + if (operands.size() != 2 && operands.size() != 3) { + LLVM_DEBUG(llvm::dbgs() << GemmStyleOp::getOperationName() + << " op function should have 2 or 3 operands.\n"); + return failure(); + } + return success(); +} + +// Gemm Style Op: glossary/gemm. +template +void RewriteGemmStyleOp(func::FuncOp entry_func_op, PatternRewriter& rewriter, + const bool enable_per_channel_quantized_weight) { + const GemmStyleOp gemm_style_op = + *entry_func_op.getOps().begin(); + + const Type input_type = entry_func_op.getArgumentTypes()[0]; + const Type filter_type = entry_func_op.getArgumentTypes()[1]; + const Type func_result_type = entry_func_op.getResultTypes()[0]; + + Value gemm_style_op_result = gemm_style_op->getResult(0); + const auto gemm_style_op_result_type = + mlir::cast(gemm_style_op_result.getType()); + const ArrayRef gemm_style_shape = + gemm_style_op_result_type.getShape(); + + Type accumulation_quantized_element_type; + TensorType new_gemm_style_op_result_type; + + const double input_scale = + mlir::cast(getElementTypeOrSelf(input_type)) + .getScale(); + + if (enable_per_channel_quantized_weight) { + ArrayRef filter_scales = + mlir::cast( + getElementTypeOrSelf(filter_type)) + .getScales(); + std::vector result_scales; + result_scales.reserve(filter_scales.size()); + + for (const double filter_scale : filter_scales) { + result_scales.push_back(input_scale * 
filter_scale); + } + + const ArrayRef zero_points = + mlir::cast( + getElementTypeOrSelf(filter_type)) + .getZeroPoints(); + + // `stablehlo.convolution` assumes the following format: + // [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f] + // `stablehlo.dot_general` can take various formats. We only per-channel + // quantize non-batch ops. + // `stablehlo.dot_general` legalizable to `tfl.fully_connected` has a + // filter rank of 2 with the last dimension as the channel dimension. + const int64_t quantization_dimension = + mlir::cast(filter_type).getShape().size() - 1; + accumulation_quantized_element_type = + quant::CreateI32F32UniformQuantizedPerAxisType( + gemm_style_op->getLoc(), *rewriter.getContext(), result_scales, + zero_points, quantization_dimension); + + new_gemm_style_op_result_type = gemm_style_op_result_type.cloneWith( + gemm_style_shape, accumulation_quantized_element_type); + } else { + const double filter_scale = + mlir::cast(getElementTypeOrSelf(filter_type)) + .getScale(); + const double result_scale = input_scale * filter_scale; + + accumulation_quantized_element_type = + quant::CreateI32F32UniformQuantizedType( + gemm_style_op->getLoc(), *rewriter.getContext(), result_scale, + /*zero_point=*/0); + + new_gemm_style_op_result_type = gemm_style_op_result_type.cloneWith( + gemm_style_shape, accumulation_quantized_element_type); + } + + gemm_style_op_result.setType(new_gemm_style_op_result_type); + + rewriter.setInsertionPointAfter(gemm_style_op); + + Operation* next_op = FindUserOfType<>(gemm_style_op); + + // If activation exists, omit clipping op. + // Since out_scale and out_zp are computed based on clipped range, + // explicit activation clipping op is not required. 
+ if (isa(next_op) && gemm_style_op->hasOneUse()) { + // bias fusion + CreateAndReturnQuantizedBiasPattern( + next_op, rewriter, entry_func_op, func_result_type, + accumulation_quantized_element_type, gemm_style_op); + } else if (auto add_op = cast_or_null( + GetBroadcastedUserOp(gemm_style_op))) { + // broadcasted bias fusion + rewriter.setInsertionPointAfter(add_op); + CreateAndReturnQuantizedBiasPattern( + add_op, rewriter, entry_func_op, func_result_type, + accumulation_quantized_element_type, gemm_style_op); + } else { + // Non fusible op + // If an op is used multiple times and is not a broadcasted shape case, + // do not apply quantization of fused patterns to prevent removal of + // dependee ops. + CreateAndReturnUniformQuantizeOp(rewriter, *gemm_style_op, entry_func_op, + func_result_type); + } +} + +// Quantizes the entry function's body containing a `DotGeneralOp`. +class QuantizeDotGeneralOpPattern : public EntryFuncBodyQuantizationPattern { + public: + explicit QuantizeDotGeneralOpPattern( + const bool enable_per_channel_quantized_weight) + : enable_per_channel_quantized_weight_( + enable_per_channel_quantized_weight) {} + + LogicalResult match(func::FuncOp entry_func_op, + const Method& quantization_method) const override { + if (!quantization_method.has_static_range_ptq()) { + return failure(); + } + return MatchGemmStyleOp(entry_func_op); + } + + void rewrite(func::FuncOp entry_func_op, const Method& quantization_method, + PatternRewriter& rewriter) const override { + DotGeneralOp dot_general_op = *entry_func_op.getOps().begin(); + const bool should_quantize_per_channel = + enable_per_channel_quantized_weight_ && + GetDotGeneralQuantizationDim(dot_general_op); + RewriteGemmStyleOp(entry_func_op, rewriter, + should_quantize_per_channel); + } + + private: + [[deprecated( + "Do not rely on this field for per-channel quantization. 
Use `Method` " + "instead.")]] const bool enable_per_channel_quantized_weight_; +}; + +// Quantizes the entry function's body containing a `ConvolutionOp`. +class QuantizeConvolutionOpPattern : public EntryFuncBodyQuantizationPattern { + public: + explicit QuantizeConvolutionOpPattern( + const bool enable_per_channel_quantized_weight) + : enable_per_channel_quantized_weight_( + enable_per_channel_quantized_weight) {} + + LogicalResult match(func::FuncOp entry_func_op, + const Method& quantization_method) const override { + if (!quantization_method.has_static_range_ptq()) { + return failure(); + } + return MatchGemmStyleOp(entry_func_op); + } + + void rewrite(func::FuncOp entry_func_op, const Method& quantization_method, + PatternRewriter& rewriter) const override { + RewriteGemmStyleOp( + entry_func_op, rewriter, + enable_per_channel_quantized_weight_ && + IsWeightPerChannelQuantized(quantization_method)); + } + + // Returns true if the quantization method indicates per-channel quantization + // for convolution weights. This method specifically matches a quantization + // dimension of 3 for the input index 1 or unspecified quantization dimension + // for the input index 1. + bool IsWeightPerChannelQuantized(const Method& quantization_method) const { + if (quantization_method.has_static_range_ptq()) { + const StaticRangePtq& static_range_ptq_spec = + quantization_method.static_range_ptq(); + + if (static_range_ptq_spec.input_quantized_types().contains(1)) { + const QuantizedType& weight_quantized_type = + static_range_ptq_spec.input_quantized_types().at(1); + if (weight_quantized_type.has_per_tensor()) { + return false; + } + const QuantizedDimension& dimension_specs = + weight_quantized_type.dimension_specs(); + return !dimension_specs.has_dimension() || + dimension_specs.dimension() == 3; + } + } + return false; + } + + private: + [[deprecated( + "Do not rely on this field for per-channel quantization. 
Use `Method` " + "instead.")]] const bool enable_per_channel_quantized_weight_; +}; + +// Quantizes the entry function's body for weight-only quantized op. +template +class QuantizeWeightOnlyOpPattern : public EntryFuncBodyQuantizationPattern { + public: + explicit QuantizeWeightOnlyOpPattern( + const bool enable_per_channel_quantized_weight) + : enable_per_channel_quantized_weight_( + enable_per_channel_quantized_weight) {} + + LogicalResult match(func::FuncOp entry_func_op, + const Method& quantization_method) const override { + if (!quantization_method.has_weight_only_ptq()) { + return failure(); + } + return MatchGemmStyleOp(entry_func_op); + } + + void rewrite(func::FuncOp entry_func_op, const Method& quantization_method, + PatternRewriter& rewriter) const override {} + + private: + [[deprecated( + "Do not rely on this field for per-channel quantization. Use `Method` " + "instead.")]] const bool enable_per_channel_quantized_weight_; +}; + +template +class QuantizeSingularOpPattern : public EntryFuncBodyQuantizationPattern { + public: + explicit QuantizeSingularOpPattern( + const bool enable_per_channel_quantized_weight) {} + + LogicalResult match(func::FuncOp entry_func_op, + const Method& quantization_method) const override { + if (!quantization_method.has_static_range_ptq()) { + return failure(); + } + const auto op_iterator_range = entry_func_op.getOps(); + if (op_iterator_range.empty()) { + LLVM_DEBUG(llvm::dbgs() << "Function does not have " + << SingularOpT::getOperationName() << " op.\n"); + return failure(); + } + + // Entry function body should have one block with two ops(op to be quantized + // and return op). 
+ Region& body = entry_func_op.getBody(); + if (body.getBlocks().size() != 1 || + body.begin()->getOperations().size() != 2) { + return failure(); + } + + if (!isa( + (*op_iterator_range.begin()).getResult().getType())) { + LLVM_DEBUG(llvm::dbgs() << SingularOpT::getOperationName() + << " op must have ranked tensor type.\n"); + return failure(); + } + return success(); + } + + void rewrite(func::FuncOp entry_func_op, const Method& quantization_method, + PatternRewriter& rewriter) const override { + auto singular_op = *entry_func_op.getOps().begin(); + Value singular_op_result = singular_op.getResult(); + + // For ops that require same operand and result types, use explicit + // requantize op rather than using `entry_func_op`'s result as op result. + auto spec = GetStableHloQuantConstraints(singular_op); + const bool has_same_operand_and_result_type = + spec->has_same_operand_and_result_type_requirement; + if (has_same_operand_and_result_type) { + const Type operand_type = entry_func_op.getArgumentTypes()[0]; + const Type func_result_type = entry_func_op.getResultTypes()[0]; + + // Get the quantized tensor manipulation op's output type and update. + const auto singular_op_result_type = + mlir::cast(singular_op_result.getType()); + const ArrayRef singular_op_shape = + singular_op_result_type.getShape(); + const TensorType new_singular_op_result_type = + singular_op_result_type.cloneWith( + singular_op_shape, mlir::cast( + getElementTypeOrSelf(operand_type))); + singular_op_result.setType(new_singular_op_result_type); + + // Create requantization op and return. + rewriter.setInsertionPointAfter(singular_op); + CreateAndReturnUniformQuantizeOp(rewriter, *singular_op, entry_func_op, + func_result_type); + } else { + singular_op_result.setType(entry_func_op.getResultTypes()[0]); + } + } +}; + +// Converts `entry_func_op` to be quantized according to the respective +// inputs and outputs of `xla_call_module_op` that are possibly quantized. 
It +// signature (type) is reset to match that of `xla_call_module_op`. +// `entry_func_body_quantization_pattern` rewrites the function's body, based on +// the new signature. `quantization_method` specifies the quantization method +// applied to the quantizable unit `xla_call_module_op` and its corresponding +// function `entry_func_op`. +void QuantizeEntryFuncOp( + const MLIRContext& ctx, PatternRewriter& rewriter, + const TF::XlaCallModuleOp xla_call_module_op, func::FuncOp entry_func_op, + const EntryFuncBodyQuantizationPattern& body_rewrite_pattern, + const Method& quantization_method) { + SetQuantizedFunctionType(rewriter, entry_func_op, xla_call_module_op); + + body_rewrite_pattern.rewrite(entry_func_op, quantization_method, rewriter); + + // Rename the function to be clear that the function has been quantized. + const std::string quantized_function_name = + GetQuantizedFunctionName(entry_func_op.getSymName()); + entry_func_op.setSymName(quantized_function_name); +} + +// Replaces `xla_call_module_op` with a newly created `func::CallOp`, where the +// callee is `callee_func_op`. The existence of `kQuantizationMethodAttr` in +// `xla_call_module_op` should be guaranteed. +void ReplaceXlaCallModuleOpWithNewCallOp(TF::XlaCallModuleOp xla_call_module_op, + func::FuncOp callee_func_op, + PatternRewriter& rewriter) { + OpBuilder::InsertionGuard insertion_guard(rewriter); + + // Create a new `CallOp` that calls `callee_func_op`. + rewriter.setInsertionPoint(xla_call_module_op); + auto call_op = + rewriter.create(xla_call_module_op.getLoc(), callee_func_op, + xla_call_module_op.getArgs()); + + // Transfer the `kQuantizationMethodAttr` attribute to the `CallOp`, + // indicating what `Method` has been applied to the quantized unit. + call_op->setAttr( + kQuantizationMethodAttr, + xla_call_module_op->getAttrOfType(kQuantizationMethodAttr)); + + rewriter.replaceOp(xla_call_module_op, call_op); +} + +// Replaces a quantized `xla_call_module_op` with a `func::CallOp`. 
The callee +// is expected to remain unquantized (thus having a signature mismatch), and it +// is also quantized accordingly. +void ReplaceQuantizedXlaCallModuleOpWithQuantizedCallOp( + const MLIRContext& ctx, PatternRewriter& rewriter, + TF::XlaCallModuleOp xla_call_module_op, + const EntryFuncBodyQuantizationPattern& body_rewrite_pattern, + const Method& quantization_method) { + const ModuleOp module_op = xla_call_module_op->getParentOfType(); + + func::FuncOp entry_func_op = + GetEntryFuncOp(xla_call_module_op, SymbolTable(module_op)); + QuantizeEntryFuncOp(ctx, rewriter, xla_call_module_op, entry_func_op, + body_rewrite_pattern, quantization_method); + + ReplaceXlaCallModuleOpWithNewCallOp(xla_call_module_op, entry_func_op, + rewriter); +} + +// Pattern that mainly does two things: +// +// 1. Replaces quantized `TF::XlaCallModuleOp` with a `func::CallOp`. +// 2. Quantizes the callee function. +// +// The inputs of this pattern assumes an invalid IR, where even if a +// `TF::XlaCallModuleOp` is quantized the callee remains unquantized. Step (2) +// not only replaces the input and output tensor types into quantized ones, but +// also rewrites the body with a quantized equivalent. +// +// `FuncBodyRewritePatternT` defines how a function body is quantized and +// rewritten. +template >> +class XlaCallModuleOpToCallOp : public OpRewritePattern { + public: + explicit XlaCallModuleOpToCallOp( + MLIRContext& ctx, const bool enable_per_channel_quantized_weight) + : OpRewritePattern::OpRewritePattern(&ctx), + enable_per_channel_quantized_weight_( + enable_per_channel_quantized_weight) {} + + LogicalResult matchAndRewrite(TF::XlaCallModuleOp op, + PatternRewriter& rewriter) const override { + ModuleOp module_op = op->getParentOfType(); + + // Ignore ops without quantization method. + // Consider adding checks for individual methods. + if (!op->getAttr(kQuantizationMethodAttr)) return failure(); + + // Ignore unquantized ops. 
+ if (!IsQuantizedXlaCallModuleOp(op)) return failure(); + + // For weight-only quantization, op should be hybrid quantized. + if (HasWeightOnlyPtqMethod(op) && !IsHybridQuantizedOp(op)) { + return failure(); + } + + func::FuncOp entry_func_op = GetEntryFuncOp(op, SymbolTable(module_op)); + if (!entry_func_op) { + op->emitError("Failed to find a valid entry function."); + return failure(); + } + Method quantization_method = GetQuantizationMethodOrDefault(op); + if (FuncBodyRewritePatternT(enable_per_channel_quantized_weight_) + .match(entry_func_op, quantization_method) + .failed()) { + return failure(); + } + + // TODO: b/331145946 - Each quantization method should be valid + // (GetQuantizationMethodOrDefault swallows invalid method attribute). Check + // the validity in `match()`. Use accessors to achieve this. + ReplaceQuantizedXlaCallModuleOpWithQuantizedCallOp( + *rewriter.getContext(), rewriter, op, + FuncBodyRewritePatternT(enable_per_channel_quantized_weight_), + quantization_method); + return success(); + } + + private: + [[deprecated( + "Do not rely on this field for per-channel quantization. Use `Method` " + "instead.")]] const bool enable_per_channel_quantized_weight_; +}; + +// Quantizes op with regions such as stablehlo.reduce_window op. +// Quantizes only when the nested region consists of ops whose quantization +// parameters can be propagated from outside. +class QuantizeOpWithRegionPattern + : public OpRewritePattern { + public: + explicit QuantizeOpWithRegionPattern(MLIRContext& ctx) + : OpRewritePattern(&ctx) {}; + + LogicalResult matchAndRewrite(mlir::quant::ir::DequantizeCastOp op, + PatternRewriter& rewriter) const final { + if (match(op).failed()) { + return failure(); + } + rewrite(op, rewriter); + return success(); + } + + private: + LogicalResult match(mlir::quant::ir::DequantizeCastOp op) const { + // Match only when there is one user of the dequantize op. 
+ if (!op.getResult().hasOneUse()) { + return failure(); + } + + for (Operation* op_with_region : op.getResult().getUsers()) { + // Among the ops with regions, only reduce_window op is supported for now. + if (!isa(op_with_region)) { + return failure(); + } + + if (!IsNestedRegionQuantizable(op_with_region)) { + return failure(); + } + + // Quantization parameters can be propagated only for same-scale ops and + // same-scale ops are quantized only when they are connected to quantized + // composite functions. + if (!GetStableHloQuantConstraints(op_with_region) + ->has_same_scale_requirement || + !IsConnectedWithQuantizedCompsiteFunction(op_with_region)) { + return failure(); + } + } + return success(); + } + + void rewrite(mlir::quant::ir::DequantizeCastOp op, + PatternRewriter& rewriter) const { + // Rewrite the floating-point ops to the quantized version, by fusing + // preceding dequantize ops and succeding quantize ops. + for (Operation* op_with_region : op.getResult().getUsers()) { + // Collect all the quantized inputs and "clone" the matched op by these + // inputs. + SmallVector inputs; + inputs.reserve(op_with_region->getNumOperands()); + for (Value operand : op_with_region->getOperands()) { + const Type operand_type = operand.getType(); + if (mlir::isa(operand_type)) { + inputs.push_back(operand); + continue; + } + + const Type element_type = + mlir::cast(operand.getType()).getElementType(); + if (auto dq_op = dyn_cast_or_null( + operand.getDefiningOp())) { + inputs.push_back(dq_op.getOperand()); + } else if (isa(element_type)) { + // If the operand is an integer tensor, then it doesn't require the + // DequantizeOp in the pattern. + inputs.push_back(operand); + } else { + return; + } + } + + // Collect all the quantized outputs and replace them by the results of + // the new quantized op. 
+ SmallVector outputs_replaced; + SmallVector output_types; + output_types.reserve(op_with_region->getNumResults()); + for (const Value result : op_with_region->getResults()) { + const Type result_type = result.getType(); + if (mlir::isa(result_type)) { + outputs_replaced.push_back(result); + output_types.push_back(result_type); + continue; + } + const Type result_element_type = + mlir::cast(result.getType()).getElementType(); + // If the user is the QuantizeOp, it must be the only user. + if (result.hasOneUse() && + isa(*result.user_begin())) { + auto user = + cast(*result.user_begin()); + outputs_replaced.push_back(user.getResult()); + output_types.push_back(user.getType()); + } else if (isa(result_element_type)) { + // If the result is an integer tensor, then it doesn't require the + // dequantize op in the pattern. + outputs_replaced.push_back(result); + output_types.push_back(result.getType()); + } else { + return; + } + } + + rewriter.setInsertionPointAfter(op_with_region); + OperationState new_state(op_with_region->getLoc(), + op_with_region->getName().getStringRef(), inputs, + output_types, op_with_region->getAttrs()); + for (int i = 0; i < op_with_region->getNumRegions(); ++i) { + new_state.addRegion(); + } + Operation* quantized_op = rewriter.create(new_state); + for (const auto& [index, region] : + llvm::enumerate(op_with_region->getRegions())) { + Region& target_region = quantized_op->getRegion(index); + IRMapping mapping; + region.cloneInto(&target_region, mapping); + } + + const Type operand_type = quantized_op->getOperandTypes()[0]; + const Type element_type = + mlir::cast(operand_type).getElementType(); + for (Region& region : quantized_op->getRegions()) { + ReplaceTypesInNestedRegion(region, element_type); + } + + for (auto [index, output] : llvm::enumerate(outputs_replaced)) { + output.replaceAllUsesWith(quantized_op->getResult(index)); + } + } + } + + // Checks if an op is quantizable in a nested region. 
+ bool IsOpQuantizableInNestedRegion(Operation& op) const { + return isa(op); + } + + // Checks if a region only consists of ops that are quantizable in a nested + // region. + // tf.CustomAggregator op cannot be inserted into region of a StableHLO op, + // thus calibration is impossible within a nested region. Therefore, when an + // op involves a region, the op is only quantizable when the region only + // consists of ops whose quantization parameters can be propagated from + // outside. + bool IsNestedRegionQuantizable(Operation* op) const { + for (Region& region : op->getRegions()) { + for (Operation& op : region.getOps()) { + if (!IsOpQuantizableInNestedRegion(op)) { + return false; + } + } + } + return true; + } + + // Replaces all types in nested regions under the assumption that the body + // consists of same-scale ops only. + void ReplaceTypesInNestedRegion(Region& region, + const Type element_type) const { + for (BlockArgument arg : region.getArguments()) { + arg.setType(ReplaceElementType(arg.getType(), element_type)); + } + + for (Operation& op : region.getOps()) { + for (Value operand : op.getOperands()) { + operand.setType(ReplaceElementType(operand.getType(), element_type)); + } + + for (Value result : op.getResults()) { + result.setType(ReplaceElementType(result.getType(), element_type)); + } + } + } + + // Replaces element type of the given tensor type while preserving shape of + // the given type. If the given type is not tensor type, just return itself. + Type ReplaceElementType(const Type type, const Type element_type) const { + if (TensorType tensor_type = mlir::dyn_cast(type)) { + return tensor_type.clone(element_type); + } + return type; + } +}; + +} // namespace + +// Checks if an op calls a composite function and all the inputs and outputs are +// quantized. 
+bool IsQuantizedCompositeFunction(func::CallOp call_op) { + if (!call_op.getCallee().starts_with("quantized_")) { + return false; + } + + bool has_quantized_types = false; + for (Value operand : call_op.getOperands()) { + if (const TensorType type = mlir::dyn_cast(operand.getType())) { + if (mlir::isa(type.getElementType())) { + return false; + } + if (mlir::isa( + type.getElementType())) { + has_quantized_types = true; + } + } + } + for (const Value result : call_op.getResults()) { + if (const auto type = mlir::dyn_cast(result.getType())) { + if (mlir::isa(type.getElementType())) { + return false; + } + if (mlir::isa( + type.getElementType())) { + has_quantized_types = true; + } + } + } + return has_quantized_types; +} + +bool IsConnectedWithQuantizedCompsiteFunction(Operation* same_scale_op) { + for (const Value operand : same_scale_op->getOperands()) { + auto dq_op = dyn_cast_or_null( + operand.getDefiningOp()); + if (!dq_op) continue; + + Operation* preceding_op = dq_op.getArg().getDefiningOp(); + if (!preceding_op) continue; + + // Check whether the preceding op is a quantized composite function. + if (isa(preceding_op)) { + auto call_op = cast(preceding_op); + if (!IsQuantizedCompositeFunction(call_op)) continue; + return true; + } + + // Check whether the preceding op is a quantized same-scale op. + if (GetStableHloQuantConstraints(preceding_op) + ->has_same_scale_requirement) { + for (const OpResult result : preceding_op->getResults()) { + const Type element_type = getElementTypeOrSelf(result.getType()); + if (mlir::isa(element_type)) { + return true; + } + } + } + } + + for (const Value result : same_scale_op->getResults()) { + // If the user is the Quantize op, it must be the only user. + if (!result.hasOneUse() || + !isa(*result.user_begin())) { + continue; + } + + auto q_op = cast(*result.user_begin()); + for (Operation* following_op : q_op->getUsers()) { + // Check whether the following op is a quantized composite function. 
+ if (isa(following_op)) { + auto call_op = cast(following_op); + if (!IsQuantizedCompositeFunction(call_op)) continue; + return true; + } + + // Check whether the following op is a quantized same-scale op. + if (GetStableHloQuantConstraints(following_op) + ->has_same_scale_requirement) { + for (Value operand : following_op->getOperands()) { + const Type element_type = getElementTypeOrSelf(operand.getType()); + if (mlir::isa(element_type)) { + return true; + } + } + } + } + } + + return false; +} + +// Compute heavy patterns should be quantized for both server and ODML targets. +// Most patterns here are useful when quantized since they are compute heavy +// or memory bound. +void PopulateCommonQuantizationPatterns( + MLIRContext& ctx, RewritePatternSet& patterns, + const bool enable_per_channel_quantized_weight) { + patterns.add>( + ctx, enable_per_channel_quantized_weight); + patterns.add>( + ctx, enable_per_channel_quantized_weight); + patterns + .add>>( + ctx, enable_per_channel_quantized_weight); + patterns + .add>>( + ctx, enable_per_channel_quantized_weight); + // TODO: b/307620772 - Per-channel quantization for gather. + patterns.add>>( + ctx, /*enable_per_channel_quantized_weight=*/false); + // Populate pattern for quantization of ops with regions such as + // `stablehlo.reduce_window` op. + patterns.add(ctx); +} + +void PopulateAllQuantizablePatterns(MLIRContext& ctx, + RewritePatternSet& patterns) { + patterns.add>>( + ctx, /*enable_per_channel_quantized_weight=*/false); +} + +} // namespace mlir::tf_quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_quantization_patterns.h b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_quantization_patterns.h new file mode 100644 index 00000000000000..f1098ed0aa120a --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_quantization_patterns.h @@ -0,0 +1,254 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_PASSES_TF_QUANTIZATION_PATTERNS_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_PASSES_TF_QUANTIZATION_PATTERNS_H_ + +#include +#include + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/IRMapping.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo +#include "tensorflow/compiler/mlir/quantization/common/tf_lift_as_function_call.h" +#include "tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_traits.h" +#include 
"tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_utils.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/ops/tf_stablehlo_op_quant_spec.h" +#include "tensorflow/core/framework/types.pb.h" + +namespace mlir::tf_quant::stablehlo { + +// Checks whether an op is connected with a quantized composite function. If +// not, the same-scale op will not be quantized. This decision is based on the +// current assumption that the performance gain of the same-scale op itself +// could not beat the overhead of the quantize and dequantize routines need to +// be added around that op. When the assumption changes, this policy might +// change as well. +bool IsConnectedWithQuantizedCompsiteFunction(Operation* same_scale_op); + +// A base rewrite pattern which matches any N-in-M-out operations with +// quantization parameters propagated to at least one of its operands. The +// quantization parameters are annotated by the QuantizeOp/DequantizeOp pairs. +// Each matched pattern are rewritten by its quantized alternatives. +// +// Quantization method is determined by the `_quantization_method` attributes +// attached to each quantizable units. +// +// Template constraints are imposed as follows: +// +// * `QuantizeOpT` should have only one operand. +// * `DequantizeOpT` should have only one result. +template () && + DequantizeOpT::template hasTrait()>> +class StableHloQuantizationPattern : public OpRewritePattern { + public: + explicit StableHloQuantizationPattern(MLIRContext* context) + // Set the benefit to a large number so that it is always preferred. + : OpRewritePattern(context, /*benefit=*/300) {} + + private: + // Collects all candidate ops for quantization, which are the + // `dequantize_op`'s users. 
+ FailureOr> CollectCandidateOps( + DequantizeOpT dequantize_op) const { + auto users = dequantize_op->getResult(0).getUsers(); + return SmallVector(users.begin(), users.end()); + } + + // Collects all candidate ops for quantization, which is the operand of + // `quantize_op`. If successful, this always returns one element which is the + // operand of `quantize_op`. + FailureOr> CollectCandidateOps( + QuantizeOpT quantize_op) const { + Value operand = quantize_op->getOperand(0); + if (QuantizedType::getQuantizedElementType(operand.getType())) { + // The input of the quantize op has already been quantized, i.e. + // rescale. + return failure(); + } + + Operation* operand_op = operand.getDefiningOp(); + if (operand_op == nullptr) { + // When `QuantizeOpT`'s operand does not have a defining op, it means it + // is a `BlockArgument`. The pattern does not match if there is no op to + // quantize. + return failure(); + } + + if (operand_op->hasTrait()) { + // Const-> QuantizeOp pattern will be handled separately. + return failure(); + } + + return SmallVector{operand_op}; + } + + LogicalResult matchAndRewrite(RootOpT op, + PatternRewriter& rewriter) const override { + // Collect all the candidate ops for quantization. + FailureOr> candidate_ops = CollectCandidateOps(op); + // Safeguard check to ensure that there is at least one quantizable op. + if (failed(candidate_ops) || candidate_ops->empty()) return failure(); + + // Rewrite the floating-point ops to the quantized version, by fusing + // preceding dequantize ops and succeding quantize ops. + for (Operation* candidate_op : *candidate_ops) { + // If it is requantize op, we shouldn't rewrite this op. + if (isa(candidate_op)) { + return failure(); + } + + // If the op is terminator, we shouldn't rewrite. 
+ if (candidate_op->hasTrait()) { + return failure(); + } + + if (!IsOpQuantizableStableHlo(candidate_op)) { + return failure(); + } + + if (GetStableHloQuantConstraints(candidate_op) + ->has_same_scale_requirement && + !IsConnectedWithQuantizedCompsiteFunction(candidate_op)) { + return failure(); + } + + // Ops with regions will be quantized in a separate pattern. + if (isa(candidate_op)) { + return failure(); + } + + const bool weight_only_quantizable = + IsWeightOnlyQuantizableOp(*candidate_op); + + // Collect all the quantized inputs and "clone" the matched op by these + // inputs. + SmallVector inputs; + inputs.reserve(candidate_op->getNumOperands()); + for (auto operand : candidate_op->getOperands()) { + Type operand_type = operand.getType(); + if (mlir::isa(operand_type)) { + inputs.push_back(operand); + continue; + } + + auto ele_type = + mlir::cast(operand.getType()).getElementType(); + if (auto dq_op = + dyn_cast_or_null(operand.getDefiningOp())) { + inputs.push_back(dq_op.getOperand()); + } else if (!ele_type.isF32()) { + // If the operand is an integer tensor, then it doesn't require the + // DequantizeOp in the pattern. + inputs.push_back(operand); + } else if (weight_only_quantizable) { + inputs.push_back(operand); + } else { + return failure(); + } + } + + // Collect all the quantized outputs and replace them by the results of + // the new quantized op. + llvm::SmallDenseMap outputs_replaced; + SmallVector output_types; + output_types.reserve(candidate_op->getNumResults()); + for (const auto& enumerated_result : + llvm::enumerate(candidate_op->getResults())) { + Value result = enumerated_result.value(); + Type result_type = result.getType(); + // Add this to the test coverage once we create test ops with none type + // results. 
+ if (mlir::isa(result_type)) { + outputs_replaced.insert({result, enumerated_result.index()}); + output_types.push_back(result_type); + continue; + } + Type result_ele_type = + mlir::cast(result.getType()).getElementType(); + // If the user is the QuantizeOp, it must be the only user. + if (result.hasOneUse() && isa(*result.user_begin())) { + auto user = cast(*result.user_begin()); + outputs_replaced.insert( + {user.getResult(), enumerated_result.index()}); + output_types.push_back(user.getType()); + } else if (!result_ele_type.isF32()) { + // If the result is an integer tensor, then it doesn't require the + // D op in the pattern. + outputs_replaced.insert({result, enumerated_result.index()}); + output_types.push_back(result.getType()); + } else if (weight_only_quantizable) { + outputs_replaced.insert({result, enumerated_result.index()}); + output_types.push_back(result.getType()); + } else { + return failure(); + } + } + + rewriter.setInsertionPointAfter(candidate_op); + OperationState new_state(candidate_op->getLoc(), + candidate_op->getName().getStringRef(), inputs, + output_types, candidate_op->getAttrs()); + for (int i = 0; i < candidate_op->getNumRegions(); ++i) { + new_state.addRegion(); + } + Operation* quantized_op = rewriter.create(new_state); + if (candidate_op->getNumRegions() != 0) { + for (const auto& indexed_regions : + llvm::enumerate(candidate_op->getRegions())) { + Region& target_region = + quantized_op->getRegion(indexed_regions.index()); + IRMapping mapping; + indexed_regions.value().cloneInto(&target_region, mapping); + } + } + for (auto output : outputs_replaced) { + output.getFirst().replaceAllUsesWith( + quantized_op->getResult(output.getSecond())); + } + } + return success(); + } +}; + +// Populates common patterns that are usually compute heavy or memory bound. 
+void PopulateCommonQuantizationPatterns( + MLIRContext& ctx, RewritePatternSet& patterns, + bool enable_per_channel_quantized_weight); + +// Populates conversion patterns for all quantizable ops, including +// ops that are not compute-heavy and data movement ops. +void PopulateAllQuantizablePatterns(MLIRContext& ctx, + RewritePatternSet& patterns); + +} // namespace mlir::tf_quant::stablehlo + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_PASSES_TF_QUANTIZATION_PATTERNS_H_ diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_quantize.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_quantize.cc new file mode 100644 index 00000000000000..5dad68992a8067 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_quantize.cc @@ -0,0 +1,111 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include + +#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo // IWYU pragma: keep +#include "tensorflow/compiler/mlir/quantization/common/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h" // IWYU pragma: keep +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_quantization_patterns.h" + +namespace mlir::tf_quant::stablehlo { + +#define GEN_PASS_DEF_QUANTIZEPASS +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h.inc" + +namespace { + +// Base struct for quantization. +template +struct StableHloQuantizationBase + : public StableHloQuantizationPattern { + explicit StableHloQuantizationBase(MLIRContext* ctx) + : StableHloQuantizationPattern(ctx) {} + + static bool AllowWeightOnlyQuantization(Operation& op) { return false; } +}; + +// Quantization rewrite pattern using DQ as the root op. +struct StableHloQuantization + : public StableHloQuantizationBase { + explicit StableHloQuantization(MLIRContext* ctx) + : StableHloQuantizationBase(ctx) {} +}; + +// Quantization rewrite pattern using Q as the root op. This is for the +// quantizable ops without floating-point operands. 
+struct StableHloQuantizationReverse + : public StableHloQuantizationBase { + explicit StableHloQuantizationReverse(MLIRContext* ctx) + : StableHloQuantizationBase(ctx) {} +}; + +class QuantizePass : public impl::QuantizePassBase { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(QuantizePass) + + using impl::QuantizePassBase::QuantizePassBase; + + explicit QuantizePass(const bool enable_per_channel_quantized_weight) { + enable_per_channel_quantized_weight_ = enable_per_channel_quantized_weight; + } + + private: + void runOnOperation() override; +}; + +void QuantizePass::runOnOperation() { + ModuleOp module_op = getOperation(); + MLIRContext& ctx = getContext(); + + RewritePatternSet patterns(&ctx); + patterns.add(&ctx); + + PopulateCommonQuantizationPatterns(ctx, patterns, + enable_per_channel_quantized_weight_); + + // Quantize all quantizable ops, including ops that are not compute-heavy. + PopulateAllQuantizablePatterns(ctx, patterns); + + if (failed(applyPatternsGreedily(module_op, std::move(patterns)))) { + // There are cases where no rewrites happen even if a pattern matches, + // causing this to result in a convergence failure. Consider this as a + // best-effort. + module_op.emitWarning("Failed to converge pattern at QuantizePass."); + } +} + +} // namespace + +} // namespace mlir::tf_quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_quantize_composite_functions.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_quantize_composite_functions.cc new file mode 100644 index 00000000000000..38379ef7b12df0 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_quantize_composite_functions.cc @@ -0,0 +1,114 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo // IWYU pragma: keep +#include "tensorflow/compiler/mlir/quantization/common/ir/QuantOps.h" // IWYU pragma: keep +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/cc/run_passes.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" // IWYU pragma: keep + +#define DEBUG_TYPE "quantize-composite-functions" + +namespace mlir::tf_quant::stablehlo { + +#define GEN_PASS_DEF_QUANTIZECOMPOSITEFUNCTIONSPASS +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h.inc" + +namespace { + +using ::tensorflow::quantization::RunPassesOnModuleOp; + +class 
QuantizeCompositeFunctionsPass + : public impl::QuantizeCompositeFunctionsPassBase< + QuantizeCompositeFunctionsPass> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(QuantizeCompositeFunctionsPass) + + using impl::QuantizeCompositeFunctionsPassBase< + QuantizeCompositeFunctionsPass>::QuantizeCompositeFunctionsPassBase; + + explicit QuantizeCompositeFunctionsPass( + const bool enable_per_channel_quantized_weight) { + enable_per_channel_quantized_weight_ = enable_per_channel_quantized_weight; + } + + private: + void runOnOperation() override; +}; + +void QuantizeCompositeFunctionsPass::runOnOperation() { + MLIRContext& ctx = getContext(); + + PassManager pm(&ctx); + // Intermediate output from QuantizePass will have quantized ops + // (XlaCallModuleOps) with quantized input and output types, which are not + // allowed in the TF dialect. + pm.enableVerifier(false); + + PrepareQuantizePassOptions options; + options.enable_per_channel_quantized_weight_ = + enable_per_channel_quantized_weight_; + // Change this to user-given bit width once we have custom configuration. + options.bit_width_ = 8; + + // Insert quantization parameters for weights for ops with `weight_only_ptq` + // attribute. + pm.addNestedPass(createInsertWeightParamPass()); + + // PrepareQuantizePass uses SymbolTable to fetch relevant GEMM ops for + // determining quantization attributes. This requires module-level context. + pm.addPass(createPrepareQuantizePass(options)); + + QuantizePassOptions quantize_options; + quantize_options.enable_per_channel_quantized_weight_ = + enable_per_channel_quantized_weight_; + + // QuantizePass modifies FuncOps referenced outside of its given scope + // and therefore requires a module-level context. + pm.addPass(createQuantizePass(quantize_options)); + pm.addNestedPass(createPostQuantizePass()); + + // Convert XlaCallModuleOps lifted but not quantized to func.call op. + // The reasons these ops are not quantized may be: + // 1. 
Disabled due to selective quantization. + // 2. Not supported, e.g. add op for server. + pm.addPass(createXlaCallModuleToCallPass()); + + // TODO: b/321729008 - move this implementation to quantization_patterns.cc. + if (merge_fusion_with_dequantize_) { + pm.addPass(createMergeFusionWithDequantizePass()); + } + + ModuleOp module_op = getOperation(); + if (const absl::Status pm_run_status = + RunPassesOnModuleOp(mlir_dump_file_name_, pm, module_op); + !pm_run_status.ok()) { + signalPassFailure(); + } +} + +} // namespace +} // namespace mlir::tf_quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_quantize_weight.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_quantize_weight.cc new file mode 100644 index 00000000000000..3b3435298f3801 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_quantize_weight.cc @@ -0,0 +1,244 @@ +/* Copyright 2023 The StableHLO Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include +#include + +#include "Eigen/Core" // from @eigen_archive +#include "llvm/ADT/SetVector.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Rewrite/FrozenRewritePatternSet.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_options.pb.h" + +// NOLINTNEXTLINE +//===----------------------------------------------------------------------===// +// The Quantization Pass for Weight. +//===----------------------------------------------------------------------===// + +namespace mlir::tf_quant::stablehlo { + +// Put the definitions inside the ::mlir::tf_quant::stablehlo namespace, to +// match the declarations in tf_passes.h. +#define GEN_PASS_DEF_QUANTIZEWEIGHTPASS +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h.inc" + +namespace { + +using QuantizationUnits = llvm::SetVector>; +using mlir::stablehlo::ConstantOp; +using mlir::stablehlo::ConvertOp; +using ::stablehlo::quantization::QuantizationComponentSpec; + +// Min/Max values used for creating ConstantOp. 
+constexpr float kMaxFloat16Value = 65504.f; +constexpr float kMinFloat16Value = -65504.f; + +class QuantizeWeightPass + : public impl::QuantizeWeightPassBase { + public: + explicit QuantizeWeightPass( + QuantizationComponentSpec quantization_component_spec) + : quantization_component_spec_(quantization_component_spec) {} + + private: + void runOnOperation() override; + QuantizationComponentSpec quantization_component_spec_; +}; + +// Collects quantizable target ops, then insert Q-DQ quantization patterns. +class QuantizeWeight : public OpRewritePattern { + public: + explicit QuantizeWeight( + MLIRContext* context, + const QuantizationComponentSpec& quantization_component_spec) + : OpRewritePattern(context), + quantization_component_spec_(quantization_component_spec) {} + + LogicalResult matchAndRewrite(ConstantOp op, + PatternRewriter& rewriter) const override { + // 1. Collect quantizable ops. + QuantizationUnits quantizable_ops = GetQuantizableOps(op); + if (quantizable_ops.empty()) { + return failure(); + } + + // 2. Quantize collected ops. + if (!QuantizeOps(rewriter, op, quantizable_ops)) { + return failure(); + } + + // 3. Complete the Q-DQ pair for each inference type. + if (!ConvertToFloat16Constant(rewriter, op)) { + return failure(); + } + return success(); + } + + private: + const QuantizationComponentSpec quantization_component_spec_; + // Marks users that are applicable for quantization where the criteria for + // determining quantizable ops differs by the inference type. + QuantizationUnits GetQuantizableOps(ConstantOp op) const { + // Non-float tensors do not need quantization. 
+ QuantizationUnits quantizable_ops; + const ShapedType type = mlir::dyn_cast(op.getType()); + if (!type || !type.getElementType().isF32()) return quantizable_ops; + + const Value value = op.getResult(); + + for (OpOperand& use : value.getUses()) { + Operation* user = use.getOwner(); + const int operand_num = use.getOperandNumber(); + quantizable_ops.insert({user, operand_num}); + } + return quantizable_ops; + } + + // Returns whether quantization is applied to filtered users. + bool QuantizeOps(PatternRewriter& rewriter, ConstantOp op, + const QuantizationUnits& quantizable_ops) const { + for (const std::pair& quant_op : quantizable_ops) { + // For f16 quantization, quantize all constant ops as float16. + QuantizeOpAsFloat16(rewriter, op, quant_op); + } + // TODO: b/264218457 - Return a value that accurately captures result + // status. + return true; + } + + // Inserts ConvertOp which is used for converting float32 ConstantOp into + // float16 quantization. If there is an existing ConvertOp connected to the + // ConstantOp, the quantizable_op will be rewired to the existing ConvertOp. + // This guarantees at most one ConvertOp is created for float32 to float16 + // conversion. + void QuantizeOpAsFloat16(PatternRewriter& rewriter, ConstantOp op, + const std::pair quant_op) const { + const auto [quantizable_op, quantize_operand_num] = quant_op; + // If the constant is an output tensor, do nothing. + if (isa(quantizable_op)) { + return; + } + + TensorType old_result_type = + mlir::dyn_cast(op.getResult().getType()); + const FloatType quantized_type = Float16Type::get(op.getContext()); + const ShapedType new_result_type = old_result_type.clone(quantized_type); + + // Insert ConvertOp if it does not exist yet. Otherwise, just rewire without + // creating a ConvertOp. + for (const OpOperand& connected_op : op.getResult().getUses()) { + ConvertOp convert_op = + dyn_cast_or_null(connected_op.getOwner()); + // ConvertOp already exists. 
Rewire the existing convert op into f16. + if (convert_op && convert_op.getType() == new_result_type) { + quantizable_op->setOperand(quantize_operand_num, convert_op); + return; + } + } + rewriter.setInsertionPointAfter(op); + ConvertOp new_convert_op = rewriter.create( + op->getLoc(), new_result_type, op.getResult()); + quantizable_op->setOperand(quantize_operand_num, + new_convert_op.getResult()); + } + + // Returns whether a ConvertOp-Operation sequence can be converted into new + // ConstantOp-Convert-Operation. The new ConstantOp has float16 data type. + bool ConvertToFloat16Constant(PatternRewriter& rewriter, + ConstantOp op) const { + for (Operation* connected_op : op.getResult().getUsers()) { + ConvertOp convert_op = dyn_cast_or_null(connected_op); + // Skip if no convert op exists. + if (!convert_op || convert_op.getResult().use_empty()) continue; + + // Get types. + const Type old_result_type = op.getResult().getType(); + const ShapedType new_result_type = + mlir::dyn_cast(convert_op.getType()); + + // Proceeds only if the converting is to float16. + if (!new_result_type.getElementType().isF16()) continue; + + // Convert values. + std::vector new_values; + const DenseFPElementsAttr value_attr = + mlir::cast(op.getValue()); + new_values.reserve(value_attr.getNumElements()); + + for (const float value : value_attr.getValues()) { + new_values.push_back(Eigen::half( + std::min(std::max(value, kMinFloat16Value), kMaxFloat16Value))); + } + const DenseElementsAttr new_value_attr = DenseFPElementsAttr::get( + new_result_type, ArrayRef(new_values)); + // Create new ConstantOp-ConvertOp-Operation sequences. At this moment, + // old ConstantOp is guaranteed to have one F32->F16 convert op regardless + // of its number of users. 
+ rewriter.setInsertionPointAfter(op); + // create new F16 constant op in that location + ConstantOp new_const = rewriter.create( + op->getLoc(), new_result_type, new_value_attr); + ConvertOp dcast = + rewriter.create(op->getLoc(), old_result_type, new_const); + // replace all convert ops with dq op. + convert_op->replaceAllUsesWith(dcast); + // Return without scanning for the next ConvertOp as only one ConvertOp is + // connected to all quantizable ops. + return true; + } + return false; + } +}; + +// TODO: b/264218457 - Refactors the current file to parse preset quantization +// options and allow modular control of quantization specs. +void QuantizeWeightPass::runOnOperation() { + func::FuncOp func = getOperation(); + MLIRContext* ctx = func.getContext(); + RewritePatternSet patterns(ctx); + + patterns.add(ctx, quantization_component_spec_); + + FrozenRewritePatternSet frozen_patterns(std::move(patterns)); + + if (failed(applyPatternsGreedily(func, frozen_patterns))) { + signalPassFailure(); + } +} + +} // namespace + +// Creates an instance of the StableHLO dialect Quantize Weight pass. +std::unique_ptr> CreateQuantizeWeightPass( + const QuantizationComponentSpec& quantization_component_spec) { + return std::make_unique(quantization_component_spec); +} + +} // namespace mlir::tf_quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_remove_sharding_custom_call.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_remove_sharding_custom_call.cc new file mode 100644 index 00000000000000..cae6c33226dca7 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_remove_sharding_custom_call.cc @@ -0,0 +1,59 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Rewrite/FrozenRewritePatternSet.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo // IWYU pragma: keep + +namespace mlir::tf_quant::stablehlo { + +#define GEN_PASS_DEF_REMOVESHARDINGCUSTOMCALLPASS +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h.inc" + +// Include patterns generated from `remove_sharding_custom_call.td`. 
+#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/remove_sharding_custom_call.inc" + +class RemoveShardingCustomCallPass + : public impl::RemoveShardingCustomCallPassBase< + RemoveShardingCustomCallPass> { + public: + using impl::RemoveShardingCustomCallPassBase< + RemoveShardingCustomCallPass>::RemoveShardingCustomCallPassBase; + + private: + void runOnOperation() override; +}; + +void RemoveShardingCustomCallPass::runOnOperation() { + func::FuncOp func_op = getOperation(); + MLIRContext& ctx = getContext(); + + RewritePatternSet patterns(&ctx); + populateWithGenerated(patterns); + + FrozenRewritePatternSet frozen_patterns(std::move(patterns)); + if (failed(applyPatternsGreedily(func_op, frozen_patterns))) { + func_op.emitWarning() << "Failed to converge " + << RemoveShardingCustomCallPass::getArgumentName(); + } +} + +} // namespace mlir::tf_quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.cc new file mode 100644 index 00000000000000..6e4a608857e39b --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.cc @@ -0,0 +1,536 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include + +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/ErrorHandling.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/IRMapping.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/ValueRange.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo +#include "stablehlo/dialect/Version.h" // from @stablehlo +#include "tensorflow/compiler/mlir/quantization/common/func.h" +#include "tensorflow/compiler/mlir/quantization/common/tf_lift_as_function_call.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/utils/stablehlo_type_utils.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/xla_call_module_attrs.h" +#include "tensorflow/core/ir/types/dialect.h" + +namespace mlir::tf_quant::stablehlo { + +#define GEN_PASS_DEF_REPLACESTABLEHLOOPSINMAINFUNCTIONWITHXLACALLMODULEOPSPASS +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h.inc" + +namespace { + +constexpr StringRef kStablehloModuleAttrsAttrName = "_stablehlo_module_attrs"; +constexpr StringRef kUsesShapePolymorphismAttr = "jax.uses_shape_polymorphism"; + +// Default version 
number for native serialization. +constexpr int64_t kDefaultVersion = 9; +// Platforms for XlaCallModuleOp. +constexpr StringRef kPlatformCpu = "CPU"; +constexpr StringRef kPlatformTpu = "TPU"; + +class ReplaceStablehloOpsInMainFunctionWithXlaCallModuleOpsPass + : public impl:: + ReplaceStablehloOpsInMainFunctionWithXlaCallModuleOpsPassBase< + ReplaceStablehloOpsInMainFunctionWithXlaCallModuleOpsPass> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( + ReplaceStablehloOpsInMainFunctionWithXlaCallModuleOpsPass) + + ReplaceStablehloOpsInMainFunctionWithXlaCallModuleOpsPass() = default; + + ReplaceStablehloOpsInMainFunctionWithXlaCallModuleOpsPass( + const ReplaceStablehloOpsInMainFunctionWithXlaCallModuleOpsPass& other) = + default; + + private: + void runOnOperation() override; +}; + +// Creates a unique stablehlo function name based on op order. +std::string CreateStablehloFunctionName(const int id) { + return Twine("_stablehlo_main_").concat(std::to_string(id)).str(); +} + +// Follows the structure of Live-variable analysis. It is a form of +// CFG (Control Flow Graph) analysis, often used in compilers. +// +// A variable is live if it holds a value that may be used in the future. +// It is live-in at node n if it is live on any of the node's in-edges. +// It is live-out at node n if it is live on any of the node's out-edges. +// def[n] refers to values that are defined at node n. +// use[n] refers to values that are used at node n. 
+// +// Given a node n, variables' liveliness is defined like the following: +// live_in[n] = use[n] U (live_out[n] - def[n]) +// live_out[n] = U {live_in[s] | s ε succ[n]} +// +// Consider a sequence of op: +// +// ``` +// node 1: %0 = stablehlo.constant +// node 2: %1 = stablehlo.constant +// node 3: %2 = stablehlo.add %0, %1 +// node 4: %3 = stablehlo.multiply %2, %1 +// node 5: return %3 +// ``` +// +// In Backward Liveliness analysis, the liveliness for each node above becomes: +// live_in[5] = use[5] U (live_out[5] - def[5]) +// = {%3} U ({∅} - {∅}) = {%3} +// live_in[4] = use[4] U (live_out[4] - def[4]) +// = {%1, %2} U ({%3} - {%3}) = {%1, %2} +// live_in[3] = use[3] U (live_out[3] - def[3]) +// = {%0, %1} U ({%1, %2} - {%2}) = {%0, %1} +// live_in[2] = use[2] U (live_out[2] - def[2]) +// = {∅} U ({%0, %1} - {%1}) = {%0} +// live_in[1] = use[1] U (live_out[1] - def[1]) +// = {∅} U ({%0} - {%0}) = {∅} +// +// This analogy is used throughout this pass to ensure only live edges form +// proper subgraphs. +class LiveOuts { + public: + LiveOuts() = default; + + explicit LiveOuts(OperandRange range) + : liveouts_(range.begin(), range.end()), prev_liveouts_(liveouts_) {} + + // Delete the current op from liveouts and moves on to the parent ops. + void update(Operation& op) { + for (Value result_value : op.getResults()) { + liveouts_.remove(result_value); + } + for (Value operand : op.getOperands()) { + liveouts_.insert(operand); + } + } + + // Snapshot the current live values to previous live values. + void snapshot_previous_state() { prev_liveouts_ = liveouts_; } + + // Return the current live values. + const SetVector& get() const { return liveouts_; } + + // Return the previous live values. + const SetVector& get_previous() const { return prev_liveouts_; } + + private: + // Use SerVector to ensure deterministic traversal order. + SetVector liveouts_; + SetVector prev_liveouts_; +}; + +// Creates the tf.XlaCallModuleOp from attributes. 
+void CreateXlaCallModuleOp(ValueRange inputs, ValueRange outputs, + const TypeRange result_types, + const SetVector& reverse_subgraph, + const func::FuncOp stablehlo_func_op, + ModuleOp module_op) { + MLIRContext* ctx = module_op.getContext(); + OpBuilder builder(ctx); + Operation* last_subgraph_op = reverse_subgraph.front(); + builder.setInsertionPointAfter(last_subgraph_op); + + // Create attributes used for creating an XlaCallModuleOp. + SmallVector shape_attrs; + for (const Type result_type : result_types) { + shape_attrs.push_back( + tf_type::ShapeAttr::get(ctx, mlir::cast(result_type))); + } + const auto empty_array_attr = ArrayAttr::get(ctx, {}); + // TODO: b/310291615 - find a better way for platform support. + const auto platforms = ArrayAttr::get( + ctx, + {StringAttr::get(ctx, kPlatformCpu), StringAttr::get(ctx, kPlatformTpu)}); + + auto xla_call_module_op = builder.create( + module_op.getLoc(), /*output=*/result_types, + /*args=*/inputs, + /*version=*/kDefaultVersion, /*module=*/"", + /*Sout=*/ArrayAttr::get(ctx, shape_attrs), + /*dim_args_spec=*/empty_array_attr, platforms, + /*function_list=*/empty_array_attr, + /*has_token_input_output=*/false, + /*disabled_checks=*/empty_array_attr); + xla_call_module_op->setAttr(TF::kStablehloEntryFunctionAttrName, + SymbolRefAttr::get(stablehlo_func_op)); + std::string target_version = + mlir::vhlo::Version::fromCompatibilityRequirement( + vhlo::Version::CompatibilityRequirement::WEEK_4) + .toString(); + xla_call_module_op->setAttr(TF::kStablehloVersionAttrName, + builder.getStringAttr(target_version)); + // Set jax.uses_shape_polymorphism=true to enable shape refinement at runtime. + // This is needed for native serialization version >= 8. 
+ xla_call_module_op->setAttr( + kStablehloModuleAttrsAttrName, + builder.getDictionaryAttr(builder.getNamedAttr( + kUsesShapePolymorphismAttr, builder.getBoolAttr(true)))); + + for (auto [original_output_value, xla_call_module_op_result_value] : + llvm::zip_equal(outputs, xla_call_module_op->getResults())) { + original_output_value.replaceAllUsesExcept(xla_call_module_op_result_value, + /*exceptedUser=*/nullptr); + } +} + +// Replaces the StableHLO ops with a separate XlaCallModuleOp, then wires it +// back into the main graph. +void ReplaceStablehloOpsWithXlaCallModuleOp( + const ArrayRef inputs, const ArrayRef outputs, + const SetVector& reverse_subgraph, const int stablehlo_func_id, + ModuleOp module_op) { + MLIRContext* ctx = module_op.getContext(); + OpBuilder builder(ctx); + + // Identify arg types & arg locs. + SmallVector arg_types; + SmallVector arg_locs; + + // Add an argument for platform_index. This allows for multiple platforms. + // TODO: b/310291615 - find a better way for platform support. + arg_types.push_back(RankedTensorType::get({}, builder.getI32Type())); + arg_locs.push_back(module_op.getLoc()); + for (const Value input_value : inputs) { + arg_types.push_back(input_value.getType()); + arg_locs.push_back(input_value.getLoc()); + } + + // Identify result types. + SmallVector result_types; + for (const Value output_value : outputs) { + result_types.push_back(output_value.getType()); + } + + // 1) Create FuncOp for the StableHLO ops. They will be separate subgraphs. 
+ builder.setInsertionPoint(&*module_op.begin()); + auto stablehlo_func_op = builder.create( + module_op.getLoc(), CreateStablehloFunctionName(stablehlo_func_id), + FunctionType::get(ctx, arg_types, result_types)); + stablehlo_func_op.setVisibility(SymbolTable::Visibility::Private); + stablehlo_func_op->setAttr(TF::kFromXlaCallModuleAttrName, + builder.getUnitAttr()); + + builder.createBlock(&stablehlo_func_op.getBody(), stablehlo_func_op.begin(), + arg_types, arg_locs); + + IRMapping mapper; + // stablehlo_func_op has 1 extra arg for platform index. + for (auto [input, stablehlo_func_arg] : llvm::zip_equal( + inputs, stablehlo_func_op.getArguments().take_back(inputs.size()))) { + mapper.map(input, stablehlo_func_arg); + } + + for (Operation* subgraph_op : llvm::reverse(reverse_subgraph)) { + // Create a deep copy of the subgraph ops' operands to the func op. + stablehlo_func_op.getBody().begin()->push_back(subgraph_op->clone(mapper)); + } + + SmallVector result_values; + for (const Value original_output_value : outputs) { + // Use the mapped values in the newly created function that correspond to + // outputs in the original function. + result_values.push_back(mapper.lookup(original_output_value)); + } + builder.create(module_op.getLoc(), result_values); + + // 2) Create XlaCallModuleOp (with ops mapped). + CreateXlaCallModuleOp(inputs, outputs, result_types, reverse_subgraph, + stablehlo_func_op, module_op); + + // 3) Erase the replaced ops. + for (Operation* subgraph_op : reverse_subgraph) { + subgraph_op->erase(); + } +} + +// Contains the actual logic for updating states and replacing StableHLO ops +// with tf.XlaCallModuleOps. 
+// Contains the actual logic for updating states and replacing StableHLO ops
+// with tf.XlaCallModuleOps.
+// NOTE(review): the `<...>` template arguments in this hunk were lost during
+// extraction and have been reconstructed; confirm against the upstream pass.
+void UpdateStatesAndReplaceStablehloOps(
+    const SetVector<Value>& operands, const SetVector<Value>& defined_values,
+    const LiveOuts& liveouts, ModuleOp module_op,
+    const SetVector<Operation*>& reverse_subgraph, const int stablehlo_func_id,
+    func::FuncOp main_func, const bool is_last_subgraph = false) {
+  // Inputs of the subgraph are operands that are not produced inside it.
+  SetVector<Value> inputs = operands;
+  for (Value defined_value : defined_values) {
+    inputs.remove(defined_value);
+  }
+
+  // Outputs are the values that stop being live after this subgraph.
+  SetVector<Value> outputs = liveouts.get_previous();
+  for (const Value live_value : liveouts.get()) {
+    outputs.remove(live_value);
+  }
+
+  if (is_last_subgraph) {
+    // Additionally remove arguments from the outputs, as it provides liveness
+    // throughout (functions as an invisible op above the very first op that
+    // returns the arguments).
+    for (const BlockArgument arg : main_func.getArguments()) {
+      outputs.remove(arg);
+    }
+  }
+
+  ReplaceStablehloOpsWithXlaCallModuleOp(
+      SmallVector<Value>(inputs.begin(), inputs.end()),
+      SmallVector<Value>(outputs.begin(), outputs.end()), reverse_subgraph,
+      stablehlo_func_id, module_op);
+}
+
+// Check if the op should be added to the subgraph.
+// The op should be added to the subgraph if all of its users match one
+// of following two conditions:
+// 1: The user is already in the current subgraph.
+// 2: The user will reach a dead end.
+//
+// If the op should be added to the subgraph and there are users who
+// will reach the dead end, add the ops on the dead end to the subgraph as well.
+bool ShouldAddOpToSubgraph(Operation* op,
+                           const SetVector<Operation*>& reverse_subgraph,
+                           const SetVector<Operation*>& ops_to_add,
+                           SmallVector<Operation*>& all_descendants) {
+  if (!op) {
+    return false;
+  }
+
+  SmallVector<Operation*> current_layer_descendants;
+  SmallVector<Operation*> next_layer_descendants;
+  int current_depth = 0;
+  current_layer_descendants.push_back(op);
+  // BFS downstream ops for current user.
+  // If any one of the descendants meet one of the three conditions, we return
+  // false for the current value:
+  // 1: The descendant is not in the ops_to_add.
+  // 2: The descendant is not a stablehlo op.
+  // 3: The depth of the descendant is larger than 5, we don't want to search
+  // too deep, max depth is arbitrarily chosen.
+  while (!current_layer_descendants.empty()) {
+    if (current_depth > 5) {
+      all_descendants.clear();
+      return false;
+    }
+    current_depth++;
+
+    for (Operation* descendant : current_layer_descendants) {
+      if (!quant::stablehlo::IsStablehloOp(descendant) ||
+          !ops_to_add.contains(descendant)) {
+        all_descendants.clear();
+        return false;
+      }
+      for (Operation* next_descendant : descendant->getUsers()) {
+        if (reverse_subgraph.contains(next_descendant)) {
+          continue;
+        }
+        next_layer_descendants.push_back(next_descendant);
+      }
+      all_descendants.push_back(descendant);
+    }
+
+    current_layer_descendants = next_layer_descendants;
+    next_layer_descendants.clear();
+  }
+
+  return true;
+}
+
+// Replaces the StableHLO ops in the main function block with
+// tf.XlaCallModuleOps as separate subgraphs. Wires them back to the main
+// function block to be compatible with SavedModel structure.
+void ReplaceStablehloOpsInMainFunctionWithXlaCallModuleOps(
+    ModuleOp module_op, func::FuncOp main_func, int& stablehlo_func_id) {
+  Block& main_func_block = main_func.getBody().front();
+
+  // LiveOuts keeps track of live values at the output of some op. The updates
+  // must be made in a reverse, bottom-up manner.
+  const auto result_values = main_func_block.getTerminator()->getOperands();
+  LiveOuts liveouts(result_values);
+
+  // Copy ops to iterate because we will be modifying the block during
+  // iteration. The ordering should be reversed because liveness analysis is a
+  // bottom-up analysis. The terminator is not included because the return
+  // statement is not included in any subgraph (e.g. XlaCallModuleOp) and is
+  // untouched.
+  SmallVector<Operation*> reverse_main_func_block_ops;
+  SetVector<Operation*> ops_to_add;
+  for (Operation& main_func_block_op :
+       llvm::reverse(main_func_block.without_terminator())) {
+    reverse_main_func_block_ops.push_back(&main_func_block_op);
+    ops_to_add.insert(&main_func_block_op);
+  }
+
+  // Create a separate subgraph invoked with XlaCallModuleOp per each
+  // set of StableHLO ops in the main func block.
+  SetVector<Operation*> reverse_subgraph;
+  SetVector<Value> operands;
+  SetVector<Value> defined_values;
+
+  // Add op to the subgraph.
+  const auto add_to_subgraph = [&](Operation* op) {
+    // Move on to the parent ops.
+    liveouts.update(*op);
+    ops_to_add.remove(op);
+
+    if (!quant::stablehlo::IsStablehloOp(op)) {
+      // Always update the liveouts when the subgraph isn't being continued.
+      liveouts.snapshot_previous_state();
+      return;
+    }
+
+    reverse_subgraph.insert(op);
+    defined_values.insert(op->getResults().begin(), op->getResults().end());
+    operands.insert(op->getOperands().begin(), op->getOperands().end());
+  };
+
+  for (Operation* op : reverse_main_func_block_ops) {
+    if (!ops_to_add.contains(op)) continue;
+    // When hitting a non-StableHLO op, i.e. tf.CustomAggregatorOp, start
+    // recursively tracing defining ops of the current subgraph's operands. This
+    // makes sure that all dependencies needed for shape inference are included
+    // in the subgraph. We only trace StableHLO ops that have all users inside
+    // the current subgraph.
+    // TODO: b/311239049 - Consider rewrite this using BFS.
+    if (!quant::stablehlo::IsStablehloOp(op)) {
+      bool should_add_op = true;
+      while (should_add_op) {
+        should_add_op = false;
+        SmallVector<Operation*> all_descendants;
+        for (Value v : operands) {
+          if (defined_values.contains(v)) continue;
+          if (ShouldAddOpToSubgraph(v.getDefiningOp(), reverse_subgraph,
+                                    ops_to_add, all_descendants)) {
+            should_add_op = true;
+            break;
+          }
+        }
+        if (should_add_op) {
+          for (auto descendant : llvm::reverse(all_descendants)) {
+            add_to_subgraph(descendant);
+          }
+        }
+      }
+      // Create an XlaCallModuleOp if reverse_subgraph isn't empty.
+      if (!reverse_subgraph.empty()) {
+        UpdateStatesAndReplaceStablehloOps(operands, defined_values, liveouts,
+                                           module_op, reverse_subgraph,
+                                           ++stablehlo_func_id, main_func);
+
+        // Reset states and start a new subgraph.
+        reverse_subgraph.clear();
+        operands.clear();
+        defined_values.clear();
+      }
+    }
+    add_to_subgraph(op);
+  }
+
+  // Create the last subgraph if it isn't empty.
+  if (!reverse_subgraph.empty()) {
+    UpdateStatesAndReplaceStablehloOps(
+        operands, defined_values, liveouts, module_op, reverse_subgraph,
+        ++stablehlo_func_id, main_func, /*is_last_subgraph=*/true);
+  }
+}
+
+// Duplicates small constants for each use.
+//
+// In the subsequent graph partitioning, constants for shape inference need to
+// be in the same subgraph. But graph partitioning stops at ops with multiple
+// uses. So here we duplicate small constants for each use so that if a
+// constant is useful for shape inference for multiple subgraphs, they can be
+// included in each subgraphs. If duplicate constants are accidentally created
+// in the same subgraph, they can be easily removed with a canonicalizer pass.
+//
+// We set a size limit since constants needed for shape inference are no
+// larger than tensor rank. This avoids duplicating large constants.
+void DuplicateSmallConstantOps(ModuleOp module_op, func::FuncOp main_func) { + OpBuilder builder(main_func.getContext()); + for (auto constant_op : + main_func.getBody().getOps()) { + builder.setInsertionPointAfter(constant_op); + if (constant_op.getResult().use_empty() || + constant_op.getResult().hasOneUse()) + continue; + // Do not duplicate constant op if the size is too large. + // 32 is chosen to be larger than all constants useful for shape references, + // while not too large to possibly significantly increase model size. + if (constant_op.getValue().getNumElements() > 32) continue; + while (!constant_op.getResult().hasOneUse()) { + auto new_constant_op = builder.clone(*constant_op.getOperation()); + constant_op.getResult().getUses().begin()->assign( + dyn_cast(new_constant_op)); + } + } +} + +void ReplaceStablehloOpsInMainFunctionWithXlaCallModuleOpsPass:: + runOnOperation() { + ModuleOp module_op = getOperation(); + + func::FuncOp main_func = quant::FindMainFuncOp(module_op); + if (!main_func) return; + + // In case the model has tf.StatefulPartitionedCallOp or tf.PartitionedCallOp, + // we recursively find called functions and process StableHLO ops in them. 
+ SmallVector func_ops; + func_ops.push_back(main_func); + int stablehlo_func_id = -1; + while (!func_ops.empty()) { + auto main_func = func_ops.back(); + func_ops.pop_back(); + if (!main_func) continue; + + SymbolTable symbol_table(module_op); + for (auto call_op : main_func.getOps()) { + func_ops.push_back(dyn_cast_or_null(symbol_table.lookup( + mlir::cast(call_op.getFAttr()).getValue()))); + } + for (auto call_op : main_func.getOps()) { + func_ops.push_back( + dyn_cast_or_null(symbol_table.lookup(call_op.getF()))); + } + + DuplicateSmallConstantOps(module_op, main_func); + ReplaceStablehloOpsInMainFunctionWithXlaCallModuleOps(module_op, main_func, + stablehlo_func_id); + } + + // TODO - b/298966126: Currently quantizable functions are identified in TF + // Quantizer via the tf_quant.composite_function UnitAttr attached to + // func ops. We remove this attribute as this interferes with VHLO conversion. + // Remove this temporary hack. + for (auto func_op : module_op.getOps()) { + func_op->removeAttr(kFusedFunctionAttr); + } +} + +} // namespace + +} // namespace mlir::tf_quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_restore_function_name.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_restore_function_name.cc new file mode 100644 index 00000000000000..d047953693e2dc --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_restore_function_name.cc @@ -0,0 +1,94 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/common/tf_lift_as_function_call.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/xla_call_module_attrs.h" + +//===----------------------------------------------------------------------===// +// The stablehlo-restore-function-name Pass. +//===----------------------------------------------------------------------===// + +namespace mlir::tf_quant::stablehlo { + +#define GEN_PASS_DEF_RESTOREFUNCTIONNAMEPASS +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h.inc" + +namespace { + +// Restores entry function name from XlaCallModuleOp attribute. +// This restoration is required because StableHLO functions are renamed during +// the XlaCallModuleSerialization. 
+class RestoreFunctionNamePass + : public impl::RestoreFunctionNamePassBase { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RestoreFunctionNamePass) + + explicit RestoreFunctionNamePass() = default; + + void runOnOperation() override; +}; + +void RestoreFunctionNameFromXlaCallModuleOp(TF::XlaCallModuleOp& call_op, + SymbolTable& symbol_table) { + if (!call_op->hasAttr(kOriginalStablehloEntryFunctionAttrName)) { + return; + } + + const auto original_function_name = call_op->getAttrOfType( + kOriginalStablehloEntryFunctionAttrName); + const auto current_function_name = call_op->getAttrOfType( + TF::kStablehloEntryFunctionAttrName); + + if (!original_function_name || !current_function_name) { + return; + } + + auto function = + symbol_table.lookup(current_function_name.getValue()); + if (function) { + function.setName(original_function_name); + } + + call_op->setAttr(TF::kStablehloEntryFunctionAttrName, + FlatSymbolRefAttr::get(original_function_name)); +} + +void RestoreFunctionNamePass::runOnOperation() { + ModuleOp module_op = getOperation(); + + MLIRContext* ctx = module_op.getContext(); + OpBuilder builder(ctx); + SymbolTable symbol_table(module_op); + + // TODO - b/298966126: Improve this logic if needed. + module_op.walk([&](TF::XlaCallModuleOp call_op) { + RestoreFunctionNameFromXlaCallModuleOp(call_op, symbol_table); + }); +} +} // namespace + +} // namespace mlir::tf_quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_unfuse_mhlo_batch_norm.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_unfuse_mhlo_batch_norm.cc new file mode 100644 index 00000000000000..8a09a010e5c4b4 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_unfuse_mhlo_batch_norm.cc @@ -0,0 +1,59 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h" // IWYU pragma: keep +#include "xla/mlir_hlo/mhlo/transforms/rewriters.h" + +//===----------------------------------------------------------------------===// +// The unfuse-mhlo-batch-norm Pass. 
+//===----------------------------------------------------------------------===// + +namespace mlir::tf_quant::stablehlo { + +#define GEN_PASS_DEF_UNFUSEMHLOBATCHNORMPASS +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h.inc" + +namespace { + +class UnfuseMhloBatchNormPass + : public impl::UnfuseMhloBatchNormPassBase { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(UnfuseMhloBatchNormPass) + + explicit UnfuseMhloBatchNormPass() = default; + + private: + void runOnOperation() override; +}; + +void UnfuseMhloBatchNormPass::runOnOperation() { + MLIRContext* ctx = &getContext(); + RewritePatternSet patterns(ctx); + mhlo::populateUnfuseBatchNormPatterns(ctx, &patterns); + + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { + return signalPassFailure(); + } +} +} // namespace + +} // namespace mlir::tf_quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_unwrap_xla_call_module_op.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_unwrap_xla_call_module_op.cc new file mode 100644 index 00000000000000..2b80378bb8fdaa --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_unwrap_xla_call_module_op.cc @@ -0,0 +1,132 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/Region.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_utils.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/xla_call_module_attrs.h" +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" + +namespace mlir::tf_quant::stablehlo { + +#define GEN_PASS_DEF_UNWRAPXLACALLMODULEOPPASS +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h.inc" + +namespace { + +// Unwraps XlaCallModule ops without quantizable trait that call function with +// '_from_xla_call_module' trait. +class UnwrapXlaCallModuleOpPass + : public impl::UnwrapXlaCallModuleOpPassBase { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(UnwrapXlaCallModuleOpPass) + + explicit UnwrapXlaCallModuleOpPass() = default; + + private: + void runOnOperation() override; +}; + +void UnwrapXlaCallModuleOp(TF::XlaCallModuleOp call_op, + SymbolTable& symbol_table) { + // Do not inline lifted quantized functions used for fusing patterns. + // TODO - b/310539922: Remove reference to TF/TFL utils. 
+ if (call_op->hasAttr(kQuantTraitAttrName)) { + return; + } + + auto function_name = call_op + ->getAttrOfType( + TF::kStablehloEntryFunctionAttrName) + .getValue(); + func::FuncOp func_op = symbol_table.lookup(function_name); + + // We should not unwrap if the function is not from + // ReplaceStablehloOpsInMainFunctionWithXlaCallModuleOpsPass. + if (!func_op->hasAttr(TF::kFromXlaCallModuleAttrName)) { + return; + } + + MLIRContext* context = call_op.getContext(); + OpBuilder builder(context); + builder.setInsertionPointAfter(call_op); + + IRMapping arg_mapper; + bool call_op_has_platform_index_arg = call_op.getPlatforms().size() > 1; + // Add an argument for platform_index. This allows for multiple platforms. + // TODO: b/310291615 - find a better way for multi-platform support. + if (call_op_has_platform_index_arg) { + arg_mapper.map(func_op.getArgument(0), + builder.create( + func_op.getLoc(), builder.getI16IntegerAttr(0))); + } + for (auto [func_arg, operand] : llvm::zip_equal( + func_op.getArguments().take_back(call_op.getNumOperands()), + call_op.getOperands())) { + arg_mapper.map(func_arg, operand); + } + + Region& function_body = func_op.getBody(); + IRMapping new_op_mapper; + for (Operation& op : function_body.getOps()) { + if (llvm::isa(op)) { + for (auto [call_result, return_value] : + llvm::zip_equal(call_op.getResults(), op.getOperands())) { + Value new_result = new_op_mapper.lookup(return_value); + + call_result.replaceAllUsesWith(new_result); + } + continue; + } + + Operation& new_op = *builder.clone(op, arg_mapper); + for (auto [result, new_result] : + llvm::zip_equal(op.getResults(), new_op.getResults())) { + new_op_mapper.map(result, new_result); + } + } + + call_op.erase(); +} + +void UnwrapXlaCallModuleOpPass::runOnOperation() { + ModuleOp module_op = getOperation(); + SymbolTable symbol_table(module_op); + + for (auto func_op : module_op.getOps()) { + Region& function_body = func_op.getBody(); + + function_body.walk([&](TF::XlaCallModuleOp 
call_op) { + UnwrapXlaCallModuleOp(call_op, symbol_table); + }); + } +} + +} // namespace + +} // namespace mlir::tf_quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_xla_call_module_to_call.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_xla_call_module_to_call.cc new file mode 100644 index 00000000000000..250123ad9190d3 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_xla_call_module_to_call.cc @@ -0,0 +1,84 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/common/tf_attrs_and_constraints.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" + +namespace mlir::tf_quant::stablehlo { + +#define GEN_PASS_DEF_XLACALLMODULETOCALLPASS +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h.inc" + +namespace { + +// Converts XlaCallModuleOps to func.call. +class XlaCallModuleToCallPass + : public impl::XlaCallModuleToCallPassBase { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(XlaCallModuleToCallPass) + + explicit XlaCallModuleToCallPass() = default; + + private: + void runOnOperation() override; +}; + +// Converts XlaCallModuleOps to func.call. 
+class XlaCallModuleOpToCallOp : public OpRewritePattern { + public: + explicit XlaCallModuleOpToCallOp(MLIRContext* context) + : OpRewritePattern(context) {} + + LogicalResult matchAndRewrite(TF::XlaCallModuleOp op, + PatternRewriter& rewriter) const override { + auto module_op = op->getParentOfType(); + SymbolTable symbol_table(module_op); + + auto entry_func_op = dyn_cast_or_null( + symbol_table.lookup(GetEntryFunctionName(op))); + if (!entry_func_op) return failure(); + + // Replace the XlaCallModuleOp with a new CallOp. + rewriter.replaceOpWithNewOp(op, entry_func_op, op.getArgs()); + return success(); + } +}; + +void XlaCallModuleToCallPass::runOnOperation() { + ModuleOp module_op = getOperation(); + MLIRContext* ctx = module_op.getContext(); + RewritePatternSet patterns(&getContext()); + patterns.add(ctx); + if (failed(applyPatternsGreedily(module_op, std::move(patterns)))) { + signalPassFailure(); + } +} + +} // namespace +} // namespace mlir::tf_quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/prepare_quantize/tf_prepare_quantize.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/prepare_quantize/tf_prepare_quantize.mlir new file mode 100644 index 00000000000000..69f50965332855 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/prepare_quantize/tf_prepare_quantize.mlir @@ -0,0 +1,140 @@ +// RUN: stablehlo-quant-opt %s -split-input-file -tf-stablehlo-prepare-quantize=enable-per-channel-quantized-weight=false -verify-diagnostics | FileCheck %s + +// ----- + +// CHECK-LABEL: func @dot +// CHECK-SAME: (%[[ARG_0:.*]]: tensor) -> tensor +func.func @dot(%arg0: tensor) -> tensor { + // CHECK: %[[cst:.*]] = stablehlo.constant + // CHECK: %[[q1:.*]] = "quantization.qcast"(%[[cst]]) + // CHECK-SAME: quant.uniform + // CHECK: %[[dq1:.*]] = "quantization.dcast"(%[[q1]]) + // CHECK-SAME: quant.uniform + %cst = stablehlo.constant dense<[[-0.960978984, -0.390246302], [-0.790828585, 
-0.601039409], [-1.0280807, -1.02731466]]> : tensor<3x2xf32> + // CHECK: %[[q2:.*]] = "quantization.qcast"(%[[ARG_0]]) + // CHECK-SAME: quant.uniform + // CHECK: %[[dq2:.*]] = "quantization.dcast"(%[[q2]]) + // CHECK-SAME: quant.uniform + %0 = "quantization.stats"(%arg0) {bitsNum = 8 : i64, layerStats = dense<[-0.999415695, 0.99998933]> : tensor<2xf32>, narrowRange = false} : (tensor) -> tensor + // CHECK: %[[dot:.*]] = stablehlo.dot %[[dq2]], %[[dq1]] + %1 = stablehlo.dot %0, %cst : (tensor, tensor<3x2xf32>) -> tensor + // CHECK: %[[q3:.*]] = "quantization.qcast"(%[[dot]]) + // CHECK-SAME: quant.uniform> + // CHECK: %[[dq3:.*]] = "quantization.dcast"(%[[q3]]) + // CHECK-SAME: quant.uniform> + %2 = "quantization.stats"(%1) {bitsNum = 8 : i64, layerStats = dense<[-3.6289506, 5.61605835]> : tensor<2xf32>, narrowRange = false} : (tensor) -> tensor + // CHECK: return %[[dq3]] + func.return %2 : tensor +} + +// ----- + +// CHECK-LABEL: func @duplicate_stats +// CHECK-SAME: (%[[ARG_0:.*]]: tensor<2x3xf32>) -> tensor<2x3xf32> +func.func @duplicate_stats(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { + // CHECK: %[[q1:.*]] = "quantization.qcast"(%[[ARG_0]]) + // CHECK: %[[dq1:.*]] = "quantization.dcast"(%[[q1]]) + // CHECK: %[[q2:.*]] = "quantization.qcast"(%[[dq1]]) + // CHECK: %[[dq2:.*]] = "quantization.dcast"(%[[q2]]) + // CHECK: stablehlo.convert %[[dq2]] + %0 = "quantization.stats"(%arg0) {bitsNum = 8 : i64, layerStats = dense<[-0.999415695, 0.99998933]> : tensor<2xf32>, narrowRange = false} : (tensor<2x3xf32>) -> tensor<2x3xf32> + %1 = "quantization.stats"(%0) {bitsNum = 8 : i64, layerStats = dense<[-2.0, 2.0]> : tensor<2xf32>, narrowRange = false} : (tensor<2x3xf32>) -> tensor<2x3xf32> + %2 = stablehlo.convert %1 : (tensor<2x3xf32>) -> (tensor<2x3xf32>) + func.return %2 : tensor<2x3xf32> +} + +// ----- + +// CHECK-LABEL: func @dot_redundant_stats +// CHECK-SAME: (%[[ARG_0:.*]]: tensor) -> tensor +func.func @dot_redundant_stats(%arg0: tensor) -> tensor { + // CHECK: 
%[[cst:.*]] = stablehlo.constant + // CHECK: %[[q1:.*]] = "quantization.qcast"(%[[cst]]) + // CHECK-SAME: quant.uniform + // CHECK: %[[dq1:.*]] = "quantization.dcast"(%[[q1]]) + // CHECK-SAME: quant.uniform + %cst = stablehlo.constant dense<[[-0.960978984, -0.390246302], [-0.790828585, -0.601039409], [-1.0280807, -1.02731466]]> : tensor<3x2xf32> + // CHECK: %[[q2:.*]] = "quantization.qcast"(%[[ARG_0]]) + // CHECK-SAME: quant.uniform + // CHECK: %[[dq2:.*]] = "quantization.dcast"(%[[q2]]) + // CHECK-SAME: quant.uniform + %0 = "quantization.stats"(%arg0) {bitsNum = 8 : i64, layerStats = dense<[-100.2, 212.4]> : tensor<2xf32>, narrowRange = false} : (tensor) -> tensor + %1 = "quantization.qcast"(%0) {volatile} : (tensor) -> tensor> + %2 = "quantization.dcast"(%1) : (tensor>) -> tensor + // CHECK: %[[dot:.*]] = stablehlo.dot %[[dq2]], %[[dq1]] + %3 = stablehlo.dot %2, %cst : (tensor, tensor<3x2xf32>) -> tensor + // CHECK: %[[q3:.*]] = "quantization.qcast"(%[[dot]]) + // CHECK-SAME: quant.uniform> + // CHECK: %[[dq3:.*]] = "quantization.dcast"(%[[q3]]) + // CHECK-SAME: quant.uniform> + %4 = "quantization.stats"(%3) {bitsNum = 8 : i64, layerStats = dense<[-3.6289506, 5.61605835]> : tensor<2xf32>, narrowRange = false} : (tensor) -> tensor + // CHECK: return %[[dq3]] + func.return %4 : tensor +} + +// ----- + +// CHECK-LABEL: func @reshape_same_scale_propagate +func.func @reshape_same_scale_propagate(%arg0: tensor<2x3xf32>) -> tensor<6xf32> { + // CHECK: %[[dq:.*]] = "quantization.dcast" + // CHECK-SAME: (tensor<2x3x!quant.uniform>) + %0 = "quantization.stats"(%arg0) {bitsNum = 8 : i64, layerStats = dense<[-0.999415695, 0.99998933]> : tensor<2xf32>, narrowRange = false} : (tensor<2x3xf32>) -> tensor<2x3xf32> + // CHECK: %[[reshape:.*]] = stablehlo.reshape %[[dq]] + %1 = stablehlo.reshape %0 : (tensor<2x3xf32>) -> (tensor<6xf32>) + // CHECK: %[[q:.*]] = "quantization.qcast"(%[[reshape]]) + // CHECK-SAME: -> tensor<6x!quant.uniform> + %2 = "quantization.stats"(%1) {bitsNum = 
8 : i64, layerStats = dense<[-2.0, 2.0]> : tensor<2xf32>, narrowRange = false} : (tensor<6xf32>) -> tensor<6xf32> + func.return %2 : tensor<6xf32> +} + +// ----- + +// CHECK-LABEL: func @merge_consecutive_qcast +// CHECK-SAME: (%[[ARG_0:.*]]: tensor, %[[ARG_1:.*]]: tensor, %[[ARG_2:.*]]: tensor) -> (tensor, tensor) +func.func @merge_consecutive_qcast(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> (tensor, tensor) { + // CHECK: "quantization.qcast"(%[[ARG_1]]) + // CHECK-SAME: -> tensor> + // CHECK: "quantization.qcast"(%[[ARG_1]]) + // CHECK-SAME: -> tensor> + %0 = "quantization.stats"(%arg0) {layerStats = dense<[-0.83811146, 2.4960899]> : tensor<2xf32>} : (tensor) -> tensor + %1 = "quantization.stats"(%arg1) {layerStats = dense<[-0.835039615, 1.000000e+00]> : tensor<2xf32>} : (tensor) -> tensor + %2 = "stablehlo.concatenate"(%0, %1) {dimension = 0 : i64} : (tensor, tensor) -> tensor + %3 = "quantization.stats"(%2) {layerStats = dense<[-0.83811146, 2.4960899]> : tensor<2xf32>} : (tensor) -> tensor + %4 = "quantization.stats"(%arg2) {layerStats = dense<[-1.5726943, 1.07351148]> : tensor<2xf32>} : (tensor) -> tensor + %5 = "stablehlo.concatenate"(%4, %1) {dimension = 0 : i64} : (tensor, tensor) -> tensor + %6 = "quantization.stats"(%5) {layerStats = dense<[-1.5726943, 4.6875381]> : tensor<2xf32>} : (tensor) -> tensor + func.return %3, %6 : tensor, tensor +} + +// ----- + +// CHECK-LABEL: func @skip_nan_inf_constant +// CHECK-SAME: (%[[ARG_0:.*]]: tensor) -> tensor +func.func @skip_nan_inf_constant(%arg0: tensor) -> tensor { + // CHECK-DAG: %[[cst0:.*]] = stablehlo.constant dense<0xFF800000> : tensor : tensor + // CHECK-DAG: %[[cst2:.*]] = stablehlo.constant dense<6.000000e+00> : tensor + // CHECK-DAG: %[[cst3:.*]] = stablehlo.constant dense<0.000000e+00> : tensor + // CHECK-NOT: %[[q0:.*]] = "quantization.qcast"(%[[cst0]]) + // CHECK-NOT: %[[q1:.*]] = "quantization.qcast"(%[[cst1]]) + // CHECK: %[[q2:.*]] = "quantization.qcast"(%[[cst2]]) + // CHECK-SAME: 
quant.uniform + // CHECK: %[[dq2:.*]] = "quantization.dcast"(%[[q2]]) + // CHECK-SAME: quant.uniform + // CHECK: %[[q3:.*]] = "quantization.qcast"(%[[cst3]]) + // CHECK-SAME: quant.uniform + // CHECK: %[[dq3:.*]] = "quantization.dcast"(%[[q3]]) + // CHECK-SAME: quant.uniform + %0 = stablehlo.constant dense<0xFF800000> : tensor + %1 = stablehlo.constant dense<0x7FC00000> : tensor + %2 = stablehlo.constant dense<6.000000e+00> : tensor + %3 = stablehlo.constant dense<0.000000e+00> : tensor + %4 = "stablehlo.add"(%0, %1) : (tensor, tensor) -> tensor + %5 = stablehlo.clamp %3, %arg0, %2 : (tensor, tensor, tensor) -> tensor + %6 = "stablehlo.reduce_window"(%5, %4) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %7 = stablehlo.maximum %arg1, %arg2 : tensor + stablehlo.return %7 : tensor + }) {padding = dense<[[0, 0], [0, 1], [0, 1], [0, 0]]> : tensor<4x2xi64>, window_dimensions = array, window_strides = array} : (tensor, tensor) -> tensor + return %6 : tensor +} diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/prepare_quantize/tf_prepare_quantize_int4.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/prepare_quantize/tf_prepare_quantize_int4.mlir new file mode 100644 index 00000000000000..81a95f9066bc6c --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/prepare_quantize/tf_prepare_quantize_int4.mlir @@ -0,0 +1,26 @@ +// RUN: stablehlo-quant-opt %s -split-input-file -tf-stablehlo-prepare-quantize=bit-width=4 -verify-diagnostics | FileCheck %s + +// CHECK-LABEL: func @dot_int4 +// CHECK-SAME: (%[[ARG_0:.*]]: tensor) -> tensor +func.func @dot_int4(%arg0: tensor) -> tensor { + // CHECK: %[[cst:.*]] = stablehlo.constant + // CHECK: %[[q1:.*]] = "quantization.qcast"(%[[cst]]) + // CHECK-SAME: quant.uniform + // CHECK: %[[dq1:.*]] = "quantization.dcast"(%[[q1]]) + // CHECK-SAME: quant.uniform + %cst = stablehlo.constant dense<[[-0.960978984, -0.390246302], [-0.790828585, -0.601039409], [-1.0280807, 
-1.02731466]]> : tensor<3x2xf32> + // CHECK: %[[q2:.*]] = "quantization.qcast"(%[[ARG_0]]) + // CHECK-SAME: quant.uniform + // CHECK: %[[dq2:.*]] = "quantization.dcast"(%[[q2]]) + // CHECK-SAME: quant.uniform + %0 = "quantization.stats"(%arg0) {bitsNum = 8 : i64, layerStats = dense<[-0.999415695, 0.99998933]> : tensor<2xf32>, narrowRange = false} : (tensor) -> tensor + // CHECK: %[[dot:.*]] = stablehlo.dot %[[dq2]], %[[dq1]] + %1 = stablehlo.dot %0, %cst : (tensor, tensor<3x2xf32>) -> tensor + // CHECK: %[[q3:.*]] = "quantization.qcast"(%[[dot]]) + // CHECK-SAME: quant.uniform> + // CHECK: %[[dq3:.*]] = "quantization.dcast"(%[[q3]]) + // CHECK-SAME: quant.uniform> + %2 = "quantization.stats"(%1) {bitsNum = 8 : i64, layerStats = dense<[-3.6289506, 5.61605835]> : tensor<2xf32>, narrowRange = false} : (tensor) -> tensor + // CHECK: return %[[dq3]] + func.return %2 : tensor +} diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/prepare_quantize/tf_prepare_quantize_per_channel.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/prepare_quantize/tf_prepare_quantize_per_channel.mlir new file mode 100644 index 00000000000000..196c517d3f4657 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/prepare_quantize/tf_prepare_quantize_per_channel.mlir @@ -0,0 +1,130 @@ +// RUN: stablehlo-quant-opt %s -split-input-file -tf-stablehlo-prepare-quantize=enable-per-channel-quantized-weight=true -verify-diagnostics | FileCheck %s + +// ----- + +module { + // CHECK-LABEL: conv_with_bias_and_relu + func.func private @conv_with_bias_and_relu(%arg0: tensor<1x3x2x3xf32>) -> tensor<1x2x2x2xf32> { + %cst = "tf.Const"() {device = "", value = dense<[7.11401462, 7.05456924]> : tensor<2xf32>} : () -> tensor<2xf32> + // CHECK: %[[q_weight_per_channel:.*]] = "quantization.qcast" + // CHECK-SAME: -> tensor<2x3x3x2x!quant.uniform:f32:3, {0.075123051020104109,0.072960192762960605}>> + // CHECK: %[[dq_weight:.*]] = 
"quantization.dcast"(%[[q_weight_per_channel]]) + %cst_0 = "tf.Const"() {device = "", value = dense<[[[[-6.30731344, 5.4962182], [1.80364347, -7.64542675], [-2.11145878, -7.08605719]], [[-9.54062747, -6.14013147], [6.12640238, -4.18223286], [5.05738974, 8.99269962]], [[3.3535192, 0.84816426], [-6.64676809, -7.95477629], [5.81315517, 9.21566581]]], [[[1.38622558, 4.63866329], [4.54742622, -1.43770897], [-3.96835279, 2.99996852]], [[0.989735424, -4.83384752], [-7.27702999, 1.17216611], [1.33735656, 0.728900194]], [[5.1286211, 8.98645591], [1.55008793, -3.85491467], [3.7003777, 9.26594448]]]]> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + // CHECK: %[[q_act:.*]] = "quantization.qcast"(%arg0) + // CHECK-SAME: -> tensor<1x3x2x3x!quant.uniform> + // CHECK: %[[dq_act:.*]] = "quantization.dcast"(%[[q_act]]) + %0 = "quantization.stats"(%arg0) {layerStats = dense<[1.27501142, 4.824783]> : tensor<2xf32>} : (tensor<1x3x2x3xf32>) -> tensor<1x3x2x3xf32> + // CHECK: "tf.XlaCallModule"(%[[dq_act]], %[[dq_weight]] + %1 = "tf.XlaCallModule"(%0, %cst_0, %cst) { + Sout = [#tf_type.shape<1x2x2x2>], config = "", + module = "composite_conv2d_with_bias_and_relu6_fn_10", + _entry_function = @composite_conv2d_with_bias_and_relu6_fn_10, + // Represents a per-channel quantization for the operand index 1 with + // quantization dimension of 3 + _quantization_method = "static_range_ptq {input_quantized_types {key: 1, value {dimension_specs {dimension: 3}}}}", + platforms = [], version = 4 : i64 + } : (tensor<1x3x2x3xf32>, tensor<2x3x3x2xf32>, tensor<2xf32>) -> tensor<1x2x2x2xf32> + %2 = "quantization.stats"(%1) {layerStats = dense<[0.000000e+00, 6.000000e+00]> : tensor<2xf32>} : (tensor<1x2x2x2xf32>) -> tensor<1x2x2x2xf32> + return %2 : tensor<1x2x2x2xf32> + } + + // CHECK-LABEL: composite_conv2d_with_bias_and_relu6_fn_10 + func.func private @composite_conv2d_with_bias_and_relu6_fn_10(%arg0: tensor<1x3x2x3xf32>, %arg1: tensor<2x3x3x2xf32>, %arg2: tensor<2xf32>) -> tensor<1x2x2x2xf32> 
attributes {tf.tf_quant.composite_function} { + %0 = "quantization.stats"(%arg1) {layerStats = dense<[-3.54062747, 0.54742622]> : tensor<2xf32>} : (tensor<2x3x3x2xf32>) -> tensor<2x3x3x2xf32> + %1 = "quantization.stats"(%arg0) {layerStats = dense<[1.27501142, 2.824783]> : tensor<2xf32>} : (tensor<1x3x2x3xf32>) -> tensor<1x3x2x3xf32> + %2 = stablehlo.convolution(%1, %0) + dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], + window = { + stride = [1, 1], pad = [[0, 0], [1, 1]], + lhs_dilate = [1, 1], + rhs_dilate = [1, 1] + } + { + batch_group_count = 1 : i64, + feature_group_count = 1 : i64 + } : (tensor<1x3x2x3xf32>, tensor<2x3x3x2xf32>) + -> tensor<1x2x2x2xf32> + %3 = "quantization.stats"(%arg2) {layerStats = dense<[7.05456924, 7.11401462]> : tensor<2xf32>} : (tensor<2xf32>) -> tensor<2xf32> + %4 = "quantization.stats"(%2) {layerStats = dense<[-1.36523, 3.57373]> : tensor<2xf32>} : (tensor<1x2x2x2xf32>) -> tensor<1x2x2x2xf32> + %5 = "chlo.broadcast_add"(%4, %3) : (tensor<1x2x2x2xf32>, tensor<2xf32>) -> tensor<1x2x2x2xf32> + %6 = "quantization.stats"(%5) {layerStats = dense<[-1.31055, 2.62842]> : tensor<2xf32>} : (tensor<1x2x2x2xf32>) -> tensor<1x2x2x2xf32> + %cst_min = stablehlo.constant dense<0.0> : tensor + %cst_max = stablehlo.constant dense<6.0> : tensor + %7 = "stablehlo.clamp"(%cst_min, %6, %cst_max) {device = ""} : (tensor, tensor<1x2x2x2xf32>, tensor) -> tensor<1x2x2x2xf32> + %8 = "quantization.stats"(%7) {layerStats = dense<[0.000000e+00, 6.000000e+00]> : tensor<2xf32>} : (tensor<1x2x2x2xf32>) -> tensor<1x2x2x2xf32> + return %8 : tensor<1x2x2x2xf32> + } +} + +// ----- + +module { + // CHECK-LABEL: dot_general + func.func private @dot_general(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { + // CHECK: %[[q_weight:.*]] = "quantization.qcast" + // CHECK-SAME: -> tensor<2x2x!quant.uniform:f32:1, {0.049663885371891529,0.060200210631363035}>> + // CHECK: %[[dq_weight:.*]] = "quantization.dcast"(%[[q_weight]]) + %cst = "tf.Const"() {device = "", value = 
dense<[[-6.30731344, 5.4962182], [1.80364347, -7.64542675]]> : tensor<2x2xf32>} : () -> tensor<2x2xf32> + // CHECK: %[[q_act:.*]] = "quantization.qcast"(%arg0) + // CHECK-SAME: -> tensor<2x2x!quant.uniform> + // CHECK: %[[dq_act:.*]] = "quantization.dcast"(%[[q_act]]) + %0 = "quantization.stats"(%arg0) {layerStats = dense<[1.27501142, 4.824783]> : tensor<2xf32>} : (tensor<2x2xf32>) -> tensor<2x2xf32> + // CHECK: "tf.XlaCallModule"(%[[dq_act]], %[[dq_weight]] + %1 = "tf.XlaCallModule"(%0, %cst) { + Sout = [#tf_type.shape<2x2>], config = "", + _entry_function = @composite_dot_general, + module = "composite_dot_general", + platforms = [], version = 4 : i64 + } : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + %2 = "quantization.stats"(%1) {layerStats = dense<[0.000000e+00, 6.000000e+00]> : tensor<2xf32>} : (tensor<2x2xf32>) -> tensor<2x2xf32> + return %2 : tensor<2x2xf32> + } + + // CHECK-LABEL: composite_dot_general + func.func private @composite_dot_general(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> { + %0 = "stablehlo.dot_general"(%arg0, %arg1) { + dot_dimension_numbers = #stablehlo.dot< + lhs_batching_dimensions = [], + rhs_batching_dimensions = [], + lhs_contracting_dimensions = [1], + rhs_contracting_dimensions = [0] + > + } : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> + } +} + +// ----- + +// Tests that the `PrepareQuantizePass` prepares for per-tensor quantization for +// the weight of convolution. This is based on the `_quantization_method` that +// does not have a `input_quantized_types` with a specified `dimension_specs`. 
+ +// CHECK-LABEL: conv_per_tensor_quantized_method +func.func private @conv_per_tensor_quantized_method(%arg0: tensor<1x3x2x3xf32>) -> tensor<1x2x2x2xf32> { + %cst = "tf.Const"() {device = "", value = dense<[7.11401462, 7.05456924]> : tensor<2xf32>} : () -> tensor<2xf32> + %cst_0 = "tf.Const"() {device = "", value = dense<[[[[-6.30731344, 5.4962182], [1.80364347, -7.64542675], [-2.11145878, -7.08605719]], [[-9.54062747, -6.14013147], [6.12640238, -4.18223286], [5.05738974, 8.99269962]], [[3.3535192, 0.84816426], [-6.64676809, -7.95477629], [5.81315517, 9.21566581]]], [[[1.38622558, 4.63866329], [4.54742622, -1.43770897], [-3.96835279, 2.99996852]], [[0.989735424, -4.83384752], [-7.27702999, 1.17216611], [1.33735656, 0.728900194]], [[5.1286211, 8.98645591], [1.55008793, -3.85491467], [3.7003777, 9.26594448]]]]> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %0 = "quantization.stats"(%arg0) {layerStats = dense<[1.27501142, 4.824783]> : tensor<2xf32>} : (tensor<1x3x2x3xf32>) -> tensor<1x3x2x3xf32> + %1 = "tf.XlaCallModule"(%0, %cst_0, %cst) { + Sout = [#tf_type.shape<1x2x2x2>], config = "", + module = "composite_conv_fn_1", + _entry_function = @composite_conv_fn_1, + _quantization_method = "static_range_ptq {}", + platforms = [], version = 4 : i64 + } : (tensor<1x3x2x3xf32>, tensor<2x3x3x2xf32>, tensor<2xf32>) -> tensor<1x2x2x2xf32> + %2 = "quantization.stats"(%1) {layerStats = dense<[0.000000e+00, 6.000000e+00]> : tensor<2xf32>} : (tensor<1x2x2x2xf32>) -> tensor<1x2x2x2xf32> + return %2 : tensor<1x2x2x2xf32> +} +// CHECK-SAME: %[[ARG_0:.+]]: tensor<1x3x2x3xf32> + +// Test that the weight is prepared for per-tensor quantization, based on the +// `_quantization_method` attribute without a `dimension_specs` field in +// `QuantizedType`. 
+// CHECK-DAG: %[[WEIGHT_CONST:.+]] = stablehlo.constant {{.*}} tensor<2x3x3x2xf32> +// CHECK: %[[Q_WEIGHT_PER_TENSOR:.*]] = "quantization.qcast"(%[[WEIGHT_CONST]]) {{.*}} (tensor<2x3x3x2xf32>) -> tensor<2x3x3x2x!quant.uniform> +// CHECK: %[[DQ_WEIGHT:.*]] = "quantization.dcast"(%[[Q_WEIGHT_PER_TENSOR]]) + +// CHECK: %[[Q_ACTIVATION:.*]] = "quantization.qcast"(%[[ARG_0]]) +// CHECK-SAME: -> tensor<1x3x2x3x!quant.uniform> +// CHECK: %[[DQ_ACTIVATION:.*]] = "quantization.dcast"(%[[Q_ACTIVATION]]) +// CHECK: "tf.XlaCallModule"(%[[DQ_ACTIVATION]], %[[DQ_WEIGHT]] diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/tf_quantize.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/tf_quantize.mlir new file mode 100644 index 00000000000000..17e38625a42e2b --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/tf_quantize.mlir @@ -0,0 +1,74 @@ +// RUN: stablehlo-quant-opt %s -split-input-file -tf-stablehlo-quantize -verify-each=false | FileCheck %s + +// Tests for PopulateFusedGemmStylePatterns are handled in +// quantize_composite_functions for module-level evaluation of functions. 
+ +module attributes {tf_saved_model.semantics} { +// CHECK: quantize_simple_xla_call_module(%[[ARG_0:.+]]: tensor<1x4xf32>) + func.func private @quantize_simple_xla_call_module(%arg0: tensor<1x4xf32>) -> tensor<1x3xf32> { + %0 = stablehlo.constant dense<1.000000e+00> : tensor<4x3xf32> + %1 = "quantization.qcast"(%0) {volatile} : (tensor<4x3xf32>) -> tensor<4x3x!quant.uniform:f32:1, {5.000000e-03, 5.000000e-03, 5.000000e-03}>> + %2 = "quantization.dcast"(%1) : (tensor<4x3x!quant.uniform:f32:1, {5.000000e-03, 5.000000e-03, 5.000000e-03}>>) -> tensor<4x3xf32> + %3 = "quantization.qcast"(%arg0) {volatile} : (tensor<1x4xf32>) -> tensor<1x4x!quant.uniform> + %4 = "quantization.dcast"(%3) : (tensor<1x4x!quant.uniform>) -> tensor<1x4xf32> + %5 = "tf.XlaCallModule"(%4, %2) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _quantization_method = "static_range_ptq { }", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x4xf32>, tensor<4x3xf32>) -> tensor<1x3xf32> + %6 = "quantization.qcast"(%5) {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> + %7 = "quantization.dcast"(%6) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + return %7 : tensor<1x3xf32> + } +// Test that the inputs and output of the tf.XlaCallModule op has been replaced +// by quantized types, and the corresponding quantization.dcast ops that turned +// those quantized types back to float types are removed. 
+// CHECK: %[[CONST_0:.+]] = stablehlo.constant dense<1.000000e+00> : tensor<4x3xf32> +// CHECK-DAG: %[[QCAST_0:.+]] = "quantization.qcast"(%[[CONST_0]]) {volatile} : (tensor<4x3xf32>) -> tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>> +// CHECK-DAG: %[[QCAST_1:.+]] = "quantization.qcast"(%[[ARG_0]]) {volatile} : (tensor<1x4xf32>) -> tensor<1x4x!quant.uniform> +// CHECK: %[[CALL_0:.+]] = call @quantized_dot_general_fn(%[[QCAST_1]], %[[QCAST_0]]) +// Test that the `Method` has been copied over. +// CHECK-SAME: {_quantization_method = "static_range_ptq { }"} +// CHECK-SAME: : (tensor<1x4x!quant.uniform>, tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>>) -> tensor<1x3x!quant.uniform> +// CHECK: %[[DCAST_0:.+]] = "quantization.dcast"(%[[CALL_0]]) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> +// CHECK: return + + func.func private @composite_dot_general_fn(%arg0: tensor<1x4xf32>, %arg1: tensor<4x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x4xf32>, tensor<4x3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> + } +} + +// ----- + +// Tests that the output of the tf.XlaCallModule op has been replaced by +// a quantized type, and the corresponding quantization.qcast ops that turned +// the float output to a quantized type is removed. 
+ +// CHECK-LABEL: quantize_simple_xla_call_module_no_operand +func.func private @quantize_simple_xla_call_module_no_operand() -> tensor<1x3xf32> { + %0 = "tf.XlaCallModule"() {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _quantization_method = "static_range_ptq { }", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : () -> tensor<1x3xf32> + %1 = "quantization.qcast"(%0) {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> + %2 = "quantization.dcast"(%1) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + return %2 : tensor<1x3xf32> +} +// CHECK: %[[XLA_CALL_MODULE_0:.+]] = "tf.XlaCallModule"() <{{{.*}}}> {{{.*}}} : () -> tensor<1x3x!quant.uniform> +// CHECK: %[[DCAST_0:.+]] = "quantization.dcast"(%[[XLA_CALL_MODULE_0]]) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> +// CHECK: "func.return"(%[[DCAST_0]]) : (tensor<1x3xf32>) -> () + +// ----- + +// Tests for emitting an error when there is no corresponding entry +// function to quantize (@composite_dot_general_fn). 
+ +module attributes {tf_saved_model.semantics} { + func.func private @error_when_no_entry_function(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { + %0 = stablehlo.constant dense<1.000000e+00> : tensor<2x3xf32> + %1 = "quantization.qcast"(%0) {volatile} : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform:f32, 5.000000e-03>> + %2 = "quantization.dcast"(%1) : (tensor<2x3x!quant.uniform:f32, 5.000000e-03>>) -> tensor<2x3xf32> + %3 = "quantization.qcast"(%arg0) {volatile} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> + %4 = "quantization.dcast"(%3) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32> +// expected-error @+2 {{Failed to find a valid entry function}} +// expected-error @+1 {{'tf.XlaCallModule' op operand #0 must be variadic of tensor of tf.dtype values}} + %5 = "tf.XlaCallModule"(%4, %2) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _quantization_method = "static_range_ptq { }", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + %6 = "quantization.qcast"(%5) {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> + %7 = "quantization.dcast"(%6) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + return %7 : tensor<1x3xf32> + } +} diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/tf_quantize_op_with_region.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/tf_quantize_op_with_region.mlir new file mode 100644 index 00000000000000..5edfea7bc49025 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/tf_quantize_op_with_region.mlir @@ -0,0 +1,241 @@ +// RUN: stablehlo-quant-opt %s -split-input-file -tf-stablehlo-quantize 
-verify-each=false | FileCheck %s + +// Tests if reduce_window op following quantized function is quantized. + +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1722 : i32}, tf_saved_model.semantics} { + // CHECK-LABEL: main_00 + // CHECK-SAME: %[[ARG0:.*]]: tensor<2x3x1x1024xf32> + func.func private @main_00(%arg0: tensor<2x3x1x1024xf32>) -> tensor<2x3x1x3xf32> attributes {tf._original_func_name = "main_0"} { + // CHECK: %[[CST0:.*]] = stablehlo.constant dense<0xFF800000> : tensor + // CHECK: %[[CST1:.*]] = stablehlo.constant dense<0xFF80000E> : tensor<2x3x1024x3xf32> + // CHECK: %[[Q0:.*]] = "quantization.qcast"(%[[CST0]]) + // CHECK: %[[Q1:.*]] = "quantization.qcast"(%[[CST1]]) + // CHECK: %[[Q2:.*]] = "quantization.qcast"(%[[ARG0]]) + // CHECK: %[[CALL:.*]] = call @quantized_dot_general_fn_1(%[[Q2]], %[[Q1]]) + + // CHECK: %[[REDUCE:.*]] = "stablehlo.reduce_window"(%[[CALL]], %[[Q0]]) + // CHECK{LITERAL}: padding = dense<[[0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi64> + // CHECK-SAME: window_dimensions = array + // CHECK: %[[ARG1:.*]]: tensor>, %[[ARG2:.*]]: tensor> + // CHECK: %[[MAX:.*]] = stablehlo.maximum %[[ARG1]], %[[ARG2]] : tensor> + // CHECK: stablehlo.return %[[MAX]] : tensor> + // CHECK: (tensor<2x3x1x3x!quant.uniform>, tensor>) -> tensor<2x3x1x3x!quant.uniform> + + // CHECK: %[[DQ:.*]] = "quantization.dcast"(%[[REDUCE]]) + // CHECK: return %[[DQ]] + + %0 = stablehlo.constant dense<0xFF800000> : tensor + %1 = stablehlo.constant dense<0xFF80000E> : tensor<2x3x1024x3xf32> + %2 = "quantization.qcast"(%0) {volatile} : (tensor) -> tensor> + %3 = "quantization.dcast"(%2) : (tensor>) -> tensor + %4 = "quantization.qcast"(%1) {volatile} : (tensor<2x3x1024x3xf32>) -> tensor<2x3x1024x3x!quant.uniform:f32, 4.000000e-01>> + %5 = "quantization.dcast"(%4) : (tensor<2x3x1024x3x!quant.uniform:f32, 4.000000e-01>>) -> tensor<2x3x1024x3xf32> + %6 = "quantization.qcast"(%arg0) {volatile} : (tensor<2x3x1x1024xf32>) -> 
tensor<2x3x1x1024x!quant.uniform> + %7 = "quantization.dcast"(%6) : (tensor<2x3x1x1024x!quant.uniform>) -> tensor<2x3x1x1024xf32> + %8 = "tf.XlaCallModule"(%7, %5) <{Sout = [#tf_type.shape<2x3x1x3>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = ["CPU"], version = 9 : i64}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _quantization_method = "static_range_ptq {}", _stablehlo_module_attrs = {jax.uses_shape_polymorphism = true}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<2x3x1x1024xf32>, tensor<2x3x1024x3xf32>) -> tensor<2x3x1x3xf32> + %9 = "quantization.qcast"(%8) {volatile} : (tensor<2x3x1x3xf32>) -> tensor<2x3x1x3x!quant.uniform> + %10 = "quantization.dcast"(%9) : (tensor<2x3x1x3x!quant.uniform>) -> tensor<2x3x1x3xf32> + %11 = "stablehlo.reduce_window"(%10, %3) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %14 = stablehlo.maximum %arg1, %arg2 : tensor + stablehlo.return %14 : tensor + }) {padding = dense<[[0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi64>, window_dimensions = array} : (tensor<2x3x1x3xf32>, tensor) -> tensor<2x3x1x3xf32> + %12 = "quantization.qcast"(%11) {volatile} : (tensor<2x3x1x3xf32>) -> tensor<2x3x1x3x!quant.uniform> + %13 = "quantization.dcast"(%12) : (tensor<2x3x1x3x!quant.uniform>) -> tensor<2x3x1x3xf32> + return %13 : tensor<2x3x1x3xf32> + } + + // CHECK: quantized_dot_general_fn_1 + func.func private @composite_dot_general_fn_1(%arg0: tensor<2x3x1x1024xf32>, %arg1: tensor<2x3x1024x3xf32>) -> tensor<2x3x1x3xf32> attributes {_from_xla_call_module} { + // CHECK: %[[DOT:.*]] = stablehlo.dot_general + // CHECK: %[[RQ:.*]] = stablehlo.uniform_quantize %[[DOT]] + // CHECK: return %[[RQ]] + + %0 = stablehlo.dot_general %arg0, %arg1, batching_dims = [0, 1] x [0, 1], contracting_dims = [3] x [2] {mhlo.frontend_attributes = {grad_x = "false", grad_y = "false"}} : (tensor<2x3x1x1024xf32>, tensor<2x3x1024x3xf32>) -> 
tensor<2x3x1x3xf32> + return %0 : tensor<2x3x1x3xf32> + } +} + +// ----- + +// Tests if reduce_window op preceding quantized function is quantized. + +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1722 : i32}, tf_saved_model.semantics} { + // CHECK-LABEL: main_00 + // CHECK-SAME: %[[ARG0:.*]]: tensor<2x3x1x1024xf32> + func.func private @main_00(%arg0: tensor<2x3x1x1024xf32>) -> tensor<2x3x1x3xf32> attributes {tf._original_func_name = "main_0"} { + // CHECK: %[[CST0:.*]] = stablehlo.constant dense<0xFF800000> : tensor + // CHECK: %[[CST1:.*]] = stablehlo.constant dense<0xFF80000E> : tensor<2x3x1024x3xf32> + // CHECK: %[[Q0:.*]] = "quantization.qcast"(%[[CST0]]) + // CHECK: %[[Q1:.*]] = "quantization.qcast"(%[[ARG0]]) + + // CHECK: %[[REDUCE:.*]] = "stablehlo.reduce_window"(%[[Q1]], %[[Q0]]) + // CHECK{LITERAL}: padding = dense<[[0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi64> + // CHECK-SAME: window_dimensions = array + // CHECK: %[[ARG1:.*]]: tensor>, %[[ARG2:.*]]: tensor> + // CHECK: %[[MAX:.*]] = stablehlo.maximum %[[ARG1]], %[[ARG2]] : tensor> + // CHECK: stablehlo.return %[[MAX]] : tensor> + // CHECK: (tensor<2x3x1x1024x!quant.uniform>, tensor>) -> tensor<2x3x1x1024x!quant.uniform> + + // CHECK: %[[Q2:.*]] = "quantization.qcast"(%[[CST1]]) + // CHECK: %[[CALL:.*]] = call @quantized_dot_general_fn_1(%[[REDUCE]], %[[Q2]]) + + // CHECK: %[[DQ:.*]] = "quantization.dcast"(%[[CALL]]) + // CHECK: return %[[DQ]] + + %0 = stablehlo.constant dense<0xFF800000> : tensor + %1 = stablehlo.constant dense<0xFF80000E> : tensor<2x3x1024x3xf32> + %2 = "quantization.qcast"(%0) {volatile} : (tensor) -> tensor> + %3 = "quantization.dcast"(%2) : (tensor>) -> tensor + %4 = "quantization.qcast"(%arg0) {volatile} : (tensor<2x3x1x1024xf32>) -> tensor<2x3x1x1024x!quant.uniform> + %5 = "quantization.dcast"(%4) : (tensor<2x3x1x1024x!quant.uniform>) -> tensor<2x3x1x1024xf32> + %6 = "stablehlo.reduce_window"(%5, %3) ({ + ^bb0(%arg1: tensor, %arg2: 
tensor): + %14 = stablehlo.maximum %arg1, %arg2 : tensor + stablehlo.return %14 : tensor + }) {padding = dense<[[0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi64>, window_dimensions = array} : (tensor<2x3x1x1024xf32>, tensor) -> tensor<2x3x1x1024xf32> + %7 = "quantization.qcast"(%6) {volatile} : (tensor<2x3x1x1024xf32>) -> tensor<2x3x1x1024x!quant.uniform> + %8 = "quantization.dcast"(%7) : (tensor<2x3x1x1024x!quant.uniform>) -> tensor<2x3x1x1024xf32> + %9 = "quantization.qcast"(%1) {volatile} : (tensor<2x3x1024x3xf32>) -> tensor<2x3x1024x3x!quant.uniform:f32, 4.000000e-01>> + %10 = "quantization.dcast"(%9) : (tensor<2x3x1024x3x!quant.uniform:f32, 4.000000e-01>>) -> tensor<2x3x1024x3xf32> + %11 = "tf.XlaCallModule"(%8, %10) <{Sout = [#tf_type.shape<2x3x1x3>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = ["CPU"], version = 9 : i64}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _quantization_method = "static_range_ptq {}", _stablehlo_module_attrs = {jax.uses_shape_polymorphism = true}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<2x3x1x1024xf32>, tensor<2x3x1024x3xf32>) -> tensor<2x3x1x3xf32> + %12 = "quantization.qcast"(%11) {volatile} : (tensor<2x3x1x3xf32>) -> tensor<2x3x1x3x!quant.uniform> + %13 = "quantization.dcast"(%12) : (tensor<2x3x1x3x!quant.uniform>) -> tensor<2x3x1x3xf32> + return %13 : tensor<2x3x1x3xf32> + } + + // CHECK: quantized_dot_general_fn_1 + func.func private @composite_dot_general_fn_1(%arg0: tensor<2x3x1x1024xf32>, %arg1: tensor<2x3x1024x3xf32>) -> tensor<2x3x1x3xf32> attributes {_from_xla_call_module} { + // CHECK: %[[DOT:.*]] = stablehlo.dot_general + // CHECK: %[[RQ:.*]] = stablehlo.uniform_quantize %[[DOT]] + // CHECK: return %[[RQ]] + + %0 = stablehlo.dot_general %arg0, %arg1, batching_dims = [0, 1] x [0, 1], contracting_dims = [3] x [2] {mhlo.frontend_attributes = {grad_x = "false", grad_y = "false"}} : 
(tensor<2x3x1x1024xf32>, tensor<2x3x1024x3xf32>) -> tensor<2x3x1x3xf32> + return %0 : tensor<2x3x1x3xf32> + } +} + +// ----- + +// Tests if reduce_window op following quantized same-scale op is quantized. + +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1722 : i32}, tf_saved_model.semantics} { + // CHECK-LABEL: main_00 + // CHECK-SAME: %[[ARG0:.*]]: tensor<2x3x1x1024xf32> + func.func private @main_00(%arg0: tensor<2x3x1x1024xf32>) -> tensor<2x3x3xf32> attributes {tf._original_func_name = "main_0"} { + // CHECK: %[[CST0:.*]] = stablehlo.constant dense<0xFF800000> : tensor + // CHECK: %[[CST1:.*]] = stablehlo.constant dense<0xFF80000E> : tensor<2x3x1024x3xf32> + // CHECK: %[[Q0:.*]] = "quantization.qcast"(%[[CST0]]) + // CHECK: %[[Q1:.*]] = "quantization.qcast"(%[[CST1]]) + // CHECK: %[[Q2:.*]] = "quantization.qcast"(%[[ARG0]]) + // CHECK: %[[CALL:.*]] = call @quantized_dot_general_fn_1(%[[Q2]], %[[Q1]]) + // CHECK: %[[RESHAPE:.*]] = stablehlo.reshape %[[CALL]] + + // CHECK: %[[REDUCE:.*]] = "stablehlo.reduce_window"(%[[RESHAPE]], %[[Q0]]) + // CHECK{LITERAL}: padding = dense<[[0, 0], [1, 1], [0, 0]]> : tensor<3x2xi64> + // CHECK-SAME: window_dimensions = array + // CHECK: %[[ARG1:.*]]: tensor>, %[[ARG2:.*]]: tensor> + // CHECK: %[[MAX:.*]] = stablehlo.maximum %[[ARG1]], %[[ARG2]] : tensor> + // CHECK: stablehlo.return %[[MAX]] : tensor> + // CHECK: (tensor<2x3x3x!quant.uniform>, tensor>) -> tensor<2x3x3x!quant.uniform> + + // CHECK: %[[DQ:.*]] = "quantization.dcast"(%[[REDUCE]]) + // CHECK: return %[[DQ]] + + %0 = stablehlo.constant dense<0xFF800000> : tensor + %1 = stablehlo.constant dense<0xFF80000E> : tensor<2x3x1024x3xf32> + %2 = "quantization.qcast"(%0) {volatile} : (tensor) -> tensor> + %3 = "quantization.dcast"(%2) : (tensor>) -> tensor + %4 = "quantization.qcast"(%1) {volatile} : (tensor<2x3x1024x3xf32>) -> tensor<2x3x1024x3x!quant.uniform:f32, 4.000000e-01>> + %5 = "quantization.dcast"(%4) : 
(tensor<2x3x1024x3x!quant.uniform:f32, 4.000000e-01>>) -> tensor<2x3x1024x3xf32> + %6 = "quantization.qcast"(%arg0) {volatile} : (tensor<2x3x1x1024xf32>) -> tensor<2x3x1x1024x!quant.uniform> + %7 = "quantization.dcast"(%6) : (tensor<2x3x1x1024x!quant.uniform>) -> tensor<2x3x1x1024xf32> + %8 = "tf.XlaCallModule"(%7, %5) <{Sout = [#tf_type.shape<2x3x1x3>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = ["CPU"], version = 9 : i64}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _quantization_method = "static_range_ptq {}", _stablehlo_module_attrs = {jax.uses_shape_polymorphism = true}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<2x3x1x1024xf32>, tensor<2x3x1024x3xf32>) -> tensor<2x3x1x3xf32> + %9 = "quantization.qcast"(%8) {volatile} : (tensor<2x3x1x3xf32>) -> tensor<2x3x1x3x!quant.uniform> + %10 = "quantization.dcast"(%9) : (tensor<2x3x1x3x!quant.uniform>) -> tensor<2x3x1x3xf32> + %11 = stablehlo.reshape %10 : (tensor<2x3x1x3xf32>) -> tensor<2x3x3xf32> + %12 = "quantization.qcast"(%11) {volatile} : (tensor<2x3x3xf32>) -> tensor<2x3x3x!quant.uniform> + %13 = "quantization.dcast"(%12) : (tensor<2x3x3x!quant.uniform>) -> tensor<2x3x3xf32> + %14 = "stablehlo.reduce_window"(%13, %3) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %17 = stablehlo.maximum %arg1, %arg2 : tensor + stablehlo.return %17 : tensor + }) {padding = dense<[[0, 0], [1, 1], [0, 0]]> : tensor<3x2xi64>, window_dimensions = array} : (tensor<2x3x3xf32>, tensor) -> tensor<2x3x3xf32> + %15 = "quantization.qcast"(%14) {volatile} : (tensor<2x3x3xf32>) -> tensor<2x3x3x!quant.uniform> + %16 = "quantization.dcast"(%15) : (tensor<2x3x3x!quant.uniform>) -> tensor<2x3x3xf32> + return %16 : tensor<2x3x3xf32> + } + + // CHECK: quantized_dot_general_fn_1 + func.func private @composite_dot_general_fn_1(%arg0: tensor<2x3x1x1024xf32>, %arg1: tensor<2x3x1024x3xf32>) -> tensor<2x3x1x3xf32> attributes 
{_from_xla_call_module} { + // CHECK: %[[DOT:.*]] = stablehlo.dot_general + // CHECK: %[[RQ:.*]] = stablehlo.uniform_quantize %[[DOT]] + // CHECK: return %[[RQ]] + + %0 = stablehlo.dot_general %arg0, %arg1, batching_dims = [0, 1] x [0, 1], contracting_dims = [3] x [2] {mhlo.frontend_attributes = {grad_x = "false", grad_y = "false"}} : (tensor<2x3x1x1024xf32>, tensor<2x3x1024x3xf32>) -> tensor<2x3x1x3xf32> + return %0 : tensor<2x3x1x3xf32> + } +} + +// ----- + +// Tests if reduce_window op preceding quantized same-scale op is quantized. + +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1722 : i32}, tf_saved_model.semantics} { + // CHECK-LABEL: main_00 + // CHECK-SAME: %[[ARG0:.*]]: tensor<2x3x1024xf32> + func.func private @main_00(%arg0: tensor<2x3x1024xf32>) -> tensor<2x3x1x3xf32> attributes {tf._original_func_name = "main_0"} { + // CHECK: %[[CST0:.*]] = stablehlo.constant dense<0xFF800000> : tensor + // CHECK: %[[CST1:.*]] = stablehlo.constant dense<0xFF80000E> : tensor<2x3x1024x3xf32> + // CHECK: %[[Q0:.*]] = "quantization.qcast"(%[[CST0]]) + // CHECK: %[[Q1:.*]] = "quantization.qcast"(%[[ARG0]]) + + // CHECK: %[[REDUCE:.*]] = "stablehlo.reduce_window"(%[[Q1]], %[[Q0]]) + // CHECK{LITERAL}: padding = dense<[[0, 0], [1, 1], [0, 0]]> : tensor<3x2xi64> + // CHECK-SAME: window_dimensions = array + // CHECK: %[[ARG1:.*]]: tensor>, %[[ARG2:.*]]: tensor> + // CHECK: %[[MAX:.*]] = stablehlo.maximum %[[ARG1]], %[[ARG2]] : tensor> + // CHECK: stablehlo.return %[[MAX]] : tensor> + // CHECK: (tensor<2x3x1024x!quant.uniform>, tensor>) -> tensor<2x3x1024x!quant.uniform> + + // CHECK: %[[RESHAPE:.*]] = stablehlo.reshape %[[REDUCE]] + // CHECK: %[[Q2:.*]] = "quantization.qcast"(%[[CST1]]) + // CHECK: %[[CALL:.*]] = call @quantized_dot_general_fn_1(%[[RESHAPE]], %[[Q2]]) + + // CHECK: %[[DQ:.*]] = "quantization.dcast"(%[[CALL]]) + // CHECK: return %[[DQ]] + + %0 = stablehlo.constant dense<0xFF800000> : tensor + %1 = stablehlo.constant 
dense<0xFF80000E> : tensor<2x3x1024x3xf32> + %2 = "quantization.qcast"(%0) {volatile} : (tensor) -> tensor> + %3 = "quantization.dcast"(%2) : (tensor>) -> tensor + %4 = "quantization.qcast"(%arg0) {volatile} : (tensor<2x3x1024xf32>) -> tensor<2x3x1024x!quant.uniform> + %5 = "quantization.dcast"(%4) : (tensor<2x3x1024x!quant.uniform>) -> tensor<2x3x1024xf32> + %6 = "stablehlo.reduce_window"(%5, %3) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %17 = stablehlo.maximum %arg1, %arg2 : tensor + stablehlo.return %17 : tensor + }) {padding = dense<[[0, 0], [1, 1], [0, 0]]> : tensor<3x2xi64>, window_dimensions = array} : (tensor<2x3x1024xf32>, tensor) -> tensor<2x3x1024xf32> + %7 = "quantization.qcast"(%6) {volatile} : (tensor<2x3x1024xf32>) -> tensor<2x3x1024x!quant.uniform> + %8 = "quantization.dcast"(%7) : (tensor<2x3x1024x!quant.uniform>) -> tensor<2x3x1024xf32> + %9 = stablehlo.reshape %8 : (tensor<2x3x1024xf32>) -> tensor<2x3x1x1024xf32> + %10 = "quantization.qcast"(%9) {volatile} : (tensor<2x3x1x1024xf32>) -> tensor<2x3x1x1024x!quant.uniform> + %11 = "quantization.dcast"(%10) : (tensor<2x3x1x1024x!quant.uniform>) -> tensor<2x3x1x1024xf32> + %12 = "quantization.qcast"(%1) {volatile} : (tensor<2x3x1024x3xf32>) -> tensor<2x3x1024x3x!quant.uniform:f32, 4.000000e-01>> + %13 = "quantization.dcast"(%12) : (tensor<2x3x1024x3x!quant.uniform:f32, 4.000000e-01>>) -> tensor<2x3x1024x3xf32> + %14 = "tf.XlaCallModule"(%11, %13) <{Sout = [#tf_type.shape<2x3x1x3>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = ["CPU"], version = 9 : i64}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _quantization_method = "static_range_ptq {}", _stablehlo_module_attrs = {jax.uses_shape_polymorphism = true}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<2x3x1x1024xf32>, tensor<2x3x1024x3xf32>) -> tensor<2x3x1x3xf32> + %15 = "quantization.qcast"(%14) {volatile} : 
(tensor<2x3x1x3xf32>) -> tensor<2x3x1x3x!quant.uniform> + %16 = "quantization.dcast"(%15) : (tensor<2x3x1x3x!quant.uniform>) -> tensor<2x3x1x3xf32> + return %16 : tensor<2x3x1x3xf32> + } + + // CHECK: quantized_dot_general_fn_1 + func.func private @composite_dot_general_fn_1(%arg0: tensor<2x3x1x1024xf32>, %arg1: tensor<2x3x1024x3xf32>) -> tensor<2x3x1x3xf32> attributes {_from_xla_call_module} { + // CHECK: %[[DOT:.*]] = stablehlo.dot_general + // CHECK: %[[RQ:.*]] = stablehlo.uniform_quantize %[[DOT]] + // CHECK: return %[[RQ]] + + %0 = stablehlo.dot_general %arg0, %arg1, batching_dims = [0, 1] x [0, 1], contracting_dims = [3] x [2] {mhlo.frontend_attributes = {grad_x = "false", grad_y = "false"}} : (tensor<2x3x1x1024xf32>, tensor<2x3x1024x3xf32>) -> tensor<2x3x1x3xf32> + return %0 : tensor<2x3x1x3xf32> + } +} diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/tf_quantize_same_scale.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/tf_quantize_same_scale.mlir new file mode 100644 index 00000000000000..5ab6ea4101db4b --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/tf_quantize_same_scale.mlir @@ -0,0 +1,373 @@ +// RUN: stablehlo-quant-opt %s -split-input-file -tf-stablehlo-quantize -verify-each=false | FileCheck %s + +module attributes {tf_saved_model.semantics} { + // CHECK-LABEL: same_scale_after_composite + // CHECK-SAME: %[[ARG0:.*]]: tensor<1x2xf32> + // CHECK-SAME: %[[ARG1:.*]]: tensor<2x3xf32> + func.func private @same_scale_after_composite(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<3x1xf32> { + // CHECK: %[[Q1:.*]] = "quantization.qcast"(%[[ARG0]]) {volatile} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> + // CHECK: %[[Q2:.*]] = "quantization.qcast"(%[[ARG1]]) {volatile} : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03}>> + // CHECK: %[[CALL:.*]] = call @quantized_dot_general_fn_1(%[[Q1]], 
%[[Q2]]) + // CHECK-SAME: tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03}>>) -> tensor<1x3x!quant.uniform> + // CHECK: %[[RESHAPE:.*]] = stablehlo.reshape %[[CALL]] : (tensor<1x3x!quant.uniform>) -> tensor<3x1x!quant.uniform> + // CHECK: %[[DQ:.*]] = "quantization.dcast"(%[[RESHAPE]]) : (tensor<3x1x!quant.uniform>) -> tensor<3x1xf32> + // CHECK: return %[[DQ]] + %0 = "quantization.qcast"(%arg0) {volatile} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> + %1 = "quantization.dcast"(%0) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32> + %2 = "quantization.qcast"(%arg1) {volatile} : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03}>> + %3 = "quantization.dcast"(%2) : (tensor<2x3x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03}>>) -> tensor<2x3xf32> + %4 = "tf.XlaCallModule"(%1, %3) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _quantization_method = "static_range_ptq {}", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + %5 = "quantization.qcast"(%4) {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> + %6 = "quantization.dcast"(%5) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + %7 = stablehlo.reshape %6 : (tensor<1x3xf32>) -> tensor<3x1xf32> + %8 = "quantization.qcast"(%7) {volatile} : (tensor<3x1xf32>) -> tensor<3x1x!quant.uniform> + %9 = "quantization.dcast"(%8) : (tensor<3x1x!quant.uniform>) -> tensor<3x1xf32> + return %9 : tensor<3x1xf32> + } + + // CHECK: quantized_dot_general_fn_1 + // CHECK-SAME: %[[ARG2:.*]]: tensor<1x2x!quant.uniform> + // CHECK-SAME: %[[ARG3:.*]]: tensor<2x3x!quant.uniform:f32:1, 
{6.000000e-03,6.000000e-03,6.000000e-03}>> + func.func private @composite_dot_general_fn_1(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { + // CHECK: %[[DOT:.*]] = stablehlo.dot_general %[[ARG2]], %[[ARG3]] + // CHECK-SAME: (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03}>>) -> tensor<1x3x!quant.uniform> + // CHECK: %[[Q3:.*]] = stablehlo.uniform_quantize %0 : (tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> + // CHECK: return %[[Q3]] + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> + } +} + +// ----- + +module attributes {tf_saved_model.semantics} { + // CHECK-LABEL: same_scale_indirectly_connected + // CHECK-SAME: %[[ARG0:.*]]: tensor<1x2xf32> + // CHECK-SAME: %[[ARG1:.*]]: tensor<2x3xf32> + func.func private @same_scale_indirectly_connected(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x3xf32> { + // CHECK: %[[Q1:.*]] = "quantization.qcast"(%[[ARG0]]) {volatile} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> + // CHECK: %[[Q2:.*]] = "quantization.qcast"(%[[ARG1]]) {volatile} : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03}>> + // CHECK: %[[CALL:.*]] = call @quantized_dot_general_fn_1(%[[Q1]], %[[Q2]]) + // CHECK-SAME: tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03}>>) -> tensor<1x3x!quant.uniform> + // CHECK: %[[RESHAPE:.*]] = stablehlo.reshape %[[CALL]] : (tensor<1x3x!quant.uniform>) -> tensor<3x1x!quant.uniform> + // CHECK: %[[TRANSPOSE:.*]] = stablehlo.transpose %[[RESHAPE]], dims = [1, 0] : (tensor<3x1x!quant.uniform>) -> tensor<1x3x!quant.uniform> + // CHECK: %[[DQ:.*]] = "quantization.dcast"(%[[TRANSPOSE]]) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + // CHECK: return %[[DQ]] + + %0 = 
"quantization.qcast"(%arg0) {volatile} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> + %1 = "quantization.dcast"(%0) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32> + %2 = "quantization.qcast"(%arg1) {volatile} : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03}>> + %3 = "quantization.dcast"(%2) : (tensor<2x3x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03}>>) -> tensor<2x3xf32> + %4 = "tf.XlaCallModule"(%1, %3) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _quantization_method = "static_range_ptq {}", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + %5 = "quantization.qcast"(%4) {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> + %6 = "quantization.dcast"(%5) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + %7 = stablehlo.reshape %6 : (tensor<1x3xf32>) -> tensor<3x1xf32> + %8 = "quantization.qcast"(%7) {volatile} : (tensor<3x1xf32>) -> tensor<3x1x!quant.uniform> + %9 = "quantization.dcast"(%8) : (tensor<3x1x!quant.uniform>) -> tensor<3x1xf32> + %10 = stablehlo.transpose %9, dims = [1, 0] : (tensor<3x1xf32>) -> tensor<1x3xf32> + %11 = "quantization.qcast"(%10) {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> + %12 = "quantization.dcast"(%11) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + return %12 : tensor<1x3xf32> + } + + // CHECK: quantized_dot_general_fn_1 + // CHECK-SAME: %[[ARG2:.*]]: tensor<1x2x!quant.uniform> + // CHECK-SAME: %[[ARG3:.*]]: tensor<2x3x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03}>> + func.func private @composite_dot_general_fn_1(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { + 
// CHECK: %[[DOT:.*]] = stablehlo.dot_general %[[ARG2]], %[[ARG3]] + // CHECK-SAME: (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03}>>) -> tensor<1x3x!quant.uniform> + // CHECK: %[[Q3:.*]] = stablehlo.uniform_quantize %0 : (tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> + // CHECK: return %[[Q3]] + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> + } +} + +// ----- + +// CHECK-LABEL: same_scale_not_connected_to_composite +func.func @same_scale_not_connected_to_composite() -> tensor<3x1xf32> { + // CHECK: %[[CST:.*]] = stablehlo.constant + // CHECK: %[[Q1:.*]] = "quantization.qcast"(%[[CST]]) {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> + // CHECK: %[[DQ1:.*]] = "quantization.dcast"(%[[Q1]]) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + // CHECK: %[[RESHAPE:.*]] = stablehlo.reshape %[[DQ1]] + // CHECK: %[[Q2:.*]] = "quantization.qcast"(%[[RESHAPE]]) {volatile} : (tensor<3x1xf32>) -> tensor<3x1x!quant.uniform> + // CHECK: %[[DQ2:.*]] = "quantization.dcast"(%[[Q2]]) : (tensor<3x1x!quant.uniform>) -> tensor<3x1xf32> + // CHECK: return %[[DQ2]] + + %0 = stablehlo.constant dense<1.000000e+00> : tensor<1x3xf32> + %1 = "quantization.qcast"(%0) {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> + %2 = "quantization.dcast"(%1) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + %3 = stablehlo.reshape %2 : (tensor<1x3xf32>) -> tensor<3x1xf32> + %4 = "quantization.qcast"(%3) {volatile} : (tensor<3x1xf32>) -> tensor<3x1x!quant.uniform> + %5 = "quantization.dcast"(%4) : (tensor<3x1x!quant.uniform>) -> tensor<3x1xf32> + return %5 : tensor<3x1xf32> +} + +// ----- + +module attributes {tf_saved_model.semantics} { + // CHECK-LABEL: concatenate_and_composite + // CHECK-SAME: %[[ARG0:.*]]: tensor<3x2xf32> + // CHECK-SAME: %[[ARG1:.*]]: tensor<1x2xf32> + // CHECK-SAME: %[[ARG2:.*]]: 
tensor<2x5xf32> + func.func private @concatenate_and_composite(%arg0: tensor<3x2xf32>, %arg1: tensor<1x2xf32>, %arg2: tensor<2x5xf32>) -> tensor<4x5xf32> { + // CHECK: %[[Q1:.*]] = "quantization.qcast"(%[[ARG0]]) {volatile} : (tensor<3x2xf32>) -> tensor<3x2x!quant.uniform> + // CHECK: %[[Q2:.*]] = "quantization.qcast"(%[[ARG1]]) {volatile} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> + // CHECK: %[[CONCAT:.*]] = stablehlo.concatenate %[[Q1]], %[[Q2]], dim = 0 + // CHECK-SAME: (tensor<3x2x!quant.uniform>, tensor<1x2x!quant.uniform>) -> tensor<4x2x!quant.uniform> + // CHECK: %[[Q3:.*]] = "quantization.qcast"(%[[ARG2]]) {volatile} : (tensor<2x5xf32>) -> tensor<2x5x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03,6.000000e-03,6.000000e-03}>> + // CHECK: %[[CALL:.*]] = call @quantized_dot_general_fn_1(%[[CONCAT]], %[[Q3]]) + // CHECK-SAME: (tensor<4x2x!quant.uniform>, tensor<2x5x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03,6.000000e-03,6.000000e-03}>>) -> tensor<4x5x!quant.uniform> + // CHECK: %[[DQ:.*]] = "quantization.dcast"(%[[CALL]]) : (tensor<4x5x!quant.uniform>) -> tensor<4x5xf32> + // CHECK: return %[[DQ]] + + %0 = "quantization.qcast"(%arg0) {volatile} : (tensor<3x2xf32>) -> tensor<3x2x!quant.uniform> + %1 = "quantization.dcast"(%0) : (tensor<3x2x!quant.uniform>) -> tensor<3x2xf32> + %2 = "quantization.qcast"(%arg1) {volatile} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> + %3 = "quantization.dcast"(%2) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32> + %4 = "stablehlo.concatenate"(%1, %3) { + dimension = 0 : i64 + } : (tensor<3x2xf32>, tensor<1x2xf32>) -> tensor<4x2xf32> + %5 = "quantization.qcast"(%4) {volatile} : (tensor<4x2xf32>) -> tensor<4x2x!quant.uniform> + %6 = "quantization.dcast"(%5) : (tensor<4x2x!quant.uniform>) -> tensor<4x2xf32> + %7 = "quantization.qcast"(%arg2) {volatile} : (tensor<2x5xf32>) -> tensor<2x5x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03,6.000000e-03,6.000000e-03}>> + %8 = 
"quantization.dcast"(%7) : (tensor<2x5x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03,6.000000e-03,6.000000e-03}>>) -> tensor<2x5xf32> + %9 = "tf.XlaCallModule"(%6, %8) {Sout = [#tf_type.shape<4x5>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _quantization_method = "static_range_ptq {}", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<4x2xf32>, tensor<2x5xf32>) -> tensor<4x5xf32> + %10 = "quantization.qcast"(%9) {volatile} : (tensor<4x5xf32>) -> tensor<4x5x!quant.uniform> + %11 = "quantization.dcast"(%10) : (tensor<4x5x!quant.uniform>) -> tensor<4x5xf32> + return %11 : tensor<4x5xf32> + } + + // CHECK: quantized_dot_general_fn_1 + // CHECK-SAME: %[[ARG3:.*]]: tensor<4x2x!quant.uniform> + // CHECK-SAME: %[[ARG4:.*]]: tensor<2x5x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03,6.000000e-03,6.000000e-03}>> + func.func private @composite_dot_general_fn_1(%arg0: tensor<4x2xf32>, %arg1: tensor<2x5xf32>) -> tensor<4x5xf32> attributes {_from_xla_call_module} { + // CHECK: %[[DOT:.*]] = stablehlo.dot_general %[[ARG3]], %[[ARG4]] + // CHECK-SAME: (tensor<4x2x!quant.uniform>, tensor<2x5x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03,6.000000e-03,6.000000e-03}>>) -> tensor<4x5x!quant.uniform> + // CHECK: %[[Q3:.*]] = stablehlo.uniform_quantize %0 : (tensor<4x5x!quant.uniform>) -> tensor<4x5x!quant.uniform> + // CHECK: return %[[Q3]] + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<4x2xf32>, tensor<2x5xf32>) -> tensor<4x5xf32> + return %0 : tensor<4x5xf32> + } +} + +// ----- + +module attributes {tf_saved_model.semantics} { + // CHECK-LABEL: composite_and_pad + // CHECK-SAME: %[[ARG0:.*]]: tensor<1x2xf32> + // CHECK-SAME: %[[ARG1:.*]]: tensor<2x3xf32> + // CHECK-SAME: 
%[[ARG2:.*]]: tensor + func.func private @composite_and_pad(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>, %arg2: tensor) -> tensor<3x9xf32> { + // CHECK: %[[Q1:.*]] = "quantization.qcast"(%[[ARG0]]) {volatile} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> + // CHECK: %[[Q2:.*]] = "quantization.qcast"(%[[ARG1]]) {volatile} : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03}>> + // CHECK: %[[CALL:.*]] = call @quantized_dot_general_fn_1(%[[Q1]], %[[Q2]]) + // CHECK-SAME: (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03}>>) -> tensor<1x3x!quant.uniform> + // CHECK: %[[Q3:.*]] = "quantization.qcast"(%arg2) {volatile} : (tensor) -> tensor> + // CHECK: %[[PAD:.*]] = stablehlo.pad %[[CALL]], %[[Q3]] + // CHECK-SAME: (tensor<1x3x!quant.uniform>, tensor>) -> tensor<3x9x!quant.uniform> + // CHECK: %[[DQ:.*]] = "quantization.dcast"(%[[PAD]]) : (tensor<3x9x!quant.uniform>) -> tensor<3x9xf32> + // CHECK: return %[[DQ]] + %0 = "quantization.qcast"(%arg0) {volatile} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> + %1 = "quantization.dcast"(%0) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32> + %2 = "quantization.qcast"(%arg1) {volatile} : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03}>> + %3 = "quantization.dcast"(%2) : (tensor<2x3x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03}>>) -> tensor<2x3xf32> + %4 = "tf.XlaCallModule"(%1, %3) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _quantization_method = "static_range_ptq {}", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + %5 = "quantization.qcast"(%4) {volatile} : 
(tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> + %6 = "quantization.dcast"(%5) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + %7 = "quantization.qcast"(%arg2) {volatile} : (tensor) -> tensor> + %8 = "quantization.dcast"(%7) : (tensor>) -> tensor + %9 = stablehlo.pad %6, %8, low = [0, 1], high = [2, 1], interior = [0, 2] : (tensor<1x3xf32>, tensor) -> tensor<3x9xf32> + %10 = "quantization.qcast"(%9) {volatile} : (tensor<3x9xf32>) -> tensor<3x9x!quant.uniform> + %11 = "quantization.dcast"(%10) : (tensor<3x9x!quant.uniform>) -> tensor<3x9xf32> + return %11 : tensor<3x9xf32> + } + + // CHECK: quantized_dot_general_fn_1 + // CHECK-SAME: %[[ARG2:.*]]: tensor<1x2x!quant.uniform> + // CHECK-SAME: %[[ARG3:.*]]: tensor<2x3x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03}>> + func.func private @composite_dot_general_fn_1(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { + // CHECK: %[[DOT:.*]] = stablehlo.dot_general %[[ARG2]], %[[ARG3]] + // CHECK-SAME: (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03}>>) -> tensor<1x3x!quant.uniform> + // CHECK: %[[Q3:.*]] = stablehlo.uniform_quantize %0 : (tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> + // CHECK: return %[[Q3]] + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> + } +} + +// ----- + +module attributes {tf_saved_model.semantics} { + // CHECK-LABEL: composite_and_select + // CHECK-SAME: %[[ARG0:.*]]: tensor<1x2xf32> + // CHECK-SAME: %[[ARG1:.*]]: tensor<2x3xf32> + // CHECK-SAME: %[[ARG2:.*]]: tensor<1x3xi1> + // CHECK-SAME: %[[ARG3:.*]]: tensor<1x3xf32> + func.func private @composite_and_select(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>, %arg2: tensor<1x3xi1>, %arg3: tensor<1x3xf32>) -> tensor<1x3xf32> { + // CHECK: %[[Q1:.*]] = "quantization.qcast"(%[[ARG0]]) {volatile} : 
(tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> + // CHECK: %[[Q2:.*]] = "quantization.qcast"(%[[ARG1]]) {volatile} : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03}>> + // CHECK: %[[CALL:.*]] = call @quantized_dot_general_fn_1(%[[Q1]], %[[Q2]]) + // CHECK-SAME: (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03}>>) -> tensor<1x3x!quant.uniform> + // CHECK: %[[Q3:.*]] = "quantization.qcast"(%[[ARG3]]) {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> + // CHECK: %[[SELECT:.*]] = stablehlo.select %[[ARG2]], %[[CALL]], %[[Q3]] : tensor<1x3xi1>, tensor<1x3x!quant.uniform> + // CHECK: %[[DQ:.*]] = "quantization.dcast"(%[[SELECT]]) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + // CHECK: return %[[DQ]] + %0 = "quantization.qcast"(%arg0) {volatile} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> + %1 = "quantization.dcast"(%0) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32> + %2 = "quantization.qcast"(%arg1) {volatile} : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03}>> + %3 = "quantization.dcast"(%2) : (tensor<2x3x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03}>>) -> tensor<2x3xf32> + %4 = "tf.XlaCallModule"(%1, %3) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _quantization_method = "static_range_ptq {}", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + %5 = "quantization.qcast"(%4) {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> + %6 = "quantization.dcast"(%5) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + %7 = "quantization.qcast"(%arg3) {volatile} : (tensor<1x3xf32>) -> 
tensor<1x3x!quant.uniform> + %8 = "quantization.dcast"(%7) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + %9 = stablehlo.select %arg2, %6, %8 : (tensor<1x3xi1>, tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<1x3xf32> + %10 = "quantization.qcast"(%9) {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> + %11 = "quantization.dcast"(%10) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + return %11 : tensor<1x3xf32> + } + + // CHECK: quantized_dot_general_fn_1 + // CHECK-SAME: %[[ARG2:.*]]: tensor<1x2x!quant.uniform> + // CHECK-SAME: %[[ARG3:.*]]: tensor<2x3x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03}>> + func.func private @composite_dot_general_fn_1(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { + // CHECK: %[[DOT:.*]] = stablehlo.dot_general %[[ARG2]], %[[ARG3]] + // CHECK-SAME: (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03}>>) -> tensor<1x3x!quant.uniform> + // CHECK: %[[Q3:.*]] = stablehlo.uniform_quantize %0 : (tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> + // CHECK: return %[[Q3]] + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> + } +} + +// ----- + +module attributes {tf_saved_model.semantics} { + // CHECK-LABEL: composite_and_broadcast_in_dim + // CHECK-SAME: %[[ARG0:.*]]: tensor<1x2xf32> + // CHECK-SAME: %[[ARG1:.*]]: tensor<2x3xf32> + func.func private @composite_and_broadcast_in_dim(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<2x3x2xf32> { + // CHECK: %[[Q1:.*]] = "quantization.qcast"(%[[ARG0]]) {volatile} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> + // CHECK: %[[Q2:.*]] = "quantization.qcast"(%[[ARG1]]) {volatile} : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03}>> + // CHECK: %[[CALL:.*]] = call 
@quantized_dot_general_fn_1(%[[Q1]], %[[Q2]]) + // CHECK-SAME: (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03}>>) -> tensor<1x3x!quant.uniform> + // CHECK: %[[BROADCAST:.*]] = stablehlo.broadcast_in_dim %[[CALL]], dims = [2, 1] : (tensor<1x3x!quant.uniform>) -> tensor<2x3x2x!quant.uniform> + // CHECK: %[[DQ:.*]] = "quantization.dcast"(%[[BROADCAST]]) : (tensor<2x3x2x!quant.uniform>) -> tensor<2x3x2xf32> + // CHECK: return %[[DQ]] + %0 = "quantization.qcast"(%arg0) {volatile} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> + %1 = "quantization.dcast"(%0) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32> + %2 = "quantization.qcast"(%arg1) {volatile} : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03}>> + %3 = "quantization.dcast"(%2) : (tensor<2x3x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03}>>) -> tensor<2x3xf32> + %4 = "tf.XlaCallModule"(%1, %3) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _quantization_method = "static_range_ptq {}", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + %5 = "quantization.qcast"(%4) {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> + %6 = "quantization.dcast"(%5) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + %7 = stablehlo.broadcast_in_dim %6, dims = [2, 1] : (tensor<1x3xf32>) -> tensor<2x3x2xf32> + %8 = "quantization.qcast"(%7) {volatile} : (tensor<2x3x2xf32>) -> tensor<2x3x2x!quant.uniform> + %9 = "quantization.dcast"(%8) : (tensor<2x3x2x!quant.uniform>) -> tensor<2x3x2xf32> + return %9 : tensor<2x3x2xf32> + } + + // CHECK: quantized_dot_general_fn_1 + // CHECK-SAME: %[[ARG2:.*]]: 
tensor<1x2x!quant.uniform> + // CHECK-SAME: %[[ARG3:.*]]: tensor<2x3x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03}>> + func.func private @composite_dot_general_fn_1(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { + // CHECK: %[[DOT:.*]] = stablehlo.dot_general %[[ARG2]], %[[ARG3]] + // CHECK-SAME: (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03}>>) -> tensor<1x3x!quant.uniform> + // CHECK: %[[Q3:.*]] = stablehlo.uniform_quantize %0 : (tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> + // CHECK: return %[[Q3]] + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> + } +} + +// ----- + +module attributes {tf_saved_model.semantics} { + // CHECK-LABEL: composite_and_gather + // CHECK-SAME: %[[ARG0:.*]]: tensor<3x4x5xf32> + // CHECK-SAME: %[[ARG1:.*]]: tensor<3x5x2xf32> + // CHECK-SAME: %[[ARG2:.*]]: tensor<2x3x2xi64> + func.func private @composite_and_gather(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x5x2xf32>, %arg2: tensor<2x3x2xi64>) -> tensor<2x3x2x2xf32> { + // CHECK: %[[Q1:.*]] = "quantization.qcast"(%[[ARG0]]) {volatile} : (tensor<3x4x5xf32>) -> tensor<3x4x5x!quant.uniform> + // CHECK: %[[Q2:.*]] = "quantization.qcast"(%[[ARG1]]) {volatile} : (tensor<3x5x2xf32>) -> tensor<3x5x2x!quant.uniform:f32, 6.000000e-03>> + // CHECK: %[[CALL:.*]] = call @quantized_dot_general_fn_1(%[[Q1]], %[[Q2]]) + // CHECK-SAME: (tensor<3x4x5x!quant.uniform>, tensor<3x5x2x!quant.uniform:f32, 6.000000e-03>>) -> tensor<3x4x2x!quant.uniform> + // CHECK: %[[GATHER:.*]] = "stablehlo.gather"(%[[CALL]], %[[ARG2]]) + // CHECK-SAME: (tensor<3x4x2x!quant.uniform>, tensor<2x3x2xi64>) -> tensor<2x3x2x2x!quant.uniform> + // CHECK: %[[DQ:.*]] = "quantization.dcast"(%[[GATHER]]) : (tensor<2x3x2x2x!quant.uniform>) -> tensor<2x3x2x2xf32> + // CHECK: return %[[DQ]] + %0 
= "quantization.qcast"(%arg0) {volatile} : (tensor<3x4x5xf32>) -> tensor<3x4x5x!quant.uniform> + %1 = "quantization.dcast"(%0) : (tensor<3x4x5x!quant.uniform>) -> tensor<3x4x5xf32> + %2 = "quantization.qcast"(%arg1) {volatile} : (tensor<3x5x2xf32>) -> tensor<3x5x2x!quant.uniform:f32, 6.000000e-03>> + %3 = "quantization.dcast"(%2) : (tensor<3x5x2x!quant.uniform:f32, 6.000000e-03>>) -> tensor<3x5x2xf32> + %4 = "tf.XlaCallModule"(%1, %3) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _quantization_method = "static_range_ptq {}", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<3x4x5xf32>, tensor<3x5x2xf32>) -> tensor<3x4x2xf32> + %5 = "quantization.qcast"(%4) {volatile} : (tensor<3x4x2xf32>) -> tensor<3x4x2x!quant.uniform> + %6 = "quantization.dcast"(%5) : (tensor<3x4x2x!quant.uniform>) -> tensor<3x4x2xf32> + %7 = "stablehlo.gather"(%6, %arg2) { + dimension_numbers = #stablehlo.gather< + offset_dims = [2, 3], + collapsed_slice_dims = [0], + start_index_map = [1, 0], + index_vector_dim = 2>, + slice_sizes = array, + indices_are_sorted = false + } : (tensor<3x4x2xf32>, tensor<2x3x2xi64>) -> tensor<2x3x2x2xf32> + %8 = "quantization.qcast"(%7) {volatile} : (tensor<2x3x2x2xf32>) -> tensor<2x3x2x2x!quant.uniform> + %9 = "quantization.dcast"(%8) : (tensor<2x3x2x2x!quant.uniform>) -> tensor<2x3x2x2xf32> + return %9 : tensor<2x3x2x2xf32> + } + + // CHECK: quantized_dot_general_fn_1 + // CHECK-SAME: %[[ARG2:.*]]: tensor<3x4x5x!quant.uniform> + // CHECK-SAME: %[[ARG3:.*]]: tensor<3x5x2x!quant.uniform:f32, 6.000000e-03>> + func.func private @composite_dot_general_fn_1(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x5x2xf32>) -> tensor<3x4x2xf32> attributes {_from_xla_call_module} { + // CHECK: %[[DOT:.*]] = stablehlo.dot_general 
%[[ARG2]], %[[ARG3]] + // CHECK-SAME: (tensor<3x4x5x!quant.uniform>, tensor<3x5x2x!quant.uniform:f32, 6.000000e-03>>) -> tensor<3x4x2x!quant.uniform> + // CHECK: %[[Q3:.*]] = stablehlo.uniform_quantize %0 : (tensor<3x4x2x!quant.uniform>) -> tensor<3x4x2x!quant.uniform> + // CHECK: return %[[Q3]] + %0 = stablehlo.dot_general %arg0, %arg1, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<3x4x5xf32>, tensor<3x5x2xf32>) -> tensor<3x4x2xf32> + return %0 : tensor<3x4x2xf32> + } +} + +// ----- + +module attributes {tf_saved_model.semantics} { + // CHECK-LABEL: composite_and_slice + // CHECK-SAME: %[[ARG0:.*]]: tensor<3x2xf32> + // CHECK-SAME: %[[ARG1:.*]]: tensor<2x4xf32> + func.func private @composite_and_slice(%arg0: tensor<3x2xf32>, %arg1: tensor<2x4xf32>) -> tensor<2x2xf32> { + // CHECK: %[[Q1:.*]] = "quantization.qcast"(%[[ARG0]]) {volatile} : (tensor<3x2xf32>) -> tensor<3x2x!quant.uniform> + // CHECK: %[[Q2:.*]] = "quantization.qcast"(%[[ARG1]]) {volatile} : (tensor<2x4xf32>) -> tensor<2x4x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03,6.000000e-03}>> + // CHECK: %[[CALL:.*]] = call @quantized_dot_general_fn_1(%[[Q1]], %[[Q2]]) + // CHECK-SAME: (tensor<3x2x!quant.uniform>, tensor<2x4x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03,6.000000e-03}>>) -> tensor<3x4x!quant.uniform> + // CHECK: %[[SLICE:.*]] = stablehlo.slice %[[CALL]] [1:3, 2:4] : (tensor<3x4x!quant.uniform>) -> tensor<2x2x!quant.uniform> + // CHECK: %[[DQ:.*]] = "quantization.dcast"(%[[SLICE]]) : (tensor<2x2x!quant.uniform>) -> tensor<2x2xf32> + // CHECK: return %[[DQ]] + %0 = "quantization.qcast"(%arg0) {volatile} : (tensor<3x2xf32>) -> tensor<3x2x!quant.uniform> + %1 = "quantization.dcast"(%0) : (tensor<3x2x!quant.uniform>) -> tensor<3x2xf32> + %2 = "quantization.qcast"(%arg1) {volatile} : (tensor<2x4xf32>) -> tensor<2x4x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03,6.000000e-03}>> + %3 = "quantization.dcast"(%2) : 
(tensor<2x4x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03,6.000000e-03}>>) -> tensor<2x4xf32> + %4 = "tf.XlaCallModule"(%1, %3) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _quantization_method = "static_range_ptq {}", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<3x2xf32>, tensor<2x4xf32>) -> tensor<3x4xf32> + %5 = "quantization.qcast"(%4) {volatile} : (tensor<3x4xf32>) -> tensor<3x4x!quant.uniform> + %6 = "quantization.dcast"(%5) : (tensor<3x4x!quant.uniform>) -> tensor<3x4xf32> + %7 = stablehlo.slice %6 [1:3, 2:4] : (tensor<3x4xf32>) -> tensor<2x2xf32> + %8 = "quantization.qcast"(%7) {volatile} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform> + %9 = "quantization.dcast"(%8) : (tensor<2x2x!quant.uniform>) -> tensor<2x2xf32> + return %9 : tensor<2x2xf32> + } + + // CHECK: quantized_dot_general_fn_1 + // CHECK-SAME: %[[ARG2:.*]]: tensor<3x2x!quant.uniform> + // CHECK-SAME: %[[ARG3:.*]]: tensor<2x4x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03,6.000000e-03}>> + func.func private @composite_dot_general_fn_1(%arg0: tensor<3x2xf32>, %arg1: tensor<2x4xf32>) -> tensor<3x4xf32> attributes {_from_xla_call_module} { + // CHECK: %[[DOT:.*]] = stablehlo.dot_general %[[ARG2]], %[[ARG3]] + // CHECK-SAME: (tensor<3x2x!quant.uniform>, tensor<2x4x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03,6.000000e-03}>>) -> tensor<3x4x!quant.uniform> + // CHECK: %[[Q3:.*]] = stablehlo.uniform_quantize %0 : (tensor<3x4x!quant.uniform>) -> tensor<3x4x!quant.uniform> + // CHECK: return %[[Q3]] + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<3x2xf32>, tensor<2x4xf32>) -> tensor<3x4xf32> + return %0 : tensor<3x4xf32> + } +} diff --git 
a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/tf_quantize_weight_only.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/tf_quantize_weight_only.mlir new file mode 100644 index 00000000000000..6a9bd42a76ae82 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/tf_quantize_weight_only.mlir @@ -0,0 +1,66 @@ +// RUN: stablehlo-quant-opt %s -split-input-file -tf-stablehlo-quantize | FileCheck %s + +// Test that hybrid quantized dot_general is produced when q/dq pair only exists +// for weight. + +module attributes {tf_saved_model.semantics} { + func.func private @quantize_dot_general_fn(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { + %cst = stablehlo.constant dense<3.000000e-01> : tensor<2x3xf32> + %0 = "quantization.qcast"(%cst) : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform> + %1 = "quantization.dcast"(%0) : (tensor<2x3x!quant.uniform>) -> tensor<2x3xf32> + %2 = "tf.XlaCallModule"(%arg0, %1) <{Sout = [#tf_type.shape<1x3>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _quantization_method = "weight_only_ptq { }", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + return %2 : tensor<1x3xf32> + } + + func.func private @composite_dot_general_fn(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> + } +} + +// CHECK-LABEL: quantize_dot_general_fn +// CHECK-SAME: %[[ARG0:.+]]: tensor<1x2xf32> +// CHECK: %[[CST:.+]] = stablehlo.constant dense<3.000000e-01> : 
tensor<2x3xf32> +// CHECK: %[[Q:.+]] = "quantization.qcast"(%[[CST]]) : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform> +// CHECK: %[[CALL:.+]] = call @quantized_dot_general_fn(%[[ARG0]], %[[Q]]) +// CHECK-SAME: {_quantization_method = "weight_only_ptq { }"} : (tensor<1x2xf32>, tensor<2x3x!quant.uniform>) -> tensor<1x3xf32> +// CHECK: return %[[CALL]] + +// CHECK: quantized_dot_general_fn +// CHECK-SAME: (%[[ARG1:.+]]: tensor<1x2xf32>, %[[ARG2:.+]]: tensor<2x3x!quant.uniform>) -> tensor<1x3xf32> +// CHECK: %[[DOT:.+]] = stablehlo.dot_general %[[ARG1]], %[[ARG2]] +// CHECK-SAME: (tensor<1x2xf32>, tensor<2x3x!quant.uniform>) -> tensor<1x3xf32> +// CHECK: return %[[DOT]] + +// ----- + +// Test that hybrid quantized convolution is produced when q/dq pair only exists +// for weight. + +module attributes {tf_saved_model.semantics} { + func.func private @quantize_conv_fn(%arg0: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"} { + %cst = stablehlo.constant dense<3.000000e-01> : tensor<2x3x3x2xf32> + %0 = "quantization.qcast"(%cst) : (tensor<2x3x3x2xf32>) -> tensor<2x3x3x2x!quant.uniform> + %1 = "quantization.dcast"(%0) : (tensor<2x3x3x2x!quant.uniform>) -> tensor<2x3x3x2xf32> + %2 = "tf.XlaCallModule"(%arg0, %1) <{Sout = [#tf_type.shape<1x3x4x2>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @composite_conv_fn, _original_entry_function = "composite_conv_fn", _quantization_method = "weight_only_ptq { }", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> + return %2 : tensor<1x3x4x2xf32> + } + + func.func private @composite_conv_fn(%arg0: tensor<1x3x4x3xf32>, %arg1: tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, 
f], window = {pad = [[0, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> + return %0 : tensor<1x3x4x2xf32> + } +} + +// CHECK-LABEL: quantize_conv_fn +// CHECK-SAME: %[[ARG0:.+]]: tensor<1x3x4x3xf32> +// CHECK: %[[CST:.+]] = stablehlo.constant dense<3.000000e-01> : tensor<2x3x3x2xf32> +// CHECK: %[[Q:.+]] = "quantization.qcast"(%[[CST]]) : (tensor<2x3x3x2xf32>) -> tensor<2x3x3x2x!quant.uniform> +// CHECK: %[[CALL:.+]] = call @quantized_conv_fn(%[[ARG0]], %[[Q]]) {_quantization_method = "weight_only_ptq { }"} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2x!quant.uniform>) -> tensor<1x3x4x2xf32> +// CHECK: return %[[CALL]] + +// CHECK: quantized_conv_fn +// CHECK-SAME: (%[[ARG1:.+]]: tensor<1x3x4x3xf32>, %[[ARG2:.+]]: tensor<2x3x3x2x!quant.uniform>) -> tensor<1x3x4x2xf32> +// CHECK: %[[CONV:.+]] = stablehlo.convolution(%[[ARG1]], %[[ARG2]]) +// CHECK-SAME: (tensor<1x3x4x3xf32>, tensor<2x3x3x2x!quant.uniform>) -> tensor<1x3x4x2xf32> +// CHECK: return %[[CONV]] diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/shape_cstr_legalize_to_hlo.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/shape_cstr_legalize_to_hlo.mlir new file mode 100644 index 00000000000000..ac7d6a51fb87b1 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/shape_cstr_legalize_to_hlo.mlir @@ -0,0 +1,110 @@ +// RUN: stablehlo-quant-opt %s -split-input-file -stablehlo-convert-shape-to-stablehlo-with-constraints --verify-diagnostics | FileCheck %s + +// CHECK-LABEL: func.func @shape_cstr_broadcastable +func.func @shape_cstr_broadcastable(%arg0: tensor<2xindex>, %arg1: tensor<2xindex>) { + %0 = shape.cstr_broadcastable %arg0, %arg1 : tensor<2xindex>, tensor<2xindex> + shape.assuming %0 { + } + func.return + // CHECK: %[[DIMS1:.*]] = builtin.unrealized_conversion_cast %arg0 : tensor<2xindex> to tensor<2xi32> + // CHECK-NEXT: %[[DIMS2:.*]] = 
builtin.unrealized_conversion_cast %arg1 : tensor<2xindex> to tensor<2xi32> + // CHECK-NEXT: %[[ONES:.*]] = stablehlo.constant dense<1> : tensor<2xi32> + // CHECK-NEXT: %[[DIMS1_IS_1:.*]] = stablehlo.compare EQ, %[[DIMS1]], %[[ONES:.*]] : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + // CHECK-NEXT: %[[DIMS2_IS_1:.*]] = stablehlo.compare EQ, %[[DIMS2]], %[[ONES:.*]] : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + // CHECK-NEXT: %[[EITHER_DIM_IS_1:.*]] = stablehlo.or %[[DIMS1_IS_1]], %[[DIMS2_IS_1]] : tensor<2xi1> + // CHECK-NEXT: %[[DIMS_EQ:.*]] = stablehlo.compare EQ, %[[DIMS1]], %[[DIMS2]] : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + // CHECK-NEXT: %[[DIMS_BROADCASTABLE:.*]] = stablehlo.or %[[EITHER_DIM_IS_1]], %[[DIMS_EQ]] : tensor<2xi1> + // CHECK-NEXT: %[[TRUE:.*]] = stablehlo.constant dense : tensor<1xi1> + // CHECK-NEXT: %[[DIM1_BROADCASTABLE:.*]] = stablehlo.slice %[[DIMS_BROADCASTABLE]] [0:1] : (tensor<2xi1>) -> tensor<1xi1> + // CHECK-NEXT: %[[BROADCASTABLE_TEMP:.*]] = stablehlo.and %[[TRUE]], %[[DIM1_BROADCASTABLE]] : tensor<1xi1> + // CHECK-NEXT: %[[DIM2_BROADCASTABLE:.*]] = stablehlo.slice %[[DIMS_BROADCASTABLE]] [1:2] : (tensor<2xi1>) -> tensor<1xi1> + // CHECK-NEXT: %[[ALL_BROADCASTABLE:.*]] = stablehlo.and %[[BROADCASTABLE_TEMP]], %[[DIM2_BROADCASTABLE]] : tensor<1xi1> + // CHECK-NEXT: %[[ALL_BROADCASTABLE_SCALAR:.*]] = stablehlo.reshape %[[ALL_BROADCASTABLE]] : (tensor<1xi1>) -> tensor + // CHECK-NEXT: stablehlo.custom_call @shape_assertion(%[[ALL_BROADCASTABLE_SCALAR]]) {error_message = "Shape assertion failed", has_side_effect = true} : (tensor) -> () + // CHECK-NEXT: %[[WITNESS:.*]] = shape.const_witness true + // CHECK-NEXT: shape.assuming %[[WITNESS]] { + // CHECK-NEXT: } + // CHECK-NEXT: return +} + +// ----- + +// CHECK-LABEL: func @shape_cstr_broadcastable_different_dims_1 +func.func @shape_cstr_broadcastable_different_dims_1(%arg0: tensor<2xindex>, %arg1: tensor<1xindex>) { + %0 = shape.cstr_broadcastable %arg0, %arg1 : 
tensor<2xindex>, tensor<1xindex> + shape.assuming %0 { + } + func.return + // CHECK: %[[DIMS1:.*]] = builtin.unrealized_conversion_cast %arg0 : tensor<2xindex> to tensor<2xi32> + // CHECK-NEXT: %[[DIMS2:.*]] = builtin.unrealized_conversion_cast %arg1 : tensor<1xindex> to tensor<1xi32> + // CHECK-NEXT: %[[PAD:.*]] = stablehlo.constant dense<1> : tensor<1xi32> + // CHECK-NEXT: %[[DIMS2_PAD:.*]] = stablehlo.concatenate %[[PAD]], %[[DIMS2]], dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> + // CHECK-NEXT: %[[ONES:.*]] = stablehlo.constant dense<1> : tensor<2xi32> + // CHECK-NEXT: %[[DIMS1_IS_1:.*]] = stablehlo.compare EQ, %[[DIMS1]], %[[ONES:.*]] : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + // CHECK-NEXT: %[[DIMS2_IS_1:.*]] = stablehlo.compare EQ, %[[DIMS2_PAD]], %[[ONES:.*]] : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + // CHECK-NEXT: %[[EITHER_DIM_IS_1:.*]] = stablehlo.or %[[DIMS1_IS_1]], %[[DIMS2_IS_1]] : tensor<2xi1> + // CHECK-NEXT: %[[DIMS_EQ:.*]] = stablehlo.compare EQ, %[[DIMS1]], %[[DIMS2_PAD]] : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + // CHECK-NEXT: %[[DIMS_BROADCASTABLE:.*]] = stablehlo.or %[[EITHER_DIM_IS_1]], %[[DIMS_EQ]] : tensor<2xi1> + // CHECK-NEXT: %[[TRUE:.*]] = stablehlo.constant dense : tensor<1xi1> + // CHECK-NEXT: %[[DIM1_BROADCASTABLE:.*]] = stablehlo.slice %[[DIMS_BROADCASTABLE]] [0:1] : (tensor<2xi1>) -> tensor<1xi1> + // CHECK-NEXT: %[[BROADCASTABLE_TEMP:.*]] = stablehlo.and %[[TRUE]], %[[DIM1_BROADCASTABLE]] : tensor<1xi1> + // CHECK-NEXT: %[[DIM2_BROADCASTABLE:.*]] = stablehlo.slice %[[DIMS_BROADCASTABLE]] [1:2] : (tensor<2xi1>) -> tensor<1xi1> + // CHECK-NEXT: %[[ALL_BROADCASTABLE:.*]] = stablehlo.and %[[BROADCASTABLE_TEMP]], %[[DIM2_BROADCASTABLE]] : tensor<1xi1> + // CHECK-NEXT: %[[ALL_BROADCASTABLE_SCALAR:.*]] = stablehlo.reshape %[[ALL_BROADCASTABLE]] : (tensor<1xi1>) -> tensor + // CHECK-NEXT: stablehlo.custom_call @shape_assertion(%[[ALL_BROADCASTABLE_SCALAR]]) {error_message = "Shape assertion failed", 
has_side_effect = true} : (tensor) -> () + // CHECK-NEXT: %[[WITNESS:.*]] = shape.const_witness true + // CHECK-NEXT: shape.assuming %[[WITNESS]] { + // CHECK-NEXT: } + // CHECK-NEXT: return +} + +// ----- + +// CHECK-LABEL: func @shape_cstr_broadcastable_different_dims_2 +func.func @shape_cstr_broadcastable_different_dims_2(%arg0: tensor<1xindex>, %arg1: tensor<2xindex>) { + %0 = shape.cstr_broadcastable %arg0, %arg1 : tensor<1xindex>, tensor<2xindex> + shape.assuming %0 { + } + func.return + // CHECK: %[[DIMS1:.*]] = builtin.unrealized_conversion_cast %arg0 : tensor<1xindex> to tensor<1xi32> + // CHECK-NEXT: %[[DIMS2:.*]] = builtin.unrealized_conversion_cast %arg1 : tensor<2xindex> to tensor<2xi32> + // CHECK-NEXT: %[[PAD:.*]] = stablehlo.constant dense<1> : tensor<1xi32> + // CHECK-NEXT: %[[DIMS1_PAD:.*]] = stablehlo.concatenate %[[PAD]], %[[DIMS1]], dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> + // CHECK-NEXT: %[[ONES:.*]] = stablehlo.constant dense<1> : tensor<2xi32> + // CHECK-NEXT: %[[DIMS1_IS_1:.*]] = stablehlo.compare EQ, %[[DIMS1_PAD]], %[[ONES:.*]] : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + // CHECK-NEXT: %[[DIMS2_IS_1:.*]] = stablehlo.compare EQ, %[[DIMS2]], %[[ONES:.*]] : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + // CHECK-NEXT: %[[EITHER_DIM_IS_1:.*]] = stablehlo.or %[[DIMS1_IS_1]], %[[DIMS2_IS_1]] : tensor<2xi1> + // CHECK-NEXT: %[[DIMS_EQ:.*]] = stablehlo.compare EQ, %[[DIMS1_PAD]], %[[DIMS2]] : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + // CHECK-NEXT: %[[DIMS_BROADCASTABLE:.*]] = stablehlo.or %[[EITHER_DIM_IS_1]], %[[DIMS_EQ]] : tensor<2xi1> + // CHECK-NEXT: %[[TRUE:.*]] = stablehlo.constant dense : tensor<1xi1> + // CHECK-NEXT: %[[DIM1_BROADCASTABLE:.*]] = stablehlo.slice %[[DIMS_BROADCASTABLE]] [0:1] : (tensor<2xi1>) -> tensor<1xi1> + // CHECK-NEXT: %[[BROADCASTABLE_TEMP:.*]] = stablehlo.and %[[TRUE]], %[[DIM1_BROADCASTABLE]] : tensor<1xi1> + // CHECK-NEXT: %[[DIM2_BROADCASTABLE:.*]] = stablehlo.slice 
%[[DIMS_BROADCASTABLE]] [1:2] : (tensor<2xi1>) -> tensor<1xi1> + // CHECK-NEXT: %[[ALL_BROADCASTABLE:.*]] = stablehlo.and %[[BROADCASTABLE_TEMP]], %[[DIM2_BROADCASTABLE]] : tensor<1xi1> + // CHECK-NEXT: %[[ALL_BROADCASTABLE_SCALAR:.*]] = stablehlo.reshape %[[ALL_BROADCASTABLE]] : (tensor<1xi1>) -> tensor + // CHECK-NEXT: stablehlo.custom_call @shape_assertion(%[[ALL_BROADCASTABLE_SCALAR]]) {error_message = "Shape assertion failed", has_side_effect = true} : (tensor) -> () + // CHECK-NEXT: %[[WITNESS:.*]] = shape.const_witness true + // CHECK-NEXT: shape.assuming %[[WITNESS]] { + // CHECK-NEXT: } + // CHECK-NEXT: return +} + +// ----- + +func.func @shape_cstr_broadcast_too_many_operands(%arg0: tensor<4xindex>, %arg1: tensor<4xindex>, %arg2: tensor<4xindex>) { + // expected-error@+1 {{failed to legalize operation 'shape.cstr_broadcastable' that was explicitly marked illegal}} + %0 = shape.cstr_broadcastable %arg0, %arg1, %arg2 : tensor<4xindex>, tensor<4xindex>, tensor<4xindex> + shape.assuming %0 { + } + func.return +} + +// ----- + +func.func @shape_cstr_broadcastable_input_shape(%arg0: !shape.shape, %arg1: !shape.shape) { + // expected-error@+1 {{failed to legalize operation 'shape.cstr_broadcastable' that was explicitly marked illegal}} + %0 = shape.cstr_broadcastable %arg0, %arg1 : !shape.shape, !shape.shape + shape.assuming %0 { + } + func.return +} diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_convert_func_to_bfloat16.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_convert_func_to_bfloat16.mlir new file mode 100644 index 00000000000000..f73515b3c5e815 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_convert_func_to_bfloat16.mlir @@ -0,0 +1,128 @@ +// RUN: stablehlo-quant-opt %s -split-input-file -tf-stablehlo-convert-func-to-bfloat16 -verify-diagnostics | FileCheck %s + +// CHECK-LABEL: @add_f32(%arg0: tensor<3x3xbf16>, %arg1: tensor<3x3xbf16>) -> tensor<3x3xbf16> 
+func.func @add_f32(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<3x3xf32> { + // CHECK-NOT: f32 + // CHECK: stablehlo.add + %0 = stablehlo.add %arg0, %arg1: (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32> + return %0 : tensor<3x3xf32> +} + +// ----- + +// CHECK-LABEL: @add_f64(%arg0: tensor<3x3xbf16>, %arg1: tensor<3x3xbf16>) -> tensor<3x3xbf16> +func.func @add_f64(%arg0: tensor<3x3xf64>, %arg1: tensor<3x3xf64>) -> tensor<3x3xf64> { + // CHECK-NOT: f64 + // CHECK: stablehlo.add + %0 = stablehlo.add %arg0, %arg1: (tensor<3x3xf64>, tensor<3x3xf64>) -> tensor<3x3xf64> + return %0 : tensor<3x3xf64> +} + +// ----- + +// CHECK-LABEL: @constant_f32() -> tensor<2x2xbf16> +func.func @constant_f32() -> tensor<2x2xf32> { + // CHECK-NOT: f32 + // CHECK{LITERAL}: stablehlo.constant dense<[[1.398440e+00, 0.000000e+00], [3.093750e+00, -2.001950e-01]]> : tensor<2x2xbf16> + %0 = stablehlo.constant dense<[[1.4, 0.0], [3.1, -0.2]]> : tensor<2x2xf32> + return %0 : tensor<2x2xf32> +} + +// ----- + +func.func @constant_elided() -> tensor<2x2xf32> { + // expected-error @+1 {{failed to legalize operation 'stablehlo.constant' that was explicitly marked illegal}} + %0 = stablehlo.constant dense_resource<__elided__> : tensor<2x2xf32> + return %0 : tensor<2x2xf32> +} + +// ----- + +// CHECK-LABEL: @reduce_window_f32(%arg0: tensor<2x3x1x3xbf16>) -> tensor<2x3x1x3xbf16> +func.func @reduce_window_f32(%arg0: tensor<2x3x1x3xf32>) -> tensor<2x3x1x3xf32> { + // CHECK-NOT: f32 + // CHECK: stablehlo.reduce_window + %0 = stablehlo.constant dense<0.0> : tensor + %1 = "stablehlo.reduce_window"(%arg0, %0) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %2 = stablehlo.maximum %arg1, %arg2 : tensor + stablehlo.return %2 : tensor + }) {padding = dense<[[0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi64>, window_dimensions = array} : (tensor<2x3x1x3xf32>, tensor) -> tensor<2x3x1x3xf32> + return %1 : tensor<2x3x1x3xf32> +} + +// ----- + +// CHECK-LABEL: @bitcast_convert_i32_f32(%arg0: 
tensor<1x256128xi32>) -> tensor<1x256128xbf16> +func.func @bitcast_convert_i32_f32(%arg0: tensor<1x256128xi32>) -> tensor<1x256128xf32> { + // CHECK: %[[BITCAST:.*]] = stablehlo.bitcast_convert %arg0 : (tensor<1x256128xi32>) -> tensor<1x256128xf32> + // CHECK: %[[CONVERT:.*]] = stablehlo.convert %[[BITCAST]] : (tensor<1x256128xf32>) -> tensor<1x256128xbf16> + // CHECK: return %[[CONVERT]] : tensor<1x256128xbf16> + %20 = stablehlo.bitcast_convert %arg0 : (tensor<1x256128xi32>) -> tensor<1x256128xf32> + return %20 : tensor<1x256128xf32> +} + +// ----- + +// CHECK-LABEL: @bitcast_convert_f32_i32(%arg0: tensor<1x256128xbf16>) -> tensor<1x256128xi32> +func.func @bitcast_convert_f32_i32(%arg0: tensor<1x256128xf32>) -> tensor<1x256128xi32> { + // CHECK: %[[CONVERT:.*]] = stablehlo.convert %arg0 : (tensor<1x256128xbf16>) -> tensor<1x256128xf32> + // CHECK: %[[BITCAST:.*]] = stablehlo.bitcast_convert %[[CONVERT]] : (tensor<1x256128xf32>) -> tensor<1x256128xi32> + // CHECK: return %[[BITCAST]] : tensor<1x256128xi32> + %20 = stablehlo.bitcast_convert %arg0 : (tensor<1x256128xf32>) -> tensor<1x256128xi32> + return %20 : tensor<1x256128xi32> +} + +// ----- + +// CHECK-LABEL: @bitcast_convert_ui32_f32(%arg0: tensor<1x256128xui32>) -> tensor<1x256128xbf16> +func.func @bitcast_convert_ui32_f32(%arg0: tensor<1x256128xui32>) -> tensor<1x256128xf32> { + // CHECK: %[[BITCAST:.*]] = stablehlo.bitcast_convert %arg0 : (tensor<1x256128xui32>) -> tensor<1x256128xf32> + // CHECK: %[[CONVERT:.*]] = stablehlo.convert %[[BITCAST]] : (tensor<1x256128xf32>) -> tensor<1x256128xbf16> + // CHECK: return %[[CONVERT]] : tensor<1x256128xbf16> + %20 = stablehlo.bitcast_convert %arg0 : (tensor<1x256128xui32>) -> tensor<1x256128xf32> + return %20 : tensor<1x256128xf32> +} + +// ----- + +// CHECK-LABEL: @bitcast_convert_f32_ui32(%arg0: tensor<1x256128xbf16>) -> tensor<1x256128xui32> +func.func @bitcast_convert_f32_ui32(%arg0: tensor<1x256128xf32>) -> tensor<1x256128xui32> { + // CHECK: %[[CONVERT:.*]] = 
stablehlo.convert %arg0 : (tensor<1x256128xbf16>) -> tensor<1x256128xf32> + // CHECK: %[[BITCAST:.*]] = stablehlo.bitcast_convert %[[CONVERT]] : (tensor<1x256128xf32>) -> tensor<1x256128xui32> + // CHECK: return %[[BITCAST]] : tensor<1x256128xui32> + %20 = stablehlo.bitcast_convert %arg0 : (tensor<1x256128xf32>) -> tensor<1x256128xui32> + return %20 : tensor<1x256128xui32> +} + +// ----- + +// CHECK-LABEL: @bitcast_convert_f32_f32(%arg0: tensor<1x256128xbf16>) -> tensor<1x256128xbf16> +func.func @bitcast_convert_f32_f32(%arg0: tensor<1x256128xf32>) -> tensor<1x256128xf32> { + // Convert bitcast_convert to no-op for f32->f32. + // CHECK: return %arg0 : tensor<1x256128xbf16> + %20 = stablehlo.bitcast_convert %arg0 : (tensor<1x256128xf32>) -> tensor<1x256128xf32> + return %20 : tensor<1x256128xf32> +} + +// ----- + +// CHECK-LABEL: @bitcast_convert_i32_ui32(%arg0: tensor<1x256128xi32>) -> tensor<1x256128xui32> +func.func @bitcast_convert_i32_ui32(%arg0: tensor<1x256128xi32>) -> tensor<1x256128xui32> { + // Do not convert bitcast_convert for legal types. + // CHECK: %[[BITCAST:.*]] = stablehlo.bitcast_convert %arg0 : (tensor<1x256128xi32>) -> tensor<1x256128xui32> + // CHECK: return %[[BITCAST]] : tensor<1x256128xui32> + %20 = stablehlo.bitcast_convert %arg0 : (tensor<1x256128xi32>) -> tensor<1x256128xui32> + return %20 : tensor<1x256128xui32> +} + +// ----- + +// CHECK-LABEL: @bitcast_convert_bf16_bf16(%arg0: tensor<1x256128xbf16>) -> tensor<1x256128xbf16> +func.func @bitcast_convert_bf16_bf16(%arg0: tensor<1x256128xbf16>) -> tensor<1x256128xbf16> { + // Do not convert bitcast_convert for legal types. 
+ // CHECK: %[[BITCAST:.*]] = stablehlo.bitcast_convert %arg0 : (tensor<1x256128xbf16>) -> tensor<1x256128xbf16> + // CHECK: return %[[BITCAST]] : tensor<1x256128xbf16> + %20 = stablehlo.bitcast_convert %arg0 : (tensor<1x256128xbf16>) -> tensor<1x256128xbf16> + return %20 : tensor<1x256128xbf16> +} diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_convert_xla_call_module_op_to_bfloat16.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_convert_xla_call_module_op_to_bfloat16.mlir new file mode 100644 index 00000000000000..d3694e7e6402df --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_convert_xla_call_module_op_to_bfloat16.mlir @@ -0,0 +1,42 @@ +// RUN: stablehlo-quant-opt %s -split-input-file -tf-xla-call-module-serialization -tf-stablehlo-convert-xla-call-module-op-to-bfloat16 -tf-xla-call-module-deserialization | FileCheck %s + +// ConvertXlaCallModuleOpToBfloat16Pass works on XlaCallModuleOps with +// serialized modules. Which makes verification difficult. Therefore we add +// (de)serialization passes so that the input and output are deserializated +// StableHLO functions. 
+ +// CHECK-LABEL: module +module { + // CHECK-LABEL: func @main + // CHECK-SAME: %[[ARG_0:.*]]: tensor<10xf32>, %[[ARG_1:.*]]: tensor<10xf32>, %[[ARG_2:.*]]: tensor<6xi32> + func.func @main( + %arg0: tensor<10xf32>, %arg1: tensor<10xf32>, %arg2: tensor<6xi32> + ) -> (tensor<10xf32>, tensor<6xi32>) { + // CHECK: %[[CAST_0:.*]] = "tf.Cast"(%[[ARG_0]]) <{Truncate = false}> : (tensor<10xf32>) -> tensor<10xbf16> + // CHECK: %[[CAST_1:.*]] = "tf.Cast"(%[[ARG_1]]) <{Truncate = false}> : (tensor<10xf32>) -> tensor<10xbf16> + // CHECK: %[[RESULT:.*]]:2 = "tf.XlaCallModule"(%[[CAST_0]], %[[CAST_1]], %[[ARG_2]]) + // CHECK-SAME: _stablehlo_version = "1.0.0" + // CHECK-SAME: (tensor<10xbf16>, tensor<10xbf16>, tensor<6xi32>) -> (tensor<10xbf16>, tensor<6xi32>) + // CHECK: %[[RESULT_CAST:.*]] = "tf.Cast"(%[[RESULT]]#0) <{Truncate = false}> : (tensor<10xbf16>) -> tensor<10xf32> + %0:2 = "tf.XlaCallModule"(%arg0, %arg1, %arg2) { + Sout = [#tf_type.shape<10>], dim_args_spec = [], + _entry_function = @main_0, + _stablehlo_version = "1.0.0", + _stablehlo_module_attrs = { mhlo.num_partitions = 1 }, module = "", + platforms = [], version = 5 : i64 + } : (tensor<10xf32>, tensor<10xf32>, tensor<6xi32>) -> (tensor<10xf32>, tensor<6xi32>) + // CHECK: return %[[RESULT_CAST]], %[[RESULT]]#1 : tensor<10xf32>, tensor<6xi32> + func.return %0#0, %0#1 : tensor<10xf32>, tensor<6xi32> + } + + // CHECK-LABEL: func private @main_0 + // CHECK-SAME: %[[ARG_0:.*]]: tensor<10xbf16>, %[[ARG_1:.*]]: tensor<10xbf16>, %[[ARG_2:.*]]: tensor<6xi32> + func.func private @main_0( + %arg0: tensor<10xf32>, %arg1: tensor<10xf32>, %arg2: tensor<6xi32> + ) -> (tensor<10xf32>, tensor<6xi32>) attributes {_from_xla_call_module} { + // CHECK: %[[ADD:.*]] = stablehlo.add %[[ARG_0]], %[[ARG_1]] : tensor<10xbf16> + %0 = stablehlo.add %arg0, %arg1 : tensor<10xf32> + // CHECK: return %[[ADD]], %[[ARG_2]] : tensor<10xbf16>, tensor<6xi32> + return %0, %arg2 : tensor<10xf32>, tensor<6xi32> + } +} diff --git 
a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_defer_activation_transpose.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_defer_activation_transpose.mlir new file mode 100644 index 00000000000000..b4216725020cb4 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_defer_activation_transpose.mlir @@ -0,0 +1,307 @@ +// RUN: stablehlo-quant-opt %s -tf-stablehlo-defer-activation-transpose \ +// RUN: -split-input-file -verify-diagnostics | FileCheck %s + +// Tests that an `add(transpose(arg0), arg1)` pattern is converted to +// `transpose(add(arg0, transpose(arg1)))`. The transpose in the activation is +// deferred to the output of `stablehlo.add` and an extra transpose op is +// inserted to the RHS to match the shape of the operand. + +// CHECK-LABEL: add_with_activation_transpose +func.func @add_with_activation_transpose(%arg0: tensor<1x3x3x4xf32>) -> tensor<1x4x3x3xf32> { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<1x4x3x3xf32> + %1 = stablehlo.transpose %arg0, dims = [0, 3, 1, 2] : (tensor<1x3x3x4xf32>) -> tensor<1x4x3x3xf32> + %2 = stablehlo.add %1, %0 : tensor<1x4x3x3xf32> + return %2 : tensor<1x4x3x3xf32> +} +// CHECK-SAME: (%[[ARG_0:.+]]: tensor<1x3x3x4xf32>) -> tensor<1x4x3x3xf32> +// CHECK-DAG: %[[CONST_0:.+]] = stablehlo.constant +// CHECK-DAG: %[[TRANSPOSE_0:.+]] = stablehlo.transpose %[[CONST_0]], dims = [0, 2, 3, 1] : (tensor<1x4x3x3xf32>) -> tensor<1x3x3x4xf32> + +// Check that the shape of the add is changed to reflect the deferred transpose. +// CHECK: %[[ADD_0:.+]] = stablehlo.add %[[ARG_0]], %[[TRANSPOSE_0]] : tensor<1x3x3x4xf32> +// CHECK: %[[TRANSPOSE_1:.+]] = stablehlo.transpose +// CHECK: return %[[TRANSPOSE_1]] + +// ----- + +// Tests that an `add(transpose(arg0), broadcast_in_dim(arg1))` pattern is +// converted to `transpose(add(arg0, transpose(broadcast_in_dim(arg1))))`. 
+// The transpose in the activation is deferred to the output of `stablehlo.add` +// and an extra transpose op is inserted to the RHS to match the shape of the +// operand. + +// CHECK-LABEL: add_with_activation_transpose_broadcasted_rhs +func.func @add_with_activation_transpose_broadcasted_rhs(%arg0: tensor<1x3x3x4xf32>) -> tensor<1x4x3x3xf32> { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<4xf32> + %1 = stablehlo.broadcast_in_dim %0, dims = [1] : (tensor<4xf32>) -> tensor<1x4x3x3xf32> + %2 = stablehlo.transpose %arg0, dims = [0, 3, 1, 2] : (tensor<1x3x3x4xf32>) -> tensor<1x4x3x3xf32> + %3 = stablehlo.add %2, %1 : tensor<1x4x3x3xf32> + return %3 : tensor<1x4x3x3xf32> +} +// CHECK-SAME: (%[[ARG_0:.+]]: tensor<1x3x3x4xf32>) -> tensor<1x4x3x3xf32> +// CHECK-DAG: %[[CONST_0:.+]] = stablehlo.constant +// CHECK-DAG: %[[BROADCAST:.+]] = stablehlo.broadcast_in_dim %[[CONST_0]], dims = [1] : (tensor<4xf32>) -> tensor<1x4x3x3xf32> +// CHECK-DAG: %[[TRANSPOSE_0:.+]] = stablehlo.transpose %[[BROADCAST]], dims = [0, 2, 3, 1] : (tensor<1x4x3x3xf32>) -> tensor<1x3x3x4xf32> + +// Check that the shape of the add is changed to reflect the deferred transpose. +// CHECK: %[[ADD_0:.+]] = stablehlo.add %[[ARG_0]], %[[TRANSPOSE_0]] : tensor<1x3x3x4xf32> +// CHECK: %[[TRANSPOSE_1:.+]] = stablehlo.transpose +// CHECK: return %[[TRANSPOSE_1]] + +// ----- + +// [No change] Tests that the activation transpose whose permutation is not +// `[0, 3, 1, 2]` is not deferred. 
+ +// CHECK-LABEL: add_with_activation_transpose_permutation_mismatch +func.func @add_with_activation_transpose_permutation_mismatch( + %arg0: tensor<1x2x3x4xf32>) -> tensor<1x3x2x4xf32> { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<1x3x2x4xf32> + %1 = stablehlo.transpose %arg0, dims = [0, 2, 1, 3] : (tensor<1x2x3x4xf32>) -> tensor<1x3x2x4xf32> + %2 = stablehlo.add %1, %0 : tensor<1x3x2x4xf32> + return %2 : tensor<1x3x2x4xf32> +} +// CHECK: %[[TRANSPOSE_0:.+]] = stablehlo.transpose +// CHECK: %[[ADD_0:.+]] = stablehlo.add %[[TRANSPOSE_0]], {{.*}} +// CHECK: return %[[ADD_0]] + +// ----- + +// [No change] Tests that the activation transpose whose rank is not 4 is not +// deferred. + +// CHECK-LABEL: add_with_activation_transpose_rank_two +func.func @add_with_activation_transpose_rank_two(%arg0: tensor<1x2xf32>) -> tensor<2x1xf32> { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<2x1xf32> + %1 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<1x2xf32>) -> tensor<2x1xf32> + %2 = stablehlo.add %1, %0 : tensor<2x1xf32> + return %2 : tensor<2x1xf32> +} +// CHECK: %[[TRANSPOSE_0:.+]] = stablehlo.transpose +// CHECK: %[[ADD_0:.+]] = stablehlo.add %[[TRANSPOSE_0]], {{.*}} +// CHECK: return %[[ADD_0]] + +// ----- + +// [No change] Tests that the right-hand side that is not a constant is not +// deferred. + +// CHECK-LABEL: add_with_activation_transpose_nonconst_rhs +func.func @add_with_activation_transpose_nonconst_rhs(%arg0: tensor<1x3x3x4xf32>, %arg1: tensor<1x4x3x3xf32>) -> tensor<1x4x3x3xf32> { + %0 = stablehlo.transpose %arg0, dims = [0, 3, 1, 2] : (tensor<1x3x3x4xf32>) -> tensor<1x4x3x3xf32> + %1 = stablehlo.add %0, %arg1 : tensor<1x4x3x3xf32> + return %1 : tensor<1x4x3x3xf32> +} +// CHECK: %[[TRANSPOSE_0:.+]] = stablehlo.transpose +// CHECK: %[[ADD_0:.+]] = stablehlo.add %[[TRANSPOSE_0]], {{.*}} +// CHECK: return %[[ADD_0]] + +// ----- + +// Tests that the transpose of the input of `stablehlo.reduce_window` is +// deferred to the result. 
The attributes are permutated according to the new +// input shape. + +// CHECK-LABEL: reduce_window_max_activation_transpose +func.func @reduce_window_max_activation_transpose(%arg0: tensor<1x16x16x4xf32>) -> tensor<1x4x8x8xf32> { + %0 = stablehlo.constant dense<0xFF800000> : tensor // -inf + %1 = stablehlo.transpose %arg0, dims = [0, 3, 1, 2] : (tensor<1x16x16x4xf32>) -> tensor<1x4x16x16xf32> + %2 = "stablehlo.reduce_window"(%1, %0) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %3 = stablehlo.maximum %arg1, %arg2 : tensor + stablehlo.return %3 : tensor + }) {window_dimensions = array, window_strides = array} : (tensor<1x4x16x16xf32>, tensor) -> tensor<1x4x8x8xf32> + return %2 : tensor<1x4x8x8xf32> +} +// CHECK-SAME: %[[ARG:.+]]: tensor<1x16x16x4xf32> +// CHECK-DAG: %[[INIT_VALUE_CONST:.+]] = stablehlo.constant dense<0xFF800000> + +// Check that the body is not modified. +// CHECK: %[[REDUCE_WINDOW:.+]] = "stablehlo.reduce_window"(%[[ARG]], %[[INIT_VALUE_CONST]]) +// CHECK: <{window_dimensions = array, window_strides = array}> +// CHECK: ^bb0(%[[REDUCE_ARG_0:.+]]: tensor, %[[REDUCE_ARG_1:.+]]: tensor): +// CHECK: %[[MAX:.+]] = stablehlo.maximum %[[REDUCE_ARG_0]], %[[REDUCE_ARG_1]] +// CHECK: stablehlo.return %[[MAX]] + +// Check that the attributes window_dimensions & window_strides are also +// permutated to match the new input shape. +// CHECK: (tensor<1x16x16x4xf32>, tensor) -> tensor<1x8x8x4xf32> + +// Check that a `stablehlo.transpose` is added to the result to match the shape +// of the users. +// CHECK: %[[TRANSPOSE:.+]] = stablehlo.transpose %[[REDUCE_WINDOW]], dims = [0, 3, 1, 2] : (tensor<1x8x8x4xf32>) -> tensor<1x4x8x8xf32> +// CHECK: return %[[TRANSPOSE]] + +// ----- + +// Tests that the transpose of the input of `stablehlo.reduce_window` is +// deferred to the result. The attributes are permutated according to the new +// input shape. 
This test is similar to the test above with the difference that +// the `stablehlo.reduce_window` has explicit optional attributes: +// `base_dilations` and `window_dilations`. + +// CHECK-LABEL: reduce_window_max_activation_transpose_explicit_optional_attrs +func.func @reduce_window_max_activation_transpose_explicit_optional_attrs( + %arg0: tensor<1x16x16x4xf32>) -> tensor<1x4x15x15xf32> { + %0 = stablehlo.constant dense<0xFF800000> : tensor // -inf + %1 = stablehlo.transpose %arg0, dims = [0, 3, 1, 2] : (tensor<1x16x16x4xf32>) -> tensor<1x4x16x16xf32> + %2 = "stablehlo.reduce_window"(%1, %0) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %3 = stablehlo.maximum %arg1, %arg2 : tensor + stablehlo.return %3 : tensor + }) { + window_dimensions = array, + window_strides = array, + base_dilations = array, + window_dilations = array + } : (tensor<1x4x16x16xf32>, tensor) -> tensor<1x4x15x15xf32> + return %2 : tensor<1x4x15x15xf32> +} +// CHECK-SAME: %[[ARG:.+]]: tensor<1x16x16x4xf32> +// CHECK-DAG: %[[INIT_VALUE_CONST:.+]] = stablehlo.constant dense<0xFF800000> + +// Check that the body is not modified. +// CHECK: %[[REDUCE_WINDOW:.+]] = "stablehlo.reduce_window"(%[[ARG]], %[[INIT_VALUE_CONST]]) +// CHECK: <{base_dilations = array, window_dilations = array, window_dimensions = array, window_strides = array}> +// CHECK: ^bb0(%[[REDUCE_ARG_0:.+]]: tensor, %[[REDUCE_ARG_1:.+]]: tensor): +// CHECK: %[[MAX:.+]] = stablehlo.maximum %[[REDUCE_ARG_0]], %[[REDUCE_ARG_1]] +// CHECK: stablehlo.return %[[MAX]] + +// Check that the attributes window_dimensions & window_strides along with +// optional attributes base_dilations and window_dilations are also permutated +// to match the new input shape. +// CHECK: (tensor<1x16x16x4xf32>, tensor) -> tensor<1x15x15x4xf32> + +// Check that a `stablehlo.transpose` is added to the result to match the shape +// of the users. 
+// CHECK: %[[TRANSPOSE:.+]] = stablehlo.transpose %[[REDUCE_WINDOW]], dims = [0, 3, 1, 2] : (tensor<1x15x15x4xf32>) -> tensor<1x4x15x15xf32> +// CHECK: return %[[TRANSPOSE]] + +// ----- + +// [No change] Tests that the transpose of the input of +// `stablehlo.reduce_window` is NOT deferred to the result, when the input +// tensor does not have rank 4. + +// CHECK-LABEL: reduce_window_max_activation_transpose +// CHECK-SAME: (%[[ARG:.+]]: tensor<16x8xf32>) -> tensor<4x8xf32> +func.func @reduce_window_max_activation_transpose_rank2(%arg0: tensor<16x8xf32>) -> tensor<4x8xf32> { + %0 = stablehlo.constant dense<0xFF800000> : tensor // -inf + %1 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<16x8xf32>) -> tensor<8x16xf32> + %2 = "stablehlo.reduce_window"(%1, %0) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %3 = stablehlo.maximum %arg1, %arg2 : tensor + stablehlo.return %3 : tensor + }) {window_dimensions = array, window_strides = array} : (tensor<8x16xf32>, tensor) -> tensor<4x8xf32> + return %2 : tensor<4x8xf32> +} +// CHECK-DAG: stablehlo.constant +// CHECK: stablehlo.transpose %[[ARG]] +// CHECK: stablehlo.reduce_window + +// ----- + +// [No change] Tests that the transpose of the input of +// `stablehlo.reduce_window` is NOT deferred to the result, when it has an +// explicit `padding` attribute. 
+ +// CHECK-LABEL: reduce_window_max_activation_transpose_with_padding +func.func @reduce_window_max_activation_transpose_with_padding(%arg0: tensor<1x16x16x4xf32>) -> tensor<1x4x9x9xf32> { + %0 = stablehlo.constant dense<0xFF800000> : tensor // -inf + %1 = stablehlo.transpose %arg0, dims = [0, 3, 1, 2] : (tensor<1x16x16x4xf32>) -> tensor<1x4x16x16xf32> + %2 = "stablehlo.reduce_window"(%1, %0) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %3 = stablehlo.maximum %arg1, %arg2 : tensor + stablehlo.return %3 : tensor + }) { + window_dimensions = array, + window_strides = array, + padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64> + } : (tensor<1x4x16x16xf32>, tensor) -> tensor<1x4x9x9xf32> + return %2 : tensor<1x4x9x9xf32> +} +// CHECK-SAME: %[[ARG:.+]]: tensor<1x16x16x4xf32> +// CHECK-DAG: stablehlo.constant +// CHECK: stablehlo.transpose %[[ARG]] +// CHECK: stablehlo.reduce_window + +// ----- + +// [No change] Tests that the transpose of the input of +// `stablehlo.reduce_window` is NOT deferred to the result, when the transpose +// isn't `[0, 3, 1, 2]` (i.e. NCHW->NHWC). 
+ +// CHECK-LABEL: reduce_window_max_activation_transpose_with_padding +func.func @reduce_window_max_activation_transpose_with_padding(%arg0: tensor<16x16x4x1xf32>) -> tensor<1x4x8x8xf32> { + %0 = stablehlo.constant dense<0xFF800000> : tensor // -inf + %1 = stablehlo.transpose %arg0, dims = [3, 2, 1, 0] : (tensor<16x16x4x1xf32>) -> tensor<1x4x16x16xf32> + %2 = "stablehlo.reduce_window"(%1, %0) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %3 = stablehlo.maximum %arg1, %arg2 : tensor + stablehlo.return %3 : tensor + }) { + window_dimensions = array, + window_strides = array + } : (tensor<1x4x16x16xf32>, tensor) -> tensor<1x4x8x8xf32> + return %2 : tensor<1x4x8x8xf32> +} +// CHECK-SAME: %[[ARG:.+]]: tensor<16x16x4x1xf32> +// CHECK-DAG: stablehlo.constant +// CHECK: stablehlo.transpose %[[ARG]] +// CHECK: stablehlo.reduce_window + +// ----- + +// Tests that an `max(transpose(arg0), arg1)` pattern is converted to +// `transpose(max(arg0, transpose(arg1)))`. The transpose in the activation is +// deferred to the output of `stablehlo.max` and an extra transpose op is +// inserted to the RHS to match the shape of the operand. + +// CHECK-LABEL: max_with_activation_transpose +func.func @max_with_activation_transpose(%arg0: tensor<1x3x3x4xf32>) -> tensor<1x4x3x3xf32> { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<1x4x3x3xf32> + %1 = stablehlo.transpose %arg0, dims = [0, 3, 1, 2] : (tensor<1x3x3x4xf32>) -> tensor<1x4x3x3xf32> + %2 = stablehlo.maximum %1, %0 : tensor<1x4x3x3xf32> + return %2 : tensor<1x4x3x3xf32> +} +// CHECK-SAME: (%[[ARG_0:.+]]: tensor<1x3x3x4xf32>) -> tensor<1x4x3x3xf32> +// CHECK-DAG: %[[CONST_0:.+]] = stablehlo.constant +// CHECK-DAG: %[[TRANSPOSE_0:.+]] = stablehlo.transpose %[[CONST_0]], dims = [0, 2, 3, 1] : (tensor<1x4x3x3xf32>) -> tensor<1x3x3x4xf32> + +// Check that the shape of the add is changed to reflect the deferred transpose. 
+// CHECK: %[[MAX_0:.+]] = stablehlo.maximum %[[ARG_0]], %[[TRANSPOSE_0]] : tensor<1x3x3x4xf32> +// CHECK: %[[TRANSPOSE_1:.+]] = stablehlo.transpose +// CHECK: return %[[TRANSPOSE_1]] + +// ----- + +// [No change] Tests that the activation transpose of `stablehlo.maximum` whose +// permutation is not `[0, 3, 1, 2]` is not deferred. + +// CHECK-LABEL: max_with_activation_transpose_permutation_mismatch +func.func @max_with_activation_transpose_permutation_mismatch( + %arg0: tensor<1x2x3x4xf32>) -> tensor<1x3x2x4xf32> { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<1x3x2x4xf32> + %1 = stablehlo.transpose %arg0, dims = [0, 2, 1, 3] : (tensor<1x2x3x4xf32>) -> tensor<1x3x2x4xf32> + %2 = stablehlo.maximum %1, %0 : tensor<1x3x2x4xf32> + return %2 : tensor<1x3x2x4xf32> +} +// CHECK: %[[TRANSPOSE_0:.+]] = stablehlo.transpose +// CHECK: %[[MAX_0:.+]] = stablehlo.maximum %[[TRANSPOSE_0]], {{.*}} +// CHECK: return %[[MAX_0]] + +// ----- + +// [No change] Tests that the activation transpose of `stablehlo.maximum` whose +// rank is not 4 is not deferred. 
+ +// CHECK-LABEL: max_with_activation_transpose_rank_two +func.func @max_with_activation_transpose_rank_two(%arg0: tensor<1x2xf32>) -> tensor<2x1xf32> { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<2x1xf32> + %1 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<1x2xf32>) -> tensor<2x1xf32> + %2 = stablehlo.maximum %1, %0 : tensor<2x1xf32> + return %2 : tensor<2x1xf32> +} +// CHECK: %[[TRANSPOSE_0:.+]] = stablehlo.transpose +// CHECK: %[[MAX_0:.+]] = stablehlo.maximum %[[TRANSPOSE_0]], {{.*}} +// CHECK: return %[[MAX_0]] diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_fold_constant_transpose.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_fold_constant_transpose.mlir new file mode 100644 index 00000000000000..da96bb0e7a681f --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_fold_constant_transpose.mlir @@ -0,0 +1,59 @@ +// RUN: stablehlo-quant-opt %s -tf-stablehlo-fold-constant-transpose \ +// RUN: -split-input-file | FileCheck %s + +// CHECK-LABEL: transpose_simple_1d +func.func @transpose_simple_1d() -> tensor<2xf32> { + %0 = stablehlo.constant dense<[0.000000e+0, 1.000000e+0]> : tensor<2xf32> + %1 = stablehlo.transpose %0, dims = [0] : (tensor<2xf32>) -> tensor<2xf32> + return %1 : tensor<2xf32> +} +// CHECK-DAG: %[[CONST_0:.+]] = stablehlo.constant dense<[0.000000e+00, 1.000000e+00]> : tensor<2xf32> +// CHECK-NOT: transpose +// CHECK: return %[[CONST_0]] : tensor<2xf32> + +// ----- + +// CHECK-LABEL: transpose_simple_2d +func.func @transpose_simple_2d() -> tensor<3x2xf32> { + %0 = stablehlo.constant dense<[[0.000000e+0, 1.000000e+0, 2.000000e+0], [3.000000e+0, 4.000000e+0, 5.000000e+0]]> : tensor<2x3xf32> + %1 = stablehlo.transpose %0, dims = [1, 0] : (tensor<2x3xf32>) -> tensor<3x2xf32> + return %1 : tensor<3x2xf32> +} +// CHECK-DAG: %[[CONST_0:.+]] = stablehlo.constant dense<{{\[\[}}0.000000e+00, 3.000000e+00], [1.000000e+00, 4.000000e+00], [2.000000e+00, 
5.000000e+00]]> : tensor<3x2xf32> +// CHECK-NOT: transpose +// CHECK: return %[[CONST_0]] : tensor<3x2xf32> + +// ----- + +// CHECK-LABEL: transpose_simple_4d +func.func @transpose_simple_4d() -> tensor<5x2x3x4xf32> { + %0 = stablehlo.constant dense<1.000000e+0> : tensor<2x3x4x5xf32> + %1 = stablehlo.transpose %0, dims = [3, 0, 1, 2] : (tensor<2x3x4x5xf32>) -> tensor<5x2x3x4xf32> + return %1 : tensor<5x2x3x4xf32> +} +// CHECK-DAG: %[[CONST_0:.+]] = stablehlo.constant dense<1.000000e+00> : tensor<5x2x3x4xf32> +// CHECK-NOT: transpose +// CHECK: return %[[CONST_0]] : tensor<5x2x3x4xf32> + +// ----- + +// Tests that int constants are not folded. + +// CHECK-LABEL: transpose_int +func.func @transpose_int() -> tensor<3x2xi32> { + %0 = stablehlo.constant dense<0> : tensor<2x3xi32> + %1 = stablehlo.transpose %0, dims = [1, 0] : (tensor<2x3xi32>) -> tensor<3x2xi32> + return %1 : tensor<3x2xi32> +} +// CHECK: transpose + +// ----- + +// Tests that transposing an argument cannot be folded. + +// CHECK-LABEL: transpose_arg +func.func @transpose_arg(%arg0: tensor<2x3xf32>) -> tensor<3x2xf32> { + %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<2x3xf32>) -> tensor<3x2xf32> + return %0 : tensor<3x2xf32> +} +// CHECK: transpose diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_insert_calibration_statistics_saver.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_insert_calibration_statistics_saver.mlir new file mode 100644 index 00000000000000..8e034735ee9aaa --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_insert_calibration_statistics_saver.mlir @@ -0,0 +1,219 @@ +// RUN: stablehlo-quant-opt %s -split-input-file -mlir-disable-threading -tf-stablehlo-insert-calibration-statistics-saver | FileCheck %s + +func.func @serving_default(%arg0: tensor<1x3x4x3xf32>) -> (tensor<1x2x2x2xf32>) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input_tensor:0", outputs = 
"PartitionedCall:0"}} { + %cst = "tf.Const"() <{value = dense<[[[[-0.891899645, 0.392044574], [0.77720493, 1.31188095], [0.255048186, 2.700150e+00]], [[-1.08111858, -0.406604826], [-0.298575521, -2.25356531], [-1.00201964, 2.54532099]], [[-1.34911358, 0.279911458], [-0.868258893, -1.36708188], [0.866317451, -2.05804896]]], [[[-0.591397941, 0.331505477], [0.715151429, 2.64073896], [1.27163255, 0.206143498]], [[0.474211812, 1.45044816], [0.119936548, 2.54149938], [-0.939900994, 0.438387245]], [[-1.12486279, -1.09022558], [0.82202208, 1.04652023], [1.30316162, 2.62054276]]]]> : tensor<2x3x3x2xf32>}> : () -> tensor<2x3x3x2xf32> + %output, %min, %max, %histogram = "tf.CustomAggregator"(%arg0) <{calibration_method = 5 : i32, id = "0", num_bins = 32 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32}> : (tensor<1x3x4x3xf32>) -> (tensor<1x3x4x3xf32>, tensor, tensor, tensor<512xi64>) + %0 = "tf.Conv2D"(%output, %cst) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 2, 2, 1], use_cudnn_on_gpu = true}> {attr_map = "0:strides,1:use_cudnn_on_gpu,2:padding,3:explicit_paddings,4:dilations", device = ""} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x2x2x2xf32> + %output_1, %min_2, %max_3, %histogram_4 = "tf.CustomAggregator"(%0) <{calibration_method = 5 : i32, id = "1", num_bins = 32 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32}> : (tensor<1x2x2x2xf32>) -> (tensor<1x2x2x2xf32>, tensor, tensor, tensor<512xi64>) + %1 = "tf.Identity"(%output_1) {device = ""} : (tensor<1x2x2x2xf32>) -> tensor<1x2x2x2xf32> + return %1 : tensor<1x2x2x2xf32> +} +// CHECK-LABEL: @serving_default +// CHECK: %[[CUSTOM_AGGREGATOR_0:.*]], %[[MIN_O:.*]], %[[MAX_O:.*]], %[[HISTOGRAM_0:.*]] = "tf.CustomAggregator" +// CKECK-SAME: <{calibration_method = 5 : i32, id = "0", num_bins = 32 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32}> : 
(tensor<1x3x4x3xf32>) -> (tensor<1x3x4x3xf32>, tensor, tensor, tensor<512xi64>) +// CHECK: %[[CUSTOM_AGGREGATOR_1:.*]], %[[MIN_1:.*]], %[[MAX_1:.*]], %[[HISTOGRAM_1:.*]] = "tf.CustomAggregator" +// CKECK-SAME: <{calibration_method = 5 : i32, id = "1", num_bins = 32 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32}> : (tensor<1x3x4x3xf32>) -> (tensor<1x3x4x3xf32>, tensor, tensor, tensor<512xi64>) +// CHECK: "tf.CalibrationStatisticsSaver"(%[[MIN_O]], %[[MAX_O]], %[[HISTOGRAM_0]], %[[MIN_1]], %[[MAX_1]], %[[HISTOGRAM_1]]) +// CHECK-SAME: <{calibration_methods = [5 : i32, 5 : i32], ids = ["0", "1"], output_file_path = "serving_default_0.pb"}> : (tensor, tensor, tensor<512xi64>, tensor, tensor, tensor<512xi64>) -> () +// CHECK: return + +// ----- + +// No CustomAggregator ops exist. +func.func private @composite_conv2d_with_bias_and_relu6_fn_1(%arg0: tensor<1x3x4x3xf32>, %arg1: tensor<2x3x3x2xf32>, %arg2: tensor<2xf32>) -> tensor<1x2x2x2xf32> attributes {tf_quant.composite_function} { + %0 = "tf.Conv2D"(%arg0, %arg1) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 2, 2, 1], use_cudnn_on_gpu = true}> {attr_map = "0:strides,1:use_cudnn_on_gpu,2:padding,3:explicit_paddings,4:dilations", device = ""} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x2x2x2xf32> + %1 = "tf.BiasAdd"(%0, %arg2) <{data_format = "NHWC"}> : (tensor<1x2x2x2xf32>, tensor<2xf32>) -> tensor<1x2x2x2xf32> + %2 = "tf.Relu6"(%1) {device = ""} : (tensor<1x2x2x2xf32>) -> tensor<1x2x2x2xf32> + return %2 : tensor<1x2x2x2xf32> +} +// CHECK-LABEL: @composite_conv2d_with_bias_and_relu6_fn_1 +// CHECK-NOT: "tf.CalibrationStatisticsSaver" + +// ----- + +// Check the IfOp is set to stateful. 
+module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1833 : i32}, tf_saved_model.semantics} { + // CHECK-LABEL: func.func @serving_default + // CHECK: "tf.If" + // CHECK-SAME: is_stateless = false + func.func @serving_default(%arg0: tensor<1x4xf32> {tf_saved_model.index_path = ["x"]}) -> (tensor<1x3xf32> {tf_saved_model.index_path = ["output"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_x:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { + %cst = "tf.Const"() <{value = dense<[0, 1]> : tensor<2xi32>}> {device = ""} : () -> tensor<2xi32> + %cst_0 = "tf.Const"() <{value = dense<1.000000e+01> : tensor}> {device = ""} : () -> tensor + %0 = "tf.Sum"(%arg0, %cst) <{keep_dims = false}> {device = ""} : (tensor<1x4xf32>, tensor<2xi32>) -> tensor + %1 = "tf.Greater"(%0, %cst_0) {device = ""} : (tensor, tensor) -> tensor + %2:2 = "tf.If"(%1, %arg0) <{else_branch = @cond_false_80, is_stateless = true, then_branch = @cond_true_70}> {Tcond = i1, Tin = [f32], Tout = [i1, f32], _lower_using_switch_merge = true, _read_only_resource_inputs = [], device = ""} : (tensor, tensor<1x4xf32>) -> (tensor, tensor<1x3xf32>) + %3 = "tf.Identity"(%2#1) {device = ""} : (tensor<1x3xf32>) -> tensor<1x3xf32> + return %3 : tensor<1x3xf32> + } + + // CHECK-LABEL: func.func private @cond_false_80 + // CHECK: "tf.CalibrationStatisticsSaver" + // CHECK-SAME: output_file_path = "cond_false_80_0.pb" + func.func private @cond_false_80(%arg0: tensor<1x4xf32> {tf._user_specified_name = "x"}) -> (tensor, tensor<1x3xf32>) attributes {tf._construction_context = "kEagerRuntime", tf._input_shapes = [#tf_type.shape<1x4>], tf._original_func_name = "cond_false_8"} { + %cst = "tf.Const"() <{value = dense : tensor}> {device = ""} : () -> tensor + %cst_0 = "tf.Const"() <{value = dense<[0.117216609, 0.933735609, 0.0728900209]> : tensor<3xf32>}> {device = ""} : () -> tensor<3xf32> + %cst_1 = "tf.Const"() 
<{value = dense<[[-0.795477629, 0.581315517, 0.921566545], [0.138622552, 0.463866323, 0.95474267], [-0.143770888, -0.796835303, 0.899996876], [0.0989735424, -0.483384758, -7.277030e-01]]> : tensor<4x3xf32>}> {device = ""} : () -> tensor<4x3xf32> + %output, %min, %max, %histogram = "tf.CustomAggregator"(%arg0) <{calibration_method = 1 : i32, id = "0", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> : (tensor<1x4xf32>) -> (tensor<1x4xf32>, tensor, tensor, tensor<0xi64>) + %0 = "tf.Identity"(%cst) {device = ""} : (tensor) -> tensor + %1 = "tf.PartitionedCall"(%output, %cst_1, %cst_0) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_with_bias_fn_1}> {_tfl_quant_trait = "fully_quantizable"} : (tensor<1x4xf32>, tensor<4x3xf32>, tensor<3xf32>) -> tensor<1x3xf32> + %output_2, %min_3, %max_4, %histogram_5 = "tf.CustomAggregator"(%1) <{calibration_method = 1 : i32, id = "1", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> : (tensor<1x3xf32>) -> (tensor<1x3xf32>, tensor, tensor, tensor<0xi64>) + %2 = "tf.Identity"(%output_2) {device = ""} : (tensor<1x3xf32>) -> tensor<1x3xf32> + return %0, %2 : tensor, tensor<1x3xf32> + } + + // CHECK-LABEL: func.func private @cond_true_70 + // CHECK: "tf.CalibrationStatisticsSaver" + // CHECK-SAME: output_file_path = "cond_true_70_0.pb" + func.func private @cond_true_70(%arg0: tensor<1x4xf32> {tf._user_specified_name = "x"}) -> (tensor, tensor<1x3xf32>) attributes {tf._construction_context = "kEagerRuntime", tf._input_shapes = [#tf_type.shape<1x4>], tf._original_func_name = "cond_true_7"} { + %cst = "tf.Const"() <{value = dense : tensor}> {device = ""} : () -> tensor + %cst_0 = "tf.Const"() <{value = dense<[0.335351914, 0.084816426, -0.664676845]> : tensor<3xf32>}> {device = ""} : () -> tensor<3xf32> + %cst_1 = "tf.Const"() <{value = dense<[[-0.630731344, 0.54962182, 0.180364341], [-0.764542698, -0.211145893, 
-0.708605706], [-0.954062759, -0.614013135, 0.612640202], [-0.418223292, 5.057390e-01, 0.899269938]]> : tensor<4x3xf32>}> {device = ""} : () -> tensor<4x3xf32> + %output, %min, %max, %histogram = "tf.CustomAggregator"(%arg0) <{calibration_method = 1 : i32, id = "2", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> : (tensor<1x4xf32>) -> (tensor<1x4xf32>, tensor, tensor, tensor<0xi64>) + %0 = "tf.Identity"(%cst) {device = ""} : (tensor) -> tensor + %1 = "tf.PartitionedCall"(%output, %cst_1, %cst_0) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_with_bias_fn_2}> {_tfl_quant_trait = "fully_quantizable"} : (tensor<1x4xf32>, tensor<4x3xf32>, tensor<3xf32>) -> tensor<1x3xf32> + %output_2, %min_3, %max_4, %histogram_5 = "tf.CustomAggregator"(%1) <{calibration_method = 1 : i32, id = "3", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> : (tensor<1x3xf32>) -> (tensor<1x3xf32>, tensor, tensor, tensor<0xi64>) + %2 = "tf.Identity"(%output_2) {device = ""} : (tensor<1x3xf32>) -> tensor<1x3xf32> + return %0, %2 : tensor, tensor<1x3xf32> + } + + func.func private @composite_matmul_with_bias_fn_1(%arg0: tensor<1x4xf32>, %arg1: tensor<4x3xf32>, %arg2: tensor<3xf32>) -> tensor<1x3xf32> attributes {tf_quant.composite_function} { + %0 = "tf.MatMul"(%arg0, %arg1) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> {attr_map = "0:transpose_a,1:transpose_b", device = ""} : (tensor<1x4xf32>, tensor<4x3xf32>) -> tensor<1x3xf32> + %1 = "tf.BiasAdd"(%0, %arg2) <{data_format = "NHWC"}> {device = ""} : (tensor<1x3xf32>, tensor<3xf32>) -> tensor<1x3xf32> + return %1 : tensor<1x3xf32> + } + + func.func private @composite_matmul_with_bias_fn_2(%arg0: tensor<1x4xf32>, %arg1: tensor<4x3xf32>, %arg2: tensor<3xf32>) -> tensor<1x3xf32> attributes {tf_quant.composite_function} { + %0 = "tf.MatMul"(%arg0, %arg1) <{grad_a = false, grad_b = false, transpose_a = 
false, transpose_b = false}> {attr_map = "0:transpose_a,1:transpose_b", device = ""} : (tensor<1x4xf32>, tensor<4x3xf32>) -> tensor<1x3xf32> + %1 = "tf.BiasAdd"(%0, %arg2) <{data_format = "NHWC"}> {device = ""} : (tensor<1x3xf32>, tensor<3xf32>) -> tensor<1x3xf32> + return %1 : tensor<1x3xf32> + } +} + +// ----- + +// Check the IfRegion is set to stateful. +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1833 : i32}, tf_saved_model.semantics} { + // CHECK-LABEL: func.func @serving_default + // CHECK: "tf.IfRegion" + // CHECK-SAME: is_stateless = false + + // CHECK: "tf.CalibrationStatisticsSaver" + // CHECK-SAME: output_file_path = "serving_default_0.pb" + + // CHECK: "tf.CalibrationStatisticsSaver" + // CHECK-SAME: output_file_path = "serving_default_1.pb" + + // CHECK: "tf.CalibrationStatisticsSaver" + // CHECK-SAME: output_file_path = "serving_default_2.pb" + func.func @serving_default(%arg0: tensor<1x4xf32> {tf_saved_model.index_path = ["x"]}) -> (tensor<1x3xf32> {tf_saved_model.index_path = ["output"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_x:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { + %cst = "tf.Const"() <{value = dense<1.000000e+01> : tensor}> {device = ""} : () -> tensor + %cst_0 = "tf.Const"() <{value = dense<[0, 1]> : tensor<2xi32>}> {device = ""} : () -> tensor<2xi32> + %cst_1 = "tf.Const"() <{value = dense<[[-0.630731344, 0.54962182, 0.180364341], [-0.764542698, -0.211145893, -0.708605706], [-0.954062759, -0.614013135, 0.612640202], [-0.418223292, 5.057390e-01, 0.899269938]]> : tensor<4x3xf32>}> {device = ""} : () -> tensor<4x3xf32> + %cst_2 = "tf.Const"() <{value = dense<[0.335351914, 0.084816426, -0.664676845]> : tensor<3xf32>}> {device = ""} : () -> tensor<3xf32> + %cst_3 = "tf.Const"() <{value = dense : tensor}> {device = ""} : () -> tensor + %cst_4 = "tf.Const"() <{value = dense<[[-0.795477629, 0.581315517, 
0.921566545], [0.138622552, 0.463866323, 0.95474267], [-0.143770888, -0.796835303, 0.899996876], [0.0989735424, -0.483384758, -7.277030e-01]]> : tensor<4x3xf32>}> {device = ""} : () -> tensor<4x3xf32> + %cst_5 = "tf.Const"() <{value = dense<[0.117216609, 0.933735609, 0.0728900209]> : tensor<3xf32>}> {device = ""} : () -> tensor<3xf32> + %output, %min, %max, %histogram = "tf.CustomAggregator"(%arg0) <{calibration_method = 1 : i32, id = "0", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> : (tensor<1x4xf32>) -> (tensor<1x4xf32>, tensor, tensor, tensor<0xi64>) + %0 = "tf.Sum"(%output, %cst_0) <{keep_dims = false}> {device = ""} : (tensor<1x4xf32>, tensor<2xi32>) -> tensor + %1 = "tf.Greater"(%0, %cst) {device = ""} : (tensor, tensor) -> tensor + %2:2 = "tf.IfRegion"(%1) <{_else_func_name = "cond_false_80", _then_func_name = "cond_true_70", is_stateless = true}> ({ + %4 = "tf.Identity"(%cst_3) {device = ""} : (tensor) -> tensor + %5 = "tf.PartitionedCall"(%output, %cst_1, %cst_2) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_with_bias_fn_2}> {_tfl_quant_trait = "fully_quantizable"} : (tensor<1x4xf32>, tensor<4x3xf32>, tensor<3xf32>) -> tensor<1x3xf32> + %output_6, %min_7, %max_8, %histogram_9 = "tf.CustomAggregator"(%5) <{calibration_method = 1 : i32, id = "1", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> : (tensor<1x3xf32>) -> (tensor<1x3xf32>, tensor, tensor, tensor<0xi64>) + %6 = "tf.Identity"(%output_6) {device = ""} : (tensor<1x3xf32>) -> tensor<1x3xf32> + "tf.Yield"(%4, %6) {device = ""} : (tensor, tensor<1x3xf32>) -> () + }, { + %4 = "tf.Identity"(%cst_3) {device = ""} : (tensor) -> tensor + %5 = "tf.PartitionedCall"(%output, %cst_4, %cst_5) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_with_bias_fn_1}> {_tfl_quant_trait = "fully_quantizable"} : (tensor<1x4xf32>, tensor<4x3xf32>, tensor<3xf32>) -> 
tensor<1x3xf32> + %output_6, %min_7, %max_8, %histogram_9 = "tf.CustomAggregator"(%5) <{calibration_method = 1 : i32, id = "2", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> : (tensor<1x3xf32>) -> (tensor<1x3xf32>, tensor, tensor, tensor<0xi64>) + %6 = "tf.Identity"(%output_6) {device = ""} : (tensor<1x3xf32>) -> tensor<1x3xf32> + "tf.Yield"(%4, %6) {device = ""} : (tensor, tensor<1x3xf32>) -> () + }) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = ""} : (tensor) -> (tensor, tensor<1x3xf32>) + %3 = "tf.Identity"(%2#1) {device = ""} : (tensor<1x3xf32>) -> tensor<1x3xf32> + return %3 : tensor<1x3xf32> + } + func.func private @composite_matmul_with_bias_fn_2(%arg0: tensor<1x4xf32>, %arg1: tensor<4x3xf32>, %arg2: tensor<3xf32>) -> tensor<1x3xf32> attributes {tf_quant.composite_function} { + %0 = "tf.MatMul"(%arg0, %arg1) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> {attr_map = "0:transpose_a,1:transpose_b", device = ""} : (tensor<1x4xf32>, tensor<4x3xf32>) -> tensor<1x3xf32> + %1 = "tf.BiasAdd"(%0, %arg2) <{data_format = "NHWC"}> {device = ""} : (tensor<1x3xf32>, tensor<3xf32>) -> tensor<1x3xf32> + return %1 : tensor<1x3xf32> + } + func.func private @composite_matmul_with_bias_fn_1(%arg0: tensor<1x4xf32>, %arg1: tensor<4x3xf32>, %arg2: tensor<3xf32>) -> tensor<1x3xf32> attributes {tf_quant.composite_function} { + %0 = "tf.MatMul"(%arg0, %arg1) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> {attr_map = "0:transpose_a,1:transpose_b", device = ""} : (tensor<1x4xf32>, tensor<4x3xf32>) -> tensor<1x3xf32> + %1 = "tf.BiasAdd"(%0, %arg2) <{data_format = "NHWC"}> {device = ""} : (tensor<1x3xf32>, tensor<3xf32>) -> tensor<1x3xf32> + return %1 : tensor<1x3xf32> + } +} + +// ----- + +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1836 : i32}, tf_saved_model.semantics} { + func.func @main(%arg0: 
tensor<10x1x1024xf32> {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor<10x1x3xf32> {tf_saved_model.index_path = ["output"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input_tensor:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { + %cst = stablehlo.constant dense<0.000000e+00>: tensor<10x1024x3xf32> + %output, %min, %max, %histogram = "tf.CustomAggregator"(%arg0) <{calibration_method = 1 : i32, id = "0", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> : (tensor<10x1x1024xf32>) -> (tensor<10x1x1024xf32>, tensor, tensor, tensor<0xi64>) + %0 = "tf.XlaCallModule"(%output, %cst) <{Sout = [#tf_type.shape<10x1x3>], dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = ["CPU"], version = 9 : i64}> {_entry_function = @composite_dot_general_with_relu_fn_1, _original_entry_function = "composite_dot_general_with_relu_fn_1", _quantization_method = "static_range_ptq { }", _stablehlo_module_attrs = {jax.uses_shape_polymorphism = true}, _tfl_quant_trait = "fully_quantizable"} : (tensor<10x1x1024xf32>, tensor<10x1024x3xf32>) -> tensor<10x1x3xf32> + %output_0, %min_1, %max_2, %histogram_3 = "tf.CustomAggregator"(%0) <{calibration_method = 1 : i32, id = "1", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> : (tensor<10x1x3xf32>) -> (tensor<10x1x3xf32>, tensor, tensor, tensor<0xi64>) + return %output_0 : tensor<10x1x3xf32> + } + // CHECK-LABEL: @main + // CHECK: %[[CUSTOM_AGGREGATOR_0:.*]], %[[MIN_O:.*]], %[[MAX_O:.*]], %[[HISTOGRAM_0:.*]] = "tf.CustomAggregator" + // CKECK-SAME: <{calibration_method = 1 : i32, id = "0", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> + // CHECK: %[[CUSTOM_AGGREGATOR_1:.*]], %[[MIN_1:.*]], %[[MAX_1:.*]], %[[HISTOGRAM_1:.*]] = "tf.CustomAggregator" + // 
CHECK-SAME: <{calibration_method = 1 : i32, id = "1", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> + // CHECK: "tf.CalibrationStatisticsSaver"(%[[MIN_O]], %[[MAX_O]], %[[HISTOGRAM_0]], %[[MIN_1]], %[[MAX_1]], %[[HISTOGRAM_1]]) + // CHECK-SAME: <{calibration_methods = [1 : i32, 1 : i32], ids = ["0", "1"], output_file_path = "main_0.pb"}> : (tensor, tensor, tensor<0xi64>, tensor, tensor, tensor<0xi64>) -> () + // CHECK: return + + func.func private @composite_dot_general_with_relu_fn_1(%arg0: tensor<10x1x1024xf32>, %arg1: tensor<10x1024x3xf32>) -> tensor<10x1x3xf32> attributes {_from_xla_call_module, tf_quant.composite_function} { + %cst = stablehlo.constant dense<0.000000e+00> : tensor<10x1x3xf32> + %0 = stablehlo.dot_general %arg0, %arg1, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] {mhlo.frontend_attributes = {grad_x = "false", grad_y = "false"}} : (tensor<10x1x1024xf32>, tensor<10x1024x3xf32>) -> tensor<10x1x3xf32> + %1 = stablehlo.maximum %0, %cst : tensor<10x1x3xf32> + return %1 : tensor<10x1x3xf32> + } + // CHECK-LABEL: func.func private @composite_dot_general_with_relu_fn_1 + // CHECK-NOT: "tf.CalibrationStatisticsSaver" +} + +// ----- + +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1836 : i32}, tf_saved_model.semantics} { + // CHECK-LABEL: func.func @main + // CHECK: "tf.CalibrationStatisticsSaver" + // CHECK-SAME: output_file_path = "main_0.pb" + // CHECK: "tf.CalibrationStatisticsSaver" + // CHECK-SAME: output_file_path = "main_1.pb" + // CHECK: "tf.CalibrationStatisticsSaver" + // CHECK-SAME: output_file_path = "main_2.pb" + func.func @main(%arg0: tensor<1x4xf32> {tf_saved_model.index_path = ["x"]}) -> (tensor<1x3xf32> {tf_saved_model.index_path = ["output"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_x:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = 
["serving_default"]} { + %cst = stablehlo.constant dense<1.000000e+01> : tensor + %cst_0 = stablehlo.constant dense<[[-0.630731344, 0.54962182, 0.180364341], [-0.764542698, -0.211145893, -0.708605706], [-0.954062759, -0.614013135, 0.612640202], [-0.418223292, 5.057390e-01, 0.899269938]]> : tensor<4x3xf32> + %c = stablehlo.constant dense : tensor + %cst_1 = stablehlo.constant dense<[[-0.795477629, 0.581315517, 0.921566545], [0.138622552, 0.463866323, 0.95474267], [-0.143770888, -0.796835303, 0.899996876], [0.0989735424, -0.483384758, -7.277030e-01]]> : tensor<4x3xf32> + %cst_2 = stablehlo.constant dense<-0.000000e+00> : tensor + %cst_3 = stablehlo.constant dense<[[0.335351914, 0.084816426, -0.664676845]]> : tensor<1x3xf32> + %cst_4 = stablehlo.constant dense<[[0.117216609, 0.933735609, 0.0728900209]]> : tensor<1x3xf32> + %output, %min, %max, %histogram = "tf.CustomAggregator"(%arg0) <{calibration_method = 1 : i32, id = "0", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> : (tensor<1x4xf32>) -> (tensor<1x4xf32>, tensor, tensor, tensor<0xi64>) + %0 = stablehlo.reduce(%output init: %cst_2) applies stablehlo.add across dimensions = [0, 1] : (tensor<1x4xf32>, tensor) -> tensor + %1 = stablehlo.compare GT, %0, %cst : (tensor, tensor) -> tensor + %2:2 = "stablehlo.if"(%1) ({ + %3 = "tf.XlaCallModule"(%output, %cst_0, %cst_3) <{Sout = [#tf_type.shape<1x3>], dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = ["CPU"], version = 9 : i64}> {_entry_function = @composite_dot_general_with_bias_same_shape_fn_2, _original_entry_function = "composite_dot_general_with_bias_same_shape_fn_2", _quantization_method = "static_range_ptq { }", _stablehlo_module_attrs = {jax.uses_shape_polymorphism = true}, _tfl_quant_trait = "fully_quantizable"} : (tensor<1x4xf32>, tensor<4x3xf32>, tensor<1x3xf32>) -> tensor<1x3xf32> + %output_5, %min_6, %max_7, %histogram_8 = 
"tf.CustomAggregator"(%3) <{calibration_method = 1 : i32, id = "1", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> : (tensor<1x3xf32>) -> (tensor<1x3xf32>, tensor, tensor, tensor<0xi64>) + stablehlo.return %c, %output_5 : tensor, tensor<1x3xf32> + }, { + %3 = "tf.XlaCallModule"(%output, %cst_1, %cst_4) <{Sout = [#tf_type.shape<1x3>], dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = ["CPU"], version = 9 : i64}> {_entry_function = @composite_dot_general_with_bias_same_shape_fn_1, _original_entry_function = "composite_dot_general_with_bias_same_shape_fn_1", _quantization_method = "static_range_ptq { }", _stablehlo_module_attrs = {jax.uses_shape_polymorphism = true}, _tfl_quant_trait = "fully_quantizable"} : (tensor<1x4xf32>, tensor<4x3xf32>, tensor<1x3xf32>) -> tensor<1x3xf32> + %output_5, %min_6, %max_7, %histogram_8 = "tf.CustomAggregator"(%3) <{calibration_method = 1 : i32, id = "2", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> : (tensor<1x3xf32>) -> (tensor<1x3xf32>, tensor, tensor, tensor<0xi64>) + stablehlo.return %c, %output_5 : tensor, tensor<1x3xf32> + }) : (tensor) -> (tensor, tensor<1x3xf32>) + return %2#1 : tensor<1x3xf32> + } + func.func private @composite_dot_general_with_bias_same_shape_fn_2(%arg0: tensor<1x4xf32>, %arg1: tensor<4x3xf32>, %arg2: tensor<1x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module, tf_quant.composite_function} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<1x4xf32>, tensor<4x3xf32>) -> tensor<1x3xf32> + %1 = stablehlo.add %0, %arg2 : tensor<1x3xf32> + return %1 : tensor<1x3xf32> + } + + func.func private @composite_dot_general_with_bias_same_shape_fn_1(%arg0: tensor<1x4xf32>, %arg1: tensor<4x3xf32>, %arg2: tensor<1x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module, 
tf_quant.composite_function} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<1x4xf32>, tensor<4x3xf32>) -> tensor<1x3xf32> + %1 = stablehlo.add %0, %arg2 : tensor<1x3xf32> + return %1 : tensor<1x3xf32> + } +} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_insert_calibration_statistics_saver_with_skipping.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_insert_calibration_statistics_saver_with_skipping.mlir new file mode 100644 index 00000000000000..a7a4e6d7b47fe2 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_insert_calibration_statistics_saver_with_skipping.mlir @@ -0,0 +1,47 @@ +// RUN: stablehlo-quant-opt %s -split-input-file -tf-stablehlo-insert-calibration-statistics-saver='aggregator-ops-to-ignore=skipping_id' | FileCheck %s + +func.func @serving_default(%arg0: tensor<1x3x4x3xf32>) -> (tensor<1x2x2x2xf32>) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input_tensor:0", outputs = "PartitionedCall:0"}} { + %cst = "tf.Const"() <{value = dense<[[[[-0.891899645, 0.392044574], [0.77720493, 1.31188095], [0.255048186, 2.700150e+00]], [[-1.08111858, -0.406604826], [-0.298575521, -2.25356531], [-1.00201964, 2.54532099]], [[-1.34911358, 0.279911458], [-0.868258893, -1.36708188], [0.866317451, -2.05804896]]], [[[-0.591397941, 0.331505477], [0.715151429, 2.64073896], [1.27163255, 0.206143498]], [[0.474211812, 1.45044816], [0.119936548, 2.54149938], [-0.939900994, 0.438387245]], [[-1.12486279, -1.09022558], [0.82202208, 1.04652023], [1.30316162, 2.62054276]]]]> : tensor<2x3x3x2xf32>}> : () -> tensor<2x3x3x2xf32> + %output, %min, %max, %histogram = "tf.CustomAggregator"(%arg0) <{calibration_method = 5 : i32, id = "skipping_id", num_bins = 32 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32}> : (tensor<1x3x4x3xf32>) -> 
(tensor<1x3x4x3xf32>, tensor, tensor, tensor<512xi64>) + %0 = "tf.Conv2D"(%output, %cst) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 2, 2, 1], use_cudnn_on_gpu = true}> {attr_map = "0:strides,1:use_cudnn_on_gpu,2:padding,3:explicit_paddings,4:dilations", device = ""} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x2x2x2xf32> + %output_1, %min_2, %max_3, %histogram_4 = "tf.CustomAggregator"(%0) <{calibration_method = 5 : i32, id = "keeping_id", num_bins = 32 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32}> : (tensor<1x2x2x2xf32>) -> (tensor<1x2x2x2xf32>, tensor, tensor, tensor<512xi64>) + %1 = "tf.Identity"(%output_1) {device = ""} : (tensor<1x2x2x2xf32>) -> tensor<1x2x2x2xf32> + return %1 : tensor<1x2x2x2xf32> +} +// CHECK-LABEL: @serving_default +// CHECK: %[[CUSTOM_AGGREGATOR_0:.*]], %[[MIN_O:.*]], %[[MAX_O:.*]], %[[HISTOGRAM_0:.*]] = "tf.CustomAggregator" +// CKECK-SAME: <{calibration_method = 5 : i32, id = "skipping_id", num_bins = 32 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32}> : (tensor<1x3x4x3xf32>) -> (tensor<1x3x4x3xf32>, tensor, tensor, tensor<512xi64>) +// CHECK: %[[CUSTOM_AGGREGATOR_1:.*]], %[[MIN_1:.*]], %[[MAX_1:.*]], %[[HISTOGRAM_1:.*]] = "tf.CustomAggregator" +// CKECK-SAME: <{calibration_method = 5 : i32, id = "keeping_id", num_bins = 32 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32}> : (tensor<1x3x4x3xf32>) -> (tensor<1x3x4x3xf32>, tensor, tensor, tensor<512xi64>) +// CHECK: "tf.CalibrationStatisticsSaver"(%[[MIN_1]], %[[MAX_1]], %[[HISTOGRAM_1]]) +// CHECK-SAME: <{calibration_methods = [5 : i32], ids = ["keeping_id"], output_file_path = "serving_default_0.pb"}> : (tensor, tensor, tensor<512xi64>) -> () +// CHECK: return + +// ----- + +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1836 : i32}, tf_saved_model.semantics} { + func.func 
@main(%arg0: tensor<10x1x1024xf32> {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor<10x1x3xf32> {tf_saved_model.index_path = ["output"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input_tensor:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { + %cst = stablehlo.constant dense<0.000000e+00>: tensor<10x1024x3xf32> + %output, %min, %max, %histogram = "tf.CustomAggregator"(%arg0) <{calibration_method = 1 : i32, id = "skipping_id", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> : (tensor<10x1x1024xf32>) -> (tensor<10x1x1024xf32>, tensor, tensor, tensor<0xi64>) + %0 = "tf.XlaCallModule"(%output, %cst) <{Sout = [#tf_type.shape<10x1x3>], dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = ["CPU"], version = 9 : i64}> {_entry_function = @composite_dot_general_with_relu_fn_1, _original_entry_function = "composite_dot_general_with_relu_fn_1", _quantization_method = "static_range_ptq { }", _stablehlo_module_attrs = {jax.uses_shape_polymorphism = true}, _tfl_quant_trait = "fully_quantizable"} : (tensor<10x1x1024xf32>, tensor<10x1024x3xf32>) -> tensor<10x1x3xf32> + %output_0, %min_1, %max_2, %histogram_3 = "tf.CustomAggregator"(%0) <{calibration_method = 1 : i32, id = "keeping_id", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> : (tensor<10x1x3xf32>) -> (tensor<10x1x3xf32>, tensor, tensor, tensor<0xi64>) + return %output_0 : tensor<10x1x3xf32> + } + // CHECK-LABEL: @main + // CHECK: %[[CUSTOM_AGGREGATOR_0:.*]], %[[MIN_O:.*]], %[[MAX_O:.*]], %[[HISTOGRAM_0:.*]] = "tf.CustomAggregator" + // CHECK-SAME: <{calibration_method = 1 : i32, id = "skipping_id", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> + // CHECK: %[[CUSTOM_AGGREGATOR_1:.*]], %[[MIN_1:.*]], %[[MAX_1:.*]],
%[[HISTOGRAM_1:.*]] = "tf.CustomAggregator" + // CHECK-SAME: <{calibration_method = 1 : i32, id = "keeping_id", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> + // CHECK: "tf.CalibrationStatisticsSaver"(%[[MIN_1]], %[[MAX_1]], %[[HISTOGRAM_1]]) + // CHECK-SAME: <{calibration_methods = [1 : i32], ids = ["keeping_id"], output_file_path = "main_0.pb"}> : (tensor, tensor, tensor<0xi64>) -> () + // CHECK: return + + func.func private @composite_dot_general_with_relu_fn_1(%arg0: tensor<10x1x1024xf32>, %arg1: tensor<10x1024x3xf32>) -> tensor<10x1x3xf32> attributes {_from_xla_call_module, tf_quant.composite_function} { + %cst = stablehlo.constant dense<0.000000e+00> : tensor<10x1x3xf32> + %0 = stablehlo.dot_general %arg0, %arg1, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] {mhlo.frontend_attributes = {grad_x = "false", grad_y = "false"}} : (tensor<10x1x1024xf32>, tensor<10x1024x3xf32>) -> tensor<10x1x3xf32> + %1 = stablehlo.maximum %0, %cst : tensor<10x1x3xf32> + return %1 : tensor<10x1x3xf32> + } + // CHECK-LABEL: func.func private @composite_dot_general_with_relu_fn_1 + // CHECK-NOT: "tf.CalibrationStatisticsSaver" +} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_insert_weight_param.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_insert_weight_param.mlir new file mode 100644 index 00000000000000..8812a2963b72e3 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_insert_weight_param.mlir @@ -0,0 +1,374 @@ +// RUN: stablehlo-quant-opt %s -split-input-file -tf-stablehlo-insert-weight-param | FileCheck %s + +// Test that q/dq pair with per-tensor quantization parameter is inserted +// between constant and XlaCallModule op with empty `weight_only_ptq` method +// and function name containing conv.
+ +func.func @qdq_for_conv_weight_empty(%arg0: tensor<1x3x2x3xf32>) -> tensor<1x2x2x2xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.000000e-01> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %0 = "tf.XlaCallModule"(%arg0, %cst) { + Sout = [#tf_type.shape<1x2x2x2>], _entry_function = @composite_conv_fn, + _original_entry_function = "composite_conv_fn", + _stablehlo_module_attrs = {}, _quantization_method = "weight_only_ptq { }", + device = "", dim_args_spec = [], disabled_checks = [], + has_token_input_output = false, module = "", platforms = [], + version = 5 : i64 + } : (tensor<1x3x2x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x2x2x2xf32> + return %0 : tensor<1x2x2x2xf32> +} + +// CHECK-LABEL: func.func @qdq_for_conv_weight_empty +// CHECK-SAME: (%[[ARG_0:.+]]: tensor<1x3x2x3xf32>) -> tensor<1x2x2x2xf32> +// CHECK: %[[CST:.+]] = "tf.Const"() <{value = dense<3.000000e-01> : tensor<2x3x3x2xf32>}> : () -> tensor<2x3x3x2xf32> +// CHECK: %[[Q:.+]] = "quantization.qcast"(%[[CST]]) : (tensor<2x3x3x2xf32>) -> tensor<2x3x3x2x!quant.uniform:f32, 0.0023622048182750312>> +// CHECK: %[[DQ:.+]] = "quantization.dcast"(%[[Q]]) : (tensor<2x3x3x2x!quant.uniform:f32, 0.0023622048182750312>>) -> tensor<2x3x3x2xf32> +// CHECK: %[[CALL:.+]] = "tf.XlaCallModule"(%[[ARG_0]], %[[DQ]]) +// CHECK-SAME: _entry_function = @composite_conv_fn, _original_entry_function = "composite_conv_fn", _quantization_method = "weight_only_ptq { }" +// CHECK-SAME: (tensor<1x3x2x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x2x2x2xf32> +// CHECK: return %[[CALL]] : tensor<1x2x2x2xf32> + +// ----- + +// Test that q/dq pair with per-tensor quantization parameter is inserted +// between constant and XlaCallModule op with empty `weight_only_ptq` method and +// function name containing dot_general. 
+ +func.func @qdq_for_dot_general_weight_empty(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.000000e-01> : tensor<2x3xf32>} : () -> tensor<2x3xf32> + %0 = "tf.XlaCallModule"(%arg0, %cst) { + Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, + _original_entry_function = "composite_dot_general_fn", + _quantization_method = "weight_only_ptq { }", _stablehlo_module_attrs = {}, + device = "", dim_args_spec = [], disabled_checks = [], + has_token_input_output = false, module = "", platforms = [], + version = 5 : i64 + } : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> +} + +// CHECK-LABEL: func.func @qdq_for_dot_general_weight_empty +// CHECK-SAME: (%[[ARG_0:.+]]: tensor<1x2xf32>) -> tensor<1x3xf32> +// CHECK: %[[CST:.+]] = "tf.Const"() <{value = dense<3.000000e-01> : tensor<2x3xf32>}> : () -> tensor<2x3xf32> +// CHECK: %[[Q:.+]] = "quantization.qcast"(%[[CST]]) : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform:f32, 0.0023622048182750312>> +// CHECK: %[[DQ:.+]] = "quantization.dcast"(%[[Q]]) : (tensor<2x3x!quant.uniform:f32, 0.0023622048182750312>>) -> tensor<2x3xf32> +// CHECK: %[[CALL:.+]] = "tf.XlaCallModule"(%[[ARG_0]], %[[DQ]]) +// CHECK-SAME: _entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _quantization_method = "weight_only_ptq { }" +// CHECK-SAME: (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> +// CHECK: return %[[CALL]] : tensor<1x3xf32> + +// ----- + +// Test that q/dq pair with per-tensor quantization parameter is inserted +// between constant and XlaCallModule op with `weight_only_ptq` method of +// `per_tensor` and function name containing conv. 
+ +func.func @qdq_for_conv_weight_per_tensor(%arg0: tensor<1x3x2x3xf32>) -> tensor<1x2x2x2xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.000000e-01> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %0 = "tf.XlaCallModule"(%arg0, %cst) { + Sout = [#tf_type.shape<1x2x2x2>], _entry_function = @composite_conv_fn, + _original_entry_function = "composite_conv_fn", + _stablehlo_module_attrs = {}, _quantization_method = "weight_only_ptq {input_quantized_types {key: 1, value {per_tensor {}}}}", + device = "", dim_args_spec = [], disabled_checks = [], + has_token_input_output = false, module = "", platforms = [], + version = 5 : i64 + } : (tensor<1x3x2x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x2x2x2xf32> + return %0 : tensor<1x2x2x2xf32> +} + +// CHECK-LABEL: func.func @qdq_for_conv_weight_per_tensor +// CHECK-SAME: (%[[ARG_0:.+]]: tensor<1x3x2x3xf32>) -> tensor<1x2x2x2xf32> +// CHECK: %[[CST:.+]] = "tf.Const"() <{value = dense<3.000000e-01> : tensor<2x3x3x2xf32>}> : () -> tensor<2x3x3x2xf32> +// CHECK: %[[Q:.+]] = "quantization.qcast"(%[[CST]]) : (tensor<2x3x3x2xf32>) -> tensor<2x3x3x2x!quant.uniform:f32, 0.0023622048182750312>> +// CHECK: %[[DQ:.+]] = "quantization.dcast"(%[[Q]]) : (tensor<2x3x3x2x!quant.uniform:f32, 0.0023622048182750312>>) -> tensor<2x3x3x2xf32> +// CHECK: %[[CALL:.+]] = "tf.XlaCallModule"(%[[ARG_0]], %[[DQ]]) +// CHECK-SAME: _entry_function = @composite_conv_fn, _original_entry_function = "composite_conv_fn", _quantization_method = "weight_only_ptq {input_quantized_types {key: 1, value {per_tensor {}}}}" +// CHECK-SAME: (tensor<1x3x2x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x2x2x2xf32> +// CHECK: return %[[CALL]] : tensor<1x2x2x2xf32> + +// ----- + +// Test that q/dq pair with per-tensor quantization parameter is inserted +// between constant and XlaCallModule op with `weight_only_ptq` method of +// `per_tensor` and function name containing dot_general. 
+ +func.func @qdq_for_dot_general_weight_per_tensor(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.000000e-01> : tensor<2x3xf32>} : () -> tensor<2x3xf32> + %0 = "tf.XlaCallModule"(%arg0, %cst) { + Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, + _original_entry_function = "composite_dot_general_fn", + _quantization_method = "weight_only_ptq {input_quantized_types {key: 1, value {per_tensor {}}}}", _stablehlo_module_attrs = {}, + device = "", dim_args_spec = [], disabled_checks = [], + has_token_input_output = false, module = "", platforms = [], + version = 5 : i64 + } : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> +} + +// CHECK-LABEL: func.func @qdq_for_dot_general_weight_per_tensor +// CHECK-SAME: (%[[ARG_0:.+]]: tensor<1x2xf32>) -> tensor<1x3xf32> +// CHECK: %[[CST:.+]] = "tf.Const"() <{value = dense<3.000000e-01> : tensor<2x3xf32>}> : () -> tensor<2x3xf32> +// CHECK: %[[Q:.+]] = "quantization.qcast"(%[[CST]]) : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform:f32, 0.0023622048182750312>> +// CHECK: %[[DQ:.+]] = "quantization.dcast"(%[[Q]]) : (tensor<2x3x!quant.uniform:f32, 0.0023622048182750312>>) -> tensor<2x3xf32> +// CHECK: %[[CALL:.+]] = "tf.XlaCallModule"(%[[ARG_0]], %[[DQ]]) +// CHECK-SAME: _entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _quantization_method = "weight_only_ptq {input_quantized_types {key: 1, value {per_tensor {}}}}" +// CHECK-SAME: (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> +// CHECK: return %[[CALL]] : tensor<1x3xf32> + +// ----- + +// Test that q/dq pair with per-channel quantization parameter is inserted +// between constant and XlaCallModule op with `weight_only_ptq` method of +// `quantized_type` without specified quantization dimension and function name +// containing conv.
+ +module attributes {tf_saved_model.semantics} { + func.func private @qdq_for_conv_weight_per_channel_default(%arg0: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %0 = "tf.XlaCallModule"(%arg0, %cst) { + Sout = [#tf_type.shape<1x3x4x2>], dim_args_spec = [], disabled_checks = [], + has_token_input_output = false, module = "", platforms = [], version = 5 : i64, + _entry_function = @composite_conv_fn, _original_entry_function = "composite_conv_fn", + _quantization_method = "weight_only_ptq {input_quantized_types {key: 1, value {dimension_specs {}}}}", + _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", + device = "" + } : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> + return %0 : tensor<1x3x4x2xf32> + } + + // CHECK: func.func private @qdq_for_conv_weight_per_channel_default(%[[ARG0:.+]]: tensor<1x3x4x3xf32>) + // CHECK: %[[CST:.+]] = "tf.Const"() <{value = dense<3.000000e-01> : tensor<2x3x3x2xf32>}> : () -> tensor<2x3x3x2xf32> + // CHECK: %[[Q:.+]] = "quantization.qcast"(%[[CST]]) : (tensor<2x3x3x2xf32>) -> tensor<2x3x3x2x!quant.uniform:f32:3, {0.0023622048182750312,0.0023622048182750312}>> + // CHECK: %[[DQ:.+]] = "quantization.dcast"(%[[Q]]) : (tensor<2x3x3x2x!quant.uniform:f32:3, {0.0023622048182750312,0.0023622048182750312}>>) -> tensor<2x3x3x2xf32> + // CHECK: %[[CALL:.+]] = "tf.XlaCallModule"(%[[ARG0]], %[[DQ]]) + // CHECK-SAME: (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> + // CHECK: return %[[CALL]] + + func.func private @composite_conv_fn(%arg0: tensor<1x3x4x3xf32>, %arg1: tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[0, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : 
(tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> + return %0 : tensor<1x3x4x2xf32> + } + // CHECK: func private @composite_conv_fn + // CHECK: %[[CONV:.+]] = stablehlo.convolution + // CHECK: return %[[CONV]] +} + +// ----- + +// Test that q/dq pair with per-channel quantization parameter is inserted +// between constant and XlaCallModule op with `weight_only_ptq` method of +// `quantized_type` without specified quantization dimension and function name +// containing dot_general. + +module attributes {tf_saved_model.semantics} { + func.func private @qdq_for_dot_general_weight_per_channel_default(%arg0: tensor<4x3x6x5xf32>) -> tensor<4x3x6x2xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.000000e-01> : tensor<4x3x5x2xf32>} : () -> tensor<4x3x5x2xf32> + %0 = "tf.XlaCallModule"(%arg0, %cst) { + Sout = [#tf_type.shape<4x3x6x2>], _entry_function = @composite_dot_general_fn, + _original_entry_function = "composite_dot_general_fn", + _quantization_method = "weight_only_ptq {input_quantized_types {key: 1, value {dimension_specs {}}}}", + _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", + device = "", dim_args_spec = [], disabled_checks = [], + has_token_input_output = false, module = "", platforms = [], + version = 5 : i64 + } : (tensor<4x3x6x5xf32>, tensor<4x3x5x2xf32>) -> tensor<4x3x6x2xf32> + return %0 : tensor<4x3x6x2xf32> + } + // CHECK: func.func private @qdq_for_dot_general_weight_per_channel_default(%[[ARG0:.+]]: tensor<4x3x6x5xf32>) + // CHECK: %[[CST:.+]] = "tf.Const"() <{value = dense<3.000000e-01> : tensor<4x3x5x2xf32>}> : () -> tensor<4x3x5x2xf32> + // CHECK: %[[Q:.+]] = "quantization.qcast"(%[[CST]]) : (tensor<4x3x5x2xf32>) -> tensor<4x3x5x2x!quant.uniform:f32:3, {0.0023622048182750312,0.0023622048182750312}>> + // CHECK: %[[DQ:.+]] = "quantization.dcast"(%[[Q]]) : (tensor<4x3x5x2x!quant.uniform:f32:3, {0.0023622048182750312,0.0023622048182750312}>>) -> tensor<4x3x5x2xf32> + //
CHECK: %[[CALL:.+]] = "tf.XlaCallModule"(%[[ARG0]], %[[DQ]]) + // CHECK-SAME: (tensor<4x3x6x5xf32>, tensor<4x3x5x2xf32>) -> tensor<4x3x6x2xf32> + // CHECK: return %[[CALL]] + + func.func private @composite_dot_general_fn(%arg0: tensor<4x3x6x5xf32>, %arg1: tensor<4x3x5x2xf32>) -> tensor<4x3x6x2xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.dot_general %arg0, %arg1, batching_dims = [0, 1] x [0, 1], contracting_dims = [3] x [2] : (tensor<4x3x6x5xf32>, tensor<4x3x5x2xf32>) -> tensor<4x3x6x2xf32> + return %0 : tensor<4x3x6x2xf32> + } + // CHECK: func private @composite_dot_general_fn + // CHECK: %[[DOT:.+]] = stablehlo.dot_general + // CHECK: return %[[DOT]] +} + +// ----- + +// Test that q/dq pair with per-channel quantization parameter is inserted +// between constant and XlaCallModule op with `weight_only_ptq` method of +// `quantized_type` with specified quantization dimension and function name +// containing conv. + +module attributes {tf_saved_model.semantics} { + func.func private @qdq_for_conv_weight_per_channel(%arg0: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %0 = "tf.XlaCallModule"(%arg0, %cst) { + Sout = [#tf_type.shape<1x3x4x2>], dim_args_spec = [], disabled_checks = [], + has_token_input_output = false, module = "", platforms = [], version = 5 : i64, + _entry_function = @composite_conv_fn, _original_entry_function = "composite_conv_fn", + _quantization_method = "weight_only_ptq {input_quantized_types {key: 1, value {dimension_specs {dimension: 3}}}}", + _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", + device = "" + } : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> + return %0 : tensor<1x3x4x2xf32> + } + + // CHECK: func.func private @qdq_for_conv_weight_per_channel(%[[ARG0:.+]]: tensor<1x3x4x3xf32>) + // CHECK: %[[CST:.+]] = "tf.Const"() <{value =
dense<3.000000e-01> : tensor<2x3x3x2xf32>}> : () -> tensor<2x3x3x2xf32> + // CHECK: %[[Q:.+]] = "quantization.qcast"(%[[CST]]) : (tensor<2x3x3x2xf32>) -> tensor<2x3x3x2x!quant.uniform:f32:3, {0.0023622048182750312,0.0023622048182750312}>> + // CHECK: %[[DQ:.+]] = "quantization.dcast"(%[[Q]]) : (tensor<2x3x3x2x!quant.uniform:f32:3, {0.0023622048182750312,0.0023622048182750312}>>) -> tensor<2x3x3x2xf32> + // CHECK: %[[CALL:.+]] = "tf.XlaCallModule"(%[[ARG0]], %[[DQ]]) + // CHECK-SAME: (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> + // CHECK: return %[[CALL]] + + func.func private @composite_conv_fn(%arg0: tensor<1x3x4x3xf32>, %arg1: tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[0, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> + return %0 : tensor<1x3x4x2xf32> + } + // CHECK: func private @composite_conv_fn + // CHECK: %[[CONV:.+]] = stablehlo.convolution + // CHECK: return %[[CONV]] +} + +// ----- + +// Test that q/dq pair with per-channel quantization parameter is inserted +// between constant and XlaCallModule op with `weight_only_ptq` method of +// `quantized_type` with specified quantization dimension and function name +// containing dot_general.
+ +module attributes {tf_saved_model.semantics} { + func.func private @qdq_for_dot_general_weight_per_channel(%arg0: tensor<4x3x6x5xf32>) -> tensor<4x3x6x2xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.000000e-01> : tensor<4x3x5x2xf32>} : () -> tensor<4x3x5x2xf32> + %0 = "tf.XlaCallModule"(%arg0, %cst) { + Sout = [#tf_type.shape<4x3x6x2>], _entry_function = @composite_dot_general_fn, + _original_entry_function = "composite_dot_general_fn", + _quantization_method = "weight_only_ptq {input_quantized_types {key: 1, value {dimension_specs {dimension: 3}}}}", + _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", + device = "", dim_args_spec = [], disabled_checks = [], + has_token_input_output = false, module = "", platforms = [], + version = 5 : i64 + } : (tensor<4x3x6x5xf32>, tensor<4x3x5x2xf32>) -> tensor<4x3x6x2xf32> + return %0 : tensor<4x3x6x2xf32> + } + // CHECK: func.func private @qdq_for_dot_general_weight_per_channel(%[[ARG0:.+]]: tensor<4x3x6x5xf32>) + // CHECK: %[[CST:.+]] = "tf.Const"() <{value = dense<3.000000e-01> : tensor<4x3x5x2xf32>}> : () -> tensor<4x3x5x2xf32> + // CHECK: %[[Q:.+]] = "quantization.qcast"(%[[CST]]) : (tensor<4x3x5x2xf32>) -> tensor<4x3x5x2x!quant.uniform:f32:3, {0.0023622048182750312,0.0023622048182750312}>> + // CHECK: %[[DQ:.+]] = "quantization.dcast"(%[[Q]]) : (tensor<4x3x5x2x!quant.uniform:f32:3, {0.0023622048182750312,0.0023622048182750312}>>) -> tensor<4x3x5x2xf32> + // CHECK: %[[CALL:.+]] = "tf.XlaCallModule"(%[[ARG0]], %[[DQ]]) + // CHECK-SAME: (tensor<4x3x6x5xf32>, tensor<4x3x5x2xf32>) -> tensor<4x3x6x2xf32> + // CHECK: return %[[CALL]] + + func.func private @composite_dot_general_fn(%arg0: tensor<4x3x6x5xf32>, %arg1: tensor<4x3x5x2xf32>) -> tensor<4x3x6x2xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.dot_general %arg0, %arg1, batching_dims = [0, 1] x [0, 1], contracting_dims = [3] x [2] : (tensor<4x3x6x5xf32>, tensor<4x3x5x2xf32>) -> 
tensor<4x3x6x2xf32> + return %0 : tensor<4x3x6x2xf32> + } + // CHECK: func private @composite_dot_general_fn + // CHECK: %[[DOT:.+]] = stablehlo.dot_general + // CHECK: return %[[DOT]] +} + +// ----- + +// Test that q/dq pair is not inserted between constant and XlaCallModule op +// whose entry function name does not include conv nor dot_general. + +func.func @no_qdq_except_conv_and_dot_general(%arg0: tensor<2x3x2xi64>) -> tensor<2x3x2x2xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.000000e-01> : tensor<3x4x2xf32>} : () -> tensor<3x4x2xf32> + %0 = "tf.XlaCallModule"(%cst, %arg0) { + Sout = [#tf_type.shape<1x3>], _entry_function = @composite_gather_fn, + _original_entry_function = "composite_gather_fn", _quantization_method = "weight_only_ptq { }", + _stablehlo_module_attrs = {}, device = "", dim_args_spec = [], + disabled_checks = [], has_token_input_output = false, module = "", + platforms = [], version = 5 : i64 + } : (tensor<3x4x2xf32>, tensor<2x3x2xi64>) -> tensor<2x3x2x2xf32> + return %0 : tensor<2x3x2x2xf32> +} + +// CHECK-LABEL: func.func @no_qdq_except_conv_and_dot_general +// CHECK-NOT: quantization.qcast +// CHECK-NOT: quantization.dcast + +// ----- + +// Test that q/dq pair is not inserted for constant whose operand number is +// not 1. 
+ +func.func @no_qdq_for_non_weight_constant(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<4.000000e-02> : tensor<3xf32>} : () -> tensor<3xf32> + %0 = "tf.XlaCallModule"(%arg0, %arg1, %cst) { + Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_with_bias_fn, + _original_entry_function = "composite_dot_general_with_bias_fn", + _stablehlo_module_attrs = {}, _quantization_method = "weight_only_ptq { }", + device = "", dim_args_spec = [], disabled_checks = [], + has_token_input_output = false, module = "", platforms = [], + version = 5 : i64 + } : (tensor<1x2xf32>, tensor<2x3xf32>, tensor<3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> +} + +// CHECK-LABEL: func.func @no_qdq_for_non_weight_constant +// CHECK-NOT: quantization.qcast +// CHECK-NOT: quantization.dcast + +// ----- + +// Test that q/dq pair is not inserted between constant and XlaCallModule op +// without `weight_only_ptq` method. + +func.func @no_qdq_for_not_quantizable_call(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.000000e-01> : tensor<2x3xf32>} : () -> tensor<2x3xf32> + %0 = "tf.XlaCallModule"(%arg0, %cst) { + Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, + _original_entry_function = "composite_dot_general_fn", + _stablehlo_module_attrs = {}, device = "", dim_args_spec = [], + disabled_checks = [], has_token_input_output = false, module = "", + platforms = [], version = 5 : i64 + } : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> +} + +// CHECK-LABEL: func.func @no_qdq_for_not_quantizable_call +// CHECK-NOT: quantization.qcast +// CHECK-NOT: quantization.dcast + +// ----- + +// Test that q/dq pair is not inserted between constant and XlaCallModule op +// with different method. 
+ +func.func @no_qdq_for_not_quantizable_call(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.000000e-01> : tensor<2x3xf32>} : () -> tensor<2x3xf32> + %0 = "tf.XlaCallModule"(%arg0, %cst) { + Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, + _original_entry_function = "composite_dot_general_fn", + _stablehlo_module_attrs = {}, device = "", dim_args_spec = [], + disabled_checks = [], has_token_input_output = false, module = "", + platforms = [], _quantization_method = "static_range_ptq { }", version = 5 : i64 + } : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> +} + +// CHECK-LABEL: func.func @no_qdq_for_not_quantizable_call +// CHECK-NOT: quantization.qcast +// CHECK-NOT: quantization.dcast + +// ----- + +// Test that q/dq pair is not inserted when constant has multiple users. + +func.func @no_qdq_for_multiple_users(%arg0: tensor<2x2xf32>) -> tensor<2x3xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.000000e-01> : tensor<2x3xf32>} : () -> tensor<2x3xf32> + %0 = "tf.XlaCallModule"(%arg0, %cst) { + Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, + _original_entry_function = "composite_dot_general_fn", + _stablehlo_module_attrs = {}, _quantization_method = "weight_only_ptq { }", + device = "", dim_args_spec = [], disabled_checks = [], + has_token_input_output = false, module = "", platforms = [], + version = 5 : i64 + } : (tensor<2x2xf32>, tensor<2x3xf32>) -> tensor<2x3xf32> + %2 = stablehlo.add %cst, %0 : tensor<2x3xf32> + return %2 : tensor<2x3xf32> +} + +// CHECK-LABEL: func.func @no_qdq_for_multiple_users +// CHECK-NOT: quantization.qcast +// CHECK-NOT: quantization.dcast diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_lift_quantizable_spots_as_functions.mlir 
b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_lift_quantizable_spots_as_functions.mlir new file mode 100644 index 00000000000000..e0c0406bb89294 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_lift_quantizable_spots_as_functions.mlir @@ -0,0 +1,861 @@ +// RUN: stablehlo-quant-opt %s -split-input-file -tf-stablehlo-lift-quantizable-spots-as-functions | FileCheck %s + +// CHECK-LABEL: @conv_fn( +// CHECK-SAME: %[[ARG_0:.*]]: tensor<1x3x3x4xf32> +func.func @conv_fn(%arg0: tensor<1x3x3x4xf32>) -> tensor<1x3x3x4xf32> { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<3x3x4x4xf32> + %1 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4xf32>, tensor<3x3x4x4xf32>) -> tensor<1x3x3x4xf32> + func.return %1: tensor<1x3x3x4xf32> +} +// CHECK: %[[CONST:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST]]) +// CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor<1x3x3x4xf32> +// CHECK: } + +// CHECK-LABEL: private @composite_conv_fn_1 +// CHECK: %[[CONV:.*]] = stablehlo.convolution(%arg0, %arg1) +// CHECK: return %[[CONV]] : tensor<1x3x3x4xf32> +// CHECK: } + +// ----- + +// CHECK-LABEL: @dot_general_fn( +// CHECK-SAME: %[[ARG_0:.*]]: tensor<1x1x167xf32> +func.func @dot_general_fn(%arg0: tensor<1x1x167xf32>) -> tensor<1x1x64xf32> { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<167x64xf32> + %1 = stablehlo.dot_general %arg0, %0, contracting_dims = [2] x [0], precision = [DEFAULT, DEFAULT] : (tensor<1x1x167xf32>, tensor<167x64xf32>) -> tensor<1x1x64xf32> + return %1 : tensor<1x1x64xf32> +} +// CHECK: %[[CONST:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST]]) +// CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor<1x1x64xf32> +// CHECK: } 
+ +// CHECK-LABEL: private @composite_dot_general_fn_1 +// CHECK: %[[DOT_GENERAL:.*]] = stablehlo.dot_general %arg0, %arg1 +// CHECK: return %[[DOT_GENERAL:.*]] : tensor<1x1x64xf32> +// CHECK: } + +// ----- + +// CHECK-LABEL: @dot_general_with_bias_same_shape_fn( +// CHECK-SAME: %[[ARG_0:.*]]: tensor<1x2xf32> +func.func @dot_general_with_bias_same_shape_fn(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<2x3xf32> + %1 = stablehlo.constant dense<2.000000e+00> : tensor<1x3xf32> + %2 = stablehlo.dot_general %arg0, %0, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + %3 = stablehlo.add %2, %1 : tensor<1x3xf32> + func.return %3: tensor<1x3xf32> +} +// CHECK: %[[CONST_0:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[CONST_1:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST_0]], %[[CONST_1]]) +// CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor<1x3xf32> +// CHECK: } + +// CHECK-LABEL: private @composite_dot_general_with_bias_same_shape_fn_1 +// CHECK: %[[DOT_GENERAL:.*]] = stablehlo.dot_general %arg0, %arg1 +// CHECK: %[[ADD:.*]] = stablehlo.add %[[DOT_GENERAL]], %arg2 +// CHECK: return %[[ADD]] : tensor<1x3xf32> +// CHECK: } + +// ----- + +// CHECK-LABEL: @conv_with_bias_fn( +// CHECK-SAME: %[[ARG_0:.*]]: tensor<1x3x3x4xf32> +func.func @conv_with_bias_fn(%arg0: tensor<1x3x3x4xf32>) -> tensor<1x3x3x4xf32> { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<3x3x4x4xf32> + %1 = stablehlo.constant dense<2.000000e+00> : tensor<4xf32> + %2 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4xf32>, tensor<3x3x4x4xf32>) -> tensor<1x3x3x4xf32> + %3 = stablehlo.broadcast_in_dim %1, dims = [3] : (tensor<4xf32>) -> tensor<1x3x3x4xf32> 
+ %4 = stablehlo.add %2, %3 : tensor<1x3x3x4xf32> + func.return %4: tensor<1x3x3x4xf32> +} +// CHECK: %[[CONST_0:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[CONST_1:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST_0]], %[[CONST_1]]) +// CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor<1x3x3x4xf32> +// CHECK: } + +// CHECK-LABEL: private @composite_conv_with_bias_fn_1 +// CHECK: %[[BROADCAST_IN_DIM:.*]] = stablehlo.broadcast_in_dim %arg2 +// CHECK: %[[CONV:.*]] = stablehlo.convolution(%arg0, %arg1) +// CHECK: %[[ADD:.*]] = stablehlo.add %[[CONV]], %[[BROADCAST_IN_DIM]] +// CHECK: return %[[ADD]] : tensor<1x3x3x4xf32> +// CHECK: } + +// ----- + +// CHECK-LABEL: @dot_general_with_bias_fn( +// CHECK-SAME: %[[ARG_0:.*]]: tensor<1x1x167xf32> +func.func @dot_general_with_bias_fn(%arg0: tensor<1x1x167xf32>) -> tensor<1x1x64xf32> { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<167x64xf32> + %1 = stablehlo.constant dense<2.000000e+00> : tensor<64xf32> + %2 = stablehlo.dot_general %arg0, %0, contracting_dims = [2] x [0], precision = [DEFAULT, DEFAULT] : (tensor<1x1x167xf32>, tensor<167x64xf32>) -> tensor<1x1x64xf32> + %3 = stablehlo.broadcast_in_dim %1, dims = [2] : (tensor<64xf32>) -> tensor<1x1x64xf32> + %4 = stablehlo.add %2, %3 : tensor<1x1x64xf32> + func.return %4: tensor<1x1x64xf32> +} +// CHECK: %[[CONST_0:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[CONST_1:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST_0]], %[[CONST_1]]) +// CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor<1x1x64xf32> +// CHECK: } + +// CHECK-LABEL: private @composite_dot_general_with_bias_fn_1 +// CHECK: %[[BROADCAST_IN_DIM:.*]] = stablehlo.broadcast_in_dim %arg2 +// CHECK: %[[DOT_GENERAL:.*]] = stablehlo.dot_general %arg0, %arg1 +// CHECK: %[[ADD:.*]] = stablehlo.add %[[DOT_GENERAL]], %[[BROADCAST_IN_DIM]] +// CHECK: return 
%[[ADD]] : tensor<1x1x64xf32> +// CHECK: } + +// ----- + +// CHECK-LABEL: @conv_with_bias_dynamic_fn( +// CHECK-SAME: %[[ARG_0:.*]]: tensor +func.func @conv_with_bias_dynamic_fn(%arg0: tensor) -> tensor { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<3x3x1x16xf32> + %1 = stablehlo.constant dense<2.000000e+00> : tensor<16xf32> + %2 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [1, 1], pad = [[1, 1], [1, 1]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo, #stablehlo]} : (tensor, tensor<3x3x1x16xf32>) -> tensor + %3 = shape.shape_of %2 : tensor -> tensor<4xindex> + %4 = stablehlo.dynamic_broadcast_in_dim %1, %3, dims = [3] : (tensor<16xf32>, tensor<4xindex>) -> tensor + %5 = stablehlo.add %2, %4 : tensor + func.return %5: tensor +} +// CHECK: %[[CONST_0:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[CONST_1:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST_0]], %[[CONST_1]]) +// CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor +// CHECK: } + +// CHECK-LABEL: private @composite_conv_with_bias_dynamic_fn_1 +// CHECK: %[[CONV:.*]] = stablehlo.convolution(%arg0, %arg1) +// CHECK: %[[SHAPE_OF:.*]] = shape.shape_of %[[CONV]] +// CHECK: %[[DYNAMIC_BROADCAST_IN_DIM:.*]] = stablehlo.dynamic_broadcast_in_dim %arg2, %[[SHAPE_OF]] +// CHECK: %[[ADD:.*]] = stablehlo.add %[[CONV]], %[[DYNAMIC_BROADCAST_IN_DIM]] +// CHECK: return %[[ADD]] : tensor +// CHECK: } + +// ----- + +// Because the operand of shape_of is other than the target conv, +// should not match conv bias pattern. 
+ +// CHECK-LABEL: @conv_with_bias_dynamic_shape_not_same_op_fn( +// CHECK-SAME: %[[ARG_0:.*]]: tensor +func.func @conv_with_bias_dynamic_shape_not_same_op_fn(%arg0: tensor) -> tensor { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<3x3x1x16xf32> + %1 = stablehlo.constant dense<2.000000e+00> : tensor<16xf32> + %2 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [1, 1], pad = [[1, 1], [1, 1]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo, #stablehlo]} : (tensor, tensor<3x3x1x16xf32>) -> tensor + %3 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [1, 1], pad = [[1, 1], [1, 1]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo, #stablehlo]} : (tensor, tensor<3x3x1x16xf32>) -> tensor + %4 = shape.shape_of %3 : tensor -> tensor<4xindex> + %5 = stablehlo.dynamic_broadcast_in_dim %1, %4, dims = [3] : (tensor<16xf32>, tensor<4xindex>) -> tensor + %6 = stablehlo.add %2, %5 : tensor + func.return %6: tensor +} +// CHECK-NOT: @composite_conv_with_bias_dynamic_fn_1 + +// ----- + +// CHECK-LABEL: @dot_general_with_bias_dynamic_fn( +// CHECK-SAME: %[[ARG_0:.*]]: tensor +func.func @dot_general_with_bias_dynamic_fn(%arg0: tensor) -> tensor { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<12544x10xf32> + %1 = stablehlo.constant dense<2.000000e+00> : tensor<10xf32> + %2 = stablehlo.dot_general %arg0, %0, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor, tensor<12544x10xf32>) -> tensor + %3 = shape.shape_of %2 : tensor -> tensor<2xindex> + %4 = stablehlo.dynamic_broadcast_in_dim %1, %3, dims = [1] : (tensor<10xf32>, tensor<2xindex>) -> tensor + %5 = stablehlo.add %2, %4 : tensor + func.return %5: tensor +} +// CHECK: %[[CONST_0:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: 
%[[CONST_1:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST_0]], %[[CONST_1]]) +// CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor +// CHECK: } + +// CHECK-LABEL: private @composite_dot_general_with_bias_dynamic_fn_1 +// CHECK: %[[DOT_GENERAL:.*]] = stablehlo.dot_general %arg0, %arg1 +// CHECK: %[[SHAPE_OF_0:.*]] = shape.shape_of %[[DOT_GENERAL]] +// CHECK: %[[DYNAMIC_BROADCAST_IN_DIM_0:.*]] = stablehlo.dynamic_broadcast_in_dim %arg2, %[[SHAPE_OF_0]] +// CHECK: %[[ADD:.*]] = stablehlo.add %[[DOT_GENERAL]], %[[DYNAMIC_BROADCAST_IN_DIM_0]] +// CHECK: return %[[ADD]] : tensor +// CHECK: } + +// ----- + +// CHECK-LABEL: @conv_with_relu_fn( +// CHECK-SAME: %[[ARG_0:.*]]: tensor<1x3x3x4xf32> +func.func @conv_with_relu_fn(%arg0: tensor<1x3x3x4xf32>) -> tensor<1x3x3x4xf32> { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<3x3x4x4xf32> + %1 = stablehlo.constant dense<0.000000e+00> : tensor<1x3x3x4xf32> + %2 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4xf32>, tensor<3x3x4x4xf32>) -> tensor<1x3x3x4xf32> + %3 = stablehlo.maximum %2, %1 : tensor<1x3x3x4xf32> + func.return %3: tensor<1x3x3x4xf32> +} +// CHECK: %[[CONST:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST]]) +// CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor<1x3x3x4xf32> +// CHECK: } + +// CHECK-LABEL: private @composite_conv_with_relu_fn_1 +// CHECK: %[[CONST:.*]] = stablehlo.constant dense<0.000000e+00> +// CHECK: %[[CONV:.*]] = stablehlo.convolution(%arg0, %arg1) +// CHECK: %[[MAX:.*]] = stablehlo.maximum %[[CONV]], %[[CONST]] +// CHECK: return %[[MAX]] : tensor<1x3x3x4xf32> +// CHECK: } + +// ----- + +// CHECK-LABEL: @dot_general_with_relu_fn( +// CHECK-SAME: %[[ARG_0:.*]]: tensor<1x1x167xf32>, +func.func 
@dot_general_with_relu_fn(%arg0: tensor<1x1x167xf32>, %arg1: tensor<167x64xf32>) -> tensor<1x1x64xf32> { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<167x64xf32> + %1 = stablehlo.constant dense<0.000000e+00> : tensor<1x1x64xf32> + %2 = stablehlo.dot_general %arg0, %0, contracting_dims = [2] x [0], precision = [DEFAULT, DEFAULT] : (tensor<1x1x167xf32>, tensor<167x64xf32>) -> tensor<1x1x64xf32> + %3 = stablehlo.maximum %2, %1 : tensor<1x1x64xf32> + return %3 : tensor<1x1x64xf32> +} +// CHECK: %[[CONST:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST]]) +// CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor<1x1x64xf32> +// CHECK: } + +// CHECK-LABEL: private @composite_dot_general_with_relu_fn_1 +// CHECK: %[[CONST:.*]] = stablehlo.constant dense<0.000000e+00> +// CHECK: %[[DOT_GENERAL:.*]] = stablehlo.dot_general %arg0, %arg1 +// CHECK: %[[MAX:.*]] = stablehlo.maximum %[[DOT_GENERAL]], %[[CONST]] +// CHECK: return %[[MAX:.*]] : tensor<1x1x64xf32> +// CHECK: } + +// ----- + +// CHECK-LABEL: @conv_with_relu_dynamic_fn( +// CHECK-SAME: %[[ARG_0:.*]]: tensor +func.func @conv_with_relu_dynamic_fn(%arg0: tensor) -> tensor { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<3x3x1x16xf32> + %1 = stablehlo.constant dense<0.000000e+00> : tensor + %2 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [1, 1], pad = [[1, 1], [1, 1]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo, #stablehlo]} : (tensor, tensor<3x3x1x16xf32>) -> tensor + %3 = shape.shape_of %2 : tensor -> tensor<4xindex> + %4 = stablehlo.dynamic_broadcast_in_dim %1, %3, dims = [] : (tensor, tensor<4xindex>) -> tensor + %5 = stablehlo.maximum %2, %4 : tensor + func.return %5: tensor +} +// CHECK: %[[CONST:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, 
%[[CONST]]) +// CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor +// CHECK: } + +// CHECK-LABEL: private @composite_conv_with_relu_dynamic_fn_1 +// CHECK: %[[CONV:.*]] = stablehlo.convolution(%arg0, %arg1) +// CHECK: %[[SHAPE_OF:.*]] = shape.shape_of %[[CONV]] +// CHECK-DAG: %[[CONST:.*]] = stablehlo.constant dense<0.000000e+00> +// CHECK: %[[DYNAMIC_BROADCAST_IN_DIM:.*]] = stablehlo.dynamic_broadcast_in_dim %[[CONST]], %[[SHAPE_OF]] +// CHECK: %[[MAX:.*]] = stablehlo.maximum %[[CONV]], %[[DYNAMIC_BROADCAST_IN_DIM]] +// CHECK: return %[[MAX]] : tensor +// CHECK: } + +// ----- + +// Because the operand of shape_of is other than the target conv, +// should not match conv relu dynamic pattern. + +// CHECK-LABEL: @conv_with_relu_dynamic_shape_not_same_op_fn( +// CHECK-SAME: %[[ARG_0:.*]]: tensor +func.func @conv_with_relu_dynamic_shape_not_same_op_fn(%arg0: tensor) -> tensor { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<3x3x1x16xf32> + %1 = stablehlo.constant dense<0.000000e+00> : tensor + %2 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [1, 1], pad = [[1, 1], [1, 1]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo, #stablehlo]} : (tensor, tensor<3x3x1x16xf32>) -> tensor + %3 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [1, 1], pad = [[1, 1], [1, 1]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo, #stablehlo]} : (tensor, tensor<3x3x1x16xf32>) -> tensor + %4 = shape.shape_of %3 : tensor -> tensor<4xindex> + %5 = stablehlo.dynamic_broadcast_in_dim %1, %4, dims = [] : (tensor, tensor<4xindex>) -> tensor + %6 = stablehlo.maximum %2, %5 : tensor + func.return %6: tensor +} +// CHECK-NOT: private @composite_conv_with_relu_dynamic_fn_1 + +// ----- + +// CHECK-LABEL: @dot_general_with_relu_dynamic_fn( +// 
CHECK-SAME: %[[ARG_0:.*]]: tensor +func.func @dot_general_with_relu_dynamic_fn(%arg0: tensor) -> tensor { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<12544x10xf32> + %1 = stablehlo.constant dense<0.000000e+00> : tensor + %2 = stablehlo.dot_general %arg0, %0, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor, tensor<12544x10xf32>) -> tensor + %3 = shape.shape_of %2 : tensor -> tensor<2xindex> + %4 = stablehlo.dynamic_broadcast_in_dim %1, %3, dims = [] : (tensor, tensor<2xindex>) -> tensor + %5 = stablehlo.maximum %2, %4 : tensor + func.return %5: tensor +} +// CHECK: %[[CONST:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST]]) +// CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor +// CHECK: } + +// CHECK-LABEL: private @composite_dot_general_with_relu_dynamic_fn_1 +// CHECK: %[[DOT_GENERAL:.*]] = stablehlo.dot_general %arg0, %arg1 +// CHECK: %[[SHAPE_OF:.*]] = shape.shape_of %[[DOT_GENERAL]] +// CHECK-DAG: %[[CONST:.*]] = stablehlo.constant dense<0.000000e+00> +// CHECK: %[[DYNAMIC_BROADCAST_IN_DIM:.*]] = stablehlo.dynamic_broadcast_in_dim %[[CONST]], %[[SHAPE_OF]] +// CHECK: %[[MAX:.*]] = stablehlo.maximum %[[DOT_GENERAL]], %[[DYNAMIC_BROADCAST_IN_DIM]] +// CHECK: return %[[MAX]] : tensor +// CHECK: } + +// ----- + +// The pattern should not match when the const value for relu is not 0. 
+ +// CHECK-LABEL: @conv_with_relu_wrong_const_fn( +// CHECK-SAME: %[[ARG_0:.*]]: tensor<1x3x3x4xf32> +func.func @conv_with_relu_wrong_const_fn(%arg0: tensor<1x3x3x4xf32>, %arg1: tensor<3x3x4x4xf32>) -> tensor<1x3x3x4xf32> { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<3x3x4x4xf32> + %1 = stablehlo.constant dense<2.000000e+00> : tensor<1x3x3x4xf32> + %2 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4xf32>, tensor<3x3x4x4xf32>) -> tensor<1x3x3x4xf32> + %3 = stablehlo.maximum %2, %1 : tensor<1x3x3x4xf32> + func.return %3: tensor<1x3x3x4xf32> +} +// CHECK: %[[CONST_0:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[CONST_1:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST_0]]) +// CHECK: %[[MAX:.*]] = stablehlo.maximum %[[XLA_CALL_MODULE]], %[[CONST_1]] +// CHECK: return %[[MAX]] : tensor<1x3x3x4xf32> +// CHECK: } + +// CHECK-LABEL: private @composite_conv_fn_1 +// CHECK-NOT: private @composite_conv_with_relu_fn_1 + +// ----- + +// CHECK-LABEL: @conv_with_relu6_fn( +// CHECK-SAME: %[[ARG_0:.*]]: tensor<1x3x3x4xf32> +func.func @conv_with_relu6_fn(%arg0: tensor<1x3x3x4xf32>) -> tensor<1x3x3x4xf32> { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<3x3x4x4xf32> + %1 = stablehlo.constant dense<0.000000e+00> : tensor<1x3x3x4xf32> + %2 = stablehlo.constant dense<6.000000e+00> : tensor<1x3x3x4xf32> + %3 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4xf32>, tensor<3x3x4x4xf32>) -> tensor<1x3x3x4xf32> + %4 = stablehlo.clamp %1, %3, %2 : tensor<1x3x3x4xf32> + func.return %4: tensor<1x3x3x4xf32> +} +// CHECK: %[[CONST:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: 
%[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST]]) +// CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor<1x3x3x4xf32> +// CHECK: } + +// CHECK-LABEL: private @composite_conv_with_relu6_fn_1 +// CHECK-DAG: %[[CONST_1:.*]] = stablehlo.constant dense<6.000000e+00> +// CHECK: %[[CONV:.*]] = stablehlo.convolution(%arg0, %arg1) +// CHECK-DAG: %[[CONST_0:.*]] = stablehlo.constant dense<0.000000e+00> +// CHECK: %[[CLAMP:.*]] = stablehlo.clamp %[[CONST_0]], %[[CONV]], %[[CONST_1]] +// CHECK: return %[[CLAMP]] : tensor<1x3x3x4xf32> +// CHECK: } + +// ----- + +// CHECK-LABEL: @dot_general_with_relu6_fn( +// CHECK-SAME: %[[ARG_0:.*]]: tensor<1x1x167xf32> +func.func @dot_general_with_relu6_fn(%arg0: tensor<1x1x167xf32>) -> tensor<1x1x64xf32> { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<167x64xf32> + %1 = stablehlo.constant dense<0.000000e+00> : tensor<1x1x64xf32> + %2 = stablehlo.constant dense<6.000000e+00> : tensor<1x1x64xf32> + %3 = stablehlo.dot_general %arg0, %0, contracting_dims = [2] x [0], precision = [DEFAULT, DEFAULT] : (tensor<1x1x167xf32>, tensor<167x64xf32>) -> tensor<1x1x64xf32> + %4 = stablehlo.clamp %1, %3, %2 : tensor<1x1x64xf32> + return %4 : tensor<1x1x64xf32> +} +// CHECK: %[[CONST:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST]]) +// CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor<1x1x64xf32> +// CHECK: } + +// CHECK-LABEL: private @composite_dot_general_with_relu6_fn_1 +// CHECK-DAG: %[[CONST_1:.*]] = stablehlo.constant dense<6.000000e+00> +// CHECK: %[[DOT_GENERAL:.*]] = stablehlo.dot_general %arg0, %arg1 +// CHECK-DAG: %[[CONST_0:.*]] = stablehlo.constant dense<0.000000e+00> +// CHECK: %[[CLAMP:.*]] = stablehlo.clamp %[[CONST_0]], %[[DOT_GENERAL]], %[[CONST_1]] +// CHECK: return %[[CLAMP]] : tensor<1x1x64xf32> +// CHECK: } + +// ----- + +// CHECK-LABEL: @conv_with_relu6_dynamic_fn( +// CHECK-SAME: %[[ARG_0:.*]]: tensor +func.func @conv_with_relu6_dynamic_fn(%arg0: 
tensor) -> tensor { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<3x3x1x16xf32> + %1 = stablehlo.constant dense<0.000000e+00> : tensor + %2 = stablehlo.constant dense<6.000000e+00> : tensor + %3 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor, tensor<3x3x1x16xf32>) -> tensor + %4 = stablehlo.clamp %1, %3, %2 : (tensor, tensor, tensor) -> tensor + func.return %4: tensor +} +// CHECK: %[[CONST:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST]]) +// CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor +// CHECK: } + +// CHECK-LABEL: private @composite_conv_with_relu6_fn_1 +// CHECK-DAG: %[[CONST_1:.*]] = stablehlo.constant dense<6.000000e+00> +// CHECK: %[[CONV:.*]] = stablehlo.convolution(%arg0, %arg1) +// CHECK-DAG: %[[CONST_0:.*]] = stablehlo.constant dense<0.000000e+00> +// CHECK: %[[CLAMP:.*]] = stablehlo.clamp %[[CONST_0]], %[[CONV]], %[[CONST_1]] +// CHECK: return %[[CLAMP]] : tensor +// CHECK: } + +// ----- + +// CHECK-LABEL: @dot_general_with_relu6_dynamic_fn( +// CHECK-SAME: %[[ARG_0:.*]]: tensor +func.func @dot_general_with_relu6_dynamic_fn(%arg0: tensor) -> tensor { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<12544x10xf32> + %1 = stablehlo.constant dense<0.000000e+00> : tensor + %2 = stablehlo.constant dense<6.000000e+00> : tensor + %3 = stablehlo.dot_general %arg0, %0, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor, tensor<12544x10xf32>) -> tensor + %4 = stablehlo.clamp %1, %3, %2 : (tensor, tensor, tensor) -> tensor + func.return %4: tensor +} +// CHECK: %[[CONST:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST]]) +// CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor +// CHECK: } + +// CHECK-LABEL: private 
@composite_dot_general_with_relu6_fn_1 +// CHECK-DAG: %[[CONST_1:.*]] = stablehlo.constant dense<6.000000e+00> +// CHECK: %[[DOT_GENERAL:.*]] = stablehlo.dot_general %arg0, %arg1 +// CHECK-DAG: %[[CONST_0:.*]] = stablehlo.constant dense<0.000000e+00> +// CHECK: %[[CLAMP:.*]] = stablehlo.clamp %[[CONST_0]], %[[DOT_GENERAL]], %[[CONST_1]] +// CHECK: return %[[CLAMP]] : tensor +// CHECK: } + +// ----- + +// CHECK-LABEL: @dot_general_with_bias_same_shape_and_relu_fn( +// CHECK-SAME: %[[ARG_0:.*]]: tensor<1x1x167xf32> +func.func @dot_general_with_bias_same_shape_and_relu_fn(%arg0: tensor<1x1x167xf32>) -> tensor<1x1x64xf32> { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<167x64xf32> + %1 = stablehlo.constant dense<2.000000e+00> : tensor<1x1x64xf32> + %2 = stablehlo.constant dense<0.000000e+00> : tensor<1x1x64xf32> + %3 = stablehlo.dot_general %arg0, %0, contracting_dims = [2] x [0], precision = [DEFAULT, DEFAULT] : (tensor<1x1x167xf32>, tensor<167x64xf32>) -> tensor<1x1x64xf32> + %4 = stablehlo.add %3, %1 : tensor<1x1x64xf32> + %5 = stablehlo.maximum %4, %2 : tensor<1x1x64xf32> + func.return %5: tensor<1x1x64xf32> +} +// CHECK: %[[CONST_0:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[CONST_1:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST_0]], %[[CONST_1]]) +// CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor<1x1x64xf32> +// CHECK: } + +// CHECK-LABEL: private @composite_dot_general_with_bias_same_shape_and_relu_fn_1 +// CHECK: %[[DOT_GENERAL:.*]] = stablehlo.dot_general %arg0, %arg1 +// CHECK-DAG: %[[CONST:.*]] = stablehlo.constant dense<0.000000e+00> +// CHECK: %[[ADD:.*]] = stablehlo.add %[[DOT_GENERAL]], %arg2 +// CHECK: %[[MAX:.*]] = stablehlo.maximum %[[ADD]], %[[CONST]] +// CHECK: return %[[MAX]] : tensor<1x1x64xf32> +// CHECK: } + +// ----- + +// CHECK-LABEL: @conv_with_bias_and_relu_fn( +// CHECK-SAME: %[[ARG_0:.*]]: tensor<1x3x3x4xf32> +func.func 
@conv_with_bias_and_relu_fn(%arg0: tensor<1x3x3x4xf32>) -> tensor<1x3x3x4xf32> { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<3x3x4x4xf32> + %1 = stablehlo.constant dense<2.000000e+00> : tensor<4xf32> + %2 = stablehlo.constant dense<0.000000e+00> : tensor<1x3x3x4xf32> + %3 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4xf32>, tensor<3x3x4x4xf32>) -> tensor<1x3x3x4xf32> + %4 = stablehlo.broadcast_in_dim %1, dims = [3] : (tensor<4xf32>) -> tensor<1x3x3x4xf32> + %5 = stablehlo.add %3, %4 : tensor<1x3x3x4xf32> + %6 = stablehlo.maximum %5, %2 : tensor<1x3x3x4xf32> + func.return %6: tensor<1x3x3x4xf32> +} +// CHECK: %[[CONST_0:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[CONST_1:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST_0]], %[[CONST_1]]) +// CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor<1x3x3x4xf32> +// CHECK: } + +// CHECK-LABEL: private @composite_conv_with_bias_and_relu_fn_1 +// CHECK: %[[BROADCAST_IN_DIM:.*]] = stablehlo.broadcast_in_dim %arg2 +// CHECK: %[[CONV:.*]] = stablehlo.convolution(%arg0, %arg1) +// CHECK-DAG: %[[CONST:.*]] = stablehlo.constant dense<0.000000e+00> +// CHECK: %[[ADD:.*]] = stablehlo.add %[[CONV]], %[[BROADCAST_IN_DIM]] +// CHECK: %[[MAX:.*]] = stablehlo.maximum %[[ADD]], %[[CONST]] +// CHECK: return %[[MAX]] : tensor<1x3x3x4xf32> +// CHECK: } + +// ----- + +// CHECK-LABEL: @dot_general_with_bias_and_relu_fn( +// CHECK-SAME: %[[ARG_0:.*]]: tensor<1x1x167xf32> +func.func @dot_general_with_bias_and_relu_fn(%arg0: tensor<1x1x167xf32>) -> tensor<1x1x64xf32> { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<167x64xf32> + %1 = stablehlo.constant dense<2.000000e+00> : tensor<64xf32> + %2 = stablehlo.constant dense<0.000000e+00> : tensor<1x1x64xf32> + %3 = stablehlo.dot_general %arg0, %0, 
contracting_dims = [2] x [0], precision = [DEFAULT, DEFAULT] : (tensor<1x1x167xf32>, tensor<167x64xf32>) -> tensor<1x1x64xf32> + %4 = stablehlo.broadcast_in_dim %1, dims = [2] : (tensor<64xf32>) -> tensor<1x1x64xf32> + %5 = stablehlo.add %3, %4 : tensor<1x1x64xf32> + %6 = stablehlo.maximum %5, %2 : tensor<1x1x64xf32> + func.return %6: tensor<1x1x64xf32> +} +// CHECK: %[[CONST_0:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[CONST_1:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST_0]], %[[CONST_1]]) +// CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor<1x1x64xf32> +// CHECK: } + +// CHECK-LABEL: private @composite_dot_general_with_bias_and_relu_fn_1 +// CHECK: %[[BROADCAST_IN_DIM:.*]] = stablehlo.broadcast_in_dim %arg2 +// CHECK: %[[DOT_GENERAL:.*]] = stablehlo.dot_general %arg0, %arg1 +// CHECK-DAG: %[[CONST:.*]] = stablehlo.constant dense<0.000000e+00> +// CHECK: %[[ADD:.*]] = stablehlo.add %[[DOT_GENERAL]], %[[BROADCAST_IN_DIM]] +// CHECK: %[[MAX:.*]] = stablehlo.maximum %[[ADD]], %[[CONST]] +// CHECK: return %[[MAX]] : tensor<1x1x64xf32> +// CHECK: } + +// ----- + +// CHECK-LABEL: @conv_with_bias_and_relu_dynamic_fn( +// CHECK-SAME: %[[ARG_0:.*]]: tensor +func.func @conv_with_bias_and_relu_dynamic_fn(%arg0: tensor) -> tensor { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<3x3x1x16xf32> + %1 = stablehlo.constant dense<2.000000e+00> : tensor<16xf32> + %2 = stablehlo.constant dense<0.000000e+00> : tensor + %3 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [1, 1], pad = [[1, 1], [1, 1]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo, #stablehlo]} : (tensor, tensor<3x3x1x16xf32>) -> tensor + %4 = shape.shape_of %3 : tensor -> tensor<4xindex> + %5 = stablehlo.dynamic_broadcast_in_dim %1, %4, dims = [3] : (tensor<16xf32>, tensor<4xindex>) -> tensor + %6 = 
stablehlo.add %3, %5 : tensor + %7 = shape.shape_of %6 : tensor -> tensor<4xindex> + %8 = stablehlo.dynamic_broadcast_in_dim %2, %7, dims = [] : (tensor, tensor<4xindex>) -> tensor + %9 = stablehlo.maximum %6, %8 : tensor + func.return %9: tensor +} +// CHECK: %[[CONST_0:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[CONST_1:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST_0]], %[[CONST_1]]) +// CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor +// CHECK: } + +// CHECK-LABEL: private @composite_conv_with_bias_and_relu_dynamic_fn_1 +// CHECK: %[[CONV:.*]] = stablehlo.convolution(%arg0, %arg1) +// CHECK: %[[SHAPE_OF_0:.*]] = shape.shape_of %[[CONV]] +// CHECK: %[[DYNAMIC_BROADCAST_IN_DIM_0:.*]] = stablehlo.dynamic_broadcast_in_dim %arg2, %[[SHAPE_OF_0]] +// CHECK: %[[ADD:.*]] = stablehlo.add %[[CONV]], %[[DYNAMIC_BROADCAST_IN_DIM_0]] +// CHECK: %[[SHAPE_OF_1:.*]] = shape.shape_of %[[ADD]] +// CHECK-DAG: %[[CONST:.*]] = stablehlo.constant dense<0.000000e+00> +// CHECK: %[[DYNAMIC_BROADCAST_IN_DIM_1:.*]] = stablehlo.dynamic_broadcast_in_dim %[[CONST]], %[[SHAPE_OF_1]] +// CHECK: %[[MAX:.*]] = stablehlo.maximum %[[ADD]], %[[DYNAMIC_BROADCAST_IN_DIM_1]] +// CHECK: return %[[MAX]] : tensor +// CHECK: } + +// ----- + +// Because the operand of shape_of is other than the target conv, +// should not match conv bias relu dynamic pattern. 
+ +// CHECK-LABEL: @conv_with_bias_and_relu_dynamic_shape_not_same_op_fn( +// CHECK-SAME: %[[ARG_0:.*]]: tensor +func.func @conv_with_bias_and_relu_dynamic_shape_not_same_op_fn(%arg0: tensor) -> tensor { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<3x3x1x16xf32> + %1 = stablehlo.constant dense<2.000000e+00> : tensor<16xf32> + %2 = stablehlo.constant dense<0.000000e+00> : tensor + %3 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [1, 1], pad = [[1, 1], [1, 1]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo, #stablehlo]} : (tensor, tensor<3x3x1x16xf32>) -> tensor + %4 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [1, 1], pad = [[1, 1], [1, 1]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo, #stablehlo]} : (tensor, tensor<3x3x1x16xf32>) -> tensor + %5 = shape.shape_of %4 : tensor -> tensor<4xindex> + %6 = stablehlo.dynamic_broadcast_in_dim %1, %5, dims = [3] : (tensor<16xf32>, tensor<4xindex>) -> tensor + %7 = stablehlo.add %3, %6 : tensor + %8 = shape.shape_of %7 : tensor -> tensor<4xindex> + %9 = stablehlo.dynamic_broadcast_in_dim %2, %8, dims = [] : (tensor, tensor<4xindex>) -> tensor + %10 = stablehlo.maximum %7, %9 : tensor + func.return %10: tensor +} +// CHECK-NOT: private @composite_conv_with_bias_and_relu_dynamic_fn_1 + +// ----- + +// CHECK-LABEL: @dot_general_with_bias_and_relu_dynamic_fn( +// CHECK-SAME: %[[ARG_0:.*]]: tensor +func.func @dot_general_with_bias_and_relu_dynamic_fn(%arg0: tensor) -> tensor { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<12544x10xf32> + %1 = stablehlo.constant dense<2.000000e+00> : tensor<10xf32> + %2 = stablehlo.constant dense<0.000000e+00> : tensor + %3 = stablehlo.dot_general %arg0, %0, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : 
(tensor, tensor<12544x10xf32>) -> tensor + %4 = shape.shape_of %3 : tensor -> tensor<2xindex> + %5 = stablehlo.dynamic_broadcast_in_dim %1, %4, dims = [1] : (tensor<10xf32>, tensor<2xindex>) -> tensor + %6 = stablehlo.add %3, %5 : tensor + %7 = shape.shape_of %6 : tensor -> tensor<2xindex> + %8 = stablehlo.dynamic_broadcast_in_dim %2, %7, dims = [] : (tensor, tensor<2xindex>) -> tensor + %9 = stablehlo.maximum %6, %8 : tensor + func.return %9: tensor +} +// CHECK: %[[CONST_0:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[CONST_1:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST_0]], %[[CONST_1]]) +// CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor +// CHECK: } + +// CHECK-LABEL: private @composite_dot_general_with_bias_and_relu_dynamic_fn_1 +// CHECK: %[[DOT_GENERAL:.*]] = stablehlo.dot_general %arg0, %arg1 +// CHECK: %[[SHAPE_OF_0:.*]] = shape.shape_of %[[DOT_GENERAL]] +// CHECK: %[[DYNAMIC_BROADCAST_IN_DIM_0:.*]] = stablehlo.dynamic_broadcast_in_dim %arg2, %[[SHAPE_OF_0]] +// CHECK: %[[ADD:.*]] = stablehlo.add %[[DOT_GENERAL]], %[[DYNAMIC_BROADCAST_IN_DIM_0]] +// CHECK: %[[SHAPE_OF_1:.*]] = shape.shape_of %[[ADD]] +// CHECK-DAG: %[[CONST:.*]] = stablehlo.constant dense<0.000000e+00> +// CHECK: %[[DYNAMIC_BROADCAST_IN_DIM_1:.*]] = stablehlo.dynamic_broadcast_in_dim %[[CONST]], %[[SHAPE_OF_1]] +// CHECK: %[[MAX:.*]] = stablehlo.maximum %[[ADD]], %[[DYNAMIC_BROADCAST_IN_DIM_1]] +// CHECK: return %[[MAX]] : tensor +// CHECK: } + +// ----- + +// CHECK-LABEL: @dot_general_with_bias_same_shape_and_relu6_fn( +// CHECK-SAME: %[[ARG_0:.*]]: tensor<1x1x167xf32> +func.func @dot_general_with_bias_same_shape_and_relu6_fn(%arg0: tensor<1x1x167xf32>) -> tensor<1x1x64xf32> { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<167x64xf32> + %1 = stablehlo.constant dense<2.000000e+00> : tensor<1x1x64xf32> + %2 = stablehlo.constant dense<0.000000e+00> : tensor<1x1x64xf32> + %3 = stablehlo.constant 
dense<6.000000e+00> : tensor<1x1x64xf32> + %4 = stablehlo.dot_general %arg0, %0, contracting_dims = [2] x [0], precision = [DEFAULT, DEFAULT] : (tensor<1x1x167xf32>, tensor<167x64xf32>) -> tensor<1x1x64xf32> + %5 = stablehlo.add %4, %1 : tensor<1x1x64xf32> + %6 = stablehlo.clamp %2, %5, %3 : tensor<1x1x64xf32> + func.return %6: tensor<1x1x64xf32> +} +// CHECK: %[[CONST_0:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[CONST_1:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST_0]], %[[CONST_1]]) +// CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor<1x1x64xf32> +// CHECK: } + +// CHECK-LABEL: private @composite_dot_general_with_bias_same_shape_and_relu6_fn_1 +// CHECK: %[[DOT_GENERAL:.*]] = stablehlo.dot_general %arg0, %arg1 +// CHECK-DAG: %[[CONST_1:.*]] = stablehlo.constant dense<6.000000e+00> +// CHECK: %[[ADD:.*]] = stablehlo.add %[[DOT_GENERAL]], %arg2 +// CHECK-DAG: %[[CONST_0:.*]] = stablehlo.constant dense<0.000000e+00> +// CHECK: %[[CLAMP:.*]] = stablehlo.clamp %[[CONST_0]], %[[ADD]], %[[CONST_1]] +// CHECK: return %[[CLAMP]] : tensor<1x1x64xf32> +// CHECK: } + +// ----- + +// CHECK-LABEL: @conv_with_bias_and_relu6_fn( +// CHECK-SAME: %[[ARG_0:.*]]: tensor<1x3x3x4xf32> +func.func @conv_with_bias_and_relu6_fn(%arg0: tensor<1x3x3x4xf32>) -> tensor<1x3x3x4xf32> { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<3x3x4x4xf32> + %1 = stablehlo.constant dense<2.000000e+00> : tensor<4xf32> + %2 = stablehlo.constant dense<0.000000e+00> : tensor<1x3x3x4xf32> + %3 = stablehlo.constant dense<6.000000e+00> : tensor<1x3x3x4xf32> + %4 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4xf32>, tensor<3x3x4x4xf32>) -> tensor<1x3x3x4xf32> + %5 = stablehlo.broadcast_in_dim %1, dims = [3] : (tensor<4xf32>) -> tensor<1x3x3x4xf32> + %6 = stablehlo.add 
%4, %5 : tensor<1x3x3x4xf32> + %7 = stablehlo.clamp %2, %6, %3 : tensor<1x3x3x4xf32> + func.return %7: tensor<1x3x3x4xf32> +} +// CHECK: %[[CONST_0:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[CONST_1:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST_0]], %[[CONST_1]]) +// CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor<1x3x3x4xf32> +// CHECK: } + +// CHECK-LABEL: private @composite_conv_with_bias_and_relu6_fn_1 +// CHECK: %[[BROADCAST_IN_DIM:.*]] = stablehlo.broadcast_in_dim %arg2 +// CHECK: %[[CONV:.*]] = stablehlo.convolution(%arg0, %arg1) +// CHECK-DAG: %[[CONST_1:.*]] = stablehlo.constant dense<6.000000e+00> +// CHECK: %[[ADD:.*]] = stablehlo.add %[[CONV]], %[[BROADCAST_IN_DIM]] +// CHECK-DAG: %[[CONST_0:.*]] = stablehlo.constant dense<0.000000e+00> +// CHECK: %[[CLAMP:.*]] = stablehlo.clamp %[[CONST_0]], %[[ADD]], %[[CONST_1]] +// CHECK: return %[[CLAMP]] : tensor<1x3x3x4xf32> +// CHECK: } + +// ----- + +// CHECK-LABEL: @dot_general_with_bias_and_relu6_fn( +// CHECK-SAME: %[[ARG_0:.*]]: tensor<1x1x167xf32> +func.func @dot_general_with_bias_and_relu6_fn(%arg0: tensor<1x1x167xf32>) -> tensor<1x1x64xf32> { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<167x64xf32> + %1 = stablehlo.constant dense<2.000000e+00> : tensor<64xf32> + %2 = stablehlo.constant dense<0.000000e+00> : tensor<1x1x64xf32> + %3 = stablehlo.constant dense<6.000000e+00> : tensor<1x1x64xf32> + %4 = stablehlo.dot_general %arg0, %0, contracting_dims = [2] x [0], precision = [DEFAULT, DEFAULT] : (tensor<1x1x167xf32>, tensor<167x64xf32>) -> tensor<1x1x64xf32> + %5 = stablehlo.broadcast_in_dim %1, dims = [2] : (tensor<64xf32>) -> tensor<1x1x64xf32> + %6 = stablehlo.add %4, %5 : tensor<1x1x64xf32> + %7 = stablehlo.clamp %2, %6, %3 : tensor<1x1x64xf32> + func.return %7: tensor<1x1x64xf32> +} +// CHECK: %[[CONST_0:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[CONST_1:.*]] = stablehlo.constant 
dense<2.000000e+00> +// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST_0]], %[[CONST_1]]) +// CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor<1x1x64xf32> +// CHECK: } + +// CHECK-LABEL: private @composite_dot_general_with_bias_and_relu6_fn_1 +// CHECK: %[[BROADCAST_IN_DIM:.*]] = stablehlo.broadcast_in_dim %arg2 +// CHECK: %[[DOT_GENERAL:.*]] = stablehlo.dot_general %arg0, %arg1 +// CHECK-DAG: %[[CONST_1:.*]] = stablehlo.constant dense<6.000000e+00> +// CHECK: %[[ADD:.*]] = stablehlo.add %[[DOT_GENERAL]], %[[BROADCAST_IN_DIM]] +// CHECK-DAG: %[[CONST_0:.*]] = stablehlo.constant dense<0.000000e+00> +// CHECK: %[[CLAMP:.*]] = stablehlo.clamp %[[CONST_0]], %[[ADD]], %[[CONST_1]] +// CHECK: return %[[CLAMP]] : tensor<1x1x64xf32> +// CHECK: } + +// ----- + +// CHECK-LABEL: @conv_with_bias_and_relu6_dynamic_fn( +// CHECK-SAME: %[[ARG_0:.*]]: tensor +func.func @conv_with_bias_and_relu6_dynamic_fn(%arg0: tensor) -> tensor { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<3x3x1x16xf32> + %1 = stablehlo.constant dense<2.000000e+00> : tensor<16xf32> + %2 = stablehlo.constant dense<0.000000e+00> : tensor + %3 = stablehlo.constant dense<6.000000e+00> : tensor + %4 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor, tensor<3x3x1x16xf32>) -> tensor + %5 = shape.shape_of %4 : tensor -> tensor<4xindex> + %6 = stablehlo.dynamic_broadcast_in_dim %1, %5, dims = [3] : (tensor<16xf32>, tensor<4xindex>) -> tensor + %7 = stablehlo.add %4, %6 : tensor + %8 = stablehlo.clamp %2, %7, %3 : (tensor, tensor, tensor) -> tensor + func.return %8: tensor +} +// CHECK: %[[CONST_0:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[CONST_1:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST_0]], %[[CONST_1]]) +// CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor 
+// CHECK: } + +// CHECK-LABEL: private @composite_conv_with_bias_and_relu6_dynamic_fn_1 +// CHECK: %[[CONV:.*]] = stablehlo.convolution(%arg0, %arg1) +// CHECK: %[[SHAPE_OF_0:.*]] = shape.shape_of %[[CONV]] +// CHECK: %[[DYNAMIC_BROADCAST_IN_DIM_0:.*]] = stablehlo.dynamic_broadcast_in_dim %arg2, %[[SHAPE_OF_0]] +// CHECK-DAG: %[[CONST_1:.*]] = stablehlo.constant dense<6.000000e+00> +// CHECK: %[[ADD:.*]] = stablehlo.add %[[CONV]], %[[DYNAMIC_BROADCAST_IN_DIM_0]] +// CHECK-DAG: %[[CONST_0:.*]] = stablehlo.constant dense<0.000000e+00> +// CHECK: %[[CLAMP:.*]] = stablehlo.clamp %[[CONST_0]], %[[ADD]], %[[CONST_1]] +// CHECK: return %[[CLAMP]] : tensor +// CHECK: } + +// ----- + +// Because the operand of shape_of is other than the target conv, +// should not match conv bias relu6 dynamic pattern. + +// CHECK-LABEL: @conv_with_bias_and_relu6_dynamic_shape_not_same_op_fn( +// CHECK-SAME: %[[ARG_0:.*]]: tensor +func.func @conv_with_bias_and_relu6_dynamic_shape_not_same_op_fn(%arg0: tensor) -> tensor { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<3x3x1x16xf32> + %1 = stablehlo.constant dense<2.000000e+00> : tensor<16xf32> + %2 = stablehlo.constant dense<0.000000e+00> : tensor + %3 = stablehlo.constant dense<6.000000e+00> : tensor + %4 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor, tensor<3x3x1x16xf32>) -> tensor + %5 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor, tensor<3x3x1x16xf32>) -> tensor + %6 = shape.shape_of %5 : tensor -> tensor<4xindex> + %7 = stablehlo.dynamic_broadcast_in_dim %1, %6, dims = [3] : (tensor<16xf32>, tensor<4xindex>) -> tensor + %8 = stablehlo.add %4, %7 : tensor + %9 = stablehlo.clamp %2, %8, %3 : (tensor, tensor, tensor) -> tensor + 
func.return %9: tensor +} +// CHECK-NOT: private @composite_conv_with_bias_and_relu6_dynamic_fn_1 + +// ----- + +// CHECK-LABEL: @dot_general_with_bias_and_relu6_dynamic_fn( +// CHECK-SAME: %[[ARG_0:.*]]: tensor +func.func @dot_general_with_bias_and_relu6_dynamic_fn(%arg0: tensor) -> tensor { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<12544x10xf32> + %1 = stablehlo.constant dense<2.000000e+00> : tensor<10xf32> + %2 = stablehlo.constant dense<0.000000e+00> : tensor + %3 = stablehlo.constant dense<6.000000e+00> : tensor + %4 = stablehlo.dot_general %arg0, %0, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor, tensor<12544x10xf32>) -> tensor + %5 = shape.shape_of %4 : tensor -> tensor<2xindex> + %6 = stablehlo.dynamic_broadcast_in_dim %1, %5, dims = [1] : (tensor<10xf32>, tensor<2xindex>) -> tensor + %7 = stablehlo.add %4, %6 : tensor + %8 = stablehlo.clamp %2, %7, %3 : (tensor, tensor, tensor) -> tensor + func.return %8: tensor +} +// CHECK: %[[CONST_0:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[CONST_1:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST_0]], %[[CONST_1]]) +// CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor +// CHECK: } + +// CHECK-LABEL: private @composite_dot_general_with_bias_and_relu6_dynamic_fn_1 +// CHECK: %[[DOT_GENERAL:.*]] = stablehlo.dot_general %arg0, %arg1 +// CHECK: %[[SHAPE_OF_0:.*]] = shape.shape_of %[[DOT_GENERAL]] +// CHECK: %[[DYNAMIC_BROADCAST_IN_DIM_0:.*]] = stablehlo.dynamic_broadcast_in_dim %arg2, %[[SHAPE_OF_0]] +// CHECK-DAG: %[[CONST_1:.*]] = stablehlo.constant dense<6.000000e+00> +// CHECK: %[[ADD:.*]] = stablehlo.add %[[DOT_GENERAL]], %[[DYNAMIC_BROADCAST_IN_DIM_0]] +// CHECK-DAG: %[[CONST_0:.*]] = stablehlo.constant dense<0.000000e+00> +// CHECK: %[[CLAMP:.*]] = stablehlo.clamp %[[CONST_0]], %[[ADD]], %[[CONST_1]] +// CHECK: return %[[CLAMP]] : tensor +// CHECK: } + +// ----- + +// CHECK-LABEL: @gather_fn( 
+func.func @gather_fn() -> tensor<2x3x2x2xi32> { + %0 = stablehlo.constant dense<1> : tensor<3x4x2xi32> + %1 = stablehlo.constant dense<1> : tensor<2x3x2xi64> + %2 = "stablehlo.gather"(%0, %1) { + dimension_numbers = #stablehlo.gather< + offset_dims = [2, 3], + collapsed_slice_dims = [0], + start_index_map = [1, 0], + index_vector_dim = 2>, + slice_sizes = array, + indices_are_sorted = false +} : (tensor<3x4x2xi32>, tensor<2x3x2xi64>) -> tensor<2x3x2x2xi32> + func.return %2: tensor<2x3x2x2xi32> +} +// CHECK: %[[OPERAND:.*]] = stablehlo.constant +// CHECK: %[[INDICES:.*]] = stablehlo.constant +// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%[[OPERAND]], %[[INDICES]]) +// CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor<2x3x2x2xi32> +// CHECK: } + +// CHECK-LABEL: private @composite_gather_fn_1 +// CHECK: %[[GATHER:.*]] = "stablehlo.gather"(%arg0, %arg1) +// CHECK: return %[[GATHER]] : tensor<2x3x2x2xi32> +// CHECK: } + +// ----- + +// Test that the name of composite functions are deterministic. There are 3 +// unsorted functions in this module and each function has 2 quantizable ops. 
+module { + func.func @conv_3_fn(%arg0: tensor<1x3x3x4xf32>) -> tensor<1x3x3x4xf32> { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<3x3x4x4xf32> + %1 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4xf32>, tensor<3x3x4x4xf32>) -> tensor<1x3x3x4xf32> + %2 = stablehlo.convolution(%1, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4xf32>, tensor<3x3x4x4xf32>) -> tensor<1x3x3x4xf32> + func.return %2: tensor<1x3x3x4xf32> + } + + func.func @conv_1_fn(%arg0: tensor<1x3x3x4xf32>) -> tensor<1x3x3x4xf32> { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<3x3x4x4xf32> + %1 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4xf32>, tensor<3x3x4x4xf32>) -> tensor<1x3x3x4xf32> + %2 = stablehlo.convolution(%1, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4xf32>, tensor<3x3x4x4xf32>) -> tensor<1x3x3x4xf32> + func.return %2: tensor<1x3x3x4xf32> + } + + func.func @conv_2_fn(%arg0: tensor<1x3x3x4xf32>) -> tensor<1x3x3x4xf32> { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<3x3x4x4xf32> + %1 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4xf32>, tensor<3x3x4x4xf32>) -> tensor<1x3x3x4xf32> + %2 = stablehlo.convolution(%1, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : 
i64} : (tensor<1x3x3x4xf32>, tensor<3x3x4x4xf32>) -> tensor<1x3x3x4xf32> + func.return %2: tensor<1x3x3x4xf32> + } +} + +// CHECK-LABEL: @conv_3_fn +// CHECK: tf.XlaCallModule +// CHECK-SAME: _entry_function = @composite_conv_fn_6, _original_entry_function = "composite_conv_fn_6" +// CHECK-SAME: _stablehlo_version = "{{.*}}" +// CHECK: tf.XlaCallModule +// CHECK-SAME: _entry_function = @composite_conv_fn_5, _original_entry_function = "composite_conv_fn_5" +// CHECK-SAME: _stablehlo_version = "{{.*}}" + +// CHECK-LABEL: @conv_1_fn +// CHECK: tf.XlaCallModule +// CHECK-SAME: _entry_function = @composite_conv_fn_2, _original_entry_function = "composite_conv_fn_2" +// CHECK-SAME: _stablehlo_version = "{{.*}}" +// CHECK: tf.XlaCallModule +// CHECK-SAME: _entry_function = @composite_conv_fn_1, _original_entry_function = "composite_conv_fn_1" +// CHECK-SAME: _stablehlo_version = "{{.*}}" + +// CHECK-LABEL: @conv_2_fn +// CHECK: tf.XlaCallModule +// CHECK-SAME: _entry_function = @composite_conv_fn_4, _original_entry_function = "composite_conv_fn_4" +// CHECK-SAME: _stablehlo_version = "{{.*}}" +// CHECK: tf.XlaCallModule +// CHECK-SAME: _entry_function = @composite_conv_fn_3, _original_entry_function = "composite_conv_fn_3" +// CHECK-SAME: _stablehlo_version = "{{.*}}" \ No newline at end of file diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_merge-fusion-with-dequantize.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_merge-fusion-with-dequantize.mlir new file mode 100644 index 00000000000000..65154cb890cfc2 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_merge-fusion-with-dequantize.mlir @@ -0,0 +1,198 @@ +// RUN: stablehlo-quant-opt %s -tf-stablehlo-merge-fusion-with-dequantize -split-input-file -verify-diagnostics | FileCheck %s + +// Merge fusion with dequantize for relu case. 
+ +module attributes {tf_saved_model.semantics} { + // CHECK-LABEL: func.func private @merge_relu_fusion + func.func private @merge_relu_fusion(%arg0: tensor<1x4xf32>) -> tensor<1x3xf32> { + %0 = stablehlo.constant() {value = dense<127> : tensor<4x3xi8>} : () -> tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>> + %1 = stablehlo.uniform_quantize %arg0 : (tensor<1x4xf32>) -> tensor<1x4x!quant.uniform> + // CHECK: call @quantized_dot_general_relu_fn + // CHECK-SAME: -> tensor<1x3xf32> + %2 = call @quantized_dot_general_relu_fn(%1, %0) : (tensor<1x4x!quant.uniform>, tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>>) -> tensor<1x3x!quant.uniform> + %3 = stablehlo.uniform_dequantize %2 : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + return %3 : tensor<1x3xf32> + } + + // CHECK-LABEL: func.func private @quantized_dot_general_relu_fn + func.func private @quantized_dot_general_relu_fn( + %arg0: tensor<1x4x!quant.uniform>, + %arg1: tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>> + ) -> tensor<1x3x!quant.uniform> attributes {_from_xla_call_module} { + // CHECK: %[[MIN:.*]] = stablehlo.constant dense<0.000000e+00> : tensor + // CHECK: %[[DOT:.*]] = stablehlo.dot_general %arg0, %arg1 + // CHECK: %[[DQ:.*]] = stablehlo.uniform_dequantize %[[DOT]] + // CHECK: %[[MAX:.*]] = chlo.broadcast_maximum %[[DQ]], %[[MIN]] + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x4x!quant.uniform>, tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>>) -> tensor<1x3x!quant.uniform> + %1 = stablehlo.uniform_quantize %0 : (tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> + return %1 : tensor<1x3x!quant.uniform> + } +} + +// ----- + +// Merge fusion with dequantize for relu6 case. 
+ +module attributes {tf_saved_model.semantics} { + // CHECK-LABEL: func.func private @merge_relu6_fusion + func.func private @merge_relu6_fusion(%arg0: tensor<1x4xf32>) -> tensor<1x3xf32> { + %0 = stablehlo.constant() {value = dense<127> : tensor<4x3xi8>} : () -> tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>> + %1 = stablehlo.uniform_quantize %arg0 : (tensor<1x4xf32>) -> tensor<1x4x!quant.uniform> + // CHECK: call @quantized_dot_general_relu6_fn + // CHECK-SAME: -> tensor<1x3xf32> + %2 = call @quantized_dot_general_relu6_fn(%1, %0) : (tensor<1x4x!quant.uniform>, tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>>) -> tensor<1x3x!quant.uniform> + %3 = stablehlo.uniform_dequantize %2 : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + return %3 : tensor<1x3xf32> + } + + // CHECK-LABEL: func.func private @quantized_dot_general_relu6_fn + func.func private @quantized_dot_general_relu6_fn( + %arg0: tensor<1x4x!quant.uniform>, + %arg1: tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>> + ) -> tensor<1x3x!quant.uniform> attributes {_from_xla_call_module} { + // CHECK-DAG: %[[MIN:.*]] = stablehlo.constant dense<0.000000e+00> : tensor + // CHECK-DAG: %[[MAX:.*]] = stablehlo.constant dense<6.000000e+00> : tensor + // CHECK: %[[DOT:.*]] = stablehlo.dot_general %arg0, %arg1 + // CHECK: %[[DQ:.*]] = stablehlo.uniform_dequantize %[[DOT]] + // CHECK: %[[CLAMP:.*]] = stablehlo.clamp %[[MIN]], %[[DQ]], %[[MAX]] + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x4x!quant.uniform>, tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>>) -> tensor<1x3x!quant.uniform> + %1 = stablehlo.uniform_quantize %0 : (tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> + return %1 : tensor<1x3x!quant.uniform> + } +} + +// ----- + +// Merge fusion with dequantize for no activation case. 
+ +module attributes {tf_saved_model.semantics} { + // CHECK-LABEL: func.func private @merge_no_act_fusion + func.func private @merge_no_act_fusion(%arg0: tensor<1x4xf32>) -> tensor<1x3xf32> { + %0 = stablehlo.constant() {value = dense<127> : tensor<4x3xi8>} : () -> tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>> + %1 = stablehlo.uniform_quantize %arg0 : (tensor<1x4xf32>) -> tensor<1x4x!quant.uniform> + // CHECK: call @quantized_dot_general_fn + // CHECK-SAME: -> tensor<1x3xf32> + %2 = call @quantized_dot_general_fn(%1, %0) : (tensor<1x4x!quant.uniform>, tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>>) -> tensor<1x3x!quant.uniform> + %3 = stablehlo.uniform_dequantize %2 : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + return %3 : tensor<1x3xf32> + } + + // CHECK-LABEL: func.func private @quantized_dot_general_fn + func.func private @quantized_dot_general_fn( + %arg0: tensor<1x4x!quant.uniform>, + %arg1: tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>> + ) -> tensor<1x3x!quant.uniform> attributes {_from_xla_call_module} { + // CHECK: %[[DOT:.*]] = stablehlo.dot_general %arg0, %arg1 + // CHECK: %[[DQ:.*]] = stablehlo.uniform_dequantize %[[DOT]] + // CHECK: return %[[DQ]] : tensor<1x3xf32> + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x4x!quant.uniform>, tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>>) -> tensor<1x3x!quant.uniform> + %1 = stablehlo.uniform_quantize %0 : (tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> + return %1 : tensor<1x3x!quant.uniform> + } +} + +// ----- + +// Do not merge when quant.uniform result is used directly. 
+ +module attributes {tf_saved_model.semantics} { + // CHECK-LABEL: func.func private @no_merge_fusion_direct_usage + func.func private @no_merge_fusion_direct_usage(%arg0: tensor<1x4xf32>) -> (tensor<1x3xf32>, tensor<1x3x!quant.uniform>) { + %0 = stablehlo.constant() {value = dense<127> : tensor<4x3xi8>} : () -> tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>> + %1 = stablehlo.uniform_quantize %arg0 : (tensor<1x4xf32>) -> tensor<1x4x!quant.uniform> + // CHECK: call @quantized_dot_general_relu_fn + // CHECK-SAME: -> tensor<1x3x!quant.uniform> + %2 = call @quantized_dot_general_relu_fn(%1, %0) : (tensor<1x4x!quant.uniform>, tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>>) -> tensor<1x3x!quant.uniform> + %3 = stablehlo.uniform_dequantize %2 : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + return %3, %2 : tensor<1x3xf32>, tensor<1x3x!quant.uniform> + } + + // CHECK-LABEL: func.func private @quantized_dot_general_relu_fn + func.func private @quantized_dot_general_relu_fn( + %arg0: tensor<1x4x!quant.uniform>, + %arg1: tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>> + ) -> tensor<1x3x!quant.uniform> attributes {_from_xla_call_module} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x4x!quant.uniform>, tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>>) -> tensor<1x3x!quant.uniform> + %1 = stablehlo.uniform_quantize %0 : (tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> + return %1 : tensor<1x3x!quant.uniform> + } +} + +// ----- + +// Do not merge when fusion and dequantize is already merged. 
+ +module attributes {tf_saved_model.semantics} { + // CHECK-LABEL: func.func private @no_merge_fusion_already_merged + func.func private @no_merge_fusion_already_merged(%arg0: tensor<1x4xf32>) -> (tensor<1x3xf32>) { + %0 = stablehlo.constant() {value = dense<127> : tensor<4x3xi8>} : () -> tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>> + %1 = stablehlo.uniform_quantize %arg0 : (tensor<1x4xf32>) -> tensor<1x4x!quant.uniform> + // CHECK: call @quantized_dot_general_fn + // CHECK-SAME: -> tensor<1x3xf32> + %2 = call @quantized_dot_general_fn(%1, %0) : (tensor<1x4x!quant.uniform>, tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>>) -> tensor<1x3xf32> + return %2 : tensor<1x3xf32> + } + + // CHECK-LABEL: func.func private @quantized_dot_general_fn + func.func private @quantized_dot_general_fn( + %arg0: tensor<1x4x!quant.uniform>, + %arg1: tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>> + ) -> tensor<1x3xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x4x!quant.uniform>, tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>>) -> tensor<1x3x!quant.uniform> + %1 = stablehlo.uniform_dequantize %0 : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + return %1 : tensor<1x3xf32> + } +} + +// ----- + +// Do not merge when function is not quantized function. 
+ +module attributes {tf_saved_model.semantics} { + // CHECK-LABEL: func.func private @merge_relu_fusion + func.func private @merge_relu_fusion(%arg0: tensor<1x4xf32>) -> tensor<1x3xf32> { + %0 = stablehlo.constant() {value = dense<127> : tensor<4x3xi8>} : () -> tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>> + %1 = stablehlo.uniform_quantize %arg0 : (tensor<1x4xf32>) -> tensor<1x4x!quant.uniform> + // CHECK: call @some_func + // CHECK-SAME: -> tensor<1x3x!quant.uniform> + %2 = call @some_func(%1, %0) : (tensor<1x4x!quant.uniform>, tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>>) -> tensor<1x3x!quant.uniform> + %3 = stablehlo.uniform_dequantize %2 : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + return %3 : tensor<1x3xf32> + } + + // CHECK-LABEL: func.func private @some_func + func.func private @some_func( + %arg0: tensor<1x4x!quant.uniform>, + %arg1: tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>> + ) -> tensor<1x3x!quant.uniform> attributes {_from_xla_call_module} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x4x!quant.uniform>, tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>>) -> tensor<1x3x!quant.uniform> + %1 = stablehlo.uniform_quantize %0 : (tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> + return %1 : tensor<1x3x!quant.uniform> + } +} + +// ----- + +// Do not merge when the quantized fusion is invalid. 
+ +module attributes {tf_saved_model.semantics} { + // CHECK-LABEL: func.func private @merge_relu_fusion + func.func private @merge_relu_fusion(%arg0: tensor<1x4xf32>) -> tensor<1x3xf32> { + %0 = stablehlo.constant() {value = dense<127> : tensor<4x3xi8>} : () -> tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>> + %1 = stablehlo.uniform_quantize %arg0 : (tensor<1x4xf32>) -> tensor<1x4x!quant.uniform> + // CHECK: call @quantized_dot_general_relu_fn + // CHECK-SAME: -> tensor<1x3x!quant.uniform> + %2 = call @quantized_dot_general_relu_fn(%1, %0) : (tensor<1x4x!quant.uniform>, tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>>) -> tensor<1x3x!quant.uniform> + %3 = stablehlo.uniform_dequantize %2 : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + return %3 : tensor<1x3xf32> + } + + // CHECK-LABEL: func.func private @quantized_dot_general_relu_fn + func.func private @quantized_dot_general_relu_fn( + %arg0: tensor<1x4x!quant.uniform>, + %arg1: tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>> + ) -> tensor<1x3x!quant.uniform> attributes {_from_xla_call_module} { + %0 = stablehlo.constant() {value = dense<2> : tensor<1x3xi8>} : () -> tensor<1x3x!quant.uniform> + return %0 : tensor<1x3x!quant.uniform> + } +} diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_nchw_convolution_to_nhwc.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_nchw_convolution_to_nhwc.mlir new file mode 100644 index 00000000000000..3dfb5555ef4359 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_nchw_convolution_to_nhwc.mlir @@ -0,0 +1,96 @@ +// RUN: stablehlo-quant-opt %s -tf-stablehlo-nchw-convolution-to-nhwc \ +// RUN: -split-input-file -verify-diagnostics | FileCheck %s + +// Tests that `stablehlo.transpose` ops are inserted for each of input, filter, +// and output. 
+// Output dimension numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f] + +// CHECK-LABEL: nchw_conv +// CHECK-SAME: %[[ARG:.+]]: tensor<1x8x4x4xf32> +func.func @nchw_conv(%arg0: tensor<1x8x4x4xf32>) -> tensor<1x8x4x4xf32> { + %0 = stablehlo.constant() {value = dense<7.000000e+00> : tensor<8x8x3x3xf32>} : () -> tensor<8x8x3x3xf32> + %2 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x8x4x4xf32>, tensor<8x8x3x3xf32>) -> tensor<1x8x4x4xf32> + return %2 : tensor<1x8x4x4xf32> +} + +// CHECK-DAG: %[[CONST:.+]] = stablehlo.constant {{.*}} : tensor<8x8x3x3xf32> +// CHECK-DAG: %[[TRANSPOSE_0:.+]] = stablehlo.transpose %[[ARG]], dims = [0, 2, 3, 1] : (tensor<1x8x4x4xf32>) -> tensor<1x4x4x8xf32> +// CHECK-DAG: %[[TRANSPOSE_1:.+]] = stablehlo.transpose %[[CONST]], dims = [2, 3, 1, 0] : (tensor<8x8x3x3xf32>) -> tensor<3x3x8x8xf32> +// CHECK: %[[CONV:.+]] = stablehlo.convolution(%[[TRANSPOSE_0]], %[[TRANSPOSE_1]]) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = {{\[\[}}1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x4x4x8xf32>, tensor<3x3x8x8xf32>) -> tensor<1x4x4x8xf32> +// CHECK: %[[TRANSPOSE_2:.+]] = stablehlo.transpose %[[CONV]], dims = [0, 3, 1, 2] : (tensor<1x4x4x8xf32>) -> tensor<1x8x4x4xf32> + +// ----- + +// Tests that the conversion doesn't happen when the input dimension numbers +// are not [b, f, 0, 1]. 
+ +// CHECK-LABEL: conv_input_dim_numbers_mismatch +func.func @conv_input_dim_numbers_mismatch(%arg0: tensor<1x4x4x8xf32>) -> tensor<1x8x4x4xf32> { + %0 = stablehlo.constant() {value = dense<7.000000e+00> : tensor<8x8x3x3xf32>} : () -> tensor<8x8x3x3xf32> + %2 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[o, i, 0, 1]->[b, f, 0, 1], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x4x4x8xf32>, tensor<8x8x3x3xf32>) -> tensor<1x8x4x4xf32> + return %2 : tensor<1x8x4x4xf32> +} + +// CHECK-NOT: stablehlo.transpose +// CHECK: %[[CONV:.+]] = stablehlo.convolution +// CHECK-SAME{LITERAL}: [b, 0, 1, f]x[o, i, 0, 1]->[b, f, 0, 1] +// CHECK-NOT: stablehlo.transpose + +// ----- + +// Tests that the conversion doesn't happen when the feature dimension numbers +// are not [i, 0, 1, o]. + +// CHECK-LABEL: conv_feature_dim_numbers_mismatch +func.func @conv_feature_dim_numbers_mismatch(%arg0: tensor<1x8x4x4xf32>) -> tensor<1x8x4x4xf32> { + %0 = stablehlo.constant() {value = dense<7.000000e+00> : tensor<8x3x3x8xf32>} : () -> tensor<8x3x3x8xf32> + %2 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, f, 0, 1]x[i, 0, 1, o]->[b, f, 0, 1], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x8x4x4xf32>, tensor<8x3x3x8xf32>) -> tensor<1x8x4x4xf32> + return %2 : tensor<1x8x4x4xf32> +} + +// CHECK-NOT: stablehlo.transpose +// CHECK: %[[CONV:.+]] = stablehlo.convolution +// CHECK-SAME{LITERAL}: [b, f, 0, 1]x[i, 0, 1, o]->[b, f, 0, 1] +// CHECK-NOT: stablehlo.transpose + +// ----- + +// Tests that the conversion doesn't happen when the output dimension numbers +// are not [b, 0, 1, f]. 
+ +// CHECK-LABEL: conv_output_dim_numbers_mismatch +func.func @conv_output_dim_numbers_mismatch(%arg0: tensor<1x8x4x4xf32>) -> tensor<1x4x4x8xf32> { + %0 = stablehlo.constant() {value = dense<7.000000e+00> : tensor<8x8x3x3xf32>} : () -> tensor<8x8x3x3xf32> + %2 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x8x4x4xf32>, tensor<8x8x3x3xf32>) -> tensor<1x4x4x8xf32> + return %2 : tensor<1x4x4x8xf32> +} + +// CHECK-NOT: stablehlo.transpose +// CHECK: %[[CONV:.+]] = stablehlo.convolution +// CHECK-SAME{LITERAL}: [b, f, 0, 1]x[o, i, 0, 1]->[b, 0, 1, f] +// CHECK-NOT: stablehlo.transpose + +// ----- + +// Tests that a quantized convolution does not match. No conversion occurs. + +// CHECK-LABEL: quantized_convolution +func.func @quantized_convolution(%arg0: tensor<1x4x3x3x!quant.uniform>, %arg1: tensor<2x4x3x3x!quant.uniform>) -> tensor<1x2x3x3x!quant.uniform> { + %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x4x3x3x!quant.uniform>, tensor<2x4x3x3x!quant.uniform>) -> tensor<1x2x3x3x!quant.uniform> + return %0 : tensor<1x2x3x3x!quant.uniform> +} + +// CHECK-NOT: stablehlo.transpose + +// ----- + +// Tests that a quantized convolution with rank > 4 does not match. +// No conversion occurs. 
+ +// CHECK-LABEL: convolution_3d +func.func @convolution_3d(%arg0: tensor<1x4x28x28x1xf32>, %arg1: tensor<2x3x3x1x16xf32>) -> tensor<1x3x26x26x16xf32> { + %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, 2, f]x[0, 1, 2, i, o]->[b, 0, 1, 2, f], window = {} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x4x28x28x1xf32>, tensor<2x3x3x1x16xf32>) -> tensor<1x3x26x26x16xf32> + return %0 : tensor<1x3x26x26x16xf32> +} + +// CHECK-NOT: stablehlo.transpose diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_optimize_graph.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_optimize_graph.mlir new file mode 100644 index 00000000000000..92484985334b38 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_optimize_graph.mlir @@ -0,0 +1,33 @@ +// RUN: stablehlo-quant-opt %s -split-input-file -tf-optimize-graph | FileCheck %s + +// CHECK-LABEL: @merge_requantization_followed_by_dequantization +// CHECK-SAME: %[[ARG_0:.*]]: tensor<1x3x4x3xf32> +func.func @merge_requantization_followed_by_dequantization(%arg0: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> { + // CHECK: %[[CST:.*]] = stablehlo.constant dense<4.000000e-01> : tensor<2x3x3x2xf32> + // CHECK: %[[QUANT_CST:.*]] = stablehlo.uniform_quantize %[[CST]] + // CHECK: %[[QUANT_ARG_0:.*]] = stablehlo.uniform_quantize %[[ARG_0]] + // CHECK: %[[CONV:.*]] = stablehlo.convolution(%[[QUANT_ARG_0]], %[[QUANT_CST]]) + // CHECK-NOT: stablehlo.uniform_quantize + // CHECK: %[[DEQUANT:.*]] = stablehlo.uniform_dequantize %[[CONV]] + // CHECK: return %[[DEQUANT]] + %cst = stablehlo.constant dense<0.4> : tensor<2x3x3x2xf32> + %quant_cst = stablehlo.uniform_quantize %cst : (tensor<2x3x3x2xf32>) -> tensor<2x3x3x2x!quant.uniform:f32, 0.015>> + %quant_arg = stablehlo.uniform_quantize %arg0 : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3x!quant.uniform> + %conv = stablehlo.convolution(%quant_arg, %quant_cst) dim_numbers = [b, 0, 1, f]x[0, 1, 
i, o]->[b, 0, 1, f], window = {pad = [[0, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32, 0.015>>) -> tensor<1x3x4x2x!quant.uniform> + %requant = stablehlo.uniform_quantize %conv : (tensor<1x3x4x2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform> + %dequant = stablehlo.uniform_dequantize %requant : (tensor<1x3x4x2x!quant.uniform>) -> tensor<1x3x4x2xf32> + func.return %dequant : tensor<1x3x4x2xf32> +} + +// ----- + +// CHECK-LABEL: @dont_merge_quantization_followed_by_quantization +// CHECK-SAME: %[[ARG_0:.*]]: tensor<1x3x4x3xf32> +func.func @dont_merge_quantization_followed_by_quantization(%arg0: tensor<1x3x4x3xf32>) -> tensor<1x3x4x3xf32> { + // CHECK: %[[QUANT_ARG_0:.*]] = stablehlo.uniform_quantize %[[ARG_0]] + // CHECK: %[[DEQUANT:.*]] = stablehlo.uniform_dequantize %[[QUANT_ARG_0]] + // CHECK: return %[[DEQUANT]] + %quant_arg = stablehlo.uniform_quantize %arg0 : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3x!quant.uniform> + %dequant = stablehlo.uniform_dequantize %quant_arg : (tensor<1x3x4x3x!quant.uniform>) -> tensor<1x3x4x3xf32> + func.return %dequant : tensor<1x3x4x3xf32> +} diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_post_quantize.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_post_quantize.mlir new file mode 100644 index 00000000000000..01f2ee34f0c80b --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_post_quantize.mlir @@ -0,0 +1,72 @@ +// RUN: stablehlo-quant-opt %s -split-input-file -tf-stablehlo-post-quantize | FileCheck %s + +// CHECK-LABEL: @remove_volatile_qdq +func.func @remove_volatile_qdq() -> tensor<3x2xf32> { + // CHECK: %[[CST:.*]] = stablehlo.constant + // CHECK-NOT: "quantization.qcast" + // CHECK-NOT: "quantization.dcast" + // CHECK: return %[[CST]] + %cst = stablehlo.constant dense<[[-0.960978984, -0.390246302], [-0.790828585, -0.601039409], [-1.0280807, 
-1.02731466]]> : tensor<3x2xf32> + %q = "quantization.qcast"(%cst) {volatile} : (tensor<3x2xf32>) -> tensor<3x2x!quant.uniform> + %dq = "quantization.dcast"(%q) : (tensor<3x2x!quant.uniform>) -> tensor<3x2xf32> + func.return %dq : tensor<3x2xf32> +} + +// ----- + +// CHECK-LABEL: @remove_volatile_qdq_with_requantization +// CHECK-SAME: %[[ARG0:.*]]: tensor<3x2xf32> +func.func @remove_volatile_qdq_with_requantization(%arg0: tensor<3x2xf32>) -> tensor<3x2xf32> { + // CHECK: %[[Q1:.*]] = stablehlo.uniform_quantize %[[ARG0]] + // CHECK: %[[Q2:.*]] = stablehlo.uniform_quantize %[[Q1]] + // CHECK: %[[ABS:.*]] = stablehlo.abs %[[Q2]] + // CHECK: %[[DQ:.*]] = stablehlo.uniform_dequantize %[[ABS]] + // CHECK: %[[ADD:.*]] = stablehlo.add %[[ARG0]], %[[DQ]] + // CHECK: return %[[ADD]] + %q1 = "quantization.qcast"(%arg0) {volatile} : (tensor<3x2xf32>) -> tensor<3x2x!quant.uniform> + %q2 = "quantization.qcast"(%q1) {volatile} : (tensor<3x2x!quant.uniform>) -> tensor<3x2x!quant.uniform> + %dq1 = "quantization.dcast"(%q2) : (tensor<3x2x!quant.uniform>) -> tensor<3x2xf32> + %abs = stablehlo.abs %q2 : (tensor<3x2x!quant.uniform>) -> tensor<3x2x!quant.uniform> + %dq2 = "quantization.dcast"(%abs) : (tensor<3x2x!quant.uniform>) -> tensor<3x2xf32> + %add = stablehlo.add %dq1, %dq2 : (tensor<3x2xf32>, tensor<3x2xf32>) -> tensor<3x2xf32> + func.return %add : tensor<3x2xf32> +} + +// ----- + +// CHECK-LABEL: @quantize_constant +// CHECK-SAME: %[[ARG0:.*]]: tensor<1x3xf32> +func.func @quantize_constant(%arg0: tensor<1x3xf32>) -> tensor<1x2xf32> { + // CHECK-DAG: %[[QCST:.*]] = stablehlo.constant() <{value = dense<-78> : tensor<3x2xi8>}> : () -> tensor<3x2x!quant.uniform:f32, 5.000000e-03>> + // CHECK-DAG: %[[Q1:.*]] = stablehlo.uniform_quantize %[[ARG0]] + // CHECK-NOT: "quantization.qcast" + // CHECK: %[[DOT:.*]] = stablehlo.dot %[[Q1]], %[[QCST]] + // CHECK: %[[DQ:.*]] = stablehlo.uniform_dequantize %[[DOT]] + // CHECK: return %[[DQ]] + %cst = stablehlo.constant dense<-0.390246302> : 
tensor<3x2xf32> + %q1 = "quantization.qcast"(%arg0) {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> + %q2 = "quantization.qcast"(%cst) {volatile} : (tensor<3x2xf32>) -> tensor<3x2x!quant.uniform:f32, 5.000000e-03>> + %dot = stablehlo.dot %q1, %q2 : (tensor<1x3x!quant.uniform>, tensor<3x2x!quant.uniform:f32, 5.000000e-03>>) -> tensor<1x2x!quant.uniform> + %dq = "quantization.dcast"(%dot) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32> + func.return %dq : tensor<1x2xf32> +} + +// ----- + +// CHECK-LABEL: @convert_quantization_qdq_to_stablehlo_uniform_qdq +// CHECK-SAME: %[[ARG0:.*]]: tensor<1x3xf32> +// CHECK-SAME: %[[ARG1:.*]]: tensor<3x2xf32> +func.func @convert_quantization_qdq_to_stablehlo_uniform_qdq(%arg0: tensor<1x3xf32>, %arg1: tensor<3x2xf32>) -> tensor<1x2xf32> { + // CHECK: %[[Q1:.*]] = stablehlo.uniform_quantize %[[ARG0]] + // CHECK-NOT: "quantization.qcast" + // CHECK: %[[Q2:.*]] = stablehlo.uniform_quantize %[[ARG1]] + // CHECK-NOT: "quantization.qcast" + // CHECK: %[[DOT:.*]] = stablehlo.dot %[[Q1]], %[[Q2]] + // CHECK: %[[DQ:.*]] = stablehlo.uniform_dequantize %[[DOT]] + // CHECK: return %[[DQ]] + %q1 = "quantization.qcast"(%arg0) {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> + %q2 = "quantization.qcast"(%arg1) {volatile} : (tensor<3x2xf32>) -> tensor<3x2x!quant.uniform:f32, 5.000000e-03>> + %dot = stablehlo.dot %q1, %q2 : (tensor<1x3x!quant.uniform>, tensor<3x2x!quant.uniform:f32, 5.000000e-03>>) -> tensor<1x2x!quant.uniform> + %dq = "quantization.dcast"(%dot) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32> + func.return %dq : tensor<1x2xf32> +} diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_quantize_composite_functions.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_quantize_composite_functions.mlir new file mode 100644 index 00000000000000..46e51a7dd0f75b --- /dev/null +++ 
b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_quantize_composite_functions.mlir @@ -0,0 +1,896 @@ +// RUN: stablehlo-quant-opt %s -split-input-file -verify-diagnostics \ +// RUN: -tf-stablehlo-quantize-composite-functions | FileCheck %s +// RUN: stablehlo-quant-opt %s -split-input-file -verify-diagnostics \ +// RUN: -tf-stablehlo-quantize-composite-functions=enable-per-channel-quantized-weight=false | FileCheck --check-prefix=CHECK-PER-TENSOR %s + +// Tests that basic dot_general is properly quantized. + +module attributes {tf_saved_model.semantics} { + func.func private @quantize_dot_general_fn(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3xf32>} : () -> tensor<2x3xf32> + %0 = "quantization.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32> + %1 = "tf.XlaCallModule"(%0, %cst) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _stablehlo_version = "1.0.0", _original_entry_function = "composite_dot_general_fn", _quantization_method = "static_range_ptq { }", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + %2 = "quantization.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor<1x3xf32>) -> tensor<1x3xf32> + return %2 : tensor<1x3xf32> + } +// Checks that the quantized XlaCallModule has been replaced by a CallOp, which +// calls the quantized entry function. 
+ +// CHECK: func.func private @quantize_dot_general_fn(%[[ARG_0:.+]]: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} +// CHECK: %[[CONST_0:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<2x3xi8>}> : () -> tensor<2x3x!quant.uniform:f32:1, {{.*}}> +// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> +// CHECK: %[[CALL_0:.+]] = call @quantized_dot_general_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]]) {_quantization_method = "static_range_ptq { }"} : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32:1, {{.*}}>) -> tensor<1x3x!quant.uniform> +// CHECK: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<1x3x!quant.uniform) -> tensor<1x3xf32> +// CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<1x3xf32> + +// CHECK-PER-TENSOR: func.func private @quantize_dot_general_fn(%[[ARG_0:.+]]: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} +// CHECK-PER-TENSOR: %[[CONST_0:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<2x3xi8>}> : () -> tensor<2x3x!quant.uniform:f32, {{.*}}> +// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> +// CHECK-PER-TENSOR: %[[CALL_0:.+]] = call @quantized_dot_general_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]]) {_quantization_method = "static_range_ptq { }"} : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32, {{.*}}>) -> tensor<1x3x!quant.uniform> +// CHECK-PER-TENSOR: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<1x3x!quant.uniform) -> tensor<1x3xf32> +// CHECK-PER-TENSOR: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<1x3xf32> + + func.func private @composite_dot_general_fn(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.dot_general %arg0, %arg1, 
contracting_dims = [1] x [0] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> + } +// Checks that the entry function is quantized for dot_general. Quantized +// dot_general outputs an i32 quantized tensor, followed by requantization to +// i8 quantized tensor. + +// CHECK: func.func private @quantized_dot_general_fn(%[[ARG_1:.+]]: tensor<1x2x!quant.uniform>, %[[ARG_2:.+]]: tensor<2x3x!quant.uniform:f32:1, {{.*}}>>) -> tensor<1x3x!quant.uniform> attributes {_from_xla_call_module} +// CHECK: %[[DOT_GENERAL_0:.+]] = stablehlo.dot_general %[[ARG_1]], %[[ARG_2]], contracting_dims = [1] x [0] : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32:1, {{.*}}>>) -> tensor<1x3x!quant.uniform> +// CHECK: %[[UNIFORM_QUANTIZE_1:.+]] = stablehlo.uniform_quantize %[[DOT_GENERAL_0]] : (tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> +// CHECK: return %[[UNIFORM_QUANTIZE_1]] : tensor<1x3x!quant.uniform> + +// CHECK-PER-TENSOR: func.func private @quantized_dot_general_fn(%[[ARG_1:.+]]: tensor<1x2x!quant.uniform>, %[[ARG_2:.+]]: tensor<2x3x!quant.uniform:f32, {{.*}}>>) -> tensor<1x3x!quant.uniform> attributes {_from_xla_call_module} +// CHECK-PER-TENSOR: %[[DOT_GENERAL_0:.+]] = stablehlo.dot_general %[[ARG_1]], %[[ARG_2]], contracting_dims = [1] x [0] : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32, {{.*}}>>) -> tensor<1x3x!quant.uniform> +// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_1:.+]] = stablehlo.uniform_quantize %[[DOT_GENERAL_0]] : (tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> +// CHECK-PER-TENSOR: return %[[UNIFORM_QUANTIZE_1]] : tensor<1x3x!quant.uniform> +} + +// ----- + +// Tests that `stablehlo.dot_general` with `batching_dim` is quantized. 
+ +module attributes {tf_saved_model.semantics} { + func.func private @quantize_dot_general_batch_per_tensor_quantized_fn(%arg0: tensor<2x2x2xf32>) -> tensor<2x2x3xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x2x3xf32>} : () -> tensor<2x2x3xf32> + %0 = "quantization.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor<2x2x2xf32>) -> tensor<2x2x2xf32> + %1 = "tf.XlaCallModule"(%0, %cst) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _stablehlo_version = "1.0.0", _original_entry_function = "composite_dot_general_fn", _quantization_method = "static_range_ptq { }", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<2x2x2xf32>, tensor<2x2x3xf32>) -> tensor<2x2x3xf32> + %2 = "quantization.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor<2x2x3xf32>) -> tensor<2x2x3xf32> + return %2 : tensor<2x2x3xf32> + } +// CHECK: func.func private @quantize_dot_general_batch_per_tensor_quantized_fn(%[[ARG_0:.+]]: tensor<2x2x2xf32>) -> tensor<2x2x3xf32> attributes {tf._original_func_name = "main_0"} +// CHECK: %[[CONST_0:.+]] = stablehlo.constant() <{value = dense<127> : tensor<2x2x3xi8>}> : () -> tensor<2x2x3x!quant.uniform:f32, {{.*}}>> +// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<2x2x2xf32>) -> tensor<2x2x2x!quant.uniform> +// CHECK: %[[CALL_0:.+]] = call @quantized_dot_general_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]]) {_quantization_method = "static_range_ptq { }"} : (tensor<2x2x2x!quant.uniform>, tensor<2x2x3x!quant.uniform:f32, {{.*}}>) -> tensor<2x2x3x!quant.uniform> +// CHECK: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<2x2x3x!quant.uniform) -> 
tensor<2x2x3xf32> +// CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<2x2x3xf32> + + func.func private @composite_dot_general_fn(%arg0: tensor<2x2x2xf32>, %arg1: tensor<2x2x3xf32>) -> tensor<2x2x3xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.dot_general %arg0, %arg1, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<2x2x2xf32>, tensor<2x2x3xf32>) -> tensor<2x2x3xf32> + return %0 : tensor<2x2x3xf32> + } +} + +// ----- + +// Tests that fused pattern for dot_general + bias is properly quantized. + +module attributes {tf_saved_model.semantics} { + func.func private @quantize_dot_general_with_bias_same_shape_fn(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3xf32>} : () -> tensor<2x3xf32> + %cst_0 = "tf.Const"() {value = dense<4.00000000e-1> : tensor<1x3xf32>} : () -> tensor<1x3xf32> + %0 = "quantization.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32> + %1 = "tf.XlaCallModule"(%0, %cst, %cst_0) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_with_bias_same_shape_fn, _stablehlo_version = "1.0.0", _original_entry_function = "composite_dot_general_with_bias_same_shape_fn", _quantization_method = "static_range_ptq { }", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>, tensor<1x3xf32>) -> tensor<1x3xf32> + %2 = "quantization.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor<1x3xf32>) -> tensor<1x3xf32> + return %2 : tensor<1x3xf32> + } +// CHECK: func.func private @quantize_dot_general_with_bias_same_shape_fn(%[[ARG_0:.+]]: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} +// 
CHECK: %[[CONST_0:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<2x3xi8>}> : () -> tensor<2x3x!quant.uniform:f32:1, {{.*}}> +// CHECK: %[[CONST_1:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<1x3xi32>}> : () -> tensor<1x3x!quant.uniform +// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> +// CHECK: %[[CALL_0:.+]] = call @quantized_dot_general_with_bias_same_shape_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) {_quantization_method = "static_range_ptq { }"} : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32:1, {{.*}}>, tensor<1x3x!quant.uniform) -> tensor<1x3x!quant.uniform +// CHECK: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<1x3x!quant.uniform) -> tensor<1x3xf32> +// CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<1x3xf32> + +// CHECK-PER-TENSOR: func.func private @quantize_dot_general_with_bias_same_shape_fn(%[[ARG_0:.+]]: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} +// CHECK-PER-TENSOR-DAG: %[[CONST_0:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<2x3xi8>}> : () -> tensor<2x3x!quant.uniform:f32, {{.*}}> +// CHECK-PER-TENSOR-DAG: %[[CONST_1:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<1x3xi32>}> : () -> tensor<1x3x!quant.uniform +// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> +// CHECK-PER-TENSOR: %[[CALL_0:.+]] = call @quantized_dot_general_with_bias_same_shape_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) {_quantization_method = "static_range_ptq { }"} : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32, {{.*}}>, tensor<1x3x!quant.uniform) -> tensor<1x3x!quant.uniform +// CHECK-PER-TENSOR: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<1x3x!quant.uniform) -> tensor<1x3xf32> +// 
CHECK-PER-TENSOR: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<1x3xf32> + + func.func private @composite_dot_general_with_bias_same_shape_fn(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>, %arg2: tensor<1x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + %1 = stablehlo.add %0, %arg2 : tensor<1x3xf32> + return %1 : tensor<1x3xf32> + } +// CHECK: func.func private @quantized_dot_general_with_bias_same_shape_fn(%[[ARG_1:.+]]: tensor<1x2x!quant.uniform>, %[[ARG_2:.+]]: tensor<2x3x!quant.uniform:f32:1, {{.*}}>>, %[[ARG_3:.+]]: tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> attributes {_from_xla_call_module} +// CHECK: %[[DOT_GENERAL_0:.+]] = stablehlo.dot_general %[[ARG_1]], %[[ARG_2]], contracting_dims = [1] x [0] : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32:1, {{.*}}>) -> tensor<1x3x!quant.uniform> +// CHECK: %[[ADD_0:.+]] = stablehlo.add %[[DOT_GENERAL_0]], %[[ARG_3]] : tensor<1x3x!quant.uniform> +// CHECK: %[[UNIFORM_QUANTIZE_1:.+]] = stablehlo.uniform_quantize %[[ADD_0]] : (tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> +// CHECK: return %[[UNIFORM_QUANTIZE_1]] : tensor<1x3x!quant.uniform> + +// CHECK-PER-TENSOR: func.func private @quantized_dot_general_with_bias_same_shape_fn(%[[ARG_1:.+]]: tensor<1x2x!quant.uniform>, %[[ARG_2:.+]]: tensor<2x3x!quant.uniform:f32, {{.*}}>>, %[[ARG_3:.+]]: tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> attributes {_from_xla_call_module} +// CHECK-PER-TENSOR: %[[DOT_GENERAL_0:.+]] = stablehlo.dot_general %[[ARG_1]], %[[ARG_2]], contracting_dims = [1] x [0] : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32, {{.*}}>) -> tensor<1x3x!quant.uniform> +// CHECK-PER-TENSOR: %[[ADD_0:.+]] = stablehlo.add %[[DOT_GENERAL_0]], %[[ARG_3]] : tensor<1x3x!quant.uniform> +// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_1:.+]] = stablehlo.uniform_quantize 
%[[ADD_0]] : (tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> +// CHECK-PER-TENSOR: return %[[UNIFORM_QUANTIZE_1]] : tensor<1x3x!quant.uniform> + +} + +// ----- + +// Tests that fused pattern for dot_general + bias with dynamic batch dimension +// is properly quantized. + +module attributes {tf_saved_model.semantics} { + func.func private @quantize_dot_general_with_bias_dynamic_fn(%arg0: tensor) -> tensor attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3xf32>} : () -> tensor<2x3xf32> + %cst_0 = "tf.Const"() {value = dense<4.00000000e-1> : tensor<3xf32>} : () -> tensor<3xf32> + %0 = "quantization.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor) -> tensor + %1 = "tf.XlaCallModule"(%0, %cst, %cst_0) {Sout = [#tf_type.shape], _entry_function = @composite_dot_general_with_bias_dynamic_fn, _stablehlo_version = "1.0.0", _original_entry_function = "composite_dot_general_with_bias_dynamic_fn", _quantization_method = "static_range_ptq { }", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor, tensor<2x3xf32>, tensor<3xf32>) -> tensor + %2 = "quantization.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor) -> tensor + return %2 : tensor + } +// CHECK: func.func private @quantize_dot_general_with_bias_dynamic_fn(%[[ARG_0:.+]]: tensor) -> tensor attributes {tf._original_func_name = "main_0"} +// CHECK-DAG: %[[CONST_0:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<2x3xi8>}> : () -> tensor<2x3x!quant.uniform:f32:1, {{.*}}> +// CHECK-DAG: %[[CONST_1:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<3xi32>}> : () -> tensor<3x!quant.uniform +// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor) -> tensor> +// 
CHECK: %[[CALL_0:.+]] = call @quantized_dot_general_with_bias_dynamic_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) {_quantization_method = "static_range_ptq { }"} : (tensor>, tensor<2x3x!quant.uniform:f32:1, {{.*}}>, tensor<3x!quant.uniform) -> tensor +// CHECK: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor) -> tensor +// CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor + +// CHECK-PER-TENSOR: func.func private @quantize_dot_general_with_bias_dynamic_fn(%[[ARG_0:.+]]: tensor) -> tensor attributes {tf._original_func_name = "main_0"} +// CHECK-PER-TENSOR-DAG: %[[CONST_0:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<2x3xi8>}> : () -> tensor<2x3x!quant.uniform:f32, {{.*}}> +// CHECK-PER-TENSOR-DAG: %[[CONST_1:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<3xi32>}> : () -> tensor<3x!quant.uniform +// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor) -> tensor> +// CHECK-PER-TENSOR: %[[CALL_0:.+]] = call @quantized_dot_general_with_bias_dynamic_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) {_quantization_method = "static_range_ptq { }"} : (tensor>, tensor<2x3x!quant.uniform:f32, {{.*}}>, tensor<3x!quant.uniform) -> tensor +// CHECK-PER-TENSOR: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor) -> tensor +// CHECK-PER-TENSOR: return %[[UNIFORM_DEQUANTIZE_0]] : tensor + + func.func private @composite_dot_general_with_bias_dynamic_fn(%arg0: tensor, %arg1: tensor<2x3xf32>, %arg2: tensor<3xf32>) -> tensor attributes {_from_xla_call_module} { + %cst_0 = stablehlo.constant dense<2> : tensor<1xi32> + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor, tensor<2x3xf32>) -> tensor + %1 = stablehlo.get_dimension_size %0, dim = 0 : (tensor) -> tensor + %2 = stablehlo.reshape %1 : (tensor) -> tensor<1xi32> + %3 = stablehlo.concatenate %2, %cst_0, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> 
tensor<2xi32> + %4 = stablehlo.dynamic_broadcast_in_dim %arg2, %3, dims = [1] : (tensor<3xf32>, tensor<2xi32>) -> tensor + %5 = stablehlo.add %0, %4 : tensor + return %5 : tensor + } +} +// CHECK: func.func private @quantized_dot_general_with_bias_dynamic_fn(%[[ARG_1:.+]]: tensor>, %[[ARG_2:.+]]: tensor<2x3x!quant.uniform:f32:1, {{.*}}>>, %[[ARG_3:.+]]: tensor<3x!quant.uniform>) -> tensor> attributes {_from_xla_call_module} +// CHECK: %[[CONST_2:.+]] = stablehlo.constant dense<2> : tensor<1xi32> +// CHECK: %[[DOT_GENERAL_0:.+]] = stablehlo.dot_general %[[ARG_1]], %[[ARG_2]], contracting_dims = [1] x [0] : (tensor>, tensor<2x3x!quant.uniform:f32:1, {{.*}}>>) -> tensor> +// CHECK: %[[GET_DIMENSION_SIZE_0:.+]] = stablehlo.get_dimension_size %[[DOT_GENERAL_0]], dim = 0 : (tensor) +// CHECK: %[[RESHAPE_0:.+]] = stablehlo.reshape %[[GET_DIMENSION_SIZE_0]] : (tensor) -> tensor<1xi32> +// CHECK: %[[CONCATENATE_0:.+]] = stablehlo.concatenate %[[RESHAPE_0]], %[[CONST_2]], dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> +// CHECK: %[[DYNAMIC_BROADCAST_IN_DIM_0:.+]] = stablehlo.dynamic_broadcast_in_dim %[[ARG_3]], %[[CONCATENATE_0]], dims = [1] : (tensor<3x!quant.uniform>, tensor<2xi32>) -> tensor> +// CHECK: %[[ADD_0:.+]] = stablehlo.add %[[DOT_GENERAL_0]], %[[DYNAMIC_BROADCAST_IN_DIM_0]] : tensor> +// CHECK: %[[UNIFORM_QUANTIZE_1:.+]] = stablehlo.uniform_quantize %[[ADD_0]] : (tensor>) +// CHECK: return %[[UNIFORM_QUANTIZE_1]] : tensor> + +// CHECK-PER-TENSOR: func.func private @quantized_dot_general_with_bias_dynamic_fn(%[[ARG_1:.+]]: tensor>, %[[ARG_2:.+]]: tensor<2x3x!quant.uniform:f32, {{.*}}>>, %[[ARG_3:.+]]: tensor<3x!quant.uniform>) -> tensor> attributes {_from_xla_call_module} +// CHECK-PER-TENSOR: %[[CONST_2:.+]] = stablehlo.constant dense<2> : tensor<1xi32> +// CHECK-PER-TENSOR: %[[DOT_GENERAL_0:.+]] = stablehlo.dot_general %[[ARG_1]], %[[ARG_2]], contracting_dims = [1] x [0] : (tensor>, tensor<2x3x!quant.uniform:f32, {{.*}}>>) -> tensor> +// 
CHECK-PER-TENSOR: %[[GET_DIMENSION_SIZE_0:.+]] = stablehlo.get_dimension_size %[[DOT_GENERAL_0]], dim = 0 : (tensor) +// CHECK-PER-TENSOR: %[[RESHAPE_0:.+]] = stablehlo.reshape %[[GET_DIMENSION_SIZE_0]] : (tensor) -> tensor<1xi32> +// CHECK-PER-TENSOR: %[[CONCATENATE_0:.+]] = stablehlo.concatenate %[[RESHAPE_0]], %[[CONST_2]], dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> +// CHECK-PER-TENSOR: %[[DYNAMIC_BROADCAST_IN_DIM_0:.+]] = stablehlo.dynamic_broadcast_in_dim %[[ARG_3]], %[[CONCATENATE_0]], dims = [1] : (tensor<3x!quant.uniform>, tensor<2xi32>) -> tensor> +// CHECK-PER-TENSOR: %[[ADD_0:.+]] = stablehlo.add %[[DOT_GENERAL_0]], %[[DYNAMIC_BROADCAST_IN_DIM_0]] : tensor> +// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_1:.+]] = stablehlo.uniform_quantize %[[ADD_0]] : (tensor>) +// CHECK-PER-TENSOR: return %[[UNIFORM_QUANTIZE_1]] : tensor> + +// ----- + +// Tests that basic convolution is properly quantized. It is per-channel +// quantized unless `enable-per-channel-quantized-weight=false`, according to +// `_quantization_method` with an `input_quantized_types` and explicit +// `dimension_specs`. + +module attributes {tf_saved_model.semantics} { + func.func private @quantize_conv_fn(%arg0: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %0 = "quantization.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3xf32> + %1 = "tf.XlaCallModule"(%0, %cst) { + Sout = [#tf_type.shape<1x3x4x2>], + dim_args_spec = [], + disabled_checks = [], + has_token_input_output = false, + module = "", + platforms = [], + version = 5 : i64, + _entry_function = @composite_conv_fn, + _stableghlo_version = "1.0.0", + _original_entry_function = "composite_conv_fn", + // Per-channel quantization at dimension 3 for input index 1. 
+ _quantization_method = "static_range_ptq {input_quantized_types {key: 1, value {dimension_specs {dimension: 3}}}}", + _stablehlo_module_attrs = {}, + _tfl_quant_trait = "fully_quantizable", + device = "" + } : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> + %2 = "quantization.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor<1x3x4x2xf32>) -> tensor<1x3x4x2xf32> + return %2 : tensor<1x3x4x2xf32> + } +// Check that the quantized XlaCallModule has been replaced by a CallOp, which +// calls the quantized entry function. + +// CHECK: func.func private @quantize_conv_fn(%[[ARG_0:.+]]: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"} +// CHECK: %[[CONST_0:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<2x3x3x2xi8>}> : () -> tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}> +// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3x!quant.uniform> +// CHECK: %[[CALL_0:.+]] = call @quantized_conv_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]]) +// CHECK-SAME: {_quantization_method = "static_range_ptq {input_quantized_types {key: 1, value {dimension_specs {dimension: 3}}}}"} : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<1x3x4x2x!quant.uniform) -> tensor<1x3x4x2xf32> +// CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<1x3x4x2xf32> + +// CHECK-PER-TENSOR: func.func private @quantize_conv_fn(%[[ARG_0:.+]]: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"} +// CHECK-PER-TENSOR: %[[CONST_0:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<2x3x3x2xi8>}> : () -> tensor<2x3x3x2x!quant.uniform:f32, {{.*}}> +// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : 
(tensor<1x3x4x3xf32>) -> tensor<1x3x4x3x!quant.uniform> +// CHECK-PER-TENSOR: %[[CALL_0:.+]] = call @quantized_conv_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]]) +// CHECK-PER-TENSOR-SAME: {_quantization_method = "static_range_ptq {input_quantized_types {key: 1, value {dimension_specs {dimension: 3}}}}"} : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK-PER-TENSOR: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<1x3x4x2x!quant.uniform) -> tensor<1x3x4x2xf32> +// CHECK-PER-TENSOR: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<1x3x4x2xf32> + + func.func private @composite_conv_fn(%arg0: tensor<1x3x4x3xf32>, %arg1: tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[0, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> + return %0 : tensor<1x3x4x2xf32> + } +// Checks that the entry function is quantized for convolution. Quantized +// convolution outputs an i32 quantized tensor, followed by requantization to +// i8 quantized tensor. 
+ +// CHECK: func.func private @quantized_conv_fn(%[[ARG_1:.+]]: tensor<1x3x4x3x!quant.uniform>, %[[ARG_2:.+]]: tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}>>) -> tensor<1x3x4x2x!quant.uniform> attributes {_from_xla_call_module} +// CHECK: %[[CONVOLUTION_0:.+]] = stablehlo.convolution(%[[ARG_1]], %[[ARG_2]]) {{.*}} : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}>>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK: %[[UNIFORM_QUANTIZE_1:.+]] = stablehlo.uniform_quantize %[[CONVOLUTION_0]] : (tensor<1x3x4x2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK: return %[[UNIFORM_QUANTIZE_1]] : tensor<1x3x4x2x!quant.uniform> + +// CHECK-PER-TENSOR: func.func private @quantized_conv_fn(%[[ARG_1:.+]]: tensor<1x3x4x3x!quant.uniform>, %[[ARG_2:.+]]: tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>>) -> tensor<1x3x4x2x!quant.uniform> attributes {_from_xla_call_module} +// CHECK-PER-TENSOR: %[[CONVOLUTION_0:.+]] = stablehlo.convolution(%[[ARG_1]], %[[ARG_2]]) {{.*}} : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_1:.+]] = stablehlo.uniform_quantize %[[CONVOLUTION_0]] : (tensor<1x3x4x2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK-PER-TENSOR: return %[[UNIFORM_QUANTIZE_1]] : tensor<1x3x4x2x!quant.uniform> +} + +// ----- + +// Tests that basic convolution is properly quantized. In this example, the +// convolution is always per-tensor quantized (even if +// enable-per-channel-quantized-weights=true), according to +// `_quantization_method`. 
+ +// CHECK-LABEL: quantize_conv_fn_per_tensor +func.func @quantize_conv_fn_per_tensor(%arg0: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> { + %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %0 = "quantization.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3xf32> + %1 = "tf.XlaCallModule"(%0, %cst) { + Sout = [#tf_type.shape<1x3x4x2>], + dim_args_spec = [], + disabled_checks = [], + has_token_input_output = false, + module = "", + platforms = [], + version = 5 : i64, + _entry_function = @composite_conv_fn, + _stablehlo_version = "1.0.0", + _original_entry_function = "composite_conv_fn", + _quantization_method = "static_range_ptq { }", + _stablehlo_module_attrs = {}, + _tfl_quant_trait = "fully_quantizable", + device = "" + } : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> + %2 = "quantization.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor<1x3x4x2xf32>) -> tensor<1x3x4x2xf32> + return %2 : tensor<1x3x4x2xf32> +} +// Check that the quantized XlaCallModule has been replaced by a CallOp, which +// calls the quantized entry function. 
+ +// CHECK-SAME: (%[[ARG_0:.+]]: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> +// CHECK-DAG: %[[CONST_0:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<2x3x3x2xi8>}> : () -> tensor<2x3x3x2x!quant.uniform +// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3x!quant.uniform> +// CHECK: %[[CALL_0:.+]] = call @quantized_conv_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]]) {_quantization_method = "static_range_ptq { }"} : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform) -> tensor<1x3x4x2x!quant.uniform> +// CHECK: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<1x3x4x2x!quant.uniform) -> tensor<1x3x4x2xf32> +// CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<1x3x4x2xf32> + +func.func private @composite_conv_fn(%arg0: tensor<1x3x4x3xf32>, %arg1: tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[0, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> + return %0 : tensor<1x3x4x2xf32> +} +// Checks that the entry function is quantized for convolution. Quantized +// convolution outputs an i32 quantized tensor, followed by requantization to +// i8 quantized tensor. 
+ +// CHECK: func.func private @quantized_conv_fn(%[[ARG_1:.+]]: tensor<1x3x4x3x!quant.uniform>, %[[ARG_2:.+]]: tensor<2x3x3x2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform> attributes {_from_xla_call_module} +// CHECK: %[[CONVOLUTION_0:.+]] = stablehlo.convolution(%[[ARG_1]], %[[ARG_2]]) {{.*}} : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK: %[[UNIFORM_QUANTIZE_1:.+]] = stablehlo.uniform_quantize %[[CONVOLUTION_0]] : (tensor<1x3x4x2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK: return %[[UNIFORM_QUANTIZE_1]] : tensor<1x3x4x2x!quant.uniform> + +// ----- + +// Tests that fused pattern for convolution + bias is properly quantized. + +// Checks that fused functions with 1D bias is properly quantized. +// The 1D bias should be broadcasted in dims [3], where it initially has +// `quantizedDimension=0`, but has `quantizedDimension=3` after broadcasting. + +module attributes {tf_saved_model.semantics} { + func.func private @quantize_conv_with_bias_1d_fn(%arg0: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %cst_0 = "tf.Const"() {value = dense<4.00000000e-1> : tensor<2xf32>} : () -> tensor<2xf32> + %0 = "quantization.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3xf32> + %1 = "tf.XlaCallModule"(%0, %cst, %cst_0) { + Sout = [#tf_type.shape<1x3x4x2>], + _entry_function = @composite_conv_with_bias_1d_fn, + _stablehlo_version = "1.0.0", + _original_entry_function = "composite_conv_with_bias_1d_fn", + _stablehlo_module_attrs = {}, + // Per-channel quantization at dimension 3 for input index 1. 
+ _quantization_method = "static_range_ptq {input_quantized_types {key: 1, value {dimension_specs {dimension: 3}}}}", + _tfl_quant_trait = "fully_quantizable", + device = "", + dim_args_spec = [], + disabled_checks = [], + has_token_input_output = false, + module = "", + platforms = [], + version = 5 : i64 + } : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>, tensor<2xf32>) -> tensor<1x3x4x2xf32> + %2 = "quantization.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor<1x3x4x2xf32>) -> tensor<1x3x4x2xf32> + return %2 : tensor<1x3x4x2xf32> + } +// CHECK: func.func private @quantize_conv_with_bias_1d_fn(%[[ARG_0:.+]]: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"} +// CHECK-DAG: %[[CONST_0:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<2x3x3x2xi8>}> : () -> tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}> +// CHECK-DAG: %[[CONST_1:.+]] = stablehlo.constant() <{value = dense<47978> : tensor<2xi32>}> : () -> tensor<2x!quant.uniform> +// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3x!quant.uniform> +// CHECK: %[[CALL_0:.+]] = call @quantized_conv_with_bias_1d_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) +// CHECK-SAME: {_quantization_method = "static_range_ptq {input_quantized_types {key: 1, value {dimension_specs {dimension: 3}}}}"} : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32:3, {0.0023622048182750312,0.0023622048182750312}>>, tensor<2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<1x3x4x2x!quant.uniform) -> tensor<1x3x4x2xf32> +// CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<1x3x4x2xf32> + +// CHECK-PER-TENSOR: func.func private @quantize_conv_with_bias_1d_fn(%[[ARG_0:.+]]: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"} +// 
CHECK-PER-TENSOR-DAG: %[[CONST_0:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<2x3x3x2xi8>}> : () -> tensor<2x3x3x2x!quant.uniform:f32, {{.*}}> +// CHECK-PER-TENSOR-DAG: %[[CONST_1:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<2xi32>}> : () -> tensor<2x!quant.uniform +// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3x!quant.uniform> +// CHECK-PER-TENSOR: %[[CALL_0:.+]] = call @quantized_conv_with_bias_1d_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) +// CHECK-PER-TENSOR-SAME: {_quantization_method = "static_range_ptq {input_quantized_types {key: 1, value {dimension_specs {dimension: 3}}}}"} : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>, tensor<2x!quant.uniform) -> tensor<1x3x4x2x!quant.uniform +// CHECK-PER-TENSOR: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<1x3x4x2x!quant.uniform) -> tensor<1x3x4x2xf32> +// CHECK-PER-TENSOR: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<1x3x4x2xf32> + + func.func private @composite_conv_with_bias_1d_fn(%arg0: tensor<1x3x4x3xf32>, %arg1: tensor<2x3x3x2xf32>, %arg2: tensor<2xf32>) -> tensor<1x3x4x2xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.broadcast_in_dim %arg2, dims = [3] : (tensor<2xf32>) -> tensor<1x3x4x2xf32> + %1 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[0, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> + %2 = stablehlo.add %1, %0 : tensor<1x3x4x2xf32> + return %2 : tensor<1x3x4x2xf32> + } +// CHECK: func.func private @quantized_conv_with_bias_1d_fn(%[[ARG_1:.+]]: tensor<1x3x4x3x!quant.uniform>, %[[ARG_2:.+]]: tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}>>, %[[ARG_3:.+]]: tensor<2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform> attributes {_from_xla_call_module} 
+// CHECK: %[[BROADCAST_IN_DIM:.+]] = stablehlo.broadcast_in_dim %arg2, dims = [3] : (tensor<2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK: %[[CONVOLUTION_0:.+]] = stablehlo.convolution(%[[ARG_1]], %[[ARG_2]]) {{.*}} : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK: %[[ADD_0:.+]] = stablehlo.add %[[CONVOLUTION_0]], %[[BROADCAST_IN_DIM]] : tensor<1x3x4x2x!quant.uniform> +// CHECK: %[[UNIFORM_QUANTIZE_1:.+]] = stablehlo.uniform_quantize %[[ADD_0]] : (tensor<1x3x4x2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK: return %[[UNIFORM_QUANTIZE_1]] : tensor<1x3x4x2x!quant.uniform> + +// CHECK-PER-TENSOR: func.func private @quantized_conv_with_bias_1d_fn(%[[ARG_1:.+]]: tensor<1x3x4x3x!quant.uniform>, %[[ARG_2:.+]]: tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>>, %[[ARG_3:.+]]: tensor<2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform> attributes {_from_xla_call_module} +// CHECK-PER-TENSOR: %[[BROADCAST_IN_DIM:.+]] = stablehlo.broadcast_in_dim %[[ARG_3]] +// CHECK-PER-TENSOR: %[[CONVOLUTION_0:.+]] = stablehlo.convolution(%[[ARG_1]], %[[ARG_2]]) {{.*}} : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK-PER-TENSOR: %[[ADD_0:.+]] = stablehlo.add %[[CONVOLUTION_0]], %[[BROADCAST_IN_DIM]] : tensor<1x3x4x2x!quant.uniform> +// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_1:.+]] = stablehlo.uniform_quantize %[[ADD_0]] : (tensor<1x3x4x2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK-PER-TENSOR: return %[[UNIFORM_QUANTIZE_1]] : tensor<1x3x4x2x!quant.uniform> +} + +// ----- + +// Checks that fused functions with 4D bias is properly quantized. +// The 4D bias should be braoadcasted in dims [0, 1, 2, 3], where it +// already has `quantizedDimension=3`. 
+ +module attributes {tf_saved_model.semantics} { + func.func private @quantize_conv_with_bias_fn(%arg0: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %cst_0 = "tf.Const"() {value = dense<4.00000000e-1> : tensor<1x1x1x2xf32>} : () -> tensor<1x1x1x2xf32> + %0 = "quantization.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3xf32> + %1 = "tf.XlaCallModule"(%0, %cst, %cst_0) { + Sout = [#tf_type.shape<1x3x4x2>], + _entry_function = @composite_conv_with_bias_fn, + _stablehlo_version = "1.0.0", + _original_entry_function = "composite_conv_with_bias_fn", + _stablehlo_module_attrs = {}, + // Per-channel quantization at dimension 3 for input index 1. + _quantization_method = "static_range_ptq {input_quantized_types {key: 1, value {dimension_specs {dimension: 3}}}}", + _tfl_quant_trait = "fully_quantizable", + device = "", + dim_args_spec = [], + disabled_checks = [], + has_token_input_output = false, + module = "", + platforms = [], + version = 5 : i64 + } : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>, tensor<1x1x1x2xf32>) -> tensor<1x3x4x2xf32> + %2 = "quantization.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor<1x3x4x2xf32>) -> tensor<1x3x4x2xf32> + return %2 : tensor<1x3x4x2xf32> + } +// CHECK: func.func private @quantize_conv_with_bias_fn(%[[ARG_0:.+]]: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"} +// CHECK-DAG: %[[CONST_0:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<2x3x3x2xi8>}> : () -> tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}> +// CHECK-DAG: %[[CONST_1:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<1x1x1x2xi32>}> : () -> tensor<1x1x1x2x!quant.uniform +// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] 
: (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3x!quant.uniform> +// CHECK: %[[CALL_0:.+]] = call @quantized_conv_with_bias_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) +// CHECK-SAME: {_quantization_method = "static_range_ptq {input_quantized_types {key: 1, value {dimension_specs {dimension: 3}}}}"} : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}>, tensor<1x1x1x2x!quant.uniform) -> tensor<1x3x4x2x!quant.uniform +// CHECK: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<1x3x4x2x!quant.uniform) -> tensor<1x3x4x2xf32> +// CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<1x3x4x2xf32> + +// CHECK-PER-TENSOR: func.func private @quantize_conv_with_bias_fn(%[[ARG_0:.+]]: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"} +// CHECK-PER-TENSOR-DAG: %[[CONST_0:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<2x3x3x2xi8>}> : () -> tensor<2x3x3x2x!quant.uniform:f32, {{.*}}> +// CHECK-PER-TENSOR-DAG: %[[CONST_1:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<1x1x1x2xi32>}> : () -> tensor<1x1x1x2x!quant.uniform +// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3x!quant.uniform> +// CHECK-PER-TENSOR: %[[CALL_0:.+]] = call @quantized_conv_with_bias_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) +// CHECK-PER-TENSOR-SAME: {_quantization_method = "static_range_ptq {input_quantized_types {key: 1, value {dimension_specs {dimension: 3}}}}"} : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>, tensor<1x1x1x2x!quant.uniform) -> tensor<1x3x4x2x!quant.uniform +// CHECK-PER-TENSOR: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<1x3x4x2x!quant.uniform) -> tensor<1x3x4x2xf32> +// CHECK-PER-TENSOR: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<1x3x4x2xf32> + + func.func private @composite_conv_with_bias_fn(%arg0: 
tensor<1x3x4x3xf32>, %arg1: tensor<2x3x3x2xf32>, %arg2: tensor<1x1x1x2xf32>) -> tensor<1x3x4x2xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.broadcast_in_dim %arg2, dims = [0, 1, 2, 3] : (tensor<1x1x1x2xf32>) -> tensor<1x3x4x2xf32> + %1 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[0, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> + %2 = stablehlo.add %1, %0 : tensor<1x3x4x2xf32> + return %2 : tensor<1x3x4x2xf32> + } +// CHECK: func.func private @quantized_conv_with_bias_fn(%[[ARG_1:.+]]: tensor<1x3x4x3x!quant.uniform>, %[[ARG_2:.+]]: tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}>>, %[[ARG_3:.+]]: tensor<1x1x1x2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform> attributes {_from_xla_call_module} +// CHECK: %[[BROADCAST_IN_DIM:.+]] = stablehlo.broadcast_in_dim %arg2, dims = [0, 1, 2, 3] : (tensor<1x1x1x2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK: %[[CONVOLUTION_0:.+]] = stablehlo.convolution(%[[ARG_1]], %[[ARG_2]]) {{.*}} : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK: %[[ADD_0:.+]] = stablehlo.add %[[CONVOLUTION_0]], %[[BROADCAST_IN_DIM]] : tensor<1x3x4x2x!quant.uniform> +// CHECK: %[[UNIFORM_QUANTIZE_1:.+]] = stablehlo.uniform_quantize %[[ADD_0]] : (tensor<1x3x4x2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK: return %[[UNIFORM_QUANTIZE_1]] : tensor<1x3x4x2x!quant.uniform> + +// CHECK-PER-TENSOR: func.func private @quantized_conv_with_bias_fn(%[[ARG_1:.+]]: tensor<1x3x4x3x!quant.uniform>, %[[ARG_2:.+]]: tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>>, %[[ARG_3:.+]]: tensor<1x1x1x2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform> attributes {_from_xla_call_module} +// CHECK-PER-TENSOR: %[[BROADCAST_IN_DIM:.+]] = stablehlo.broadcast_in_dim %arg2 +// CHECK-PER-TENSOR: %[[CONVOLUTION_0:.+]] 
= stablehlo.convolution(%[[ARG_1]], %[[ARG_2]]) {{.*}} : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK-PER-TENSOR: %[[ADD_0:.+]] = stablehlo.add %[[CONVOLUTION_0]], %[[BROADCAST_IN_DIM]] : tensor<1x3x4x2x!quant.uniform> +// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_1:.+]] = stablehlo.uniform_quantize %[[ADD_0]] : (tensor<1x3x4x2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK-PER-TENSOR: return %[[UNIFORM_QUANTIZE_1]] : tensor<1x3x4x2x!quant.uniform> +} + +// ----- + +// Tests that fused pattern for convolution + bias with dynamic batch dimension +// is properly quantized. + +module attributes {tf_saved_model.semantics} { + func.func private @quantize_conv_with_bias_dynamic_fn(%arg0: tensor) -> tensor attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %cst_0 = "tf.Const"() {value = dense<4.00000000e-1> : tensor<1x1x1x2xf32>} : () -> tensor<1x1x1x2xf32> + %0 = "quantization.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor) -> tensor + %1 = "tf.XlaCallModule"(%0, %cst, %cst_0) { + Sout = [#tf_type.shape<1x3x4x2>], + _entry_function = @composite_conv_with_bias_dynamic_fn, + _stablehlo_version = "1.0.0", + _original_entry_function = "composite_conv_with_bias_dynamic_fn", + _stablehlo_module_attrs = {}, + // Per-channel quantization at dimension 3 for input index 1. 
+ _quantization_method = "static_range_ptq {input_quantized_types {key: 1, value {dimension_specs {dimension: 3}}}}", + _tfl_quant_trait = "fully_quantizable", + device = "", + dim_args_spec = [], + disabled_checks = [], + has_token_input_output = false, + module = "", + platforms = [], + version = 5 : i64 + } : (tensor, tensor<2x3x3x2xf32>, tensor<1x1x1x2xf32>) -> tensor + %2 = "quantization.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor) -> tensor + return %2 : tensor + } +// CHECK: func.func private @quantize_conv_with_bias_dynamic_fn(%[[ARG_0:.+]]: tensor) -> tensor attributes {tf._original_func_name = "main_0"} +// CHECK-DAG: %[[CONST_0:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<2x3x3x2xi8>}> : () -> tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}> +// CHECK-DAG: %[[CONST_1:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<1x1x1x2xi32>}> : () -> tensor<1x1x1x2x!quant.uniform +// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor) -> tensor> +// CHECK: %[[CALL_0:.+]] = call @quantized_conv_with_bias_dynamic_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) +// CHECK-SAME: {_quantization_method = "static_range_ptq {input_quantized_types {key: 1, value {dimension_specs {dimension: 3}}}}"} : (tensor>, tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}>, tensor<1x1x1x2x!quant.uniform) -> tensor +// CHECK: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor) -> tensor +// CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor + +// CHECK-PER-TENSOR: func.func private @quantize_conv_with_bias_dynamic_fn(%[[ARG_0:.+]]: tensor) -> tensor attributes {tf._original_func_name = "main_0"} +// CHECK-PER-TENSOR-DAG: %[[CONST_0:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<2x3x3x2xi8>}> : () -> tensor<2x3x3x2x!quant.uniform:f32, {{.*}}> +// CHECK-PER-TENSOR-DAG: %[[CONST_1:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : 
tensor<1x1x1x2xi32>}> : () -> tensor<1x1x1x2x!quant.uniform +// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor) -> tensor> +// CHECK-PER-TENSOR: %[[CALL_0:.+]] = call @quantized_conv_with_bias_dynamic_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) +// CHECK-PER_TENSOR-SAME: {_quantization_method = "static_range_ptq {input_quantized_types {key: 1, value {dimension_specs {dimension: 3}}}}"} : (tensor>, tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>, tensor<1x1x1x2x!quant.uniform) -> tensor +// CHECK-PER-TENSOR: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor) -> tensor +// CHECK-PER-TENSOR: return %[[UNIFORM_DEQUANTIZE_0]] : tensor + + func.func private @composite_conv_with_bias_dynamic_fn(%arg0: tensor, %arg1: tensor<2x3x3x2xf32>, %arg2: tensor<1x1x1x2xf32>) -> tensor attributes {_from_xla_call_module} { + %cst_0 = stablehlo.constant dense<3> : tensor<1xi32> + %cst_1 = stablehlo.constant dense<4> : tensor<1xi32> + %cst_2 = stablehlo.constant dense<2> : tensor<1xi32> + %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[0, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor, tensor<2x3x3x2xf32>) -> tensor + %1 = stablehlo.get_dimension_size %0, dim = 0 : (tensor) -> tensor + %2 = stablehlo.reshape %1 : (tensor) -> tensor<1xi32> + %3 = stablehlo.concatenate %2, %cst_0, %cst_1, %cst_2, dim = 0 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> + %4 = stablehlo.dynamic_broadcast_in_dim %arg2, %3, dims = [0, 1, 2, 3] : (tensor<1x1x1x2xf32>, tensor<4xi32>) -> tensor + %5 = stablehlo.add %0, %4 : tensor + return %5 : tensor + } +} +// CHECK: func.func private @quantized_conv_with_bias_dynamic_fn(%[[ARG_1:.+]]: tensor>, %[[ARG_2:.+]]: tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}>>, %[[ARG_3:.+]]: tensor<1x1x1x2x!quant.uniform>) -> tensor> attributes 
{_from_xla_call_module} +// CHECK-DAG: %[[CONST_2:.+]] = stablehlo.constant dense<3> : tensor<1xi32> +// CHECK-DAG: %[[CONST_3:.+]] = stablehlo.constant dense<4> : tensor<1xi32> +// CHECK-DAG: %[[CONST_4:.+]] = stablehlo.constant dense<2> : tensor<1xi32> +// CHECK: %[[CONVOLUTION_0:.+]] = stablehlo.convolution(%[[ARG_1]], %[[ARG_2]]) {{.*}} : (tensor>, tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}>) -> tensor> +// CHECK: %[[GET_DIMENSION_SIZE_0:.+]] = stablehlo.get_dimension_size %[[CONVOLUTION_0]], dim = 0 : (tensor) +// CHECK: %[[RESHAPE_0:.+]] = stablehlo.reshape %[[GET_DIMENSION_SIZE_0]] : (tensor) -> tensor<1xi32> +// CHECK: %[[CONCATENATE_0:.+]] = stablehlo.concatenate %[[RESHAPE_0]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]], dim = 0 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> +// CHECK: %[[DYNAMIC_BROADCAST_IN_DIM_0:.+]] = stablehlo.dynamic_broadcast_in_dim %[[ARG_3]], %[[CONCATENATE_0]], dims = [0, 1, 2, 3] : (tensor<1x1x1x2x!quant.uniform>, tensor<4xi32>) -> tensor> +// CHECK: %[[ADD_0:.+]] = stablehlo.add %[[CONVOLUTION_0]], %[[DYNAMIC_BROADCAST_IN_DIM_0]] : tensor> +// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ADD_0]] : (tensor>) +// CHECK: return %[[UNIFORM_QUANTIZE_0]] : tensor> + +// CHECK-PER-TENSOR: func.func private @quantized_conv_with_bias_dynamic_fn(%[[ARG_1:.+]]: tensor>, %[[ARG_2:.+]]: tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>>, %[[ARG_3:.+]]: tensor<1x1x1x2x!quant.uniform>) -> tensor> attributes {_from_xla_call_module} +// CHECK-PER-TENSOR-DAG: %[[CONST_2:.+]] = stablehlo.constant dense<3> : tensor<1xi32> +// CHECK-PER-TENSOR-DAG: %[[CONST_3:.+]] = stablehlo.constant dense<4> : tensor<1xi32> +// CHECK-PER-TENSOR-DAG: %[[CONST_4:.+]] = stablehlo.constant dense<2> : tensor<1xi32> +// CHECK-PER-TENSOR: %[[CONVOLUTION_0:.+]] = stablehlo.convolution(%[[ARG_1]], %[[ARG_2]]) {{.*}} : (tensor>, tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>) -> tensor> +// CHECK-PER-TENSOR: 
%[[GET_DIMENSION_SIZE_0:.+]] = stablehlo.get_dimension_size %[[CONVOLUTION_0]], dim = 0 : (tensor) +// CHECK-PER-TENSOR: %[[RESHAPE_0:.+]] = stablehlo.reshape %[[GET_DIMENSION_SIZE_0]] : (tensor) -> tensor<1xi32> +// CHECK-PER-TENSOR: %[[CONCATENATE_0:.+]] = stablehlo.concatenate %[[RESHAPE_0]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]], dim = 0 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> +// CHECK-PER-TENSOR: %[[DYNAMIC_BROADCAST_IN_DIM_0:.+]] = stablehlo.dynamic_broadcast_in_dim %[[ARG_3]], %[[CONCATENATE_0]], dims = [0, 1, 2, 3] : (tensor<1x1x1x2x!quant.uniform>, tensor<4xi32>) -> tensor> +// CHECK-PER-TENSOR: %[[ADD_0:.+]] = stablehlo.add %[[CONVOLUTION_0]], %[[DYNAMIC_BROADCAST_IN_DIM_0]] : tensor> +// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ADD_0]] : (tensor>) +// CHECK-PER-TENSOR: return %[[UNIFORM_QUANTIZE_0]] : tensor> + +// ----- + +// Tests that fused pattern for convolution + bias + relu with +// dynamic batch dimension is properly quantized. + +// Note that this checks for identical condition as +// quantize_conv_with_bias_dynamic_fn, omitting stablehlo.maximum. +// This is because activation clipping which includes 0.0f can be simply +// omitted from the graph as the lifted function's out_scale and out_zp are +// already calculated based on the clipped distribution. +// Note that the resulting scale and zero point should be calculated based on +// clipped range [0, r_max]. 
+ +module attributes {tf_saved_model.semantics} { + func.func private @quantize_conv_with_bias_and_relu_dynamic_fn(%arg0: tensor) -> tensor attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %cst_0 = "tf.Const"() {value = dense<4.00000000e-1> : tensor<1x1x1x2xf32>} : () -> tensor<1x1x1x2xf32> + %0 = "quantization.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor) -> tensor + %1 = "tf.XlaCallModule"(%0, %cst, %cst_0) { + Sout = [#tf_type.shape<1x3x4x2>], + _entry_function = @composite_conv_with_bias_and_relu_dynamic_fn, + _stablehlo_version = "1.0.0", + _original_entry_function = "composite_conv_with_bias_and_relu_dynamic_fn", + _stablehlo_module_attrs = {}, + // Per-channel quantization at dimension 3 for input index 1. + _quantization_method = "static_range_ptq {input_quantized_types {key: 1, value {dimension_specs {dimension: 3}}}}", + _tfl_quant_trait = "fully_quantizable", + device = "", + dim_args_spec = [], + disabled_checks = [], + has_token_input_output = false, + module = "", + platforms = [], + version = 5 : i64 + } : (tensor, tensor<2x3x3x2xf32>, tensor<1x1x1x2xf32>) -> tensor + %2 = "quantization.stats"(%1) {layerStats = dense<[0.00000000e-6, 8.00000000e-1]> : tensor<2xf32>} : (tensor) -> tensor + return %2 : tensor + } +// CHECK: func.func private @quantize_conv_with_bias_and_relu_dynamic_fn(%[[ARG_0:.+]]: tensor) -> tensor attributes {tf._original_func_name = "main_0"} +// CHECK-DAG: %[[CONST_0:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<2x3x3x2xi8>}> : () -> tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}> +// CHECK-DAG: %[[CONST_1:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<1x1x1x2xi32>}> : () -> tensor<1x1x1x2x!quant.uniform +// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor) -> tensor> +// CHECK: %[[CALL_0:.+]] = call 
@quantized_conv_with_bias_and_relu_dynamic_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) +// CHECK-SAME: {_quantization_method = "static_range_ptq {input_quantized_types {key: 1, value {dimension_specs {dimension: 3}}}}"} : (tensor>, tensor<2x3x3x2x!quant.uniform:f32:3, {0.0023622048182750312,0.0023622048182750312}>>, tensor<1x1x1x2x!quant.uniform>) -> tensor> +// CHECK: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor) -> tensor +// CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor + +// CHECK-PER-TENSOR: func.func private @quantize_conv_with_bias_and_relu_dynamic_fn(%[[ARG_0:.+]]: tensor) -> tensor attributes {tf._original_func_name = "main_0"} +// CHECK-PER-TENSOR-DAG: %[[CONST_0:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<2x3x3x2xi8>}> : () -> tensor<2x3x3x2x!quant.uniform:f32, {{.*}}> +// CHECK-PER-TENSOR-DAG: %[[CONST_1:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<1x1x1x2xi32>}> : () -> tensor<1x1x1x2x!quant.uniform +// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor) -> tensor> +// CHECK-PER-TENSOR: %[[CALL_0:.+]] = call @quantized_conv_with_bias_and_relu_dynamic_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) +// CHECK-PER-TENSOR-SAME: {_quantization_method = "static_range_ptq {input_quantized_types {key: 1, value {dimension_specs {dimension: 3}}}}"} : (tensor>, tensor<2x3x3x2x!quant.uniform:f32, 0.0023622048182750312>>, tensor<1x1x1x2x!quant.uniform>) -> tensor> +// CHECK-PER-TENSOR: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor) -> tensor +// CHECK-PER-TENSOR: return %[[UNIFORM_DEQUANTIZE_0]] : tensor + + func.func private @composite_conv_with_bias_and_relu_dynamic_fn(%arg0: tensor, %arg1: tensor<2x3x3x2xf32>, %arg2: tensor<1x1x1x2xf32>) -> tensor attributes {_from_xla_call_module} { + %cst_0 = stablehlo.constant dense<3> : tensor<1xi32> + %cst_1 = stablehlo.constant dense<4> : tensor<1xi32> 
+ %cst_2 = stablehlo.constant dense<2> : tensor<1xi32> + %cst_3 = stablehlo.constant dense<0.000000e+00> : tensor + %cst_4 = stablehlo.constant dense<6.000000e+00> : tensor + %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[0, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor, tensor<2x3x3x2xf32>) -> tensor + %1 = stablehlo.get_dimension_size %0, dim = 0 : (tensor) -> tensor + %2 = stablehlo.reshape %1 : (tensor) -> tensor<1xi32> + %3 = stablehlo.concatenate %2, %cst_0, %cst_1, %cst_2, dim = 0 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> + %4 = stablehlo.dynamic_broadcast_in_dim %arg2, %3, dims = [0, 1, 2, 3] : (tensor<1x1x1x2xf32>, tensor<4xi32>) -> tensor + %5 = stablehlo.add %0, %4 : tensor + %6 = stablehlo.clamp %cst_3, %5, %cst_4 : (tensor, tensor, tensor) -> tensor + return %6 : tensor + } +} +// CHECK: func.func private @quantized_conv_with_bias_and_relu_dynamic_fn(%[[ARG_1:.+]]: tensor>, %[[ARG_2:.+]]: tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}>>, %[[ARG_3:.+]]: tensor<1x1x1x2x!quant.uniform>) -> tensor> attributes {_from_xla_call_module} +// CHECK-DAG: %[[CONST_2:.+]] = stablehlo.constant dense<3> : tensor<1xi32> +// CHECK-DAG: %[[CONST_3:.+]] = stablehlo.constant dense<4> : tensor<1xi32> +// CHECK-DAG: %[[CONST_4:.+]] = stablehlo.constant dense<2> : tensor<1xi32> +// CHECK: %[[CONVOLUTION_0:.+]] = stablehlo.convolution(%[[ARG_1]], %[[ARG_2]]) {{.*}} : (tensor>, tensor<2x3x3x2x!quant.uniform:f32:3, {0.0023622048182750312,0.0023622048182750312}>>) -> tensor> +// CHECK: %[[GET_DIMENSION_SIZE_0:.+]] = stablehlo.get_dimension_size %[[CONVOLUTION_0]], dim = 0 : (tensor) +// CHECK: %[[RESHAPE_0:.+]] = stablehlo.reshape %[[GET_DIMENSION_SIZE_0]] : (tensor) -> tensor<1xi32> +// CHECK: %[[CONCATENATE_0:.+]] = stablehlo.concatenate %[[RESHAPE_0]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]], dim = 0 : (tensor<1xi32>, tensor<1xi32>, 
tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> +// CHECK: %[[DYNAMIC_BROADCAST_IN_DIM_0:.+]] = stablehlo.dynamic_broadcast_in_dim %[[ARG_3]], %[[CONCATENATE_0]], dims = [0, 1, 2, 3] : (tensor<1x1x1x2x!quant.uniform>, tensor<4xi32>) -> tensor> +// CHECK: %[[ADD_0:.+]] = stablehlo.add %[[CONVOLUTION_0]], %[[DYNAMIC_BROADCAST_IN_DIM_0]] : tensor> +// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ADD_0]] : (tensor>) -> tensor> +// CHECK: return %[[UNIFORM_QUANTIZE_0]] : tensor> + +// CHECK-PER-TENSOR: func.func private @quantized_conv_with_bias_and_relu_dynamic_fn(%[[ARG_1:.+]]: tensor>, %[[ARG_2:.+]]: tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>>, %[[ARG_3:.+]]: tensor<1x1x1x2x!quant.uniform>) -> tensor> attributes {_from_xla_call_module} +// CHECK-PER-TENSOR-DAG: %[[CONST_2:.+]] = stablehlo.constant dense<3> : tensor<1xi32> +// CHECK-PER-TENSOR-DAG: %[[CONST_3:.+]] = stablehlo.constant dense<4> : tensor<1xi32> +// CHECK-PER-TENSOR-DAG: %[[CONST_4:.+]] = stablehlo.constant dense<2> : tensor<1xi32> +// CHECK-PER-TENSOR: %[[CONVOLUTION_0:.+]] = stablehlo.convolution(%[[ARG_1]], %[[ARG_2]]) {{.*}} : (tensor>, tensor<2x3x3x2x!quant.uniform:f32, 0.0023622048182750312>>) -> tensor> +// CHECK-PER-TENSOR: %[[GET_DIMENSION_SIZE_0:.+]] = stablehlo.get_dimension_size %[[CONVOLUTION_0]], dim = 0 : (tensor) +// CHECK-PER-TENSOR: %[[RESHAPE_0:.+]] = stablehlo.reshape %[[GET_DIMENSION_SIZE_0]] : (tensor) -> tensor<1xi32> +// CHECK-PER-TENSOR: %[[CONCATENATE_0:.+]] = stablehlo.concatenate %[[RESHAPE_0]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]], dim = 0 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> +// CHECK-PER-TENSOR: %[[DYNAMIC_BROADCAST_IN_DIM_0:.+]] = stablehlo.dynamic_broadcast_in_dim %[[ARG_3]], %[[CONCATENATE_0]], dims = [0, 1, 2, 3] : (tensor<1x1x1x2x!quant.uniform>, tensor<4xi32>) -> tensor> +// CHECK-PER-TENSOR: %[[ADD_0:.+]] = stablehlo.add %[[CONVOLUTION_0]], %[[DYNAMIC_BROADCAST_IN_DIM_0]] : tensor> +// 
CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ADD_0]] : (tensor>) -> tensor> +// CHECK-PER-TENSOR: return %[[UNIFORM_QUANTIZE_0]] : tensor> + +// ----- + +// Tests that fused pattern for convolution + bias + relu6 with +// dynamic batch dimension is properly quantized. + +// Note that this checks for identical condition as +// quantize_conv_with_bias_dynamic_fn, omitting stablehlo.clamp. +// This is because activation clipping which includes 0.0f can be simply +// omitted from the graph as the lifted function's out_scale and out_zp are +// already calculated based on the clipped distribution. +// Note that the resulting scale and zero point should be calculated based on +// clipped range [0, r_max]. + +module attributes {tf_saved_model.semantics} { + func.func private @quantize_conv_with_bias_and_relu6_dynamic_fn(%arg0: tensor) -> tensor attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %cst_0 = "tf.Const"() {value = dense<4.00000000e-1> : tensor<1x1x1x2xf32>} : () -> tensor<1x1x1x2xf32> + %0 = "quantization.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor) -> tensor + %1 = "tf.XlaCallModule"(%0, %cst, %cst_0) { + Sout = [#tf_type.shape<1x3x4x2>], + _entry_function = @composite_conv_with_bias_and_relu6_dynamic_fn, + _stablehlo_version = "1.0.0", + _original_entry_function = "composite_conv_with_bias_and_relu6_dynamic_fn", + _stablehlo_module_attrs = {}, + // Per-channel quantization at dimension 3 for input index 1. 
+ _quantization_method = "static_range_ptq {input_quantized_types {key: 1, value {dimension_specs {dimension: 3}}}}", + _tfl_quant_trait = "fully_quantizable", + device = "", + dim_args_spec = [], + disabled_checks = [], + has_token_input_output = false, + module = "", + platforms = [], + version = 5 : i64 + } : (tensor, tensor<2x3x3x2xf32>, tensor<1x1x1x2xf32>) -> tensor + %2 = "quantization.stats"(%1) {layerStats = dense<[5.00000000e-6, 6.00000000e-1]> : tensor<2xf32>} : (tensor) -> tensor + return %2 : tensor + } +// CHECK: func.func private @quantize_conv_with_bias_and_relu6_dynamic_fn(%[[ARG_0:.+]]: tensor) -> tensor attributes {tf._original_func_name = "main_0"} +// CHECK-DAG: %[[CONST_0:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<2x3x3x2xi8>}> : () -> tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}> +// CHECK-DAG: %[[CONST_1:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<1x1x1x2xi32>}> : () -> tensor<1x1x1x2x!quant.uniform +// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor) -> tensor> +// CHECK: %[[CALL_0:.+]] = call @quantized_conv_with_bias_and_relu6_dynamic_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) +// CHECK-SAME: {_quantization_method = "static_range_ptq {input_quantized_types {key: 1, value {dimension_specs {dimension: 3}}}}"} : (tensor>, tensor<2x3x3x2x!quant.uniform:f32:3, {0.0023622048182750312,0.0023622048182750312}>>, tensor<1x1x1x2x!quant.uniform>) -> tensor> +// CHECK: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor) -> tensor +// CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor + +// CHECK-PER-TENSOR: func.func private @quantize_conv_with_bias_and_relu6_dynamic_fn(%[[ARG_0:.+]]: tensor) -> tensor attributes {tf._original_func_name = "main_0"} +// CHECK-PER-TENSOR-DAG: %[[CONST_0:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<2x3x3x2xi8>}> : () -> tensor<2x3x3x2x!quant.uniform:f32, {{.*}}> +// CHECK-PER-TENSOR-DAG: 
%[[CONST_1:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<1x1x1x2xi32>}> : () -> tensor<1x1x1x2x!quant.uniform +// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor) -> tensor> +// CHECK-PER-TENSOR: %[[CALL_0:.+]] = call @quantized_conv_with_bias_and_relu6_dynamic_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) +// CHECK-PER-TENSOR-SAME: {_quantization_method = "static_range_ptq {input_quantized_types {key: 1, value {dimension_specs {dimension: 3}}}}"} : (tensor>, tensor<2x3x3x2x!quant.uniform:f32, 0.0023622048182750312>>, tensor<1x1x1x2x!quant.uniform>) -> tensor> +// CHECK-PER-TENSOR: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor) -> tensor +// CHECK-PER-TENSOR: return %[[UNIFORM_DEQUANTIZE_0]] : tensor + + func.func private @composite_conv_with_bias_and_relu6_dynamic_fn(%arg0: tensor, %arg1: tensor<2x3x3x2xf32>, %arg2: tensor<1x1x1x2xf32>) -> tensor attributes {_from_xla_call_module} { + %cst_0 = stablehlo.constant dense<3> : tensor<1xi32> + %cst_1 = stablehlo.constant dense<4> : tensor<1xi32> + %cst_2 = stablehlo.constant dense<2> : tensor<1xi32> + %cst_3 = stablehlo.constant dense<0.000000e+00> : tensor + %cst_4 = stablehlo.constant dense<6.000000e+00> : tensor + %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[0, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor, tensor<2x3x3x2xf32>) -> tensor + %1 = stablehlo.get_dimension_size %0, dim = 0 : (tensor) -> tensor + %2 = stablehlo.reshape %1 : (tensor) -> tensor<1xi32> + %3 = stablehlo.concatenate %2, %cst_0, %cst_1, %cst_2, dim = 0 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> + %4 = stablehlo.dynamic_broadcast_in_dim %arg2, %3, dims = [0, 1, 2, 3] : (tensor<1x1x1x2xf32>, tensor<4xi32>) -> tensor + %5 = stablehlo.add %0, %4 : tensor + %6 = stablehlo.clamp %cst_3, %5, %cst_4 : (tensor,
tensor, tensor) -> tensor + return %6 : tensor + } +} +// CHECK: func.func private @quantized_conv_with_bias_and_relu6_dynamic_fn(%[[ARG_1:.+]]: tensor>, %[[ARG_2:.+]]: tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}>>, %[[ARG_3:.+]]: tensor<1x1x1x2x!quant.uniform>) -> tensor> attributes {_from_xla_call_module} +// CHECK-DAG: %[[CONST_2:.+]] = stablehlo.constant dense<3> : tensor<1xi32> +// CHECK-DAG: %[[CONST_3:.+]] = stablehlo.constant dense<4> : tensor<1xi32> +// CHECK-DAG: %[[CONST_4:.+]] = stablehlo.constant dense<2> : tensor<1xi32> +// CHECK: %[[CONVOLUTION_0:.+]] = stablehlo.convolution(%[[ARG_1]], %[[ARG_2]]) {{.*}} : (tensor>, tensor<2x3x3x2x!quant.uniform:f32:3, {0.0023622048182750312,0.0023622048182750312}>>) -> tensor> +// CHECK: %[[GET_DIMENSION_SIZE_0:.+]] = stablehlo.get_dimension_size %[[CONVOLUTION_0]], dim = 0 : (tensor) +// CHECK: %[[RESHAPE_0:.+]] = stablehlo.reshape %[[GET_DIMENSION_SIZE_0]] : (tensor) -> tensor<1xi32> +// CHECK: %[[CONCATENATE_0:.+]] = stablehlo.concatenate %[[RESHAPE_0]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]], dim = 0 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> +// CHECK: %[[DYNAMIC_BROADCAST_IN_DIM_0:.+]] = stablehlo.dynamic_broadcast_in_dim %[[ARG_3]], %[[CONCATENATE_0]], dims = [0, 1, 2, 3] : (tensor<1x1x1x2x!quant.uniform>, tensor<4xi32>) -> tensor> +// CHECK: %[[ADD_0:.+]] = stablehlo.add %[[CONVOLUTION_0]], %[[DYNAMIC_BROADCAST_IN_DIM_0]] : tensor> +// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ADD_0]] : (tensor>) -> tensor> +// CHECK: return %[[UNIFORM_QUANTIZE_0]] : tensor> + +// CHECK-PER-TENSOR: func.func private @quantized_conv_with_bias_and_relu6_dynamic_fn(%[[ARG_1:.+]]: tensor>, %[[ARG_2:.+]]: tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>>, %[[ARG_3:.+]]: tensor<1x1x1x2x!quant.uniform>) -> tensor> attributes {_from_xla_call_module} +// CHECK-PER-TENSOR-DAG: %[[CONST_2:.+]] = stablehlo.constant dense<3> : tensor<1xi32> +// CHECK-PER-TENSOR-DAG: 
%[[CONST_3:.+]] = stablehlo.constant dense<4> : tensor<1xi32> +// CHECK-PER-TENSOR-DAG: %[[CONST_4:.+]] = stablehlo.constant dense<2> : tensor<1xi32> +// CHECK-PER-TENSOR: %[[CONVOLUTION_0:.+]] = stablehlo.convolution(%[[ARG_1]], %[[ARG_2]]) {{.*}} : (tensor>, tensor<2x3x3x2x!quant.uniform:f32, 0.0023622048182750312>>) -> tensor> +// CHECK-PER-TENSOR: %[[GET_DIMENSION_SIZE_0:.+]] = stablehlo.get_dimension_size %[[CONVOLUTION_0]], dim = 0 : (tensor) +// CHECK-PER-TENSOR: %[[RESHAPE_0:.+]] = stablehlo.reshape %[[GET_DIMENSION_SIZE_0]] : (tensor) -> tensor<1xi32> +// CHECK-PER-TENSOR: %[[CONCATENATE_0:.+]] = stablehlo.concatenate %[[RESHAPE_0]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]], dim = 0 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> +// CHECK-PER-TENSOR: %[[DYNAMIC_BROADCAST_IN_DIM_0:.+]] = stablehlo.dynamic_broadcast_in_dim %[[ARG_3]], %[[CONCATENATE_0]], dims = [0, 1, 2, 3] : (tensor<1x1x1x2x!quant.uniform>, tensor<4xi32>) -> tensor> +// CHECK-PER-TENSOR: %[[ADD_0:.+]] = stablehlo.add %[[CONVOLUTION_0]], %[[DYNAMIC_BROADCAST_IN_DIM_0]] : tensor> +// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ADD_0]] : (tensor>) -> tensor> +// CHECK-PER-TENSOR: return %[[UNIFORM_QUANTIZE_0]] : tensor> + +// ----- + +// Tests that XlaCallModule op is not quantized and converted to func.call without the quantization.stats ops. 
+ +module attributes {tf_saved_model.semantics} { + func.func private @not_quantized_without_stats_fn(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3xf32>} : () -> tensor<2x3xf32> + %1 = "tf.XlaCallModule"(%arg0, %cst) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _stablehlo_version = "1.0.0", _original_entry_function = "composite_dot_general_fn", _quantization_method = "static_range_ptq { }", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + return %1 : tensor<1x3xf32> + } +// Check that "tf.Const" is converted to stablehlo.constant. XlaCallModule is +// not quantized. + +// CHECK: func.func private @not_quantized_without_stats_fn(%[[ARG_0:.+]]: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} +// CHECK: %[[CONST_0:.+]] = stablehlo.constant dense<3.000000e-01> : tensor<2x3xf32> +// CHECK: %[[CALL:.+]] = call @composite_dot_general_fn(%[[ARG_0]], %[[CONST_0]]) : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> +// CHECK: return %[[CALL]] + + func.func private @composite_dot_general_fn(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> + } +// CHECK: func.func private @composite_dot_general_fn(%[[ARG_1:.+]]: tensor<1x2xf32>, %[[ARG_2:.+]]: tensor<2x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} +// Check that the composite_dot_general_fn is untouched.
+// CHECK: %[[DOT_GENERAL_0:.+]] = stablehlo.dot_general %[[ARG_1]], %[[ARG_2]] +// CHECK: return %[[DOT_GENERAL_0]] +} + +// ----- + +// Tests that basic `stablehlo.gather` is properly quantized. + +module attributes {tf_saved_model.semantics} { +// CHECK: func.func private @quantize_gather_fn(%[[ARG:.+]]: tensor<3x4x2xf32>) -> tensor<2x3x2x2xf32> attributes {tf._original_func_name = "main_0"} + func.func private @quantize_gather_fn(%arg: tensor<3x4x2xf32>) -> tensor<2x3x2x2xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<1> : tensor<2x3x2xi32>} : () -> tensor<2x3x2xi32> + %0 = "quantization.stats"(%arg) {layerStats = dense<[4.00000000e-6, 9.80000000e-1]> : tensor<2xf32>} : (tensor<3x4x2xf32>) -> tensor<3x4x2xf32> + %1 = "tf.XlaCallModule"(%0, %cst) {Sout = [#tf_type.shape<2x3x2x2>], _entry_function = @composite_gather_fn, _stablehlo_version = "1.0.0", _original_entry_function = "composite_gather_fn", _quantization_method = "static_range_ptq { }", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<3x4x2xf32>, tensor<2x3x2xi32>) -> tensor<2x3x2x2xf32> + %2 = "quantization.stats"(%1) {layerStats = dense<[4.00000000e-6, 9.80000000e-1]> : tensor<2xf32>} : (tensor<2x3x2x2xf32>) -> tensor<2x3x2x2xf32> + return %2 : tensor<2x3x2x2xf32> + } +// Checks that the quantized XlaCallModule has been replaced by a CallOp, which +// calls the quantized entry function. 
+// CHECK: %[[CONST:.+]] = stablehlo.constant dense<{{.*}}> : tensor<2x3x2xi32> +// CHECK: %[[UNIFORM_QUANTIZE:.+]] = stablehlo.uniform_quantize %[[ARG]] : (tensor<3x4x2xf32>) -> tensor<3x4x2x!quant.uniform> +// CHECK: %[[CALL:.+]] = call @quantized_gather_fn(%[[UNIFORM_QUANTIZE]], %[[CONST]]) {_quantization_method = "static_range_ptq { }"} : (tensor<3x4x2x!quant.uniform>, tensor<2x3x2xi32>) -> tensor<2x3x2x2x!quant.uniform> +// CHECK: %[[UNIFORM_DEQUANTIZE:.+]] = stablehlo.uniform_dequantize %[[CALL]] : (tensor<2x3x2x2x!quant.uniform>) -> tensor<2x3x2x2xf32> +// CHECK: return %[[UNIFORM_DEQUANTIZE]] : tensor<2x3x2x2xf32> + +// CHECK: func.func private @quantized_gather_fn(%[[ARG_0:.+]]: tensor<3x4x2x!quant.uniform>, %[[ARG_1:.+]]: tensor<2x3x2xi32>) -> tensor<2x3x2x2x!quant.uniform> attributes {_from_xla_call_module} + func.func private @composite_gather_fn(%arg0: tensor<3x4x2xf32>, %arg1: tensor<2x3x2xi32>) -> tensor<2x3x2x2xf32> attributes {_from_xla_call_module} { + %0 = "stablehlo.gather"(%arg0, %arg1) { + dimension_numbers = #stablehlo.gather< + offset_dims = [2, 3], + collapsed_slice_dims = [0], + start_index_map = [1, 0], + index_vector_dim = 2>, + slice_sizes = array, + indices_are_sorted = false + } : (tensor<3x4x2xf32>, tensor<2x3x2xi32>) -> tensor<2x3x2x2xf32> + return %0 : tensor<2x3x2x2xf32> + } +// CHECK: %[[GATHER:.+]] = "stablehlo.gather"(%[[ARG_0]], %[[ARG_1]]) {{.*}} : (tensor<3x4x2x!quant.uniform>, tensor<2x3x2xi32>) -> tensor<2x3x2x2x!quant.uniform> +// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[GATHER]] : tensor<2x3x2x2x!quant.uniform> +// CHECK: return %[[UNIFORM_QUANTIZE_0]] : tensor<2x3x2x2x!quant.uniform> +} + +// ----- + +// Tests that a basic `stablehlo.add` and a fused `stablehlo.dot_general` +// are properly quantized.
+ +module attributes {tf_saved_model.semantics} { +// CHECK: func.func private @quantize_add_fn(%[[ARG:.+]]: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} + func.func private @quantize_add_fn(%arg: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { + %cst_0 = "tf.Const"() {value = dense<1.00000000e-1> : tensor<1x2xf32>} : () -> tensor<1x2xf32> + %cst_1 = "tf.Const"() {value = dense<1.00000000e-1> : tensor<2x3xf32>} : () -> tensor<2x3xf32> + %0 = "quantization.stats"(%arg) {layerStats = dense<[4.00000000e-6, 9.80000000e-1]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32> + %1 = "tf.XlaCallModule"(%0, %cst_0) {Sout = [#tf_type.shape<1x2>], _entry_function = @composite_add_fn, _stablehlo_version = "1.0.0", _original_entry_function = "composite_add_fn", _quantization_method = "static_range_ptq { }", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<1x2xf32> + %2 = "quantization.stats"(%1) {layerStats = dense<[4.00000000e-6, 9.80000000e-1]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32> + %3 = "quantization.stats"(%2) {layerStats = dense<[5.00000000e-6, 6.00000000e-1]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32> + %4 = "tf.XlaCallModule"(%3, %cst_1) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _stablehlo_version = "1.0.0", _original_entry_function = "composite_dot_general_fn", _quantization_method = "static_range_ptq { }", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + %5 = "quantization.stats"(%4) {layerStats = 
dense<[5.00000000e-6, 9.80000000e-1]> : tensor<2xf32>} : (tensor<1x3xf32>) -> tensor<1x3xf32> + return %5 : tensor<1x3xf32> + } +// CHECK: %[[CONST:.+]] = stablehlo.constant() <{value = dense<127> : tensor<1x2xi8>}> : () -> tensor<1x2x!quant.uniform> +// CHECK: %[[CONST_0:.+]] = stablehlo.constant() <{value = dense<127> : tensor<2x3xi8>}> : () -> tensor<2x3x!quant.uniform:f32:1, {{.*}}>> +// CHECK: %[[UNIFORM_QUANTIZE:.+]] = stablehlo.uniform_quantize %[[ARG]] : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> +// CHECK: %[[CALL:.+]] = call @quantized_add_fn(%[[UNIFORM_QUANTIZE]], %[[CONST]]) {_quantization_method = "static_range_ptq { }"} : (tensor<1x2x!quant.uniform>, tensor<1x2x!quant.uniform>) -> tensor<1x2x!quant.uniform> +// CHECK: %[[UNIFORM_DEQUANTIZE:.+]] = stablehlo.uniform_dequantize %[[CALL]] : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32> +// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[UNIFORM_DEQUANTIZE]] : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> +// CHECK: %[[CALL_0:.+]] = call @quantized_dot_general_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]]) {_quantization_method = "static_range_ptq { }"} : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32:1, {{.*}}>>) -> tensor<1x3x!quant.uniform> +// CHECK: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> +// CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<1x3xf32> + +// CHECK: func.func private @quantized_add_fn(%[[ARG_0:.+]]: tensor<1x2x!quant.uniform>, %[[ARG_1:.+]]: tensor<1x2x!quant.uniform>) -> tensor<1x2x!quant.uniform> attributes {_from_xla_call_module} + func.func private @composite_add_fn(%arg0: tensor<1x2xf32>, %arg1: tensor<1x2xf32>) -> tensor<1x2xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.add %arg0, %arg1 : tensor<1x2xf32> + return %0 : tensor<1x2xf32> + } +// CHECK: %[[ADD:.+]] = stablehlo.add %arg0, %arg1 : (tensor<1x2x!quant.uniform>, tensor<1x2x!quant.uniform>) -> 
tensor<1x2x!quant.uniform> +// CHECK: return %[[ADD]] : tensor<1x2x!quant.uniform> + +// CHECK: func.func private @quantized_dot_general_fn(%[[ARG_0:.+]]: tensor<1x2x!quant.uniform>, %[[ARG_1:.+]]: tensor<2x3x!quant.uniform:f32:1, {{.*}}>>) -> tensor<1x3x!quant.uniform> attributes {_from_xla_call_module} + func.func private @composite_dot_general_fn(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> + } +// CHECK: %[[DOT_GENERAL:.+]] = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32:1,{{.*}}>>) -> tensor<1x3x!quant.uniform> +// CHECK: %[[UNIFORM_QUANTIZE:.+]] = stablehlo.uniform_quantize %[[DOT_GENERAL]] : (tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> +// CHECK: return %[[UNIFORM_QUANTIZE]] : tensor<1x3x!quant.uniform> +} + +// ----- + +// Tests that `stablehlo.add` is not quantized and emits error when the function +// does not include two ops. 
+ +module attributes {tf_saved_model.semantics} { + func.func private @not_quantize_fn_when_not_singular(%arg: tensor<1x2xf32>) -> tensor<1x2xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<1.00000000e-1> : tensor<1x2xf32>} : () -> tensor<1x2xf32> + %0 = "quantization.stats"(%arg) {layerStats = dense<[4.00000000e-6, 9.80000000e-1]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32> + %1 = "tf.XlaCallModule"(%0, %cst) {Sout = [#tf_type.shape<1x2>], _entry_function = @composite_add_fn, _stablehlo_version = "1.0.0", _original_entry_function = "composite_add_fn", _quantization_method = "static_range_ptq { }", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<1x2xf32> + // expected-error@+1 {{'stablehlo.uniform_dequantize' op operand #0 must be ranked tensor of per-tensor integer quantized or per-axis integer quantized values, but got 'tensor<1x2xf32>'}} + %2 = "quantization.stats"(%1) {layerStats = dense<[4.00000000e-6, 9.80000000e-1]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32> + return %2 : tensor<1x2xf32> + } + + func.func private @composite_add_fn(%arg0: tensor<1x2xf32>, %arg1: tensor<1x2xf32>) -> tensor<1x2xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.add %arg0, %arg1 : tensor<1x2xf32> + %1 = stablehlo.add %0, %arg1 : tensor<1x2xf32> + return %1 : tensor<1x2xf32> + } +} + +// ----- + +// Tests that `stablehlo.gather` without `static_range_ptq` is not quantized. 
+ +module attributes {tf_saved_model.semantics} { + func.func private @not_quantize_singular_op_without_static_range_ptq(%arg: tensor<3x4x2xf32>) -> tensor<2x3x2x2xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<1> : tensor<2x3x2xi32>} : () -> tensor<2x3x2xi32> + %0 = "quantization.stats"(%arg) {layerStats = dense<[4.00000000e-6, 9.80000000e-1]> : tensor<2xf32>} : (tensor<3x4x2xf32>) -> tensor<3x4x2xf32> + %1 = "tf.XlaCallModule"(%0, %cst) {Sout = [#tf_type.shape<2x3x2x2>], _entry_function = @composite_gather_fn, _stablehlo_version = "1.0.0", _original_entry_function = "composite_gather_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<3x4x2xf32>, tensor<2x3x2xi32>) -> tensor<2x3x2x2xf32> + // expected-error@+1 {{'stablehlo.uniform_dequantize' op operand #0 must be ranked tensor of per-tensor integer quantized or per-axis integer quantized values, but got 'tensor<2x3x2x2xf32>'}} + %2 = "quantization.stats"(%1) {layerStats = dense<[4.00000000e-6, 9.80000000e-1]> : tensor<2xf32>} : (tensor<2x3x2x2xf32>) -> tensor<2x3x2x2xf32> + return %2 : tensor<2x3x2x2xf32> + } + + func.func private @composite_gather_fn(%arg0: tensor<3x4x2xf32>, %arg1: tensor<2x3x2xi32>) -> tensor<2x3x2x2xf32> attributes {_from_xla_call_module} { + %0 = "stablehlo.gather"(%arg0, %arg1) { + dimension_numbers = #stablehlo.gather< + offset_dims = [2, 3], + collapsed_slice_dims = [0], + start_index_map = [1, 0], + index_vector_dim = 2>, + slice_sizes = array, + indices_are_sorted = false + } : (tensor<3x4x2xf32>, tensor<2x3x2xi32>) -> tensor<2x3x2x2xf32> + return %0 : tensor<2x3x2x2xf32> + } +} diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_quantize_composite_functions_weight_only.mlir 
b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_quantize_composite_functions_weight_only.mlir new file mode 100644 index 00000000000000..1467313c585a92 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_quantize_composite_functions_weight_only.mlir @@ -0,0 +1,122 @@ +// RUN: stablehlo-quant-opt %s -split-input-file -verify-diagnostics \ +// RUN: -tf-stablehlo-quantize-composite-functions | FileCheck --check-prefix=CHECK %s + +// Test that per-tensor weight-only quantized dot_general op is produced when +// empty `weight_only_ptq` is provided. + +module attributes {tf_saved_model.semantics} { + func.func private @quantize_dot_general_per_tensor(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { + %0 = stablehlo.constant dense<3.000000e-01> : tensor<2x3xf32> + %1 = "tf.XlaCallModule"(%arg0, %0) <{Sout = [#tf_type.shape<1x3>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _quantization_method = "weight_only_ptq { }", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + return %1 : tensor<1x3xf32> + } + + func.func private @composite_dot_general_fn(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> + } +} + +// CHECK-LABEL: quantize_dot_general_per_tensor +// CHECK-SAME: %[[ARG0:.+]]: tensor<1x2xf32> +// CHECK: %[[CST:.+]] = stablehlo.constant() <{value = dense<127> : tensor<2x3xi8>}> : () -> tensor<2x3x!quant.uniform:f32, 0.0023622048182750312>> +// CHECK: %[[CALL:.+]] = call 
@quantized_dot_general_fn(%[[ARG0]], %[[CST]]) {_quantization_method = "weight_only_ptq { }"} : (tensor<1x2xf32>, tensor<2x3x!quant.uniform:f32, 0.0023622048182750312>>) -> tensor<1x3xf32> +// CHECK: return %[[CALL]] + +// CHECK: quantized_dot_general_fn +// CHECK-SAME: (%[[ARG1:.+]]: tensor<1x2xf32>, %[[ARG2:.+]]: tensor<2x3x!quant.uniform:f32, 0.0023622048182750312>>) -> tensor<1x3xf32> +// CHECK: %[[DOT:.+]] = stablehlo.dot_general %[[ARG1]], %[[ARG2]] +// CHECK-SAME: (tensor<1x2xf32>, tensor<2x3x!quant.uniform:f32, 0.0023622048182750312>>) -> tensor<1x3xf32> +// CHECK: return %[[DOT]] + +// ----- + +// Test that per-tensor weight-only quantized convolution op is produced when +// empty `weight_only_ptq` is provided. + +module attributes {tf_saved_model.semantics} { + func.func private @quantize_conv_per_tensor(%arg0: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"} { + %0 = stablehlo.constant dense<3.000000e-01> : tensor<2x3x3x2xf32> + %1 = "tf.XlaCallModule"(%arg0, %0) <{Sout = [#tf_type.shape<1x3x4x2>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @composite_conv_fn, _original_entry_function = "composite_conv_fn", _quantization_method = "weight_only_ptq { }", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> + return %1 : tensor<1x3x4x2xf32> + } + + func.func private @composite_conv_fn(%arg0: tensor<1x3x4x3xf32>, %arg1: tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[0, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> + return %0 : tensor<1x3x4x2xf32> + } +} + +// CHECK-LABEL: 
quantize_conv_per_tensor +// CHECK-SAME: %[[ARG0:.+]]: tensor<1x3x4x3xf32> +// CHECK: %[[CST:.+]] = stablehlo.constant() <{value = dense<127> : tensor<2x3x3x2xi8>}> : () -> tensor<2x3x3x2x!quant.uniform:f32, 0.0023622048182750312>> +// CHECK: %[[CALL:.+]] = call @quantized_conv_fn(%[[ARG0]], %[[CST]]) {_quantization_method = "weight_only_ptq { }"} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2x!quant.uniform:f32, 0.0023622048182750312>>) -> tensor<1x3x4x2xf32> +// CHECK: return %[[CALL]] + +// CHECK: quantized_conv_fn +// CHECK-SAME: (%[[ARG1:.+]]: tensor<1x3x4x3xf32>, %[[ARG2:.+]]: tensor<2x3x3x2x!quant.uniform:f32, 0.0023622048182750312>>) -> tensor<1x3x4x2xf32> +// CHECK: %[[CONV:.+]] = stablehlo.convolution(%[[ARG1]], %[[ARG2]]) +// CHECK-SAME: (tensor<1x3x4x3xf32>, tensor<2x3x3x2x!quant.uniform:f32, 0.0023622048182750312>>) -> tensor<1x3x4x2xf32> +// CHECK: return %[[CONV]] + +// ----- + +// Test that per-channel weight-only quantized dot_general op is produced when +// `weight_only_ptq` with `dimension_specs` is provided. 
+ +module attributes {tf_saved_model.semantics} { + func.func private @quantize_dot_general_per_channel(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { + %0 = stablehlo.constant dense<3.000000e-01> : tensor<2x3xf32> + %1 = "tf.XlaCallModule"(%arg0, %0) <{Sout = [#tf_type.shape<1x3>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _quantization_method = "weight_only_ptq {input_quantized_types {key: 1, value {dimension_specs {}}}}", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + return %1 : tensor<1x3xf32> + } + + func.func private @composite_dot_general_fn(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> + } +} + +// CHECK-LABEL: quantize_dot_general_per_channel +// CHECK-SAME: %[[ARG0:.+]]: tensor<1x2xf32> +// CHECK: %[[CST:.+]] = stablehlo.constant() <{value = dense<127> : tensor<2x3xi8>}> : () -> tensor<2x3x!quant.uniform:f32:1, {0.0023622048182750312,0.0023622048182750312,0.0023622048182750312}>> +// CHECK: %[[CALL:.+]] = call @quantized_dot_general_fn(%[[ARG0]], %[[CST]]) {_quantization_method = "weight_only_ptq {input_quantized_types {key: 1, value {dimension_specs {}}}}"} +// CHECK-SAME: (tensor<1x2xf32>, tensor<2x3x!quant.uniform:f32:1, {0.0023622048182750312,0.0023622048182750312,0.0023622048182750312}>>) -> tensor<1x3xf32> +// CHECK: return %[[CALL]] + +// CHECK: quantized_dot_general_fn +// CHECK-SAME: (%[[ARG1:.+]]: tensor<1x2xf32>, %[[ARG2:.+]]: tensor<2x3x!quant.uniform:f32:1, 
{0.0023622048182750312,0.0023622048182750312,0.0023622048182750312}>>) -> tensor<1x3xf32> +// CHECK: %[[DOT:.+]] = stablehlo.dot_general %[[ARG1]], %[[ARG2]] +// CHECK-SAME: (tensor<1x2xf32>, tensor<2x3x!quant.uniform:f32:1, {0.0023622048182750312,0.0023622048182750312,0.0023622048182750312}>>) -> tensor<1x3xf32> +// CHECK: return %[[DOT]] + +// ----- + +// Test that per-channel weight-only quantized convolution op is produced when +// `weight_only_ptq` with `dimension_specs` is provided. + +module attributes {tf_saved_model.semantics} { + func.func private @quantize_conv_per_channel(%arg0: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"} { + %0 = stablehlo.constant dense<3.000000e-01> : tensor<2x3x3x2xf32> + %1 = "tf.XlaCallModule"(%arg0, %0) <{Sout = [#tf_type.shape<1x3x4x2>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @composite_conv_fn, _original_entry_function = "composite_conv_fn", _quantization_method = "weight_only_ptq {input_quantized_types {key: 1, value {dimension_specs {}}}}", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> + return %1 : tensor<1x3x4x2xf32> + } + + func.func private @composite_conv_fn(%arg0: tensor<1x3x4x3xf32>, %arg1: tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[0, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> + return %0 : tensor<1x3x4x2xf32> + } +} + +// CHECK-LABEL: quantize_conv_per_channel +// CHECK-SAME: %[[ARG0:.+]]: tensor<1x3x4x3xf32> +// CHECK: %[[CST:.+]] = stablehlo.constant() <{value = dense<127> : tensor<2x3x3x2xi8>}> : () -> 
tensor<2x3x3x2x!quant.uniform:f32:3, {0.0023622048182750312,0.0023622048182750312}>> +// CHECK: %[[CALL:.+]] = call @quantized_conv_fn(%[[ARG0]], %[[CST]]) {_quantization_method = "weight_only_ptq {input_quantized_types {key: 1, value {dimension_specs {}}}}"} +// CHECK-SAME: (tensor<1x3x4x3xf32>, tensor<2x3x3x2x!quant.uniform:f32:3, {0.0023622048182750312,0.0023622048182750312}>>) -> tensor<1x3x4x2xf32> +// CHECK: return %[[CALL]] + +// CHECK: quantized_conv_fn +// CHECK-SAME: (%[[ARG1:.+]]: tensor<1x3x4x3xf32>, %[[ARG2:.+]]: tensor<2x3x3x2x!quant.uniform:f32:3, {0.0023622048182750312,0.0023622048182750312}>>) -> tensor<1x3x4x2xf32> +// CHECK: %[[CONV:.+]] = stablehlo.convolution(%[[ARG1]], %[[ARG2]]) +// CHECK-SAME: (tensor<1x3x4x3xf32>, tensor<2x3x3x2x!quant.uniform:f32:3, {0.0023622048182750312,0.0023622048182750312}>>) -> tensor<1x3x4x2xf32> +// CHECK: return %[[CONV]] diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_remove_sharding_custom_call.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_remove_sharding_custom_call.mlir new file mode 100644 index 00000000000000..c408290bd4a915 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_remove_sharding_custom_call.mlir @@ -0,0 +1,20 @@ +// RUN: stablehlo-quant-opt %s -tf-stablehlo-remove-sharding-custom-call \ +// RUN: -split-input-file | FileCheck %s + +// CHECK-LABEL: sharding_custom_call_removed +func.func @sharding_custom_call_removed(%arg0: tensor<3xf32>) -> tensor<3xf32> { + %1 = stablehlo.custom_call @Sharding(%arg0) {mhlo.sharding = ""} : (tensor<3xf32>) -> tensor<3xf32> + return %1 : tensor<3xf32> +} +// CHECK-NOT: custom_call + +// ----- + +// Tests that a custom_call that is not @Sharding is not removed. 
+ +// CHECK-LABEL: custom_call_not_removed +func.func @custom_call_not_removed(%arg0: tensor<3xf32>) -> tensor<3xf32> { + %1 = stablehlo.custom_call @NotSharding(%arg0) : (tensor<3xf32>) -> tensor<3xf32> + return %1 : tensor<3xf32> +} +// CHECK: custom_call @NotSharding diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.mlir new file mode 100644 index 00000000000000..e140973c9fb298 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.mlir @@ -0,0 +1,476 @@ +// RUN: stablehlo-quant-opt %s -split-input-file \ +// RUN: -tf-stablehlo-replace-stablehlo-ops-in-main-function-with-xla-call-module-ops \ +// RUN: | FileCheck %s + +// Modules with "main" or "serving_default" should properly run this pass and +// convert subgraphs into XLACallModuleOp. 
+ +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1629 : i32}, tf_saved_model.semantics} { + + // CHECK: func private @_stablehlo_main_1 + // CHECK: %[[CONSTANT_0:.*]] = stablehlo.constant dense<1.000000e+03> : tensor<1024x3xf32> + // CHECK: %[[CONSTANT_1:.*]] = stablehlo.constant dense<1.000000e+03> : tensor<1x3xf32> + // CHECK: return + // CHECK: } + + // CHECK: func private @_stablehlo_main_0 + // CHECK: %[[CONSTANT_0:.*]] = stablehlo.constant dense<1.000000e+03> : tensor<3x64xf32> + // CHECK: %[[CONSTANT_1:.*]] = stablehlo.constant dense<1.000000e+03> : tensor<1x64xf32> + // CHECK: return + // CHECK: } + + func.func @main(%arg0: tensor<1x1024xf32> {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor<1x64xf32> {tf_saved_model.index_path = ["output"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input_tensor:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { + %0 = stablehlo.constant dense<1.000000e+03> : tensor<1024x3xf32> + %1 = stablehlo.constant dense<1.000000e+03> : tensor<1x3xf32> + %2:4 = "tf.CustomAggregator"(%arg0) {calibration_method = 1 : i32, id = "0", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32} : (tensor<1x1024xf32>) -> (tensor<1x1024xf32>, tensor, tensor, tensor<*xi64>) + %3 = "tf.XlaCallModule"(%2#0, %0, %1) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_version = "{{.*}}", _tfl_quant_trait = "fully_quantizable", dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = ["CPU"], version = 9 : i64} : (tensor<1x1024xf32>, tensor<1024x3xf32>, tensor<1x3xf32>) -> tensor<1x3xf32> + %4:4 = "tf.CustomAggregator"(%3) {calibration_method = 1 : i32, id = "1", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : 
f32, num_bins = 0 : i32} : (tensor<1x3xf32>) -> (tensor<1x3xf32>, tensor, tensor, tensor<*xi64>) + %5 = stablehlo.constant dense<1.000000e+03> : tensor<3x64xf32> + %6 = stablehlo.constant dense<1.000000e+03> : tensor<1x64xf32> + %7:4 = "tf.CustomAggregator"(%4#0) {calibration_method = 1 : i32, id = "0", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32} : (tensor<1x3xf32>) -> (tensor<1x3xf32>, tensor, tensor, tensor<*xi64>) + %8 = "tf.XlaCallModule"(%7#0, %5, %6) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_with_relu_fn_1, _original_entry_function = "composite_dot_general_with_relu_fn_1", _stablehlo_version = "{{.*}}", _tfl_quant_trait = "fully_quantizable", dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = ["CPU"], version = 9 : i64} : (tensor<1x3xf32>, tensor<3x64xf32>, tensor<1x64xf32>) -> tensor<1x64xf32> + %9:4 = "tf.CustomAggregator"(%6) {calibration_method = 1 : i32, id = "1", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32} : (tensor<1x64xf32>) -> (tensor<1x64xf32>, tensor, tensor, tensor<*xi64>) + return %9#0 : tensor<1x64xf32> + } + + // CHECK: %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP_0:.*]] = "tf.XlaCallModule"() <{Sout = [#tf_type.shape<{{.*}}>, #tf_type.shape<{{.*}}>], {{.*}}, module = "", platforms = ["CPU", "TPU"], version = 9 : i64}> {_entry_function = @_stablehlo_main_1 + // CHECK: %[[CUSTOM_AGGREGATOR_0:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg0) <{calibration_method = 1 : i32, id = "0", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> + // CHECK: %[[XLA_CALL_MODULE_0:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_0]], %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP_0:.*]], %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP_0:.*]]) <{Sout = [#tf_type.shape<1x3>], {{.*}}, module = "", platforms = ["CPU"], version = 
9 : i64}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_version = "{{.*}}", _tfl_quant_trait = "fully_quantizable"} + // CHECK: %[[CUSTOM_AGGREGATOR_1:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%[[XLA_CALL_MODULE_0]]) + // CHECK: %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP_1:.*]] = "tf.XlaCallModule"() <{Sout = [#tf_type.shape<{{.*}}>, #tf_type.shape<{{.*}}>], {{.*}}, module = "", platforms = ["CPU", "TPU"], version = 9 : i64}> {_entry_function = @_stablehlo_main_0 + // CHECK: %[[CUSTOM_AGGREGATOR_2:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%[[CUSTOM_AGGREGATOR_1]]) <{calibration_method = 1 : i32, id = "0", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> + // CHECK: %[[XLA_CALL_MODULE_1:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_2]], %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP_1:.*]], %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP_1:.*]]) <{Sout = [#tf_type.shape<1x3>], {{.*}}, module = "", platforms = ["CPU"], version = 9 : i64}> {_entry_function = @composite_dot_general_with_relu_fn_1, _original_entry_function = "composite_dot_general_with_relu_fn_1", _stablehlo_version = "{{.*}}", _tfl_quant_trait = "fully_quantizable"} + // CHECK: %[[CUSTOM_AGGREGATOR_3:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%[[XLA_CALL_MODULE_1:.*]]) + // CHECK: return %[[CUSTOM_AGGREGATOR_3]] : tensor<1x64xf32> + // CHECK: } + + // CHECK: @composite_dot_general_fn_1 + // CHECK-NOT: tf_quant.composite_function + func.func private @composite_dot_general_fn_1(%arg0: tensor<1x1024xf32>, %arg1: tensor<1024x3xf32>, %arg2: tensor<1x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module, tf_quant.composite_function} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32> + %1 = stablehlo.add %0, %arg2 : tensor<1x3xf32> + return %0 : 
tensor<1x3xf32> + } + + // CHECK: @composite_dot_general_with_relu_fn_1 + // CHECK-NOT: tf_quant.composite_function + func.func private @composite_dot_general_with_relu_fn_1(%arg0: tensor<1x3xf32>, %arg1: tensor<3x64xf32>, %arg2: tensor<1x64xf32>) -> tensor<1x64xf32> { + %0 = stablehlo.constant dense<0.000000e+00> : tensor<1x64xf32> + %1 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<1x3xf32>, tensor<3x64xf32>) -> tensor<1x64xf32> + %2 = stablehlo.add %1, %arg2 : tensor<1x64xf32> + %3 = stablehlo.maximum %2, %0 : tensor<1x64xf32> + return %3 : tensor<1x64xf32> + } +} + + +// ----- + +// Tests that the subgraph in serving_default excluding the tf.Identity is +// converted to a single XlaCallModuleOp. + +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1654 : i32}, tf_saved_model.semantics} { + + // CHECK: func private @_stablehlo_main_0(%arg0: tensor, %arg1: tensor<1x1024xf32>) + // CHECK: %[[CONSTANT_0:.*]] = stablehlo.constant dense<0.134728625> : tensor<1x3xf32> + // CHECK: %[[CONSTANT_1:.*]] = stablehlo.constant dense<-1.280000e+02> : tensor<1x1024xf32> + // CHECK: %[[CONSTANT_2:.*]] = stablehlo.constant dense<0.003921567> : tensor<1x1024xf32> + // CHECK: %[[DIVIDE:.*]] = stablehlo.divide %arg1, %[[CONSTANT_2]] + // CHECK: %[[ADD:.*]] = stablehlo.add %[[DIVIDE]], %[[CONSTANT_1]] + // CHECK return %[[ADD]] + // CHECK: } + + // CHECK: @serving_default + func.func @serving_default(%arg0: tensor<1x1024xf32> {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor<1x1024xf32> {tf_saved_model.index_path = ["output"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input_tensor:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { + %0 = stablehlo.constant dense<0.134728625> : tensor<1x3xf32> + %1 = stablehlo.constant dense<-1.280000e+02> : tensor<1x1024xf32> + %2 = stablehlo.constant 
dense<0.003921567> : tensor<1x1024xf32> + %3 = stablehlo.divide %arg0, %2 : tensor<1x1024xf32> + %4 = stablehlo.add %3, %1 : tensor<1x1024xf32> + %5 = "tf.Identity"(%4) {device = ""} : (tensor<1x1024xf32>) -> tensor<1x1024xf32> + return %5 : tensor<1x1024xf32> + } + + // CHECK: %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP:.*]] = "tf.XlaCallModule"(%arg0) <{Sout = [#tf_type.shape<1x1024>], {{.*}}, module = "", platforms = ["CPU", "TPU"], version = 9 : i64}> {_entry_function = @_stablehlo_main_0, _stablehlo_module_attrs = {jax.uses_shape_polymorphism = true}, _stablehlo_version = "{{.*}}"} : (tensor<1x1024xf32>) -> tensor<1x1024xf32> + // CHECK: %[[IDENTITY:.*]] = "tf.Identity"(%[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP]]) + // CHECK: return %[[IDENTITY]] + // CHECK } + +} + +// ----- + +// Tests that the first stablehlo.constant is converted to XlaCallModuleOp. + +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1629 : i32}, tf_saved_model.semantics} { + // CHECK: func private @_stablehlo_main_0 + // CHECK: %[[CONSTANT:.*]] = stablehlo.constant dense<1.000000e+03> : tensor<1024x3xf32> + // CHECK: return %[[CONSTANT:.*]] + // CHECK: } + + // CHECK: @serving_default + func.func @serving_default(%arg0: tensor<1x1024xf32> {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor<1x3xf32> {tf_saved_model.index_path = ["output"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input_tensor:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { + %0 = stablehlo.constant dense<1.000000e+03> : tensor<1024x3xf32> + %1:4 = "tf.CustomAggregator"(%arg0) {calibration_method = 1 : i32, id = "0", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32} : (tensor<1x1024xf32>) -> (tensor<1x1024xf32>, tensor, tensor, tensor<*xi64>) + %2 = "tf.XlaCallModule"(%1#0, %0) {Sout = [#tf_type.shape<1x3>], _entry_function = 
@composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_version = "{{.*}}", _tfl_quant_trait = "fully_quantizable", dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = ["CPU"], version = 9 : i64} : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32> + %3:4 = "tf.CustomAggregator"(%2) {calibration_method = 1 : i32, id = "1", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32} : (tensor<1x3xf32>) -> (tensor<1x3xf32>, tensor, tensor, tensor<*xi64>) + return %3#0 : tensor<1x3xf32> + } + + // CHECK: %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP:.*]] = "tf.XlaCallModule"() <{Sout = [#tf_type.shape<1024x3>], {{.*}}, module = "", platforms = ["CPU", "TPU"], version = 9 : i64}> {_entry_function = @_stablehlo_main_0, _stablehlo_module_attrs = {jax.uses_shape_polymorphism = true}, _stablehlo_version = "{{.*}}"} + // CHECK: %[[CUSTOM_AGGREGATOR_0:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg0) <{calibration_method = 1 : i32, id = "0", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> + // CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR:.*]], %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP:.*]]) <{Sout = [#tf_type.shape<1x3>], {{.*}}}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_version = "{{.*}}" + // CHECK: %[[CUSTOM_AGGREGATOR_1:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%[[XLA_CALL_MODULE:.*]]) + // CHECK: return %[[CUSTOM_AGGREGATOR_1]] + // CHECK: } + + // CHECK: @composite_dot_general_fn_1 + // CHECK-NOT: tf_quant.composite_function + func.func private @composite_dot_general_fn_1(%arg0: tensor<1x1024xf32>, %arg1: tensor<1024x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module, tf_quant.composite_function} { + %0 = stablehlo.dot_general %arg0, %arg1, 
contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> + } +} + +// ----- + +// Tests to confirm that the StableHLO graph is not replaced if "main" or +// "serving_default" function is not in the module. + +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1629 : i32}, tf_saved_model.semantics} { + // CHECK-NOT: func private @_stablehlo_main_ + + // CHECK-LABEL: @random_name + func.func @random_name(%arg0: tensor<1x1024xf32> {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor<1x3xf32> {tf_saved_model.index_path = ["output"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input_tensor:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { + %0 = stablehlo.constant dense<1.000000e+03> : tensor<1024x3xf32> + %1:4 = "tf.CustomAggregator"(%arg0) {calibration_method = 1 : i32, id = "0", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32} : (tensor<1x1024xf32>) -> (tensor<1x1024xf32>, tensor, tensor, tensor<*xi64>) + %2 = "tf.XlaCallModule"(%1#0, %0) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_version = "1.0.0", _tfl_quant_trait = "fully_quantizable", dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = ["CPU"], version = 9 : i64} : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32> + %3:4 = "tf.CustomAggregator"(%2) {calibration_method = 1 : i32, id = "1", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32} : (tensor<1x3xf32>) -> (tensor<1x3xf32>, tensor, tensor, tensor<*xi64>) + return %3#0 : tensor<1x3xf32> + } + + // CHECK: %[[CONSTANT:.*]] = stablehlo.constant dense<1.000000e+03> : 
tensor<1024x3xf32> + // CHECK: %[[CUSTOM_AGGREGATOR_0:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg0) <{calibration_method = 1 : i32, id = "0", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> + // CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR:.*]], %[[XLA_CALL_MODULE_EXTRACTED_FROM_SUBGRAPH:.*]]) <{Sout = [#tf_type.shape<1x3>], {{.*}}, module = "", platforms = ["CPU"], version = 9 : i64}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_version = "1.0.0" + // CHECK: %[[CUSTOM_AGGREGATOR_1:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%[[XLA_CALL_MODULE:.*]]) + // CHECK: return %[[CUSTOM_AGGREGATOR_1]] + // CHECK: } + + // CHECK: @composite_dot_general_fn_1 + // CHECK: tf_quant.composite_function + func.func private @composite_dot_general_fn_1(%arg0: tensor<1x1024xf32>, %arg1: tensor<1024x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module, tf_quant.composite_function} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> + } +} + +// ----- + +// Tests where StableHLO graph in main has a small constant to be duplicated. 
+ +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1629 : i32}, tf_saved_model.semantics} { + // CHECK: func private @_stablehlo_main_1(%arg0: tensor) -> tensor<1024x3xf32> attributes {_from_xla_call_module} + // CHECK: %[[CONSTANT1:.*]] = stablehlo.constant dense<1.000000e+03> : tensor<1024x3xf32> + // CHECK: return %[[CONSTANT1:.*]] + // CHECK: } + + // CHECK: func private @_stablehlo_main_0(%arg0: tensor + // CHECK-SAME: %[[INPUT1:.*]]: tensor<1024x3xf32>, %[[INPUT2:.*]]: tensor<1024x3xf32> + // CHECK: %[[CONSTANT2:.*]] = stablehlo.constant dense<1.000000e+03> : tensor<1024x3xf32> + // CHECK: %[[ADD:.*]] = stablehlo.add %[[INPUT1]], %[[CONSTANT2]] : tensor<1024x3xf32> + // CHECK: %[[MUL:.*]] = stablehlo.multiply %[[INPUT1]], %[[INPUT2]] : tensor<1024x3xf32> + // CHECK: return %[[ADD]], %[[MUL]] + // CHECK: } + + // CHECK: @serving_default + func.func @serving_default(%arg0: tensor<1024x1024xf32> {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor<1024x3xf32> {tf_saved_model.index_path = ["output1"]}, tensor<1024x3xf32> {tf_saved_model.index_path = ["output2"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input_tensor:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { + %0 = stablehlo.constant dense<1.000000e+03> : tensor<1024x3xf32> + %1:4 = "tf.CustomAggregator"(%arg0) {calibration_method = 1 : i32, id = "0", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32} : (tensor<1024x1024xf32>) -> (tensor<1024x1024xf32>, tensor, tensor, tensor<*xi64>) + %2 = "tf.XlaCallModule"(%1#0, %0) {Sout = [#tf_type.shape<1024x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_version = "1.0.0", _tfl_quant_trait = "fully_quantizable", dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms 
= [], version = 5 : i64} : (tensor<1024x1024xf32>, tensor<1024x3xf32>) -> tensor<1024x3xf32> + %3:4 = "tf.CustomAggregator"(%2) {calibration_method = 1 : i32, id = "1", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32} : (tensor<1024x3xf32>) -> (tensor<1024x3xf32>, tensor, tensor, tensor<*xi64>) + %4 = stablehlo.constant dense<1.000000e+03> : tensor<1024x3xf32> + %5 = stablehlo.add %3#0, %4 : tensor<1024x3xf32> + %6 = stablehlo.multiply %3#0, %0 : tensor<1024x3xf32> + return %5, %6 : tensor<1024x3xf32>, tensor<1024x3xf32> + } + + // CHECK: %[[SUBGRAPH_1:.*]] = "tf.XlaCallModule"() <{Sout = [#tf_type.shape<1024x3>], {{.*}} ["CPU", "TPU"], {{.*}}}> {_entry_function = @_stablehlo_main_1 + // CHECK: %[[CUSTOM_AGGREGATOR_1:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg0) <{calibration_method = 1 : i32, id = "0", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> + // CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_1]], %[[SUBGRAPH_1]]) <{Sout = [#tf_type.shape<1024x3>], {{.*}}}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_version = "1.0.0" + // CHECK: %[[CUSTOM_AGGREGATOR_2:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%[[XLA_CALL_MODULE:.*]]) + // CHECK: %[[SUBGRAPH_2:.*]]:2 = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_2]], %[[SUBGRAPH_1]]) <{Sout = [#tf_type.shape<1024x3>, #tf_type.shape<1024x3>], {{.*}}}> {_entry_function = @_stablehlo_main_0 + // CHECK: return %[[SUBGRAPH_2]]#0, %[[SUBGRAPH_2]]#1 + // CHECK: } + + // CHECK: @composite_dot_general_fn_1 + // CHECK-NOT: tf_quant.composite_function + func.func private @composite_dot_general_fn_1(%arg0: tensor<1x1024xf32>, %arg1: tensor<1024x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module, tf_quant.composite_function} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0], precision = 
[DEFAULT, DEFAULT] : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> + } +} + +// ----- + +// Tests where StableHLO graph in main has branches. +// This test makes sure tracing won't stop at op (%1) with multiple uses. + +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1629 : i32}, tf_saved_model.semantics} { + // CHECK: func private @_stablehlo_main_1(%arg0: tensor) -> tensor<3x11xf32> + // CHECK: %[[CONSTANT_1:.*]] = stablehlo.constant dense<1.000000e+03> : tensor<3x11xf32> + // CHECK: return %[[CONSTANT_1:.*]] + // CHECK: } + + // CHECK: func private @_stablehlo_main_0 + // CHECK-SAME: (%arg0: tensor, %[[INPUT_1:.*]]: tensor<3x11xf32>) + // CHECK-SAME: -> tensor<3x11xf32> + // CHECK: %[[CONSTANT_2:.*]] = stablehlo.constant dense<1.000000e+01> : tensor<3x11xf32> + // CHECK: %[[ADD:.*]] = stablehlo.add %[[INPUT_1]], %[[CONSTANT_2]] : tensor<3x11xf32> + // CHECK: %[[MUL:.*]] = stablehlo.multiply %[[ADD]], %[[CONSTANT_2]] : tensor<3x11xf32> + // CHECK: return %[[MUL]] + // CHECK: } + + // CHECK: @serving_default + func.func @serving_default(%arg0: tensor<3x3xf32> {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor<3x11xf32> {tf_saved_model.index_path = ["output1"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input_tensor:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { + %0 = stablehlo.constant dense<1.000000e+03> : tensor<3x11xf32> + // %1 is large enough that it won't be duplicated. 
+ %1 = stablehlo.constant dense<1.000000e+01> : tensor<3x11xf32> + %2:4 = "tf.CustomAggregator"(%arg0) {calibration_method = 1 : i32, id = "0", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32} : (tensor<3x3xf32>) -> (tensor<3x3xf32>, tensor, tensor, tensor<*xi64>) + %3 = "tf.XlaCallModule"(%2#0, %0) {Sout = [#tf_type.shape<3x11>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_version = "1.0.0", _tfl_quant_trait = "fully_quantizable", dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<3x3xf32>, tensor<3x11xf32>) -> tensor<3x11xf32> + %4:4 = "tf.CustomAggregator"(%3) {calibration_method = 1 : i32, id = "1", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32} : (tensor<3x11xf32>) -> (tensor<3x11xf32>, tensor, tensor, tensor<*xi64>) + %5 = stablehlo.add %4#0, %1 : tensor<3x11xf32> + %6 = stablehlo.multiply %5, %1 : tensor<3x11xf32> + return %6 : tensor<3x11xf32> + } + + // CHECK: %[[SUBGRAPH_1:.*]] = "tf.XlaCallModule"() <{Sout = [#tf_type.shape<3x11>], {{.*}} ["CPU", "TPU"], {{.*}}}> {_entry_function = @_stablehlo_main_1 + // CHECK: %[[CUSTOM_AGGREGATOR_1:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg0) <{calibration_method = 1 : i32, id = "0", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> + // CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_1]], %[[SUBGRAPH_1]]) <{Sout = [#tf_type.shape<3x11>], {{.*}}}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_version = "1.0.0" + // CHECK: %[[CUSTOM_AGGREGATOR_2:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%[[XLA_CALL_MODULE:.*]]) + // CHECK: %[[SUBGRAPH_2:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_2]]) <{Sout = 
[#tf_type.shape<3x11>], {{.*}} ["CPU", "TPU"], {{.*}}}> {_entry_function = @_stablehlo_main_0 + // CHECK: return %[[SUBGRAPH_2]] + // CHECK: } + + // CHECK: @composite_dot_general_fn_1 + // CHECK-NOT: tf_quant.composite_function + func.func private @composite_dot_general_fn_1(%arg0: tensor<3x3xf32>, %arg1: tensor<3x11xf32>) -> tensor<3x11xf32> attributes {_from_xla_call_module, tf_quant.composite_function} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3x3xf32>, tensor<3x11xf32>) -> tensor<3x11xf32> + return %0 : tensor<3x11xf32> + } +} + +// ----- + +// Tests where StableHLO graph in main has dead end. +// This test makes sure tracing will include the dead end from the op in the +// same sub graph: +// stablehlo.add and %0 along with its dead end branch are in the same sub +// graph. + +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1629 : i32}, tf_saved_model.semantics} { + // CHECK: func.func private @_stablehlo_main_1(%arg0: tensor) -> tensor<1024x3xf32> attributes {_from_xla_call_module} { + // CHECK: %[[CONSTANT_0:.*]] = stablehlo.constant dense<2.000000e+03> : tensor<1024x3xf32> + // CHECK: return %[[CONSTANT_0]] + // CHECK: } + + // CHECK: func.func private @_stablehlo_main_0(%arg0: tensor, %[[ARG_1:.*]]: tensor<1024x3xf32>) -> tensor<1024x3xf32> attributes {_from_xla_call_module} { + // CHECK: %[[CONSTANT_1:.*]] = stablehlo.constant dense<1.000000e+03> : tensor<1024x3xf32> + // CHECK: %[[CONSTANT_2:.*]] = stablehlo.constant dense<5.000000e+01> : tensor<1024x3xf32> + // CHECK: %[[CONSTANT_3:.*]] = stablehlo.constant dense<4.000000e+03> : tensor<1024x3xf32> + // CHECK: %[[REMAINDER:.*]] = stablehlo.remainder %[[CONSTANT_3]], %[[CONSTANT_1]] : tensor<1024x3xf32> + // CHECK: %[[COMPARE:.*]] = stablehlo.compare EQ, %[[REMAINDER]], %[[CONSTANT_2]], NOTYPE : (tensor<1024x3xf32>, tensor<1024x3xf32>) -> tensor<1024x3xi1> + // CHECK: stablehlo.custom_call 
@shape_assertion(%[[COMPARE]]) {error_message = "Shape assertion failed", has_side_effect = true} : (tensor<1024x3xi1>) -> () + // CHECK: %[[ADD:.*]] = stablehlo.add %[[ARG_1]], %[[CONSTANT_3]] + // CHECK: return %[[ADD]] + // CHECK: } + + // CHECK: @serving_default + func.func @serving_default(%arg0: tensor<1024x1024xf32> {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor<1024x3xf32> {tf_saved_model.index_path = ["output"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input_tensor:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { + %0 = stablehlo.constant dense<4.000000e+03> : tensor<1024x3xf32> + %1 = stablehlo.constant dense<1.000000e+03> : tensor<1024x3xf32> + %2 = stablehlo.constant dense<5.000000e+01> : tensor<1024x3xf32> + %3 = stablehlo.remainder %0, %1 : tensor<1024x3xf32> + %4 = stablehlo.compare EQ, %3, %2, NOTYPE : (tensor<1024x3xf32>, tensor<1024x3xf32>) -> tensor<1024x3xi1> + stablehlo.custom_call @shape_assertion(%4) {error_message = "Shape assertion failed", has_side_effect = true} : (tensor<1024x3xi1>) -> () + %5 = stablehlo.constant dense<2.000000e+03> : tensor<1024x3xf32> + %6:4 = "tf.CustomAggregator"(%arg0) {calibration_method = 1 : i32, id = "0", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32} : (tensor<1024x1024xf32>) -> (tensor<1024x1024xf32>, tensor, tensor, tensor<*xi64>) + %7 = "tf.XlaCallModule"(%6#0, %5) {Sout = [#tf_type.shape<1024x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_version = "1.0.0", _tfl_quant_trait = "fully_quantizable", dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1024x1024xf32>, tensor<1024x3xf32>) -> tensor<1024x3xf32> + %8:4 = "tf.CustomAggregator"(%7) {calibration_method = 1 : i32, id = "1", max_percentile = 
0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32} : (tensor<1024x3xf32>) -> (tensor<1024x3xf32>, tensor, tensor, tensor<*xi64>) + %9 = stablehlo.add %8#0, %0 : tensor<1024x3xf32> + return %9 : tensor<1024x3xf32> + } + // CHECK: %[[SUBGRAPH_0:.*]] = "tf.XlaCallModule"() <{Sout = [#tf_type.shape<1024x3>], {{.*}} ["CPU", "TPU"], {{.*}}}> {_entry_function = @_stablehlo_main_1 + // CHECK: %[[CUSTOM_AGGREGATOR_0:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg0) <{calibration_method = 1 : i32, id = "0", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> + // CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_0]], %[[SUBGRAPH_0]]) <{Sout = [#tf_type.shape<1024x3>], {{.*}}}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_version = "1.0.0" + // CHECK: %[[CUSTOM_AGGREGATOR_1:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%[[XLA_CALL_MODULE:.*]]) + // CHECK: %[[SUBGRAPH_1:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_1]]) <{Sout = [#tf_type.shape<1024x3>], {{.*}}}> {_entry_function = @_stablehlo_main_0 + // CHECK: return %[[SUBGRAPH_1]] : tensor<1024x3xf32> + // CHECK: } +} + +// ----- + +// Tests where StableHLO graph in main has branch. +// This test makes sure the branch will not be added to subgraph when it reaches +// a tf op: +// stablehlo.add and %0 are not in the same subgraph. 
+ +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1629 : i32}, tf_saved_model.semantics} { + // CHECK: func.func private @_stablehlo_main_2(%arg0: tensor) -> (tensor<1024x3xf32>, tensor<1024x3xf32>) attributes {_from_xla_call_module} { + // CHECK: %[[CONSTANT_0:.*]] = stablehlo.constant dense<4.000000e+03> : tensor<1024x3xf32> + // CHECK: %[[CONSTANT_1:.*]] = stablehlo.constant dense<1.000000e+03> : tensor<1024x3xf32> + // CHECK: %[[REMAINDER:.*]] = stablehlo.remainder %[[CONSTANT_0]], %[[CONSTANT_1]] : tensor<1024x3xf32> + // CHECK: return %[[CONSTANT_0]], %[[REMAINDER]] + // CHECK: } + + // CHECK: func.func private @_stablehlo_main_1(%arg0: tensor) -> tensor<1024x3xf32> attributes {_from_xla_call_module} { + // CHECK: %[[CONSTANT_2:.*]] = stablehlo.constant dense<2.000000e+03> : tensor<1024x3xf32> + // CHECK: return %[[CONSTANT_2]] : tensor<1024x3xf32> + // CHECK: } + + // CHECK: func.func private @_stablehlo_main_0(%arg0: tensor, %[[ARG_1:.*]]: tensor<1024x3xf32>, %[[ARG_2:.*]]: tensor<1024x3xf32>) -> tensor<1024x3xf32> attributes {_from_xla_call_module} { + // CHECK: %[[ADD:.*]] = stablehlo.add %[[ARG_1]], %[[ARG_2]] + // CHECK: return %[[ADD]] + + // CHECK: @serving_default + func.func @serving_default(%arg0: tensor<1024x1024xf32> {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor<1024x3xf32> {tf_saved_model.index_path = ["output"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input_tensor:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { + %0 = stablehlo.constant dense<4.000000e+03> : tensor<1024x3xf32> + %1 = stablehlo.constant dense<1.000000e+03> : tensor<1024x3xf32> + %2 = stablehlo.remainder %0, %1 : tensor<1024x3xf32> + %3 = "tf.Identity"(%2) {device = ""} : (tensor<1024x3xf32>) -> tensor<1024x3xf32> + %4 = stablehlo.constant dense<2.000000e+03> : tensor<1024x3xf32> + %5:4 = "tf.CustomAggregator"(%arg0) {calibration_method = 1 
: i32, id = "0", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32} : (tensor<1024x1024xf32>) -> (tensor<1024x1024xf32>, tensor, tensor, tensor<*xi64>) + %6 = "tf.XlaCallModule"(%5#0, %4) {Sout = [#tf_type.shape<1024x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_version = "1.0.0", _tfl_quant_trait = "fully_quantizable", dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1024x1024xf32>, tensor<1024x3xf32>) -> tensor<1024x3xf32> + %7:4 = "tf.CustomAggregator"(%6) {calibration_method = 1 : i32, id = "1", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32} : (tensor<1024x3xf32>) -> (tensor<1024x3xf32>, tensor, tensor, tensor<*xi64>) + %8 = stablehlo.add %7#0, %0 : tensor<1024x3xf32> + return %8 : tensor<1024x3xf32> + } + // CHECK: %[[SUBGRAPH_0:.*]]:2 = "tf.XlaCallModule"() <{Sout = [#tf_type.shape<1024x3>, #tf_type.shape<1024x3>], {{.*}} ["CPU", "TPU"], {{.*}}}> {_entry_function = @_stablehlo_main_2 + // CHECK: %[[IDENTIFY:.*]] = "tf.Identity"(%[[SUBGRAPH_0]]#1) {device = ""} : (tensor<1024x3xf32>) -> tensor<1024x3xf32> + // CHECK: %[[SUBGRAPH_1:.*]] = "tf.XlaCallModule"() <{Sout = [#tf_type.shape<1024x3>], {{.*}} ["CPU", "TPU"], {{.*}}}> {_entry_function = @_stablehlo_main_1 + // CHECK: %[[CUSTOM_AGGREGATOR_0:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg0) <{calibration_method = 1 : i32, id = "0", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> + // CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_0]], %[[SUBGRAPH_1]]) <{Sout = [#tf_type.shape<1024x3>], {{.*}}}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_version = "1.0.0" + // CHECK: 
%[[CUSTOM_AGGREGATOR_1:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%[[XLA_CALL_MODULE:.*]]) + // CHECK: %[[SUBGRAPH_2:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_1]], %[[SUBGRAPH_0]]#0) <{Sout = [#tf_type.shape<1024x3>], {{.*}}}> {_entry_function = @_stablehlo_main_0 + // CHECK: return %[[SUBGRAPH_2]] : tensor<1024x3xf32> + // CHECK: } +} + +// ----- + +// Tests where StableHLO graph in main has dead end. +// This test checks tracing will stop if the dead end is too deep (>5): +// stablehlo.add and %0 are not in the same subgraph. + +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1629 : i32}, tf_saved_model.semantics} { + // CHECK: func.func private @_stablehlo_main_1(%arg0: tensor) -> (tensor<1024x3xf32>, tensor<1024x3xf32>) attributes {_from_xla_call_module} { + // CHECK: %[[CONSTANT_0:.*]] = stablehlo.constant dense<4.000000e+03> : tensor<1024x3xf32> + // CHECK: %[[CONSTANT_1:.*]] = stablehlo.constant dense<1.000000e+03> : tensor<1024x3xf32> + // CHECK: %[[CONSTANT_2:.*]] = stablehlo.constant dense<5.000000e+01> : tensor<1024x3xf32> + // CHECK: %[[REMAINDER_0:.*]] = stablehlo.remainder %[[CONSTANT_0]], %[[CONSTANT_1]] : tensor<1024x3xf32> + // CHECK: %[[REMAINDER_1:.*]] = stablehlo.remainder %[[REMAINDER_0]], %[[CONSTANT_1]] : tensor<1024x3xf32> + // CHECK: %[[REMAINDER_2:.*]] = stablehlo.remainder %[[REMAINDER_1]], %[[CONSTANT_1]] : tensor<1024x3xf32> + // CHECK: %[[REMAINDER_3:.*]] = stablehlo.remainder %[[REMAINDER_2]], %[[CONSTANT_1]] : tensor<1024x3xf32> + // CHECK: %[[COMPARE:.*]] = stablehlo.compare EQ, %[[REMAINDER_3]], %[[CONSTANT_2]], NOTYPE : (tensor<1024x3xf32>, tensor<1024x3xf32>) -> tensor<1024x3xi1> + // CHECK: stablehlo.custom_call @shape_assertion(%[[COMPARE]]) {error_message = "Shape assertion failed", has_side_effect = true} : (tensor<1024x3xi1>) -> () + // CHECK: %[[CONSTANT_3:.*]] = stablehlo.constant dense<2.000000e+03> : tensor<1024x3xf32> + // CHECK: return %[[CONSTANT_0]], 
%[[CONSTANT_3]] + // CHECK: } + + // CHECK: func.func private @_stablehlo_main_0(%arg0: tensor, %[[ARG_1:.*]]: tensor<1024x3xf32>, %[[ARG_2:.*]]: tensor<1024x3xf32>) -> tensor<1024x3xf32> attributes {_from_xla_call_module} { + // CHECK: %[[ADD:.*]] = stablehlo.add %[[ARG_1]], %[[ARG_2]] + // CHECK: return %[[ADD]] + // CHECK: } + + // CHECK: @serving_default + func.func @serving_default(%arg0: tensor<1024x1024xf32> {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor<1024x3xf32> {tf_saved_model.index_path = ["output"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input_tensor:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { + %0 = stablehlo.constant dense<4.000000e+03> : tensor<1024x3xf32> + %1 = stablehlo.constant dense<1.000000e+03> : tensor<1024x3xf32> + %2 = stablehlo.constant dense<5.000000e+01> : tensor<1024x3xf32> + %3 = stablehlo.remainder %0, %1 : tensor<1024x3xf32> + %4 = stablehlo.remainder %3, %1 : tensor<1024x3xf32> + %5 = stablehlo.remainder %4, %1 : tensor<1024x3xf32> + %6 = stablehlo.remainder %5, %1 : tensor<1024x3xf32> + %7 = stablehlo.compare EQ, %6, %2, NOTYPE : (tensor<1024x3xf32>, tensor<1024x3xf32>) -> tensor<1024x3xi1> + stablehlo.custom_call @shape_assertion(%7) {error_message = "Shape assertion failed", has_side_effect = true} : (tensor<1024x3xi1>) -> () + %8 = stablehlo.constant dense<2.000000e+03> : tensor<1024x3xf32> + %9:4 = "tf.CustomAggregator"(%arg0) {calibration_method = 1 : i32, id = "0", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32} : (tensor<1024x1024xf32>) -> (tensor<1024x1024xf32>, tensor, tensor, tensor<*xi64>) + %10 = "tf.XlaCallModule"(%9#0, %8) {Sout = [#tf_type.shape<1024x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_version = "{{.*}}", _tfl_quant_trait = "fully_quantizable", dim_args_spec = [], disabled_checks = [], 
function_list = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1024x1024xf32>, tensor<1024x3xf32>) -> tensor<1024x3xf32> + %11:4 = "tf.CustomAggregator"(%10) {calibration_method = 1 : i32, id = "1", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32} : (tensor<1024x3xf32>) -> (tensor<1024x3xf32>, tensor, tensor, tensor<*xi64>) + %12 = stablehlo.add %11#0, %0 : tensor<1024x3xf32> + return %12 : tensor<1024x3xf32> + } + // CHECK: %[[SUBGRAPH_0:.*]]:2 = "tf.XlaCallModule"() <{Sout = [#tf_type.shape<1024x3>, #tf_type.shape<1024x3>], {{.*}} ["CPU", "TPU"], {{.*}}}> {_entry_function = @_stablehlo_main_1 + // CHECK: %[[CUSTOM_AGGREGATOR_0:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg0) <{calibration_method = 1 : i32, id = "0", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> + // CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_0]], %[[SUBGRAPH_0]]#1) <{Sout = [#tf_type.shape<1024x3>], {{.*}}}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_version = "{{.*}}" + // CHECK: %[[CUSTOM_AGGREGATOR_1:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%[[XLA_CALL_MODULE:.*]]) + // CHECK: %[[SUBGRAPH_1:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_1]], %[[SUBGRAPH_0]]#0) <{Sout = [#tf_type.shape<1024x3>], {{.*}}}> {_entry_function = @_stablehlo_main_0 + // CHECK: return %[[SUBGRAPH_1]] : tensor<1024x3xf32> + // CHECK: } +} + +// ----- + +// main function contains PartitionedCall and StatefulPartitionedCall ops which +// are used to preserve aliased functions. This test makes sure stablehlo ops in +// each PartitionedCall function are lifted.
+ +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1629 : i32}, tf_saved_model.semantics} { + // CHECK: func private @_stablehlo_main_2 + // CHECK: stablehlo.multiply %arg1, %arg2 : tensor<3x3xf32> + // CHECK: return + // CHECK: } + + // CHECK: func private @_stablehlo_main_1 + // CHECK: stablehlo.add %arg1, %arg2 : tensor<3x3xf32> + // CHECK: return + // CHECK: } + + // CHECK: func private @_stablehlo_main_0 + // CHECK: stablehlo.constant dense<1.000000e+03> : tensor<3x3xf32> + // CHECK: stablehlo.constant dense<2.000000e+03> : tensor<3x3xf32> + // CHECK: return + // CHECK: } + + func.func @main() -> (tensor<3x3xf32> {tf_saved_model.index_path = ["output"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input_tensor:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { + %0 = stablehlo.constant dense<1.000000e+03> : tensor<3x3xf32> + %1 = stablehlo.constant dense<2.000000e+03> : tensor<3x3xf32> + %2 = "tf.StatefulPartitionedCall"(%0, %1) <{ + config = "", config_proto = "", executor_type = "", f = @some_func + }> { + _collective_manager_ids = [], device = "" + } : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32> + %3 = "tf.PartitionedCall"(%2, %1) <{ + config = "", config_proto = "", executor_type = "", f = @some_other_func + }> { + _collective_manager_ids = [], device = "" + } : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32> + return %3 : tensor<3x3xf32> + } + // CHECK: func.func @main + // CHECK: %[[INPUT:.*]]:3 = "tf.XlaCallModule"() + // CHECK-SAME: _entry_function = @_stablehlo_main_0 + // CHECK: %[[ADD:.*]] = "tf.StatefulPartitionedCall"(%[[INPUT]]#1, %[[INPUT]]#2) + // CHECK-SAME: f = @some_func + // CHECK: "tf.PartitionedCall"(%[[ADD]], %[[INPUT]]#0) + // CHECK-SAME: f = @some_other_func + // CHECK: return + + func.func private @some_func(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<3x3xf32> attributes {tf._noinline = 
true} { + %0 = stablehlo.add %arg0, %arg1 : tensor<3x3xf32> + return %0 : tensor<3x3xf32> + } + // CHECK: func.func private @some_func + // CHECK: tf.XlaCallModule + // CHECK-SAME: _entry_function = @_stablehlo_main_1 + // CHECK: return + + func.func private @some_other_func(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<3x3xf32> attributes {tf._noinline = true} { + %0 = stablehlo.multiply %arg0, %arg1 : tensor<3x3xf32> + return %0 : tensor<3x3xf32> + } + // CHECK: func.func private @some_other_func + // CHECK: tf.XlaCallModule + // CHECK-SAME: _entry_function = @_stablehlo_main_2 + // CHECK: return +} diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_restore_function_name.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_restore_function_name.mlir new file mode 100644 index 00000000000000..b6f746c8e46961 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_restore_function_name.mlir @@ -0,0 +1,52 @@ +// RUN: stablehlo-quant-opt %s -split-input-file -tf-stablehlo-restore-function-name | FileCheck %s + +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1646 : i32}, tf_saved_model.semantics} { + // CHECK-LABEL: @serving_default + // CHECK-SAME: %[[ARG0:[^:[:space:]]+]] + // CHECK-SAME: %[[ARG1:[^:[:space:]]+]] + func.func private @serving_default(%arg0: tensor<1x4xf32>, %arg1: tensor<4x3xf32>) -> tensor<1x3xf32> { + %0 = "tf.XlaCallModule"(%arg0, %arg1) {Sout = [#tf_type.shape<1x3>], _entry_function = @main, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x4xf32>, tensor<4x3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> + // CHECK: %[[CALL:.+]] = "tf.XlaCallModule"(%[[ARG0]], %[[ARG1]]) + // CHECK-SAME: 
_entry_function = @composite_dot_general_fn_1 + // CHECK-SAME: _original_entry_function = "composite_dot_general_fn_1" + // CHECK: return %[[CALL]] + } + + // CHECK: @composite_dot_general_fn_1 + // CHECK-SAME: %[[ARG2:[^:[:space:]]+]] + // CHECK-SAME: %[[ARG3:[^:[:space:]]+]] + func.func private @main(%arg0: tensor<1x4xf32>, %arg1: tensor<4x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x4xf32>, tensor<4x3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> + // CHECK: %[[DOT:.+]] = stablehlo.dot_general %[[ARG2]], %[[ARG3]] + // CHECK: return %[[DOT]] + } +} + +// ----- + +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1646 : i32}, tf_saved_model.semantics} { + // CHECK-LABEL: @serving_default + // CHECK-SAME: %[[ARG0:[^:[:space:]]+]] + // CHECK-SAME: %[[ARG1:[^:[:space:]]+]] + func.func private @serving_default(%arg0: tensor<1x4xf32>, %arg1: tensor<4x3xf32>) -> tensor<1x3xf32> { + %0 = "tf.XlaCallModule"(%arg0, %arg1) {Sout = [#tf_type.shape<1x3>], _entry_function = @main, _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x4xf32>, tensor<4x3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> + // CHECK: %[[CALL:.+]] = "tf.XlaCallModule"(%[[ARG0]], %[[ARG1]]) + // CHECK-SAME: _entry_function = @main + // CHECK-NOT: _original_entry_function = "composite_dot_general_fn_1" + // CHECK: return %[[CALL]] + } + + // CHECK: @main + // CHECK-NOT: @composite_dot_general_fn_1 + // CHECK-SAME: %[[ARG2:[^:[:space:]]+]] + // CHECK-SAME: %[[ARG3:[^:[:space:]]+]] + func.func private @main(%arg0: tensor<1x4xf32>, %arg1: tensor<4x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : 
(tensor<1x4xf32>, tensor<4x3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> + // CHECK: %[[DOT:.+]] = stablehlo.dot_general %[[ARG2]], %[[ARG3]] + // CHECK: return %[[DOT]] + } +} diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_shape_cstr_legalize_to_hlo.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_shape_cstr_legalize_to_hlo.mlir new file mode 100644 index 00000000000000..e0a2ba600993ca --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_shape_cstr_legalize_to_hlo.mlir @@ -0,0 +1,110 @@ +// RUN: stablehlo-quant-opt %s -split-input-file -tf-stablehlo-convert-shape-to-stablehlo-with-constraints --verify-diagnostics | FileCheck %s + +// CHECK-LABEL: func.func @shape_cstr_broadcastable +func.func @shape_cstr_broadcastable(%arg0: tensor<2xindex>, %arg1: tensor<2xindex>) { + %0 = shape.cstr_broadcastable %arg0, %arg1 : tensor<2xindex>, tensor<2xindex> + shape.assuming %0 { + } + func.return + // CHECK: %[[DIMS1:.*]] = builtin.unrealized_conversion_cast %arg0 : tensor<2xindex> to tensor<2xi32> + // CHECK-NEXT: %[[DIMS2:.*]] = builtin.unrealized_conversion_cast %arg1 : tensor<2xindex> to tensor<2xi32> + // CHECK-NEXT: %[[ONES:.*]] = stablehlo.constant dense<1> : tensor<2xi32> + // CHECK-NEXT: %[[DIMS1_IS_1:.*]] = stablehlo.compare EQ, %[[DIMS1]], %[[ONES:.*]] : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + // CHECK-NEXT: %[[DIMS2_IS_1:.*]] = stablehlo.compare EQ, %[[DIMS2]], %[[ONES:.*]] : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + // CHECK-NEXT: %[[EITHER_DIM_IS_1:.*]] = stablehlo.or %[[DIMS1_IS_1]], %[[DIMS2_IS_1]] : tensor<2xi1> + // CHECK-NEXT: %[[DIMS_EQ:.*]] = stablehlo.compare EQ, %[[DIMS1]], %[[DIMS2]] : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + // CHECK-NEXT: %[[DIMS_BROADCASTABLE:.*]] = stablehlo.or %[[EITHER_DIM_IS_1]], %[[DIMS_EQ]] : tensor<2xi1> + // CHECK-NEXT: %[[TRUE:.*]] = stablehlo.constant dense : tensor<1xi1> + // CHECK-NEXT: 
%[[DIM1_BROADCASTABLE:.*]] = stablehlo.slice %[[DIMS_BROADCASTABLE]] [0:1] : (tensor<2xi1>) -> tensor<1xi1> + // CHECK-NEXT: %[[BROADCASTABLE_TEMP:.*]] = stablehlo.and %[[TRUE]], %[[DIM1_BROADCASTABLE]] : tensor<1xi1> + // CHECK-NEXT: %[[DIM2_BROADCASTABLE:.*]] = stablehlo.slice %[[DIMS_BROADCASTABLE]] [1:2] : (tensor<2xi1>) -> tensor<1xi1> + // CHECK-NEXT: %[[ALL_BROADCASTABLE:.*]] = stablehlo.and %[[BROADCASTABLE_TEMP]], %[[DIM2_BROADCASTABLE]] : tensor<1xi1> + // CHECK-NEXT: %[[ALL_BROADCASTABLE_SCALAR:.*]] = stablehlo.reshape %[[ALL_BROADCASTABLE]] : (tensor<1xi1>) -> tensor + // CHECK-NEXT: stablehlo.custom_call @shape_assertion(%[[ALL_BROADCASTABLE_SCALAR]]) {error_message = "Shape assertion failed", has_side_effect = true} : (tensor) -> () + // CHECK-NEXT: %[[WITNESS:.*]] = shape.const_witness true + // CHECK-NEXT: shape.assuming %[[WITNESS]] { + // CHECK-NEXT: } + // CHECK-NEXT: return +} + +// ----- + +// CHECK-LABEL: func @shape_cstr_broadcastable_different_dims_1 +func.func @shape_cstr_broadcastable_different_dims_1(%arg0: tensor<2xindex>, %arg1: tensor<1xindex>) { + %0 = shape.cstr_broadcastable %arg0, %arg1 : tensor<2xindex>, tensor<1xindex> + shape.assuming %0 { + } + func.return + // CHECK: %[[DIMS1:.*]] = builtin.unrealized_conversion_cast %arg0 : tensor<2xindex> to tensor<2xi32> + // CHECK-NEXT: %[[DIMS2:.*]] = builtin.unrealized_conversion_cast %arg1 : tensor<1xindex> to tensor<1xi32> + // CHECK-NEXT: %[[PAD:.*]] = stablehlo.constant dense<1> : tensor<1xi32> + // CHECK-NEXT: %[[DIMS2_PAD:.*]] = stablehlo.concatenate %[[PAD]], %[[DIMS2]], dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> + // CHECK-NEXT: %[[ONES:.*]] = stablehlo.constant dense<1> : tensor<2xi32> + // CHECK-NEXT: %[[DIMS1_IS_1:.*]] = stablehlo.compare EQ, %[[DIMS1]], %[[ONES:.*]] : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + // CHECK-NEXT: %[[DIMS2_IS_1:.*]] = stablehlo.compare EQ, %[[DIMS2_PAD]], %[[ONES:.*]] : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + // 
CHECK-NEXT: %[[EITHER_DIM_IS_1:.*]] = stablehlo.or %[[DIMS1_IS_1]], %[[DIMS2_IS_1]] : tensor<2xi1> + // CHECK-NEXT: %[[DIMS_EQ:.*]] = stablehlo.compare EQ, %[[DIMS1]], %[[DIMS2_PAD]] : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + // CHECK-NEXT: %[[DIMS_BROADCASTABLE:.*]] = stablehlo.or %[[EITHER_DIM_IS_1]], %[[DIMS_EQ]] : tensor<2xi1> + // CHECK-NEXT: %[[TRUE:.*]] = stablehlo.constant dense : tensor<1xi1> + // CHECK-NEXT: %[[DIM1_BROADCASTABLE:.*]] = stablehlo.slice %[[DIMS_BROADCASTABLE]] [0:1] : (tensor<2xi1>) -> tensor<1xi1> + // CHECK-NEXT: %[[BROADCASTABLE_TEMP:.*]] = stablehlo.and %[[TRUE]], %[[DIM1_BROADCASTABLE]] : tensor<1xi1> + // CHECK-NEXT: %[[DIM2_BROADCASTABLE:.*]] = stablehlo.slice %[[DIMS_BROADCASTABLE]] [1:2] : (tensor<2xi1>) -> tensor<1xi1> + // CHECK-NEXT: %[[ALL_BROADCASTABLE:.*]] = stablehlo.and %[[BROADCASTABLE_TEMP]], %[[DIM2_BROADCASTABLE]] : tensor<1xi1> + // CHECK-NEXT: %[[ALL_BROADCASTABLE_SCALAR:.*]] = stablehlo.reshape %[[ALL_BROADCASTABLE]] : (tensor<1xi1>) -> tensor + // CHECK-NEXT: stablehlo.custom_call @shape_assertion(%[[ALL_BROADCASTABLE_SCALAR]]) {error_message = "Shape assertion failed", has_side_effect = true} : (tensor) -> () + // CHECK-NEXT: %[[WITNESS:.*]] = shape.const_witness true + // CHECK-NEXT: shape.assuming %[[WITNESS]] { + // CHECK-NEXT: } + // CHECK-NEXT: return +} + +// ----- + +// CHECK-LABEL: func @shape_cstr_broadcastable_different_dims_2 +func.func @shape_cstr_broadcastable_different_dims_2(%arg0: tensor<1xindex>, %arg1: tensor<2xindex>) { + %0 = shape.cstr_broadcastable %arg0, %arg1 : tensor<1xindex>, tensor<2xindex> + shape.assuming %0 { + } + func.return + // CHECK: %[[DIMS1:.*]] = builtin.unrealized_conversion_cast %arg0 : tensor<1xindex> to tensor<1xi32> + // CHECK-NEXT: %[[DIMS2:.*]] = builtin.unrealized_conversion_cast %arg1 : tensor<2xindex> to tensor<2xi32> + // CHECK-NEXT: %[[PAD:.*]] = stablehlo.constant dense<1> : tensor<1xi32> + // CHECK-NEXT: %[[DIMS1_PAD:.*]] = stablehlo.concatenate %[[PAD]], 
%[[DIMS1]], dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> + // CHECK-NEXT: %[[ONES:.*]] = stablehlo.constant dense<1> : tensor<2xi32> + // CHECK-NEXT: %[[DIMS1_IS_1:.*]] = stablehlo.compare EQ, %[[DIMS1_PAD]], %[[ONES:.*]] : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + // CHECK-NEXT: %[[DIMS2_IS_1:.*]] = stablehlo.compare EQ, %[[DIMS2]], %[[ONES:.*]] : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + // CHECK-NEXT: %[[EITHER_DIM_IS_1:.*]] = stablehlo.or %[[DIMS1_IS_1]], %[[DIMS2_IS_1]] : tensor<2xi1> + // CHECK-NEXT: %[[DIMS_EQ:.*]] = stablehlo.compare EQ, %[[DIMS1_PAD]], %[[DIMS2]] : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + // CHECK-NEXT: %[[DIMS_BROADCASTABLE:.*]] = stablehlo.or %[[EITHER_DIM_IS_1]], %[[DIMS_EQ]] : tensor<2xi1> + // CHECK-NEXT: %[[TRUE:.*]] = stablehlo.constant dense : tensor<1xi1> + // CHECK-NEXT: %[[DIM1_BROADCASTABLE:.*]] = stablehlo.slice %[[DIMS_BROADCASTABLE]] [0:1] : (tensor<2xi1>) -> tensor<1xi1> + // CHECK-NEXT: %[[BROADCASTABLE_TEMP:.*]] = stablehlo.and %[[TRUE]], %[[DIM1_BROADCASTABLE]] : tensor<1xi1> + // CHECK-NEXT: %[[DIM2_BROADCASTABLE:.*]] = stablehlo.slice %[[DIMS_BROADCASTABLE]] [1:2] : (tensor<2xi1>) -> tensor<1xi1> + // CHECK-NEXT: %[[ALL_BROADCASTABLE:.*]] = stablehlo.and %[[BROADCASTABLE_TEMP]], %[[DIM2_BROADCASTABLE]] : tensor<1xi1> + // CHECK-NEXT: %[[ALL_BROADCASTABLE_SCALAR:.*]] = stablehlo.reshape %[[ALL_BROADCASTABLE]] : (tensor<1xi1>) -> tensor + // CHECK-NEXT: stablehlo.custom_call @shape_assertion(%[[ALL_BROADCASTABLE_SCALAR]]) {error_message = "Shape assertion failed", has_side_effect = true} : (tensor) -> () + // CHECK-NEXT: %[[WITNESS:.*]] = shape.const_witness true + // CHECK-NEXT: shape.assuming %[[WITNESS]] { + // CHECK-NEXT: } + // CHECK-NEXT: return +} + +// ----- + +func.func @shape_cstr_broadcast_too_many_operands(%arg0: tensor<4xindex>, %arg1: tensor<4xindex>, %arg2: tensor<4xindex>) { + // expected-error@+1 {{failed to legalize operation 'shape.cstr_broadcastable' that was 
explicitly marked illegal}} + %0 = shape.cstr_broadcastable %arg0, %arg1, %arg2 : tensor<4xindex>, tensor<4xindex>, tensor<4xindex> + shape.assuming %0 { + } + func.return +} + +// ----- + +func.func @shape_cstr_broadcastable_input_shape(%arg0: !shape.shape, %arg1: !shape.shape) { + // expected-error@+1 {{failed to legalize operation 'shape.cstr_broadcastable' that was explicitly marked illegal}} + %0 = shape.cstr_broadcastable %arg0, %arg1 : !shape.shape, !shape.shape + shape.assuming %0 { + } + func.return +} diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_unfuse_mhlo_batch_norm.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_unfuse_mhlo_batch_norm.mlir new file mode 100644 index 00000000000000..e6dd30102e1d2f --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_unfuse_mhlo_batch_norm.mlir @@ -0,0 +1,30 @@ +// RUN: stablehlo-quant-opt %s -split-input-file -tf-stablehlo-unfuse-mhlo-batch-norm | FileCheck %s + +// CHECK-LABEL: @unfuse_batch_norm +// CHECK-SAME: %[[X:[^:[:space:]]+]] +// CHECK-SAME: %[[SCALE:[^:[:space:]]+]] +// CHECK-SAME: %[[OFFSET:[^:[:space:]]+]] +// CHECK-SAME: %[[MEAN:[^:[:space:]]+]] +// CHECK-SAME: %[[VARIANCE:[^:[:space:]]+]] +func.func @unfuse_batch_norm( + %x: tensor<4x256xf32>, %scale: tensor<256xf32>, %offset: tensor<256xf32>, + %mean: tensor<256xf32>, %variance: tensor<256xf32>) + -> (tensor<4x256xf32>) { + // CHECK-DAG: %[[EPS_BCAST:.+]] = mhlo.constant dense<1.001000e-05> : tensor<256xf32> + // CHECK-DAG: %[[VARIANCE_EPS:.+]] = mhlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor<256xf32> + // CHECK-DAG: %[[STDDEV:.+]] = mhlo.sqrt %[[VARIANCE_EPS]] : tensor<256xf32> + // CHECK-DAG: %[[STDDEV_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[STDDEV]]) <{broadcast_dimensions = dense<1> : tensor<1xi64>}> : (tensor<256xf32>) -> tensor<4x256xf32> + // CHECK-DAG: %[[SCALE_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[SCALE]]) <{broadcast_dimensions = dense<1> : tensor<1xi64>}> : 
(tensor<256xf32>) -> tensor<4x256xf32> + // CHECK-DAG: %[[OFFSET_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[OFFSET]]) <{broadcast_dimensions = dense<1> : tensor<1xi64>}> : (tensor<256xf32>) -> tensor<4x256xf32> + // CHECK-DAG: %[[MEAN_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[MEAN]]) <{broadcast_dimensions = dense<1> : tensor<1xi64>}> : (tensor<256xf32>) -> tensor<4x256xf32> + // CHECK: %[[X_CENTER:.+]] = mhlo.subtract %[[X]], %[[MEAN_BCAST]] : tensor<4x256xf32> + // CHECK: %[[X_SCALED:.+]] = mhlo.multiply %[[X_CENTER]], %[[SCALE_BCAST]] : tensor<4x256xf32> + // CHECK: %[[X_NORMED:.+]] = mhlo.divide %[[X_SCALED]], %[[STDDEV_BCAST]] : tensor<4x256xf32> + // CHECK: %[[RESULT:.+]] = mhlo.add %[[X_NORMED]], %[[OFFSET_BCAST]] : tensor<4x256xf32> + %0 = "mhlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance) + {epsilon = 1.001000e-05 : f32, feature_index = 1 : i64} : + (tensor<4x256xf32>, tensor<256xf32>, tensor<256xf32>, tensor<256xf32>, + tensor<256xf32>) -> tensor<4x256xf32> + // CHECK-DAG: return %[[RESULT]] + func.return %0 : tensor<4x256xf32> +} diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_unwrap_xla_call_module_op.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_unwrap_xla_call_module_op.mlir new file mode 100644 index 00000000000000..e31ec5a24cf8c1 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_unwrap_xla_call_module_op.mlir @@ -0,0 +1,53 @@ +// RUN: stablehlo-quant-opt %s -split-input-file -tf-stablehlo-unwrap-xla-call-module-op | FileCheck %s + +// Tests if XlaCallModule op without quantizable trait that calls function with +// '_from_xla_call_module' trait is unwrapped. +// Tests if XlaCallModule op with quantizable trait is not unwrapped. +// Tests if XlaCallModule op without quantizable trait that calls function +// without '_from_xla_call_module' trait is not unwrapped. 
+ +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1682 : i32}, tf_saved_model.semantics} { + // CHECK-LABEL: @main_00 + // CHECK: %[[ARG0:.*]]: tensor<10x1x1024xf32> + func.func private @main_00(%arg0: tensor<10x1x1024xf32>) -> tensor<6x5xf32> attributes {tf._original_func_name = "main_0"} { + %0 = "tf.Const"() <{value = dense<1.000000e+00> : tensor<10x1024x3xf32>}> : () -> tensor<10x1024x3xf32> + %1 = "tf.XlaCallModule"(%arg0, %0) <{Sout = [#tf_type.shape<10x1x3>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @composite_dot_general_fn_1, _stablehlo_version = "1.0.0", _original_entry_function = "composite_dot_general_fn_1", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<10x1x1024xf32>, tensor<10x1024x3xf32>) -> tensor<10x1x3xf32> + %2 = "tf.XlaCallModule"(%1) <{Sout = [#tf_type.shape<3x10>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = ["CPU"], version = 9 : i64}> {_entry_function = @main_0, _stablehlo_version = "1.0.0", _stablehlo_module_attrs = {}, device = ""} : (tensor<10x1x3xf32>) -> tensor<3x10xf32> + %3 = "tf.XlaCallModule"(%2) <{Sout = [#tf_type.shape<6x5>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = ["CPU"], version = 9 : i64}> {_entry_function = @main_1, _stablehlo_version = "1.0.0", _stablehlo_module_attrs = {}, device = ""} : (tensor<3x10xf32>) -> tensor<6x5xf32> + return %3 : tensor<6x5xf32> + } + // CHECK: %[[CST:.*]] = "tf.Const"() + // CHECK-NEXT: %[[CALL1:.*]] = "tf.XlaCallModule"(%[[ARG0]], %[[CST]]) + // CHECK-SAME: _entry_function = @composite_dot_general_fn_1 + // CHECK-SAME: _tfl_quant_trait = "fully_quantizable" + // CHECK-NOT: "tf.XlaCallModule" + // CHECK-NEXT: %[[RESHAPE:.*]] = stablehlo.reshape %[[CALL1]] : (tensor<10x1x3xf32>) -> tensor<3x10xf32> 
+ // CHECK-NEXT: %[[CALL2:.*]] = "tf.XlaCallModule"(%[[RESHAPE]]) + // CHECK-SAME: _entry_function = @main_1 + // CHECK-NOT: _tfl_quant_trait = "fully_quantizable" + // CHECK-NEXT: return %[[CALL2]] + + // CHECK: @composite_dot_general_fn_1 + func.func private @composite_dot_general_fn_1(%arg0: tensor<10x1x1024xf32>, %arg1: tensor<10x1024x3xf32>) -> tensor<10x1x3xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.dot_general %arg0, %arg1, batching_dims = [0] x [0], contracting_dims = [2] x [1] {mhlo.frontend_attributes = {grad_x = "false", grad_y = "false"}} : (tensor<10x1x1024xf32>, tensor<10x1024x3xf32>) -> tensor<10x1x3xf32> + return %0 : tensor<10x1x3xf32> + } + // CHECK: %[[DOT:.*]] = stablehlo.dot_general + // CHECK-NEXT: return %[[DOT]] + + // CHECK: @main_0 + func.func private @main_0(%arg0: tensor<10x1x3xf32>) -> tensor<3x10xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.reshape %arg0 : (tensor<10x1x3xf32>) -> tensor<3x10xf32> + return %0 : tensor<3x10xf32> + } + // CHECK: %[[RESHAPE:.*]] = stablehlo.reshape + // CHECK-NEXT: return %[[RESHAPE]] + + // CHECK: @main_1 + func.func private @main_1(%arg0: tensor<3x10xf32>) -> tensor<6x5xf32> { + %0 = stablehlo.reshape %arg0 : (tensor<3x10xf32>) -> tensor<6x5xf32> + return %0 : tensor<6x5xf32> + } + // CHECK: %[[RESHAPE:.*]] = stablehlo.reshape + // CHECK-NEXT: return %[[RESHAPE]] +} diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_xla_call_module_to_call.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_xla_call_module_to_call.mlir new file mode 100644 index 00000000000000..15374881b67791 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_xla_call_module_to_call.mlir @@ -0,0 +1,23 @@ +// RUN: stablehlo-quant-opt %s -split-input-file -tf-stablehlo-xla-call-module-to-call | FileCheck %s + +// ----- + +// Tests composite tf.XlaCallModule is converted to func.call. 
+ +module { + // CHECK-LABEL: func.func @main + func.func @main(%arg0: tensor<1x1024xf32>) -> tensor<1x3xf32> { + // CHECK: call @composite_dot_general_fn_1 + // CHECK-SAME: (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32> + // CHECK-NOT: tf.XlaCallModule + %0 = "tf.Const"() <{value = dense<0.5> : tensor<1024x3xf32>}> : () -> tensor<1024x3xf32> + %2 = "tf.XlaCallModule"(%arg0, %0) <{Sout = [#tf_type.shape<1x3>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @composite_dot_general_fn_1, _stablehlo_version = "1.0.0", _original_entry_function = "composite_dot_general_fn_1", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32> + return %2 : tensor<1x3xf32> + } + // CHECK-LABEL: func.func private @composite_dot_general_fn_1 + // CHECK-SAME: -> tensor<1x3xf32> + func.func private @composite_dot_general_fn_1(%arg0: tensor<1x1024xf32>, %arg1: tensor<1024x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> + } +} diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tools/stablehlo_quant_opt.cc b/tensorflow/compiler/mlir/quantization/stablehlo/tools/stablehlo_quant_opt.cc index c14cff87984890..105ab22d159b5d 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tools/stablehlo_quant_opt.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tools/stablehlo_quant_opt.cc @@ -28,10 +28,12 @@ limitations under the License. 
#include "stablehlo/transforms/Passes.h" // from @stablehlo #include "tensorflow/compiler/mlir/init_mlir.h" #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/quantization/common/ir/QuantOps.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/passes.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" @@ -46,6 +48,7 @@ int main(int argc, char** argv) { mlir::registerAllPasses(); mlir::registerTensorFlowPasses(); mlir::quant::stablehlo::registerPasses(); + mlir::tf_quant::stablehlo::registerPasses(); mlir::quant::stablehlo::registerBridgePasses(); mlir::stablehlo::registerPasses(); mlir::mhlo::registerAllMhloPasses(); @@ -64,7 +67,7 @@ int main(int argc, char** argv) { mlir::quantfork::QuantizationForkDialect, mlir::stablehlo::StablehloDialect, mlir::tf_executor::TensorFlowExecutorDialect, - mlir::vhlo::VhloDialect>(); + mlir::vhlo::VhloDialect, mlir::quant::ir::TFQuantDialect>(); mlir::mhlo::registerAllMhloDialects(registry); mlir::func::registerAllExtensions(registry); return failed( diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/BUILD index bfc5cde2dcbc82..ce3813b41e594e 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/BUILD @@ -86,6 +86,21 @@ cc_library( ], ) +td_library( + name = "tf_quant_td_files", + srcs = [ + "passes/tf_post_quantize.td", + ], + compatible_with = 
get_compatible_with_portable(), + deps = [ + "//tensorflow/compiler/mlir/quantization/common:quant_td_files", + "//tensorflow/compiler/mlir/quantization/common/ir:QuantizationOpsTdFiles", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files", + "@llvm-project//mlir:ArithOpsTdFiles", + "@llvm-project//mlir:FuncTdFiles", + ], +) + td_library( name = "quant_td_files", srcs = [ @@ -114,12 +129,7 @@ td_library( gentbl_cc_library( name = "convert_tf_xla_op_to_tf_op_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "passes/convert_tf_xla_op_to_tf_op.inc", - ), - ], + tbl_outs = {"passes/convert_tf_xla_op_to_tf_op.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "passes/convert_tf_xla_op_to_tf_op.td", deps = [":quant_td_files"], @@ -128,12 +138,7 @@ gentbl_cc_library( gentbl_cc_library( name = "cast_bf16_ops_to_f32_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "passes/cast_bf16_ops_to_f32.inc", - ), - ], + tbl_outs = {"passes/cast_bf16_ops_to_f32.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "passes/cast_bf16_ops_to_f32.td", deps = [":quant_td_files"], @@ -142,12 +147,7 @@ gentbl_cc_library( gentbl_cc_library( name = "prepare_lifting_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "passes/prepare_lifting.inc", - ), - ], + tbl_outs = {"passes/prepare_lifting.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "passes/prepare_lifting.td", deps = [":quant_td_files"], @@ -156,12 +156,7 @@ gentbl_cc_library( gentbl_cc_library( name = "lift_quantizable_spots_as_functions_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "passes/lift_quantizable_spots_as_functions.inc", - ), - ], + tbl_outs = {"passes/lift_quantizable_spots_as_functions.inc": 
["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "passes/lift_quantizable_spots_as_functions.td", deps = [":quant_td_files"], @@ -170,12 +165,7 @@ gentbl_cc_library( gentbl_cc_library( name = "lift_quantizable_spots_as_functions_drq_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "passes/lift_quantizable_spots_as_functions_drq.inc", - ), - ], + tbl_outs = {"passes/lift_quantizable_spots_as_functions_drq.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "passes/lift_quantizable_spots_as_functions_drq.td", deps = [":quant_td_files"], @@ -184,12 +174,7 @@ gentbl_cc_library( gentbl_cc_library( name = "prepare_quantize_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "passes/prepare_quantize.inc", - ), - ], + tbl_outs = {"passes/prepare_quantize.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "passes/prepare_quantize.td", deps = [":quant_td_files"], @@ -198,12 +183,7 @@ gentbl_cc_library( gentbl_cc_library( name = "quantize_composite_functions_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "passes/quantize_composite_functions.inc", - ), - ], + tbl_outs = {"passes/quantize_composite_functions.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "passes/quantize_composite_functions.td", deps = [":quant_td_files"], @@ -212,16 +192,10 @@ gentbl_cc_library( gentbl_cc_library( name = "tf_quant_ops_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-op-decls"], - "passes/tf_quant_ops.h.inc", - ), - ( - ["-gen-op-defs"], - "passes/tf_quant_ops.cc.inc", - ), - ], + tbl_outs = { + "passes/tf_quant_ops.h.inc": ["-gen-op-decls"], + "passes/tf_quant_ops.cc.inc": ["-gen-op-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "passes/tf_quant_ops.td", deps = 
[ @@ -232,12 +206,7 @@ gentbl_cc_library( gentbl_cc_library( name = "optimize_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "passes/optimize.inc", - ), - ], + tbl_outs = {"passes/optimize.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "passes/optimize.td", deps = [":quant_td_files"], @@ -246,12 +215,7 @@ gentbl_cc_library( gentbl_cc_library( name = "convert_tpu_model_to_cpu_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "passes/convert_tpu_model_to_cpu.inc", - ), - ], + tbl_outs = {"passes/convert_tpu_model_to_cpu.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "passes/convert_tpu_model_to_cpu.td", deps = [":quant_td_files"], @@ -260,26 +224,25 @@ gentbl_cc_library( gentbl_cc_library( name = "post_quantize_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "passes/post_quantize.inc", - ), - ], + tbl_outs = {"passes/post_quantize.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "passes/post_quantize.td", deps = [":quant_td_files"], ) +gentbl_cc_library( + name = "tf_post_quantize_inc_gen", + compatible_with = get_compatible_with_portable(), + tbl_outs = {"passes/tf_post_quantize.inc": ["-gen-rewriters"]}, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "passes/tf_post_quantize.td", + deps = [":tf_quant_td_files"], +) + gentbl_cc_library( name = "preprocess_op_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "passes/preprocess_op.inc", - ), - ], + tbl_outs = {"passes/preprocess_op.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "passes/preprocess_op.td", deps = [":quant_td_files"], @@ -319,12 +282,7 @@ cc_library( gentbl_cc_library( name = "replace_cast_hacks_with_tf_xla_ops_inc_gen", compatible_with = 
get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "passes/replace_cast_hacks_with_tf_xla_ops.inc", - ), - ], + tbl_outs = {"passes/replace_cast_hacks_with_tf_xla_ops.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "passes/replace_cast_hacks_with_tf_xla_ops.td", deps = [":quant_td_files"], @@ -475,6 +433,52 @@ cc_library( alwayslink = True, ) +cc_library( + name = "tf_passes", + srcs = [ + "passes/prepare_lifting.inc", + "passes/tf_add_quantization_unit_loc.cc", + "passes/tf_convert_fake_quant_to_qdq.cc", + "passes/tf_post_quantize.cc", + "passes/tf_post_quantize.inc", + "passes/tf_prepare_lifting.cc", + ], + hdrs = [ + "passes/tf_passes.h", + ], + compatible_with = get_compatible_with_portable(), + deps = [ + ":prepare_lifting_inc_gen", + ":quantization_options_proto_cc", + ":remove_identity_op_pattern", + ":tf_post_quantize_inc_gen", + "//tensorflow/compiler/mlir/quantization/common:attrs_and_constraints", + "//tensorflow/compiler/mlir/quantization/common/ir:QuantOps", + "//tensorflow/compiler/mlir/quantization/common/quantization_lib", + "//tensorflow/compiler/mlir/quantization/common/tf_quantization_lib:tf_quantization_config", + "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", + "//tensorflow/compiler/mlir/quantization/tensorflow/cc:constant_fold", + "//tensorflow/compiler/mlir/quantization/tensorflow/cc:quantization_unit_loc", + "//tensorflow/compiler/mlir/quantization/tensorflow/ops:tf_op_quant_spec", + "//tensorflow/compiler/mlir/quantization/tensorflow/utils:temp_fake_quant_utils", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + 
"@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + ], + # Alwayslink is required for registering the MLIR passes. + # TODO(b/255530126): Split the pass registration from the definitions to avoid binary size bloat. + alwayslink = True, +) + cc_library( name = "quantize_preprocess", srcs = [ @@ -487,7 +491,6 @@ cc_library( deps = [ ":passes", "//tensorflow/compiler/mlir/lite:tensorflow_lite", - "//tensorflow/compiler/mlir/lite/stablehlo:tf_stablehlo", "//tensorflow/compiler/mlir/quantization/stablehlo:bridge_passes", "//tensorflow/compiler/mlir/quantization/stablehlo/cc:pass_pipeline", "//tensorflow/compiler/mlir/quantization/tensorflow/cc:run_passes", @@ -495,6 +498,7 @@ cc_library( "//tensorflow/compiler/mlir/stablehlo:fuse_convolution_pass", "//tensorflow/compiler/mlir/stablehlo:legalize_tf_xla_call_module_to_stablehlo_pass", "//tensorflow/compiler/mlir/stablehlo:rename_entrypoint_to_main", + "//tensorflow/compiler/mlir/stablehlo:tf_stablehlo", "//tensorflow/compiler/mlir/stablehlo:unfuse_batch_norm_pass", "//tensorflow/compiler/mlir/tensorflow:error_util", "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes", @@ -593,6 +597,7 @@ tf_cc_binary( srcs = ["passes/tf_quant_opt.cc"], deps = [ ":passes", + ":tf_passes", "//tensorflow/compiler/mlir:init_mlir", "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", "//tensorflow/compiler/mlir/quantization/common/ir:QuantOps", diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_saver_op.cc b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_saver_op.cc index e72a71f4a35da0..09dfcae58466b1 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_saver_op.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_saver_op.cc @@ -74,7 +74,7 @@ proto file.)doc"); class CalibrationStatisticsSaverOp : public 
OpKernel { public: explicit CalibrationStatisticsSaverOp( - absl::Nonnull context) + OpKernelConstruction* absl_nonnull context) : OpKernel(context) { std::string output_file_path; OP_REQUIRES_OK(context, @@ -128,7 +128,7 @@ class CalibrationStatisticsSaverOp : public OpKernel { } } - void Compute(absl::Nonnull context) override { + void Compute(OpKernelContext* absl_nonnull context) override { for (int idx = 0; idx < ids_.size(); ++idx) { AssignIfNotExists( ids_[idx], static_cast(calibration_methods_[idx])); diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/add_dump_tensor_op.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/add_dump_tensor_op.cc index 57b8b23de72cc4..0b73b9c550b62a 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/add_dump_tensor_op.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/add_dump_tensor_op.cc @@ -164,17 +164,25 @@ class AddDumpTensorOpPass }; template -class AddDumpTensorOp - : public OpRewritePattern::SplitMatchAndRewrite { +class AddDumpTensorOp : public OpRewritePattern { public: // Does not take ownership of context, which must refer to a valid value that // outlives this object. 
explicit AddDumpTensorOp(MLIRContext *context, DebuggerType debugger_type, std::string log_dir_path) - : OpRewritePattern::SplitMatchAndRewrite(context), + : OpRewritePattern(context), debugger_type_(debugger_type), log_dir_path_(std::move(log_dir_path)) {} + LogicalResult matchAndRewrite(LiftedOpT op, + PatternRewriter &rewriter) const override { + if (match(op).failed()) { + return failure(); + } + rewrite(op, rewriter); + return success(); + } + private: SmallVector CreateDumpAttributes( PatternRewriter &rewriter, const StringRef folder_name, @@ -204,7 +212,7 @@ class AddDumpTensorOp return symbol_table.insert(new_ref_func); } - LogicalResult match(LiftedOpT op) const override { + LogicalResult match(LiftedOpT op) const { if (!op->hasAttr(kQuantTraitAttrName) || op->getNumResults() != 1) { return failure(); } @@ -219,7 +227,7 @@ class AddDumpTensorOp return success(); } - void rewrite(LiftedOpT op, PatternRewriter &rewriter) const override { + void rewrite(LiftedOpT op, PatternRewriter &rewriter) const { // Only support ops with 1 results Value result = op->getResult(0); rewriter.setInsertionPointAfterValue(result); diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/cast_bf16_ops_to_f32.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/cast_bf16_ops_to_f32.cc index 7fea73725af761..50d4030083d99b 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/cast_bf16_ops_to_f32.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/cast_bf16_ops_to_f32.cc @@ -47,13 +47,22 @@ class CastBf16OpsToF32Pass void runOnOperation() override; }; -class CastBf16OpsToF32 : public RewritePattern::SplitMatchAndRewrite { +class CastBf16OpsToF32 : public RewritePattern { public: explicit CastBf16OpsToF32(MLIRContext* context) - : SplitMatchAndRewrite(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} + : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} + + LogicalResult matchAndRewrite(Operation* op, + PatternRewriter& 
rewriter) const override { + if (match(op).failed()) { + return failure(); + } + rewrite(op, rewriter); + return success(); + } private: - LogicalResult match(Operation* op) const override { + LogicalResult match(Operation* op) const { if (isa(op) || op->getName().hasTrait()) { return failure(); @@ -71,7 +80,7 @@ class CastBf16OpsToF32 : public RewritePattern::SplitMatchAndRewrite { return failure(); } - void rewrite(Operation* op, PatternRewriter& rewriter) const override { + void rewrite(Operation* op, PatternRewriter& rewriter) const { // Casts inputs of the operation. for (int i = 0; i < op->getNumOperands(); i++) { Value input = op->getOperand(i); diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_custom_aggregation_ops.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_custom_aggregation_ops.cc index 886f9cd28a127b..ec7ffefd2d43f7 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_custom_aggregation_ops.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_custom_aggregation_ops.cc @@ -86,7 +86,7 @@ std::optional GetCompsiteFunctionName(Operation *op) { return entry_function_attr.getValue(); } else { TF::PartitionedCallOp call_op = dyn_cast_or_null(op); - const auto f_attr = call_op.getFAttr().dyn_cast(); + const auto f_attr = mlir::dyn_cast(call_op.getFAttr()); if (!f_attr) return std::nullopt; return f_attr.getValue(); } diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/lift_quantizable_spots_as_functions.td b/tensorflow/compiler/mlir/quantization/tensorflow/passes/lift_quantizable_spots_as_functions.td index d56ee05dc071dc..9e0f26d8793684 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/lift_quantizable_spots_as_functions.td +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/lift_quantizable_spots_as_functions.td @@ -25,8 +25,8 @@ include "tensorflow/compiler/mlir/quantization/common/lift_as_function_call.td" 
//===----------------------------------------------------------------------===// class IsFusedOpEndsWith : AttrConstraint< - CPred<"!$_self.cast().empty() && " - "$_self.cast()[$_self.cast().size() - 1]." + CPred<"!llvm::cast($_self).empty() && " + "llvm::cast($_self)[llvm::cast($_self).size() - 1]." "cast<::mlir::StringAttr>().str() == \"" # OpName # "\"">, "Matching fused '" # OpName # "' op at the end">; diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.td b/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.td index d75a01be7d2182..338fdc91fc521c 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.td +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.td @@ -83,21 +83,21 @@ class HasEqualElementSize shape_1, list shape_2> : Constraint< "Checks if the given dimensions contain the same number of elements.">; def ReshapableTo1DTensor : Constraint< - CPred<"quant::ReshapableTo1DTensor($0.getType().cast())">, + CPred<"quant::ReshapableTo1DTensor(llvm::cast($0.getType()))">, "Checks if the value dims are all ones except the right most dim">; def ReshapeTo1DTensor : NativeCodeCall< "quant::ReshapeTo1DTensor($_builder, $_loc, $0)">; def HasEqualShape : Constraint().hasRank() && " - "$1.getType().cast().hasRank() && " - "$0.getType().cast().getShape() == $1.getType().cast().getShape()">, + "llvm::cast($0.getType()).hasRank() && " + "llvm::cast($1.getType()).hasRank() && " + "llvm::cast($0.getType()).getShape() == llvm::cast($1.getType()).getShape()">, "Checks if the shapes of tensors are same.">; // Make the 1D value $0 broadcastable with the shape of $1. def MakeOneDimValueBroadcastable : NativeCodeCall< - "MakeOneDimValueBroadcastable($_builder, $_loc, $0, $1.getType().cast())">; + "MakeOneDimValueBroadcastable($_builder, $_loc, $0, llvm::cast($1.getType()))">; // Match convolution op with "NHWC" data format or matmul op. 
def SupportedAffineOpMatcher : NativeCodeCall< diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize.cc index c18d76327ca844..1590a447a131c3 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize.cc @@ -85,7 +85,7 @@ struct TFQuantizationBase Operation* quantized_op, const CustomMap& custom_op_map) { auto call_op = cast(quantized_op); StringRef function_name = - call_op.getFAttr().cast().getValue(); + llvm::cast(call_op.getFAttr()).getValue(); // The below can be generalized as there are more read-only ops added such // as slice. const bool is_gather = function_name.contains("gather"); @@ -98,7 +98,7 @@ struct TFQuantizationBase const CustomMap& custom_op_map) { auto call_op = cast(quantized_op); StringRef function_name = - call_op.getFAttr().cast().getValue(); + llvm::cast(call_op.getFAttr()).getValue(); // The below can be generalized as there are more read-only ops added such // as slice. 
bool is_gather = false; @@ -221,16 +221,16 @@ class QuantizeSameScaleOpsPattern inputs.reserve(quantizing_op->getNumOperands()); for (const auto& operand : quantizing_op->getOperands()) { Type operand_type = operand.getType(); - if (operand_type.isa()) { + if (isa(operand_type)) { inputs.push_back(operand); continue; } - Type elem_type = operand_type.cast().getElementType(); + Type elem_type = llvm::cast(operand_type).getElementType(); if (auto dq_op = dyn_cast_or_null( operand.getDefiningOp())) { - auto dq_arg_type = dq_op.getArg().getType().cast(); - auto qtype = dq_arg_type.getElementType().cast(); + auto dq_arg_type = llvm::cast(dq_op.getArg().getType()); + auto qtype = llvm::cast(dq_arg_type.getElementType()); auto scast_op = rewriter.create( dq_op->getLoc(), dq_arg_type.clone(qtype.getStorageType()), dq_op.getArg()); @@ -253,12 +253,12 @@ class QuantizeSameScaleOpsPattern llvm::enumerate(quantizing_op->getResults())) { Value result = enumerated_result.value(); Type result_type = result.getType(); - if (result_type.isa()) { + if (isa(result_type)) { outputs_replaced.insert({result, enumerated_result.index()}); output_types.push_back(result_type); continue; } - auto result_tensor_type = result_type.cast(); + auto result_tensor_type = llvm::cast(result_type); // If the user is the Quantize op, it must be the only user. if (result.hasOneUse() && llvm::isa(*result.user_begin())) { @@ -266,10 +266,8 @@ class QuantizeSameScaleOpsPattern llvm::cast(*result.user_begin()); outputs_replaced.insert( {user.getResult(), enumerated_result.index()}); - auto qtype = user.getType() - .cast() - .getElementType() - .cast(); + auto qtype = llvm::cast( + llvm::cast(user.getType()).getElementType()); output_types.push_back( result_tensor_type.clone(qtype.getStorageType())); } else if (!result_tensor_type.getElementType().isF32()) { @@ -338,7 +336,7 @@ class QuantizeSameScaleOpsPattern // Check if the preceding op is a quantized same-scale op. 
if (llvm::isa(preceding_op)) { auto sc_op = llvm::cast(preceding_op); - auto sc_arg_type = sc_op.getArg().getType().dyn_cast(); + auto sc_arg_type = llvm::dyn_cast(sc_op.getArg().getType()); if (sc_arg_type.getElementType().isInteger(8)) { return true; } @@ -364,7 +362,8 @@ class QuantizeSameScaleOpsPattern // Check if the preceding op is a quantized same-scale op. if (llvm::isa(following_op)) { auto sc_op = llvm::cast(following_op); - auto sc_arg_type = sc_op.getResult().getType().dyn_cast(); + auto sc_arg_type = + llvm::dyn_cast(sc_op.getResult().getType()); if (sc_arg_type.getElementType().isInteger(8)) { return true; } @@ -381,28 +380,28 @@ class QuantizeSameScaleOpsPattern return false; } - const auto f_attr = call_op.getFAttr().dyn_cast(); + const auto f_attr = llvm::dyn_cast(call_op.getFAttr()); if (!f_attr || !f_attr.getValue().starts_with("composite_")) { return false; } bool has_quantized_types = false; for (Value input : call_op.getArgs()) { - if (auto type = input.getType().dyn_cast()) { - if (type.getElementType().isa()) { + if (auto type = llvm::dyn_cast(input.getType())) { + if (isa(type.getElementType())) { return false; } - if (type.getElementType().isa()) { + if (isa(type.getElementType())) { has_quantized_types = true; } } } for (Value output : call_op.getOutput()) { - if (auto type = output.getType().dyn_cast()) { - if (type.getElementType().isa()) { + if (auto type = llvm::dyn_cast(output.getType())) { + if (isa(type.getElementType())) { return false; } - if (type.getElementType().isa()) { + if (isa(type.getElementType())) { has_quantized_types = true; } } @@ -432,10 +431,11 @@ struct QuantizeAvgPoolOpPattern if (!preceding_sc_op) return failure(); // Check if the same-scale requirement is met. 
- auto dq_arg_type = preceding_sc_op.getArg().getType().cast(); - auto qtype = dq_arg_type.getElementType().cast(); - auto q_result_type = sc_op.getType().cast(); - auto out_qtype = q_result_type.getElementType().cast(); + auto dq_arg_type = + llvm::cast(preceding_sc_op.getArg().getType()); + auto qtype = llvm::cast(dq_arg_type.getElementType()); + auto q_result_type = llvm::cast(sc_op.getType()); + auto out_qtype = llvm::cast(q_result_type.getElementType()); if (qtype != out_qtype) { avg_pool_op.emitError( "The preceding StorageCastOp and the following " diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/remove_var_init_by_const.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/remove_var_init_by_const.cc index 33fbe5406040f7..ae3a25b32199e7 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/remove_var_init_by_const.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/remove_var_init_by_const.cc @@ -62,28 +62,25 @@ class RemoveVariableInitializationByConstPass // pattern. `tf.VarHandleOp` and `tf.Const` are removed unless they are used by // other ops. struct RemoveVariableAssignmentByConst - : public OpRewritePattern::SplitMatchAndRewrite { + : public OpRewritePattern { // Inherit the constructors. 
- using SplitMatchAndRewrite::SplitMatchAndRewrite; + using OpRewritePattern::OpRewritePattern; - LogicalResult match(TF::AssignVariableOp assign_op) const override { + LogicalResult matchAndRewrite(TF::AssignVariableOp assign_op, + PatternRewriter& rewriter) const override { Value resource_operand = assign_op.getOperand(0); Value assigned_value_operand = assign_op.getOperand(1); - if (isa(resource_operand.getDefiningOp()) && - isa(assigned_value_operand.getDefiningOp())) { - return success(); - } else { + if (!isa(resource_operand.getDefiningOp()) || + !isa(assigned_value_operand.getDefiningOp())) { return failure(); } - } - void rewrite(TF::AssignVariableOp assign_op, - PatternRewriter& rewriter) const override { // `TF::ConstOp` and `TF::VarHandleOp` are not manually erased. // `applyPatternsGreedily` performs dead code elimination and unsed // ops will be erased during the optimization. rewriter.eraseOp(assign_op); + return success(); } }; diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/replace_cast_hacks_with_tf_xla_ops.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/replace_cast_hacks_with_tf_xla_ops.cc index d1e46b4eb56031..ec5adb87d88c8c 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/replace_cast_hacks_with_tf_xla_ops.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/replace_cast_hacks_with_tf_xla_ops.cc @@ -628,8 +628,7 @@ Value CreateXlaConvOp(OpBuilder &builder, Location loc, Value input, Value filter, Value input_zp, Value conv_output, ArrayAttr strides, ArrayAttr dilations, StringAttr conv_padding, ArrayAttr explicit_paddings, - int feature_group_cnt, bool four_bit = false, - int num_dims = 4) { + int feature_group_cnt, int num_dims = 4) { int32_t input_zp_value; if (!GetSplatValue(input_zp, input_zp_value)) { emitError(loc, @@ -675,14 +674,6 @@ Value CreateXlaConvOp(OpBuilder &builder, Location loc, Value input, conv_padding, explicit_paddings, padding, num_dims); std::string 
precision_config_str; - if (four_bit) { - input = PackOperand(builder, loc, input, /*pack_dim=*/num_dims - 1); - filter = PackOperand(builder, loc, filter, /*pack_dim=*/num_dims - 2); - xla::PrecisionConfig precision_config; - precision_config.add_operand_precision(xla::PrecisionConfig::PACKED_NIBBLE); - precision_config.add_operand_precision(xla::PrecisionConfig::PACKED_NIBBLE); - precision_config_str = precision_config.SerializeAsString(); - } Value xla_conv_output = builder .create( @@ -774,14 +765,13 @@ Value CreateXlaConvOpFromTfConv3dOp(OpBuilder &builder, Location loc, return CreateXlaConvOp(builder, loc, input, filter, input_zp, conv_output, strides, dilations, conv_padding, /*explicit_paddings=*/nullptr, feature_group_cnt, - /*four_bit=*/false, /*num_dims=*/5); + /*num_dims=*/5); } // Helper function to create an XlaDotV2Op. Value CreateXlaDotV2Op(OpBuilder &builder, Location loc, Value input, Value weight, Value input_zp, Value weight_zp, - Value output, const xla::DotDimensionNumbers &dnums, - bool four_bit = false) { + Value output, const xla::DotDimensionNumbers &dnums) { int32_t input_zp_value = 0; int32_t weight_zp_value = 0; if (input_zp != nullptr && !GetSplatValue(input_zp, input_zp_value)) { @@ -797,14 +787,6 @@ Value CreateXlaDotV2Op(OpBuilder &builder, Location loc, Value input, } std::string precision_config_str; - if (four_bit) { - input = PackOperand(builder, loc, input, /*pack_dim=*/1); - weight = PackOperand(builder, loc, weight, /*pack_dim=*/0); - xla::PrecisionConfig precision_config; - precision_config.add_operand_precision(xla::PrecisionConfig::PACKED_NIBBLE); - precision_config.add_operand_precision(xla::PrecisionConfig::PACKED_NIBBLE); - precision_config_str = precision_config.SerializeAsString(); - } Value dot_result = builder diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_add_quantization_unit_loc.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_add_quantization_unit_loc.cc new file mode 
100644 index 00000000000000..208b9d0cb83f47 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_add_quantization_unit_loc.cc @@ -0,0 +1,202 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/strings/match.h" +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/tensorflow/cc/quantization_unit_loc.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_op_quant_spec.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_passes.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" + +namespace mlir { +namespace quant { 
+namespace { + +using QuantizationUnit = + tensorflow::quantization::UnitWiseQuantizationSpec::QuantizationUnit; + +// Adds QuantizationUnitLoc to quantizable layers. +class TFAddQuantizationUnitLocPass + : public PassWrapper> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TFAddQuantizationUnitLocPass) + explicit TFAddQuantizationUnitLocPass() = default; + + StringRef getArgument() const final { + // This is the argument used to refer to the pass in + // the textual format (on the commandline for example). + return "tf-quant-add-quantization-unit-loc"; + } + StringRef getDescription() const final { + return "Add QuantizationUnitLoc to quantizable layers."; + } + + private: + void runOnOperation() override; +}; + +// TF graph nodes are imported with one of following location patterns: +// FusedLoc[NameLoc(op_type:), ..., NameLoc(node_name@func_name)] or +// FusedLoc[NameLoc(op_type:), ..., CallSiteLoc(node_name@func_name)]. See +// tensorflow/compiler/mlir/tensorflow/translate/import_model.cc for more +// details. +bool IsImportLocPattern(FusedLoc loc) { + ArrayRef locations = mlir::cast(loc).getLocations(); + if (locations.size() < 2 || !isa(locations.front())) return false; + + StringRef op_type_with_suffix = + mlir::cast(locations.front()).getName().strref(); + if (!op_type_with_suffix.ends_with(":")) return false; + + return absl::c_all_of(locations, [](Location loc) { + return isa(loc) || + (isa(loc) && + isa(mlir::cast(loc).getCallee())); + }); +} + +// Finds the pattern of the location created by `ImporterBase::GetLocation` +// in `tensorflow/compiler/mlir/tensorflow/translate/import_model.cc`. 
+void FindQuantizationUnitsRecursively(Location loc, + SmallVector& units) { + if (!isa(loc)) return; + + auto set_node_and_func_name = [](QuantizationUnit& new_unit, + StringRef name_loc_id) { + if (name_loc_id.contains("@")) { + new_unit.set_node_name(name_loc_id.split('@').first.str()); + new_unit.set_func_name(name_loc_id.split('@').second.str()); + } else { + new_unit.set_node_name(name_loc_id.str()); + } + }; + + ArrayRef locations = mlir::cast(loc).getLocations(); + if (IsImportLocPattern(mlir::cast(loc))) { + QuantizationUnit new_unit; + // Op type is a NameLoc with the ":" suffix. + StringRef op_type_with_suffix = + mlir::cast(locations.front()).getName().strref(); + StringRef op_type = + op_type_with_suffix.substr(0, op_type_with_suffix.size() - 1); + new_unit.set_op_type(op_type.str()); + + if (isa(locations.back())) { + StringRef name_loc_id = + mlir::cast(locations.back()).getName().strref(); + set_node_and_func_name(new_unit, name_loc_id); + } else { + Location callee = mlir::cast(locations.back()).getCallee(); + StringRef name_loc_id = mlir::cast(callee).getName().strref(); + set_node_and_func_name(new_unit, name_loc_id); + } + units.push_back(new_unit); + } else { + for (Location child_loc : locations) { + FindQuantizationUnitsRecursively(child_loc, units); + } + } +} + +// Finds the QuantizationUnit from location. +std::optional FindQuantizationUnit(Operation* op) { + SmallVector quant_units; + FindQuantizationUnitsRecursively(op->getLoc(), quant_units); + + if (quant_units.size() == 1) { + return *quant_units.begin(); + } + // Among units, return the one with the same type as given op. 
+ StringRef given_op_type = op->getName().getStringRef(); + for (const QuantizationUnit& quant_unit : quant_units) { + if (absl::StrContains(given_op_type.lower(), + StringRef(quant_unit.op_type()).lower())) { + return quant_unit; + } + } + + return std::nullopt; +} + +class AddQuantizationUnitLoc : public RewritePattern { + public: + explicit AddQuantizationUnitLoc(MLIRContext* context) + : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} + + private: + LogicalResult matchAndRewrite(Operation* op, + PatternRewriter& rewriter) const override { + if (!IsOpWithQuantizableTrait(op) || + FindQuantizationUnitFromLoc(op->getLoc()).has_value()) { + return failure(); + } + + std::optional quantization_unit = + FindQuantizationUnit(op); + if (!quantization_unit.has_value()) return failure(); + + if (quantization_unit->func_name().empty()) { + std::string func_name = + op->getParentOfType().getSymNameAttr().str(); + quantization_unit->set_func_name(func_name); + } + QuantizationUnitLoc unit_loc(getContext(), quantization_unit.value()); + op->setLoc(unit_loc); + + return success(); + } +}; + +void TFAddQuantizationUnitLocPass::runOnOperation() { + MLIRContext* ctx = &getContext(); + RewritePatternSet patterns(ctx); + func::FuncOp func = getOperation(); + + patterns.add(ctx); + if (failed(applyPatternsGreedily(func, std::move(patterns)))) { + func.emitError() << "tf-quant-add-quantization-unit-loc pattern " + "conversion did not converge."; + signalPassFailure(); + } +} + +} // namespace + +// Creates an instance of `TFAddQuantizationUnitLocPass`. 
+std::unique_ptr> +CreateTFAddQuantizationUnitLocPass() { + return std::make_unique(); +} + +static PassRegistration pass; + +} // namespace quant +} // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_convert_fake_quant_to_qdq.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_convert_fake_quant_to_qdq.cc new file mode 100644 index 00000000000000..d1a1fd04b30440 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_convert_fake_quant_to_qdq.cc @@ -0,0 +1,90 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project +#include "mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project // IWYU pragma: keep, for applyPatternsGreedily +#include "tensorflow/compiler/mlir/quantization/common/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_passes.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/utils/temp_fake_quant_utils.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" + +namespace mlir { +namespace quant { +namespace { + +class TFConvertFakeQuantToQdqPass + : public PassWrapper> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TFConvertFakeQuantToQdqPass) + + StringRef getArgument() const final { + // This is the argument used to refer to the pass in + // the textual format (on the commandline for example). + return "tf-quant-convert-fake-quant-to-qdq"; + } + + StringRef getDescription() const final { + // This is a brief description of the pass. 
+ return "Convert Fake Quant op to quant.qcast and quant.dcast pairs"; + } + + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + registry.insert(); + registry.insert(); + } + + void runOnOperation() override; +}; + +static PassRegistration pass; + +void TFConvertFakeQuantToQdqPass::runOnOperation() { + MLIRContext* ctx = &getContext(); + func::FuncOp func = getOperation(); + + if (failed(tf_quant::ConvertFakeQuantOps( + func, ctx, /*use_fake_quant_num_bits=*/false))) { + func.emitError() << "quant-convert-fake-quant-to-qdq pass failed."; + signalPassFailure(); + } + + // For removing dead FakeQuant* ops + RewritePatternSet patterns(ctx); + if (failed(applyPatternsGreedily(func, std::move(patterns)))) { + signalPassFailure(); + } +} + +} // namespace + +std::unique_ptr> +CreateTFConvertFakeQuantToQdqPass() { + return std::make_unique(); +} + +} // namespace quant +} // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_passes.h b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_passes.h new file mode 100644 index 00000000000000..c227dc7ad01969 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_passes.h @@ -0,0 +1,54 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_PASSES_TF_PASSES_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_PASSES_TF_PASSES_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" + +namespace mlir { +namespace quant { + +// Creates a pass that add QuantizationUnitLoc to quantizable layers. +std::unique_ptr> +CreateTFAddQuantizationUnitLocPass(); + +// Converts FakeQuant ops to quant.qcast and quant.dcast (QDQ) pairs. +std::unique_ptr> +CreateTFConvertFakeQuantToQdqPass(); + +// Apply graph optimizations such as fusing and constant folding to prepare +// lifting. +std::unique_ptr> CreateTFPrepareLiftingPass( + tensorflow::quantization::OpSet target_opset); + +// Creates an instance of the PostQuantize pass, which will remove unnecessary +// ops from the final quantized graph. +std::unique_ptr> CreatePostQuantizePass(); + +} // namespace quant +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_PASSES_TF_PASSES_H_ diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_post_quantize.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_post_quantize.cc new file mode 100644 index 00000000000000..33da66846b2aeb --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_post_quantize.cc @@ -0,0 +1,158 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This transformation pass applies some clean up steps after quantization. + +#include +#include +#include + +#include "llvm/Support/Casting.h" +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/common/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project + +//===----------------------------------------------------------------------===// +// The post-quantize Passes. +// +namespace mlir { +namespace quant { +namespace { + +// Applies all the clean up steps after quantization. +class TFPostQuantizePass + : public PassWrapper> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TFPostQuantizePass) + + // Constructor used by the PassRegistration. This will remove the adaptor ops. 
+ explicit TFPostQuantizePass() = default; + + StringRef getArgument() const final { + // This is the argument used to refer to the pass in + // the textual format (on the commandline for example). + return "tf-quant-post-quantize"; + } + StringRef getDescription() const final { + // This is a brief description of the pass. + return "Apply post quantization clean up after quantization"; + } + + void runOnOperation() override; +}; + +enum RemoveVolatileOpsType { + // Remove all volatile quant-dequant ops. + kPreserveNone, + // Preserve volatile quant-dequants for input and output ops. + kPreserveInputsAndOutputs, +}; + +// Remove the back-to-back quantize and dequantize ops with volatile attribute. +template +struct RemoveVolatileOps + : public OpRewritePattern { + explicit RemoveVolatileOps(MLIRContext* context) + : OpRewritePattern(context, 1) {} + + LogicalResult matchAndRewrite(mlir::quant::ir::DequantizeCastOp op, + PatternRewriter& rewriter) const override { + auto input_op = op.getArg().getDefiningOp(); + if (auto q = llvm::dyn_cast_or_null( + input_op)) { + if (!q->getAttr(kVolatileOpAttrName)) return failure(); + + if (remove_volatile_ops_type == kPreserveInputsAndOutputs) { + // Don't remove leading and trailing QDQ for PTQ workflow, so the io + // modifying lib can work correctly. + if (!q.getArg().getDefiningOp()) return failure(); + if (op->hasOneUse() && + op->user_begin()->hasTrait()) + return failure(); + } + // If the quantize op is a requantize op, it is being used in other scale + // adjustments and should be kept. Instead, moving dequantize op before + // the requantize op to remove the unnecessary requantize op. 
+ if (auto qtype = + QuantizedType::getQuantizedElementType(q.getArg().getType())) { + rewriter.setInsertionPoint(op); + rewriter.replaceOpWithNewOp( + op, op.getResult().getType(), q.getArg()); + return success(); + } + + op.replaceAllUsesWith(q.getArg()); + return success(); + } + return failure(); + } +}; + +// The StorageCastOp is used to cast from a quantized type to its storage type +// or the opposite. If none of its input and output is quantized, the op has +// no effect and should be removed. +class RemoveRedundantScast + : public mlir::OpRewritePattern { + public: + explicit RemoveRedundantScast(MLIRContext* context) + : OpRewritePattern(context) {} + + private: + LogicalResult matchAndRewrite(mlir::quant::ir::StorageCastOp scast_op, + PatternRewriter& rewriter) const override { + if (QuantizedType::getQuantizedElementType(scast_op.getArg().getType()) || + QuantizedType::getQuantizedElementType(scast_op.getType())) { + return failure(); + } + + scast_op.replaceAllUsesWith(scast_op.getArg()); + return success(); + } +}; + +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_post_quantize.inc" + +void TFPostQuantizePass::runOnOperation() { + RewritePatternSet patterns(&getContext()); + auto func = getOperation(); + auto* ctx = func.getContext(); + patterns.add, + RemoveVolatileOps, RemoveRedundantScast>(ctx); + populateWithGenerated(patterns); + if (failed(applyPatternsGreedily(func, std::move(patterns)))) { + signalPassFailure(); + } +} + +} // namespace + +// Creates an instance of the TensorFlow dialect PostQuantize pass. 
+std::unique_ptr> CreateTFPostQuantizePass() { + return std::make_unique(); +} + +static PassRegistration pass; + +} // namespace quant +} // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_post_quantize.td b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_post_quantize.td new file mode 100644 index 00000000000000..e5cea091c8f194 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_post_quantize.td @@ -0,0 +1,35 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +include "mlir/IR/OpBase.td" +include "mlir/IR/PatternBase.td" +include "mlir/Dialect/Func/IR/FuncOps.td" +include "mlir/Dialect/Arith/IR/ArithOps.td" +include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" +include "mlir/Dialect/Arith/IR/ArithOps.td" +include "tensorflow/compiler/mlir/quantization/common/ir/QuantOps.td" +include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.td" + +// Re-orders the Identity op following a quantized composite function. This +// allows the QuantizeCompositeFunctionsPass to merge the DequantizeCast with +// the quantized composite function to optimize the requantization part. 
+def ReorderIdentityFollowingQuantizedFunction : Pat< + (Quantization_DequantizeCastOp:$output + (Quantization_StorageCastOp + (TF_IdentityOp + (Quantization_StorageCastOp $value)))), + (TF_IdentityOp + (Quantization_DequantizeCastOp + $value, (returnType (GetValueType $output))))>; diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_prepare_lifting.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_prepare_lifting.cc new file mode 100644 index 00000000000000..cff41ef6cdadd9 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_prepare_lifting.cc @@ -0,0 +1,349 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include +#include + +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/common/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/cc/constant_fold.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/remove_identity_op_pattern.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/einsum.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_passes.h" + +namespace mlir { +namespace quant { +namespace { + +using ::tensorflow::quantization::OpSet; + +class 
TFPrepareLiftingPass + : public PassWrapper> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TFPrepareLiftingPass) + + TFPrepareLiftingPass() = default; + + explicit TFPrepareLiftingPass(OpSet op_set) { op_set_ = op_set; } + + TFPrepareLiftingPass(const TFPrepareLiftingPass& other) { + op_set_ = other.op_set_; + } + + StringRef getArgument() const final { + // This is the argument used to refer to the pass in + // the textual format (on the commandline for example). + return "tf-quant-prepare-lifting"; + } + + StringRef getDescription() const final { + // This is a brief description of the pass. + return "Apply graph optimizations such as fusing and constant folding to " + "prepare lifting."; + } + + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } + + void runOnOperation() override; + + private: + Option op_set_{ + *this, "target-opset", llvm::cl::init(OpSet::TF), + llvm::cl::desc("Choose target opset."), + llvm::cl::values( + clEnumValN(OpSet::TF, "TF", + "Uses TF ops that mimic quantization behavior"), + clEnumValN(OpSet::XLA, "XLA", "Uses TF XLA ops"), + clEnumValN(OpSet::UNIFORM_QUANTIZED, "UNIFORM_QUANTIZED", + "Uses TF Uniform Quantized ops"))}; +}; + +// Check if given indices in `val1` has same number of elements as given +// indices in `val2`. 
+bool HasEqualElementSize(Value val1, Value val2, ArrayRef val1_indices, + ArrayRef val2_indices) { + ShapedType val1_shape = mlir::cast(val1.getType()); + ShapedType val2_shape = mlir::cast(val2.getType()); + if (!val1_shape.hasRank() || !val2_shape.hasRank()) return false; + + int val1_result = 1; + int val2_result = 1; + for (auto idx : val1_indices) { + if (idx < 0) idx = idx + val1_shape.getRank(); + if (idx >= val1_shape.getRank() || val1_shape.isDynamicDim(idx)) { + return false; + } + val1_result *= val1_shape.getDimSize(idx); + } + + for (auto idx : val2_indices) { + if (idx < 0) idx = idx + val2_shape.getRank(); + if (idx >= val2_shape.getRank() || val2_shape.isDynamicDim(idx)) { + return false; + } + val2_result *= val2_shape.getDimSize(idx); + } + + return val1_result == val2_result; +} + +// Checks if a shape has dim sizes of all ones except the right most dim. +bool ReshapableTo1DTensor(ShapedType rhs_shape) { + for (auto rank = 0; rank < rhs_shape.getRank() - 1; rank++) { + if (rhs_shape.getDimSize(rank) != 1) { + return false; + } + } + return true; +} + +Value ReshapeTo1DTensor(OpBuilder& builder, Location loc, Value value) { + auto shape = mlir::cast(value.getType()); + if (shape.getRank() != 1) { + SmallVector new_shape; + new_shape.push_back(shape.getNumElements()); + value = builder.create( + loc, value, Create1DConstValue(builder, loc, new_shape)); + } + return ConstantFoldOpIfPossible(value.getDefiningOp()).front(); +} + +// Matches convolution op with "NHWC" data format or matmul op with false adj_y. 
+// The list of supported ops in this function is: +// - Conv2DOp +// - Conv3DOp +// - DepthwiseConv2dNativeOp +// - MatMulOp +// - BatchMatMulV2Op +LogicalResult MatchSupportedAffineOp(Operation* op, Value& binding_output, + Value& binding_input, + Value& binding_weight) { + bool is_supported_affine_op = false; + if (llvm::isa(op)) { + if (const auto data_format = op->getAttrOfType("data_format")) { + is_supported_affine_op = + data_format.getValue() == "NHWC" || data_format.getValue() == "NDHWC"; + } + } else if (llvm::isa(op)) { + if (const auto adj_y = op->getAttrOfType("adj_y")) { + is_supported_affine_op = !adj_y.getValue(); + } + } else if (llvm::isa(op)) { + if (const auto adj_y = op->getAttrOfType("transpose_b")) { + is_supported_affine_op = !adj_y.getValue(); + } + } + + if (!is_supported_affine_op) return failure(); + + // Bind input, output and weight to the given values. + binding_output = op->getResult(0); + binding_input = op->getOperand(0); + binding_weight = op->getOperand(1); + return success(); +} + +// Makes the 1D value broadcastable with the `rhs_shape`. 
+Value MakeOneDimValueBroadcastable(OpBuilder& builder, Location loc, + Value value, ShapedType rhs_shape) { + ShapedType value_shape = mlir::dyn_cast_or_null(value.getType()); + if (!value_shape || value_shape.getRank() != 1 || + !value_shape.hasStaticShape() || !rhs_shape.hasStaticShape()) { + return {}; + } + + int64_t num_elements = value_shape.getNumElements(); + SmallVector new_shape; + for (auto idx : llvm::reverse(llvm::seq(0, rhs_shape.getRank()))) { + const int64_t rhs_dim = rhs_shape.getDimSize(idx); + if (num_elements % rhs_dim != 0) { + return {}; + } + new_shape.push_back(rhs_dim); + num_elements = num_elements / rhs_dim; + if (num_elements == 1) break; + } + absl::c_reverse(new_shape); + + auto reshape_op = builder.create( + loc, value, Create1DConstValue(builder, loc, new_shape)); + return ConstantFoldOpIfPossible(reshape_op).front(); +} + +// Checks if a value can be symmetrically quantized. +bool CanBeSymmetricallyQuantized(Value weight) { + auto dq_op = weight.getDefiningOp(); + if (!dq_op) return true; + + auto qtype = + mlir::cast(dq_op.getArg().getType()).getElementType(); + if (auto uniform_type = llvm::dyn_cast_or_null(qtype)) { + return uniform_type.getZeroPoint() == 0; + } else if (auto per_axis_type = + llvm::dyn_cast_or_null(qtype)) { + return absl::c_all_of(per_axis_type.getZeroPoints(), + [](int64_t x) { return x == 0; }); + } + return false; +} + +// Multiplies two 1D arrays with broadcasting support. +template +SmallVector MultiplyTwoArrays(ArrayRef a, ArrayRef b) { + auto get_value_at = [](ArrayRef v, size_t i) -> T { + if (v.size() == 1) return v.front(); + return v[i]; + }; + + size_t max_size = std::max(a.size(), b.size()); + SmallVector result(max_size); + for (size_t i : llvm::seq(0, max_size)) { + result[i] = get_value_at(a, i) * get_value_at(b, i); + } + return result; +} + +// Multiplies the value followed by a FakeQuant op and adjusts the quantization +// params. This function only supports symmetrically quantized values. 
+Value MultiplyFakeQuantValue(OpBuilder& builder, Location loc, Value value, + Value multiplier) { + auto dq_op = value.getDefiningOp(); + if (!dq_op) { + auto mul_op = builder.create(loc, value, multiplier); + return mul_op.getResult(); + } + auto q_op = dq_op.getArg().getDefiningOp(); + if (!q_op) return {}; + + Value float_value = q_op.getArg(); + Value new_value = builder.create(loc, float_value, multiplier); + auto new_value_type = mlir::cast(new_value.getType()); + + // Get multiplier value in double. + DenseFPElementsAttr multiplier_attr; + if (!matchPattern(multiplier, m_Constant(&multiplier_attr)) || + mlir::cast(multiplier_attr.getType()).getRank() > 1) { + return {}; + } + std::vector multiplier_values; + absl::c_transform(multiplier_attr, std::back_inserter(multiplier_values), + [](auto v) { return FloatAttr::getValueAsDouble(v); }); + ArrayRef multiplier_array(multiplier_values.data(), + multiplier_values.size()); + + // Multiply the quantization parameters by the multiplier. 
+ QuantizedType new_qtype; + auto element_type = mlir::cast(q_op.getType()).getElementType(); + if (auto uniform_type = llvm::dyn_cast(element_type)) { + if (multiplier_attr.isSplat()) { + double new_scale = multiplier_array.front() * uniform_type.getScale(); + new_qtype = UniformQuantizedType::get( + uniform_type.getFlags(), uniform_type.getStorageType(), + uniform_type.getExpressedType(), new_scale, + uniform_type.getZeroPoint(), uniform_type.getStorageTypeMin(), + uniform_type.getStorageTypeMax()); + } else { + auto new_scales = + MultiplyTwoArrays(multiplier_array, {uniform_type.getScale()}); + int32_t quantized_dim = new_value_type.getRank() - 1; + auto new_zero_points = + SmallVector(new_scales.size(), uniform_type.getZeroPoint()); + new_qtype = UniformQuantizedPerAxisType::get( + uniform_type.getFlags(), uniform_type.getStorageType(), + uniform_type.getExpressedType(), new_scales, new_zero_points, + quantized_dim, uniform_type.getStorageTypeMin(), + uniform_type.getStorageTypeMax()); + } + } else if (auto per_axis_type = + llvm::dyn_cast_or_null( + element_type)) { + auto new_scales = + MultiplyTwoArrays(multiplier_array, per_axis_type.getScales()); + new_qtype = UniformQuantizedPerAxisType::get( + per_axis_type.getFlags(), per_axis_type.getStorageType(), + per_axis_type.getExpressedType(), new_scales, + per_axis_type.getZeroPoints(), per_axis_type.getQuantizedDimension(), + per_axis_type.getStorageTypeMin(), per_axis_type.getStorageTypeMax()); + } + + auto quantize = builder.create( + q_op.getLoc(), new_value_type.clone(new_qtype), new_value); + auto dequantize = builder.create( + dq_op.getLoc(), new_value_type, quantize.getResult()); + return dequantize.getResult(); +} + +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.inc" + +void TFPrepareLiftingPass::runOnOperation() { + MLIRContext* ctx = &getContext(); + auto func = getOperation(); + + // The pattern includes decomposing batch normalization ops, fusing add/mul + // 
with a constant operand to a preceding affine operation. + RewritePatternSet patterns(ctx); + populateWithGenerated(patterns); + patterns.add(ctx); + if (op_set_ != OpSet::XLA) { + // Convert Einsum into BatchMatMul for non-XLA opsets. + // For the uniform opset, it is requested to maintain the BatchMatmul logic. + // For the TF opset, since we need to test the effect we remain it as a + // future work. + patterns.add(ctx); + } + + if (failed(applyPatternsGreedily(func, std::move(patterns)))) { + func.emitError() << "tf-quant-prepare-lifting failed."; + signalPassFailure(); + } +} + +} // namespace + +std::unique_ptr> CreateTFPrepareLiftingPass( + const OpSet target_opset) { + return std::make_unique(target_opset); +} + +static PassRegistration pass; + +} // namespace quant +} // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/quantize_preprocess.cc b/tensorflow/compiler/mlir/quantization/tensorflow/quantize_preprocess.cc index 41716fd6fbe966..b77ffda92506f1 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/quantize_preprocess.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/quantize_preprocess.cc @@ -32,7 +32,6 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_stablehlo_pass.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.h" @@ -41,6 +40,7 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/stablehlo/transforms/legalize_tf_xla_call_module_to_stablehlo_pass.h" #include "tensorflow/compiler/mlir/stablehlo/transforms/rename_entrypoint_to_main.h" #include "tensorflow/compiler/mlir/stablehlo/transforms/stablehlo_passes.h" +#include "tensorflow/compiler/mlir/stablehlo/transforms/tf_stablehlo_pass.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_freeze_variables.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h" diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_add_quantization_unit_loc.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_add_quantization_unit_loc.mlir new file mode 100644 index 00000000000000..81c735b7513328 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_add_quantization_unit_loc.mlir @@ -0,0 +1,50 @@ +// RUN: tf-quant-opt %s -mlir-print-debuginfo -mlir-print-local-scope -tf-quant-add-quantization-unit-loc | FileCheck %s + +func.func @conv2d_unmatching_loc_pattern(%arg0: tensor<1x3x4x3xf32>) -> (tensor<1x3x2x2xf32>) { + %cst = "tf.Const"() {device = "", value = dense_resource<__elided__> : tensor<2x3x3x2xbf16>} : () -> tensor<2x3x3x2xbf16> + %0 = "tf.Cast"(%arg0) {Truncate = false, device = ""} : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3xbf16> + %1 = "tf.Conv2D"(%0, %cst) {data_format = "NHWC", dilations = [1, 1, 1, 1], padding = "SAME", strides = [1, 1, 2, 1]} + : (tensor<1x3x4x3xbf16>, tensor<2x3x3x2xbf16>) -> tensor<1x3x2x2xbf16> loc("Model/conv2d") + %2 = "tf.Cast"(%1) {Truncate = false} : (tensor<1x3x2x2xbf16>) -> tensor<1x3x2x2xf32> + %3 = "tf.IdentityN"(%2) {device = ""} : (tensor<1x3x2x2xf32>) -> tensor<1x3x2x2xf32> + return %3 : tensor<1x3x2x2xf32> +// CHECK: tf.Conv2D +// CHECK-SAME: loc("Model/conv2d") +} + +func.func @conv2d_with_valid_loc(%arg0: tensor<1x3x4x3xf32>) -> (tensor<1x3x2x2xf32>) { + %cst = 
"tf.Const"() {device = "", value = dense_resource<__elided__> : tensor<2x3x3x2xbf16>} : () -> tensor<2x3x3x2xbf16> + %0 = "tf.Cast"(%arg0) {Truncate = false, device = ""} : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3xbf16> + %1 = "tf.Conv2D"(%0, %cst) {data_format = "NHWC", dilations = [1, 1, 1, 1], padding = "SAME", strides = [1, 1, 2, 1]} + : (tensor<1x3x4x3xbf16>, tensor<2x3x3x2xbf16>) -> tensor<1x3x2x2xbf16> loc(fused["Conv2D:", "Model/conv2d"]) + %2 = "tf.Cast"(%1) {Truncate = false} : (tensor<1x3x2x2xbf16>) -> tensor<1x3x2x2xf32> + %3 = "tf.IdentityN"(%2) {device = ""} : (tensor<1x3x2x2xf32>) -> tensor<1x3x2x2xf32> + return %3 : tensor<1x3x2x2xf32> +// CHECK: tf.Conv2D +// CHECK-SAME: loc(callsite("Model/conv2d@conv2d_with_valid_loc"("Conv2D") at "QuantizationUnit({{.*}})")) +} + +func.func @conv2d_with_callsite_loc(%arg0: tensor<1x3x4x3xf32>) -> (tensor<1x3x2x2xf32>) { + %cst = "tf.Const"() {device = "", value = dense_resource<__elided__> : tensor<2x3x3x2xbf16>} : () -> tensor<2x3x3x2xbf16> + %0 = "tf.Cast"(%arg0) {Truncate = false, device = ""} : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3xbf16> + %1 = "tf.Conv2D"(%0, %cst) {data_format = "NHWC", dilations = [1, 1, 1, 1], padding = "SAME", strides = [1, 1, 2, 1]} + : (tensor<1x3x4x3xbf16>, tensor<2x3x3x2xbf16>) -> tensor<1x3x2x2xbf16> loc(fused["Conv2D:", callsite("Model/conv2d" at "model.py":10:8)]) + %2 = "tf.Cast"(%1) {Truncate = false} : (tensor<1x3x2x2xbf16>) -> tensor<1x3x2x2xf32> + %3 = "tf.IdentityN"(%2) {device = ""} : (tensor<1x3x2x2xf32>) -> tensor<1x3x2x2xf32> + return %3 : tensor<1x3x2x2xf32> +// CHECK: tf.Conv2D +// CHECK-SAME: loc(callsite("Model/conv2d@conv2d_with_callsite_loc"("Conv2D") at "QuantizationUnit({{.*}})")) +} + +func.func @conv2d_with_func_name(%arg0: tensor<1x3x4x3xf32>) -> (tensor<1x3x2x2xf32>) { + %cst = "tf.Const"() {device = "", value = dense_resource<__elided__> : tensor<2x3x3x2xbf16>} : () -> tensor<2x3x3x2xbf16> + %0 = "tf.Cast"(%arg0) {Truncate = false, device = ""} : 
(tensor<1x3x4x3xf32>) -> tensor<1x3x4x3xbf16> + %1 = "tf.Conv2D"(%0, %cst) {data_format = "NHWC", dilations = [1, 1, 1, 1], padding = "SAME", strides = [1, 1, 2, 1]} + : (tensor<1x3x4x3xbf16>, tensor<2x3x3x2xbf16>) -> tensor<1x3x2x2xbf16> loc(fused["Conv2D:", "Model/conv2d@original_func"]) + %2 = "tf.Cast"(%1) {Truncate = false} : (tensor<1x3x2x2xbf16>) -> tensor<1x3x2x2xf32> + %3 = "tf.IdentityN"(%2) {device = ""} : (tensor<1x3x2x2xf32>) -> tensor<1x3x2x2xf32> + return %3 : tensor<1x3x2x2xf32> +// CHECK: tf.Conv2D +// CHECK-SAME: loc(callsite("Model/conv2d@original_func"("Conv2D") at "QuantizationUnit({{.*}})")) +} + diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_convert_fake_quant_to_qdq.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_convert_fake_quant_to_qdq.mlir new file mode 100644 index 00000000000000..2909f73d4bba6b --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_convert_fake_quant_to_qdq.mlir @@ -0,0 +1,44 @@ +// RUN: tf-quant-opt %s -tf-quant-convert-fake-quant-to-qdq | FileCheck %s + +func.func @fakeQuantArgs(%arg0: tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> { + %0 = "tf.FakeQuantWithMinMaxArgs"(%arg0) { + min = -0.1 : f32, max = 0.2 : f32, num_bits = 8 + } : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> + func.return %0 : tensor<8x8x8x8xf32> +} +// CHECK: func @fakeQuantArgs +// CHECK-NEXT: %[[q:.*]] = "quantization.qcast"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8x!quant.uniform> +// CHECK-NEXT: %[[dq:.*]] = "quantization.dcast"(%[[q]]) +// CHECK-NEXT: return %[[dq]] + +func.func @doNotHandleNonEightBitFakeQuant(%arg0: tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> { + %0 = "tf.FakeQuantWithMinMaxArgs"(%arg0) { + min = -0.1 : f32, max = 0.2 : f32, num_bits = 16 + } : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> + func.return %0 : tensor<8x8x8x8xf32> +} +// CHECK: func @doNotHandleNonEightBitFakeQuant +// CHECK: tf.FakeQuantWithMinMaxArgs +// CHECK-NOT: "quantization.qcast" + 
+func.func @fakeQuantVars(%arg0: tensor<3xf32>, %arg1: tensor<4x3xf32>) -> (tensor<3xf32>, tensor<4x3xf32>) { + %cst = "tf.Const"() {value = dense<-0.950868546> : tensor} : () -> tensor + %cst_0 = "tf.Const"() {value = dense<9.951540e-01> : tensor} : () -> tensor + %cst_1 = "tf.Const"() {value = dense<[-0.5, -0.4, -0.7]> : tensor<3xf32>} : () -> tensor<3xf32> + %cst_2 = "tf.Const"() {value = dense<[0.5, 0.6, 0.3]> : tensor<3xf32>} : () -> tensor<3xf32> + %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %cst, %cst_0) { + device = "", narrow_range = false, num_bits = 8 : i64 + } : (tensor<3xf32>, tensor, tensor) -> tensor<3xf32> + %1 = "tf.FakeQuantWithMinMaxVarsPerChannel"(%arg1, %cst_1, %cst_2) { + device = "", narrow_range = true, num_bits = 8 : i64 + } : (tensor<4x3xf32>, tensor<3xf32>, tensor<3xf32>) -> tensor<4x3xf32> + func.return %0, %1 : tensor<3xf32>, tensor<4x3xf32> +} + +// CHECK: %[[q1:.*]] = "quantization.qcast"(%arg0) +// CHECK-SAME: tensor<3x!quant.uniform> +// CHECK: %[[dq1:.*]] = "quantization.dcast"(%[[q1]]) +// CHECK: %[[q2:.*]] = "quantization.qcast"(%arg1) +// CHECK-SAME: tensor<4x3x!quant.uniform:f32:1, {0.003937007874015748,0.0039370079913477263:-25,0.003937007874015748:51}>> +// CHECK: %[[dq2:.*]] = "quantization.dcast"(%[[q2]]) +// CHECK: return %[[dq1]], %[[dq2]] diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_prepare_lifting.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_prepare_lifting.mlir new file mode 100644 index 00000000000000..b8384cbc4c2127 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_prepare_lifting.mlir @@ -0,0 +1,401 @@ +// RUN: tf-quant-opt %s -tf-quant-prepare-lifting -split-input-file | FileCheck %s + +func.func @decompose_batch_norm(%arg0: tensor<*xf32>) -> (tensor<*xf32>) { + %cst = "tf.Const"() {value = dense<1.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32> + %cst_0 = "tf.Const"() {value = dense<0.500000e+00> : tensor<2xf32>} : () -> tensor<2xf32> + 
%add, %batch_mean, %batch_variance, %reserve_space_1, %reserve_space_2, %reserve_space_3 = "tf.FusedBatchNormV3"(%arg0, %cst, %cst_0, %cst_0, %cst) {data_format = "NHWC", device = "", epsilon = 9.99999974E-5 : f32, exponential_avg_factor = 1.000000e+00 : f32, is_training = false} : (tensor<*xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>) -> (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>, tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) + func.return %add : tensor<*xf32> +} +// CHECK: func @decompose_batch_norm +// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() <{value = dense<2.49743462E-5> : tensor<2xf32>}> : () -> tensor<2xf32> +// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<0.999950051> : tensor<2xf32>}> : () -> tensor<2xf32> +// CHECK: %[[mul:.*]] = "tf.Mul"(%arg0, %[[CONST_0]]) : (tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32> +// CHECK: %[[add:.*]] = "tf.AddV2"(%[[mul]], %[[CONST]]) : (tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32> +// CHECK-NEXT: return %[[add]] : tensor<*xf32> + +// ----- + +func.func @not_decompose_batch_norm(%arg0: tensor<*xf32>) -> (tensor<*xf32>) { + %cst = "tf.Const"() {value = dense<1.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32> + %cst_0 = "tf.Const"() {value = dense<0.500000e+00> : tensor<2xf32>} : () -> tensor<2xf32> + %bn, %batch_mean, %batch_variance, %reserve_space_1, %reserve_space_2, %reserve_space_3 = "tf.FusedBatchNormV3"(%arg0, %cst, %cst_0, %cst_0, %cst) {data_format = "NHWC", device = "", epsilon = 9.99999974E-5 : f32, exponential_avg_factor = 1.000000e+00 : f32, is_training = true} : (tensor<*xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>) -> (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>, tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) + func.return %bn : tensor<*xf32> +} +// CHECK: func @not_decompose_batch_norm +// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<2xf32>}> : () -> tensor<2xf32> +// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() <{value = 
dense<5.000000e-01> : tensor<2xf32>}> : () -> tensor<2xf32> +// CHECK: %[[bn:.*]], %batch_mean, %batch_variance, %reserve_space_1, %reserve_space_2, %reserve_space_3 = "tf.FusedBatchNormV3"(%arg0, %[[CONST]], %[[CONST_0]], %[[CONST_0]], %[[CONST]]) <{data_format = "NHWC", epsilon = 9.99999974E-5 : f32, exponential_avg_factor = 1.000000e+00 : f32, is_training = true}> {device = ""} : (tensor<*xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>) -> (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>, tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) +// CHECK-NEXT: return %[[bn]] : tensor<*xf32> + +// ----- + +func.func @convert_add_to_biasadd(%arg0: tensor<1x3x4x3xf32>) -> (tensor<1x3x2x2xf32>) { + %cst = "tf.Const"() {value = dense<1.000000e+00> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %cst_0 = "tf.Const"() {value = dense<0.500000e+00> : tensor<2xf32>} : () -> tensor<2xf32> + %0 = "tf.Conv2D"(%arg0, %cst) {data_format = "NHWC", dilations = [1, 1, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x2x2xf32> + %1 = "tf.AddV2"(%0, %cst_0) : (tensor<1x3x2x2xf32>, tensor<2xf32>) -> tensor<1x3x2x2xf32> + func.return %1 : tensor<1x3x2x2xf32> +} +// CHECK: func @convert_add_to_biasadd +// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<2x3x3x2xf32>}> : () -> tensor<2x3x3x2xf32> +// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<5.000000e-01> : tensor<2xf32>}> : () -> tensor<2xf32> +// CHECK: %[[CONV2D:.*]] = "tf.Conv2D"(%arg0, %[[CONST]]) <{data_format = "NHWC", dilations = [1, 1, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true}> : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x2x2xf32> +// CHECK-NEXT: %[[BIASADD:.*]] = "tf.BiasAdd"(%[[CONV2D]], %[[CONST_0]]) <{data_format = "NHWC"}> : (tensor<1x3x2x2xf32>, tensor<2xf32>) -> tensor<1x3x2x2xf32> +// 
CHECK-NEXT: return %[[BIASADD]] : tensor<1x3x2x2xf32> + +// ----- + +func.func @not_convert_add_to_biasadd(%arg0: tensor<1x3x4x3xf32>) -> (tensor<1x3x2x3xf32>) { + %cst = "tf.Const"() {value = dense<1.000000e+00> : tensor<2x3x3x3xf32>} : () -> tensor<2x3x3x3xf32> + %cst_0 = "tf.Const"() {value = dense<0.500000e+00> : tensor<1x3x2x3xf32>} : () -> tensor<1x3x2x3xf32> + %0 = "tf.Conv2D"(%arg0, %cst) {data_format = "NHWC", dilations = [1, 1, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x3x4x3xf32>, tensor<2x3x3x3xf32>) -> tensor<1x3x2x3xf32> + %1 = "tf.AddV2"(%0, %cst_0) : (tensor<1x3x2x3xf32>, tensor<1x3x2x3xf32>) -> tensor<1x3x2x3xf32> + func.return %1 : tensor<1x3x2x3xf32> +} +// CHECK: func @not_convert_add_to_biasadd +// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<2x3x3x3xf32>}> : () -> tensor<2x3x3x3xf32> +// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<5.000000e-01> : tensor<1x3x2x3xf32>}> : () -> tensor<1x3x2x3xf32> +// CHECK-NEXT: %[[CONV2D:.*]] = "tf.Conv2D"(%arg0, %[[CONST]]) <{data_format = "NHWC", dilations = [1, 1, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true}> : (tensor<1x3x4x3xf32>, tensor<2x3x3x3xf32>) -> tensor<1x3x2x3xf32> +// CHECK-NEXT: %[[ADD:.*]] = "tf.AddV2"(%[[CONV2D]], %[[CONST_0]]) : (tensor<1x3x2x3xf32>, tensor<1x3x2x3xf32>) -> tensor<1x3x2x3xf32> +// CHECK-NEXT: return %[[ADD]] : tensor<1x3x2x3xf32> + +// ----- + +func.func @fuse_conv2d_and_mul(%arg0: tensor<1x3x4x3xf32>) -> (tensor<1x3x2x2xf32>) { + %cst = "tf.Const"() {value = dense<2.000000e+00> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %cst_0 = "tf.Const"() {value = dense<0.400000e+00> : tensor<2xf32>} : () -> tensor<2xf32> + %0 = "tf.Conv2D"(%arg0, %cst) {data_format = "NHWC", dilations = [1, 1, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : 
(tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x2x2xf32> + %1 = "tf.Mul"(%0, %cst_0) : (tensor<1x3x2x2xf32>, tensor<2xf32>) -> tensor<1x3x2x2xf32> + func.return %1 : tensor<1x3x2x2xf32> +} +// CHECK: func @fuse_conv2d_and_mul +// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() <{value = dense<8.000000e-01> : tensor<2x3x3x2xf32>}> : () -> tensor<2x3x3x2xf32> +// CHECK-NEXT: %[[CONV2D:.*]] = "tf.Conv2D"(%arg0, %[[CONST]]) <{data_format = "NHWC", dilations = [1, 1, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true}> : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x2x2xf32> +// CHECK-NEXT: return %[[CONV2D]] : tensor<1x3x2x2xf32> + +// ----- + +func.func @not_fuse_conv2d_and_mul(%arg0: tensor<1x3x4x3xf32>) -> (tensor<1x3x2x2xf32>) { + %cst = "tf.Const"() {value = dense<2.000000e+00> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %cst_0 = "tf.Const"() {value = dense<0.400000e+00> : tensor<2x2xf32>} : () -> tensor<2x2xf32> + %0 = "tf.Conv2D"(%arg0, %cst) {data_format = "NHWC", dilations = [1, 1, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x2x2xf32> + %1 = "tf.Mul"(%0, %cst_0) : (tensor<1x3x2x2xf32>, tensor<2x2xf32>) -> tensor<1x3x2x2xf32> + func.return %1 : tensor<1x3x2x2xf32> +} +// CHECK: func @not_fuse_conv2d_and_mul +// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() <{value = dense<2.000000e+00> : tensor<2x3x3x2xf32>}> : () -> tensor<2x3x3x2xf32> +// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<4.000000e-01> : tensor<2x2xf32>}> : () -> tensor<2x2xf32> +// CHECK-NEXT: %[[CONV2D:.*]] = "tf.Conv2D"(%arg0, %[[CONST]]) <{data_format = "NHWC", dilations = [1, 1, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true}> : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x2x2xf32> +// CHECK-NEXT: %[[ADD:.*]] = "tf.Mul"(%[[CONV2D]], %[[CONST_0]]) : 
(tensor<1x3x2x2xf32>, tensor<2x2xf32>) -> tensor<1x3x2x2xf32> +// CHECK-NEXT: return %[[ADD]] : tensor<1x3x2x2xf32> + +// ----- + +func.func @fuse_conv2d_with_bias_and_mul(%arg0: tensor<1x3x4x3xf32>) -> (tensor<1x3x2x2xf32>) { + %cst = "tf.Const"() {value = dense<2.000000e+00> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %cst_0 = "tf.Const"() {value = dense<0.400000e+00> : tensor<2xf32>} : () -> tensor<2xf32> + %cst_1 = "tf.Const"() {value = dense<0.500000e+00> : tensor<2xf32>} : () -> tensor<2xf32> + %0 = "tf.Conv2D"(%arg0, %cst) {data_format = "NHWC", dilations = [1, 1, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x2x2xf32> + %1 = "tf.BiasAdd"(%0, %cst_0) {data_format = "NHWC"} : (tensor<1x3x2x2xf32>, tensor<2xf32>) -> tensor<1x3x2x2xf32> + %2 = "tf.Mul"(%1, %cst_1) : (tensor<1x3x2x2xf32>, tensor<2xf32>) -> tensor<1x3x2x2xf32> + func.return %2 : tensor<1x3x2x2xf32> +} +// CHECK: func @fuse_conv2d_with_bias_and_mul +// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<2x3x3x2xf32>}> : () -> tensor<2x3x3x2xf32> +// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<2.000000e-01> : tensor<2xf32>}> : () -> tensor<2xf32> +// CHECK-NEXT: %[[CONV2D:.*]] = "tf.Conv2D"(%arg0, %[[CONST]]) <{data_format = "NHWC", dilations = [1, 1, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true}> : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x2x2xf32> +// CHECK-NEXT: %[[BIASADD:.*]] = "tf.BiasAdd"(%[[CONV2D]], %[[CONST_0]]) <{data_format = "NHWC"}> : (tensor<1x3x2x2xf32>, tensor<2xf32>) -> tensor<1x3x2x2xf32> +// CHECK-NEXT: return %[[BIASADD]] : tensor<1x3x2x2xf32> + +// ----- + +func.func @not_fuse_conv2d_with_bias_and_mul(%arg0: tensor<1x3x4x3xf32>) -> (tensor<1x3x2x2xf32>, tensor<1x3x2x2xf32>) { + %cst = "tf.Const"() {value = dense<2.000000e+00> : tensor<2x3x3x2xf32>} : () 
-> tensor<2x3x3x2xf32> + %cst_0 = "tf.Const"() {value = dense<0.400000e+00> : tensor<2xf32>} : () -> tensor<2xf32> + %cst_1 = "tf.Const"() {value = dense<0.800000e+00> : tensor<2xf32>} : () -> tensor<2xf32> + %0 = "tf.Conv2D"(%arg0, %cst) {data_format = "NHWC", dilations = [1, 1, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x2x2xf32> + %1 = "tf.BiasAdd"(%0, %cst_0) {data_format = "NHWC"} : (tensor<1x3x2x2xf32>, tensor<2xf32>) -> tensor<1x3x2x2xf32> + %2 = "tf.Mul"(%0, %cst_1) : (tensor<1x3x2x2xf32>, tensor<2xf32>) -> tensor<1x3x2x2xf32> + func.return %1, %2 : tensor<1x3x2x2xf32>, tensor<1x3x2x2xf32> +} +// CHECK: func @not_fuse_conv2d_with_bias_and_mul +// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() <{value = dense<2.000000e+00> : tensor<2x3x3x2xf32>}> : () -> tensor<2x3x3x2xf32> +// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<4.000000e-01> : tensor<2xf32>}> : () -> tensor<2xf32> +// CHECK-DAG: %[[CONST_1:.*]] = "tf.Const"() <{value = dense<8.000000e-01> : tensor<2xf32>}> : () -> tensor<2xf32> +// CHECK-NEXT: %[[CONV2D:.*]] = "tf.Conv2D"(%arg0, %[[CONST]]) <{data_format = "NHWC", dilations = [1, 1, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true}> : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x2x2xf32> +// CHECK-NEXT: %[[BIASADD:.*]] = "tf.BiasAdd"(%[[CONV2D]], %[[CONST_0]]) <{data_format = "NHWC"}> : (tensor<1x3x2x2xf32>, tensor<2xf32>) -> tensor<1x3x2x2xf32> +// CHECK-NEXT: %[[MUL:.*]] = "tf.Mul"(%[[CONV2D]], %[[CONST_1]]) : (tensor<1x3x2x2xf32>, tensor<2xf32>) -> tensor<1x3x2x2xf32> +// CHECK-NEXT: return %[[BIASADD]], %[[MUL]] : tensor<1x3x2x2xf32>, tensor<1x3x2x2xf32> + +// ----- + +func.func @fuse_conv2d_with_bias_and_add(%arg0: tensor<1x3x4x3xf32>) -> (tensor<1x3x2x2xf32>) { + %cst = "tf.Const"() {value = dense<2.000000e+00> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + 
%cst_0 = "tf.Const"() {value = dense<0.500000e+00> : tensor<2xf32>} : () -> tensor<2xf32> + %cst_1 = "tf.Const"() {value = dense<0.500000e+00> : tensor<2xf32>} : () -> tensor<2xf32> + %0 = "tf.Conv2D"(%arg0, %cst) {data_format = "NHWC", dilations = [1, 1, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x2x2xf32> + %1 = "tf.BiasAdd"(%0, %cst_0) {data_format = "NHWC"} : (tensor<1x3x2x2xf32>, tensor<2xf32>) -> tensor<1x3x2x2xf32> + %2 = "tf.AddV2"(%1, %cst_1) : (tensor<1x3x2x2xf32>, tensor<2xf32>) -> tensor<1x3x2x2xf32> + func.return %2 : tensor<1x3x2x2xf32> +} +// CHECK: func @fuse_conv2d_with_bias_and_add +// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() <{value = dense<2.000000e+00> : tensor<2x3x3x2xf32>}> : () -> tensor<2x3x3x2xf32> +// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<2xf32>}> : () -> tensor<2xf32> +// CHECK-NEXT: %[[CONV2D:.*]] = "tf.Conv2D"(%arg0, %[[CONST]]) <{data_format = "NHWC", dilations = [1, 1, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true}> : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x2x2xf32> +// CHECK-NEXT: %[[BIASADD:.*]] = "tf.BiasAdd"(%[[CONV2D]], %[[CONST_0]]) <{data_format = "NHWC"}> : (tensor<1x3x2x2xf32>, tensor<2xf32>) -> tensor<1x3x2x2xf32> +// CHECK-NEXT: return %[[BIASADD]] : tensor<1x3x2x2xf32> + +// ----- + +func.func @not_fuse_conv2d_with_bias_and_add(%arg0: tensor<1x3x4x3xf32>, %arg1: tensor<2xf32>) -> (tensor<1x3x2x2xf32>) { + %cst = "tf.Const"() {value = dense<2.000000e+00> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %cst_0 = "tf.Const"() {value = dense<0.400000e+00> : tensor<2xf32>} : () -> tensor<2xf32> + %0 = "tf.Conv2D"(%arg0, %cst) {data_format = "NHWC", dilations = [1, 1, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x3x4x3xf32>, 
tensor<2x3x3x2xf32>) -> tensor<1x3x2x2xf32> + %1 = "tf.BiasAdd"(%0, %cst_0) {data_format = "NHWC"} : (tensor<1x3x2x2xf32>, tensor<2xf32>) -> tensor<1x3x2x2xf32> + %2 = "tf.AddV2"(%1, %arg1) : (tensor<1x3x2x2xf32>, tensor<2xf32>) -> tensor<1x3x2x2xf32> + func.return %2 : tensor<1x3x2x2xf32> +} +// CHECK: func @not_fuse_conv2d_with_bias_and_add +// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() <{value = dense<2.000000e+00> : tensor<2x3x3x2xf32>}> : () -> tensor<2x3x3x2xf32> +// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<4.000000e-01> : tensor<2xf32>}> : () -> tensor<2xf32> +// CHECK-NEXT: %[[CONV2D:.*]] = "tf.Conv2D"(%arg0, %[[CONST]]) <{data_format = "NHWC", dilations = [1, 1, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true}> : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x2x2xf32> +// CHECK-NEXT: %[[BIASADD:.*]] = "tf.BiasAdd"(%[[CONV2D]], %[[CONST_0]]) <{data_format = "NHWC"}> : (tensor<1x3x2x2xf32>, tensor<2xf32>) -> tensor<1x3x2x2xf32> +// CHECK-NEXT: %[[ADD:.*]] = "tf.AddV2"(%[[BIASADD]], %arg1) : (tensor<1x3x2x2xf32>, tensor<2xf32>) -> tensor<1x3x2x2xf32> +// CHECK-NEXT: return %[[ADD]] : tensor<1x3x2x2xf32> + +// ----- + +func.func @match_depthwise_conv2d_and_add(%arg0: tensor<*xf32>) -> (tensor<*xf32>) { + %cst = "tf.Const"() {value = dense<2.000000e+00> : tensor<2x3x3x1xf32>} : () -> tensor<2x3x3x1xf32> + %cst_0 = "tf.Const"() {value = dense<0.400000e+00> : tensor<3xf32>} : () -> tensor<3xf32> + %0 = "tf.DepthwiseConv2dNative"(%arg0, %cst) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<*xf32>, tensor<2x3x3x1xf32>) -> tensor + %1 = "tf.AddV2"(%0, %cst_0) : (tensor, tensor<3xf32>) -> tensor<*xf32> + func.return %1 : tensor<*xf32> +} +// CHECK: func @match_depthwise_conv2d_and_add +// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() <{value = dense<2.000000e+00> : tensor<2x3x3x1xf32>}> : () -> 
tensor<2x3x3x1xf32> +// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<4.000000e-01> : tensor<3xf32>}> : () -> tensor<3xf32> +// CHECK-NEXT: %[[DEPTHWISE_CONV2D:.*]] = "tf.DepthwiseConv2dNative"(%arg0, %[[CONST]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]}> {device = ""} : (tensor<*xf32>, tensor<2x3x3x1xf32>) -> tensor +// CHECK-NEXT: %[[BIASADD:.*]] = "tf.BiasAdd"(%[[DEPTHWISE_CONV2D]], %[[CONST_0]]) <{data_format = "NHWC"}> : (tensor, tensor<3xf32>) -> tensor<*xf32> +// CHECK-NEXT: return %[[BIASADD]] : tensor<*xf32> + +// ----- + +func.func @match_depthwise_conv2d_and_mul(%arg0: tensor<*xf32>) -> (tensor) { + %cst = "tf.Const"() {value = dense<2.000000e+00> : tensor<2x3x3x1xf32>} : () -> tensor<2x3x3x1xf32> + %cst_0 = "tf.Const"() {value = dense<0.400000e+00> : tensor<3xf32>} : () -> tensor<3xf32> + %0 = "tf.DepthwiseConv2dNative"(%arg0, %cst) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<*xf32>, tensor<2x3x3x1xf32>) -> tensor + %1 = "tf.Mul"(%0, %cst_0) : (tensor, tensor<3xf32>) -> tensor + func.return %1 : tensor +} +// CHECK: func @match_depthwise_conv2d_and_mul +// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() <{value = dense<8.000000e-01> : tensor<2x3x3x1xf32>}> : () -> tensor<2x3x3x1xf32> +// CHECK-NEXT: %[[DEPTHWISE_CONV2D:.*]] = "tf.DepthwiseConv2dNative"(%arg0, %[[CONST]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]}> {device = ""} : (tensor<*xf32>, tensor<2x3x3x1xf32>) -> tensor +// CHECK-NEXT: return %[[DEPTHWISE_CONV2D]] : tensor + +// ----- + +func.func @match_depthwise_conv2d_with_bias_and_add(%arg0: tensor<*xf32>) -> (tensor) { + %cst = "tf.Const"() {value = dense<2.000000e+00> : tensor<2x3x3x1xf32>} : () -> tensor<2x3x3x1xf32> + %cst_0 = "tf.Const"() {value = dense<0.400000e+00> : tensor<3xf32>} : () -> 
tensor<3xf32> + %cst_1 = "tf.Const"() {value = dense<0.400000e+00> : tensor<3xf32>} : () -> tensor<3xf32> + %0 = "tf.DepthwiseConv2dNative"(%arg0, %cst) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<*xf32>, tensor<2x3x3x1xf32>) -> tensor + %1 = "tf.BiasAdd"(%0, %cst_0) {data_format = "NHWC"} : (tensor, tensor<3xf32>) -> tensor + %2 = "tf.AddV2"(%1, %cst_1) : (tensor, tensor<3xf32>) -> tensor + func.return %2 : tensor +} +// CHECK: func @match_depthwise_conv2d_with_bias_and_add +// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() <{value = dense<2.000000e+00> : tensor<2x3x3x1xf32>}> : () -> tensor<2x3x3x1xf32> +// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<8.000000e-01> : tensor<3xf32>}> : () -> tensor<3xf32> +// CHECK-NEXT: %[[DEPTHWISE_CONV2D:.*]] = "tf.DepthwiseConv2dNative"(%arg0, %[[CONST]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]}> {device = ""} : (tensor<*xf32>, tensor<2x3x3x1xf32>) -> tensor +// CHECK-NEXT: %[[BIASADD:.*]] = "tf.BiasAdd"(%[[DEPTHWISE_CONV2D]], %[[CONST_0]]) <{data_format = "NHWC"}> : (tensor, tensor<3xf32>) -> tensor +// CHECK-NEXT: return %[[BIASADD]] : tensor + +// ----- + +func.func @match_depthwise_conv2d_with_bias_and_mul(%arg0: tensor<*xf32>) -> (tensor) { + %cst = "tf.Const"() {value = dense<2.000000e+00> : tensor<2x3x3x1xf32>} : () -> tensor<2x3x3x1xf32> + %cst_0 = "tf.Const"() {value = dense<0.400000e+00> : tensor<3xf32>} : () -> tensor<3xf32> + %cst_1 = "tf.Const"() {value = dense<0.500000e+00> : tensor<3xf32>} : () -> tensor<3xf32> + %0 = "tf.DepthwiseConv2dNative"(%arg0, %cst) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<*xf32>, tensor<2x3x3x1xf32>) -> tensor + %1 = "tf.BiasAdd"(%0, %cst_0) {data_format = "NHWC"} : (tensor, tensor<3xf32>) -> tensor + %2 = "tf.Mul"(%1, 
%cst_1) : (tensor, tensor<3xf32>) -> tensor + func.return %2 : tensor +} +// CHECK: func @match_depthwise_conv2d_with_bias_and_mul +// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<2x3x3x1xf32>}> : () -> tensor<2x3x3x1xf32> +// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<2.000000e-01> : tensor<3xf32>}> : () -> tensor<3xf32> +// CHECK-NEXT: %[[DEPTHWISE_CONV2D:.*]] = "tf.DepthwiseConv2dNative"(%arg0, %[[CONST]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]}> {device = ""} : (tensor<*xf32>, tensor<2x3x3x1xf32>) -> tensor +// CHECK-NEXT: %[[BIASADD:.*]] = "tf.BiasAdd"(%[[DEPTHWISE_CONV2D]], %[[CONST_0]]) <{data_format = "NHWC"}> : (tensor, tensor<3xf32>) -> tensor +// CHECK-NEXT: return %[[BIASADD]] : tensor + +// ----- + +func.func @lower_einsum(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x5x6xf32>) -> tensor<3x4x6xf32> { + %0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "ijk,ikm->ijm"}: (tensor<3x4x5xf32>, tensor<3x5x6xf32>) -> tensor<3x4x6xf32> + func.return %0 : tensor<3x4x6xf32> +} +// CHECK-LABEL: lower_einsum +// CHECK: "tf.BatchMatMulV2"(%arg0, %arg1) <{adj_x = false, adj_y = false}> : (tensor<3x4x5xf32>, tensor<3x5x6xf32>) -> tensor<3x4x6xf32> + +// ----- + +func.func @removing_identity_after_const(%arg0: tensor<*xf32>) -> (tensor<*xf32>) { + %cst = "tf.Const"() {value = dense<2.000000e+00> : tensor<2x3x3x1xf32>} : () -> tensor<2x3x3x1xf32> + %cst_0 = "tf.Const"() {value = dense<0.400000e+00> : tensor<3xf32>} : () -> tensor<3xf32> + %cst_1 = "tf.Const"() {value = dense<0.500000e+00> : tensor<3xf32>} : () -> tensor<3xf32> + %identity = "tf.Identity"(%cst) : (tensor<2x3x3x1xf32>) -> tensor<2x3x3x1xf32> + %0 = "tf.DepthwiseConv2dNative"(%arg0, %identity) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<*xf32>, tensor<2x3x3x1xf32>) -> 
tensor<*xf32> + %1 = "tf.BiasAdd"(%0, %cst_0) {data_format = "NHWC"} : (tensor<*xf32>, tensor<3xf32>) -> tensor<*xf32> + %2 = "tf.Mul"(%1, %cst_1) : (tensor<*xf32>, tensor<3xf32>) -> tensor<*xf32> + func.return %2 : tensor<*xf32> +} +// CHECK: func @removing_identity_after_const +// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<2x3x3x1xf32>}> : () -> tensor<2x3x3x1xf32> +// CHECK: %[[DEPTHWISE_CONV2D:.*]] = "tf.DepthwiseConv2dNative"(%arg0, %[[CONST]]) + +// ----- + +func.func @not_removing_identity_of_returning_value(%arg0: tensor<*xf32>) -> (tensor<*xf32>) { + %cst = "tf.Const"() {value = dense<2.000000e+00> : tensor<2x3x3x1xf32>} : () -> tensor<2x3x3x1xf32> + %cst_0 = "tf.Const"() {value = dense<0.400000e+00> : tensor<3xf32>} : () -> tensor<3xf32> + %cst_1 = "tf.Const"() {value = dense<0.500000e+00> : tensor<3xf32>} : () -> tensor<3xf32> + %0 = "tf.DepthwiseConv2dNative"(%arg0, %cst) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<*xf32>, tensor<2x3x3x1xf32>) -> tensor<*xf32> + %1 = "tf.BiasAdd"(%0, %cst_0) {data_format = "NHWC"} : (tensor<*xf32>, tensor<3xf32>) -> tensor<*xf32> + %2 = "tf.Mul"(%1, %cst_1) : (tensor<*xf32>, tensor<3xf32>) -> tensor<*xf32> + %3 = "tf.Identity"(%2) : (tensor<*xf32>) -> tensor<*xf32> + func.return %3 : tensor<*xf32> +} +// CHECK: func @not_removing_identity_of_returning_value +// CHECK: %[[identity:.*]] = "tf.Identity" +// CHECK: return %[[identity]] : tensor<*xf32> + +// ----- + +func.func @batch_norm_with_q_dq(%arg0: tensor<1x3x4x3xf32>) -> (tensor<1x3x2x2xf32>) { + %cst = "tf.Const"() {device = "", value = dense<1.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32> + %cst_0 = "tf.Const"() {device = "", value = dense<5.000000e-01> : tensor<2xf32>} : () -> tensor<2xf32> + %cst_1 = "tf.Const"() {device = "", value = dense<5.000000e-01> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %0 = 
"quantization.qcast"(%cst_1) : (tensor<2x3x3x2xf32>) -> tensor<2x3x3x2x!quant.uniform:f32:3, {0.003937007874015748,0.003937007874015748}>> + %1 = "quantization.dcast"(%0) : (tensor<2x3x3x2x!quant.uniform:f32:3, {0.003937007874015748,0.003937007874015748}>>) -> tensor<2x3x3x2xf32> + %2 = "quantization.qcast"(%arg0) : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3x!quant.uniform> + %3 = "quantization.dcast"(%2) : (tensor<1x3x4x3x!quant.uniform>) -> tensor<1x3x4x3xf32> + %4 = "tf.Conv2D"(%3, %1) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x2x2xf32> + %y, %batch_mean, %batch_variance, %reserve_space_1, %reserve_space_2, %reserve_space_3 = "tf.FusedBatchNormV3"(%4, %cst, %cst_0, %cst, %cst_0) {data_format = "NHWC", device = "", epsilon = 9.99999974E-5 : f32, exponential_avg_factor = 1.000000e+00 : f32, is_training = false} : (tensor<1x3x2x2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>) -> (tensor<1x3x2x2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<*xf32>) + %5 = "tf.Relu6"(%y) {device = ""} : (tensor<1x3x2x2xf32>) -> tensor<1x3x2x2xf32> + %6 = "quantization.qcast"(%5) : (tensor<1x3x2x2xf32>) -> tensor<1x3x2x2x!quant.uniform:f32:3, {0.0026771653824903836:-60,0.0032283464285332388:-28}>> + %7 = "quantization.dcast"(%6) : (tensor<1x3x2x2x!quant.uniform:f32:3, {0.0026771653824903836:-60,0.0032283464285332388:-28}>>) -> tensor<1x3x2x2xf32> + %8 = "tf.Identity"(%7) {device = ""} : (tensor<1x3x2x2xf32>) -> tensor<1x3x2x2xf32> + %9 = "tf.Identity"(%8) {device = ""} : (tensor<1x3x2x2xf32>) -> tensor<1x3x2x2xf32> + return %9 : tensor<1x3x2x2xf32> +} + +// CHECK: func @batch_norm_with_q_dq +// CHECK-DAG: %[[cst:.*]] = "tf.Const"() <{value = dense<0.707036077> : tensor<2x3x3x2xf32>}> : () -> tensor<2x3x3x2xf32> +// CHECK-DAG: %[[cst_0:.*]] = "tf.Const"() <{value = 
dense<-0.914072155> : tensor<2xf32>}> : () -> tensor<2xf32> +// CHECK: %[[q_input:.*]] = "quantization.qcast"(%arg0) : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3x!quant.uniform> +// CHECK: %[[dq_input:.*]] = "quantization.dcast"(%[[q_input]]) : (tensor<1x3x4x3x!quant.uniform>) -> tensor<1x3x4x3xf32> +// CHECK: %[[q_weight:.*]] = "quantization.qcast"(%[[cst]]) : (tensor<2x3x3x2xf32>) -> tensor<2x3x3x2x!quant.uniform:f32:3, {0.005567213212411235,0.005567213212411235}>> +// CHECK: %[[dq_weight:.*]] = "quantization.dcast"(%[[q_weight]]) : (tensor<2x3x3x2x!quant.uniform:f32:3, {0.005567213212411235,0.005567213212411235}>>) -> tensor<2x3x3x2xf32> +// CHECK: %[[conv:.*]] = "tf.Conv2D"(%[[dq_input]], %[[dq_weight]]) +// CHECK: %[[bias:.*]] = "tf.BiasAdd"(%[[conv]], %[[cst_0]]) <{data_format = "NHWC"}> +// CHECK: %[[relu6:.*]] = "tf.Relu6"(%[[bias]]) + +// ----- + +func.func @remove_check_numerics_op(%arg0: tensor<*xf32>) -> tensor<*xf32> { + %0 = "tf.CheckNumerics"(%arg0) {device = "", message = "transformer"} : (tensor<*xf32>) -> tensor<*xf32> + func.return %0 : tensor<*xf32> +} + +// CHECK: func @remove_check_numerics_op +// CHECK: return %arg0 : tensor<*xf32> + +// ----- + +func.func @remove_stop_gradient_op(%arg0: tensor<*xf32>) -> tensor<*xf32> { + %0 = "tf.StopGradient"(%arg0) {device = ""} : (tensor<*xf32>) -> tensor<*xf32> + func.return %0 : tensor<*xf32> +} + +// CHECK: func @remove_stop_gradient_op +// CHECK: return %arg0 : tensor<*xf32> + +// ----- + +func.func @conv2d_with_large_weight_and_mul(%arg0: tensor) -> (tensor) { + %cst = "tf.Const"() {value = dense<2.000000e+00> : tensor<48x48x3x1xf32>} : () -> tensor<48x48x3x1xf32> + %cst_0 = "tf.Const"() {value = dense<0.400000e+00> : tensor<256xf32>} : () -> tensor<256xf32> + %cst_1 = "tf.Const"() {value = dense<0.500000e+00> : tensor<256xf32>} : () -> tensor<256xf32> + %w = "tf.AddV2"(%cst, %cst_1) : (tensor<48x48x3x1xf32>, tensor<256xf32>) -> tensor<48x48x3x256xf32> + %0 = "tf.Conv2D"(%arg0, %w) {data_format = 
"NHWC", dilations = [1, 1, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor, tensor<48x48x3x256xf32>) -> tensor + %1 = "tf.BiasAdd"(%0, %cst_0) {data_format = "NHWC"} : (tensor, tensor<256xf32>) -> tensor + %2 = "tf.Mul"(%1, %cst_1) : (tensor, tensor<256xf32>) -> tensor + func.return %2 : tensor +} +// CHECK: func @conv2d_with_large_weight_and_mul +// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() <{value = dense<1.250000e+00> : tensor<48x48x3x256xf32>}> : () -> tensor<48x48x3x256xf32> +// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<2.000000e-01> : tensor<256xf32>}> : () -> tensor<256xf32> +// CHECK-NEXT: %[[CONV2D:.*]] = "tf.Conv2D"(%arg0, %[[CONST]]) +// CHECK-NEXT: %[[BIASADD:.*]] = "tf.BiasAdd"(%[[CONV2D]], %[[CONST_0]]) +// CHECK-NEXT: return %[[BIASADD]] + +// ----- + +func.func @depthwise_conv2d_with_large_weight_and_add(%arg0: tensor<*xf32>) -> (tensor) { + %cst = "tf.Const"() {value = dense<2.000000e+00> : tensor<48x48x3x1xf32>} : () -> tensor<48x48x3x1xf32> + %cst_0 = "tf.Const"() {value = dense<0.400000e+00> : tensor<3xf32>} : () -> tensor<3xf32> + %cst_1 = "tf.Const"() {value = dense<0.400000e+00> : tensor<3xf32>} : () -> tensor<3xf32> + %cst_2 = "tf.Const"() {value = dense<0.500000e+00> : tensor<256xf32>} : () -> tensor<256xf32> + %w = "tf.AddV2"(%cst, %cst_2) : (tensor<48x48x3x1xf32>, tensor<256xf32>) -> tensor<48x48x3x256xf32> + %0 = "tf.DepthwiseConv2dNative"(%arg0, %w) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<*xf32>, tensor<48x48x3x256xf32>) -> tensor + %1 = "tf.BiasAdd"(%0, %cst_0) {data_format = "NHWC"} : (tensor, tensor<3xf32>) -> tensor + %2 = "tf.AddV2"(%1, %cst_1) : (tensor, tensor<3xf32>) -> tensor + func.return %2 : tensor +} +// CHECK: func @depthwise_conv2d_with_large_weight_and_add +// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() <{value = dense<2.500000e+00> : 
tensor<48x48x3x256xf32>}> : () -> tensor<48x48x3x256xf32> +// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<8.000000e-01> : tensor<3xf32>}> : () -> tensor<3xf32> +// CHECK-NEXT: %[[DEPTHWISE_CONV2D:.*]] = "tf.DepthwiseConv2dNative"(%arg0, %[[CONST]]) +// CHECK-NEXT: %[[BIASADD:.*]] = "tf.BiasAdd"(%[[DEPTHWISE_CONV2D]], %[[CONST_0]]) +// CHECK-NEXT: return %[[BIASADD]] + +// ----- + +func.func @fuse_conv2d_with_sub_and_mul(%arg0: tensor<1x3x4x3xf32>) -> (tensor<1x3x2x2xf32>) { + %cst = "tf.Const"() {value = dense<2.000000e+00> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %cst_0 = "tf.Const"() {value = dense<0.400000e+00> : tensor<1x1x1x2xf32>} : () -> tensor<1x1x1x2xf32> + %cst_1 = "tf.Const"() {value = dense<0.200000e+00> : tensor<1x1x1x2xf32>} : () -> tensor<1x1x1x2xf32> + %0 = "tf.Conv2D"(%arg0, %cst) {data_format = "NHWC", dilations = [1, 1, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x2x2xf32> + %1 = "tf.Sub"(%0, %cst_0) {data_format = "NHWC"} : (tensor<1x3x2x2xf32>, tensor<1x1x1x2xf32>) -> tensor<1x3x2x2xf32> + %2 = "tf.Mul"(%1, %cst_1) : (tensor<1x3x2x2xf32>, tensor<1x1x1x2xf32>) -> tensor<1x3x2x2xf32> + func.return %2 : tensor<1x3x2x2xf32> +} + +// CHECK: func @fuse_conv2d_with_sub_and_mul +// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() <{value = dense<-0.0800000056> : tensor<2xf32>}> : () -> tensor<2xf32> +// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<4.000000e-01> : tensor<2x3x3x2xf32>}> : () -> tensor<2x3x3x2xf32> +// CHECK-NEXT: %[[CONV:.*]] = "tf.Conv2D"(%arg0, %[[CONST_0]]) +// CHECK-NEXT: %[[BIAS_ADD:.*]] = "tf.BiasAdd"(%[[CONV]], %[[CONST]]) +// CHECK-NEXT: return %[[BIAS_ADD]] + +// ----- + +func.func @fuse_conv2d_with_sub_mul_addv2(%arg0: tensor<1x3x4x3xf32>) -> (tensor<1x3x2x2xf32>) { + %cst = "tf.Const"() {value = dense<2.000000e+00> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %cst_0 = "tf.Const"() 
{value = dense<0.400000e+00> : tensor<1x1x1x2xf32>} : () -> tensor<1x1x1x2xf32> + %cst_1 = "tf.Const"() {value = dense<0.200000e+00> : tensor<1x1x1x2xf32>} : () -> tensor<1x1x1x2xf32> + %cst_2 = "tf.Const"() {value = dense<0.300000e+00> : tensor<1x1x1x2xf32>} : () -> tensor<1x1x1x2xf32> + %0 = "tf.Conv2D"(%arg0, %cst) {data_format = "NHWC", dilations = [1, 1, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x2x2xf32> + %1 = "tf.Sub"(%0, %cst_0) {data_format = "NHWC"} : (tensor<1x3x2x2xf32>, tensor<1x1x1x2xf32>) -> tensor<1x3x2x2xf32> + %2 = "tf.Mul"(%1, %cst_1) : (tensor<1x3x2x2xf32>, tensor<1x1x1x2xf32>) -> tensor<1x3x2x2xf32> + %3 = "tf.AddV2"(%2, %cst_2) : (tensor<1x3x2x2xf32>, tensor<1x1x1x2xf32>) -> tensor<1x3x2x2xf32> + func.return %3 : tensor<1x3x2x2xf32> +} + +// CHECK: func @fuse_conv2d_with_sub_mul_addv2 +// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() <{value = dense<2.200000e-01> : tensor<2xf32>}> : () -> tensor<2xf32> +// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<4.000000e-01> : tensor<2x3x3x2xf32>}> : () -> tensor<2x3x3x2xf32> +// CHECK-NEXT: %[[CONV:.*]] = "tf.Conv2D"(%arg0, %[[CONST_0]]) +// CHECK-NEXT: %[[BIAS_ADD:.*]] = "tf.BiasAdd"(%[[CONV]], %[[CONST]]) +// CHECK-NEXT: return %[[BIAS_ADD]] diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/utils/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/utils/BUILD index fcd42b88cc30c9..dba0230f5d6d1f 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/utils/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/utils/BUILD @@ -23,6 +23,25 @@ cc_library( ], ) +cc_library( + name = "temp_fake_quant_utils", + srcs = ["temp_fake_quant_utils.cc"], + hdrs = [ + "temp_fake_quant_utils.h", + ], + compatible_with = get_compatible_with_portable(), + deps = [ + "//tensorflow/compiler/mlir/quantization/common/ir:QuantOps", + 
"//tensorflow/compiler/mlir/quantization/common/tf_quantization_lib", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Dialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) + cc_library( name = "tf_quantize_op_utils", srcs = ["tf_quantize_op_utils.cc"], diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/utils/temp_fake_quant_utils.cc b/tensorflow/compiler/mlir/quantization/tensorflow/utils/temp_fake_quant_utils.cc new file mode 100644 index 00000000000000..bcde1612898a17 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/utils/temp_fake_quant_utils.cc @@ -0,0 +1,73 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +// Copied and modified from +// //third_party/tensorflow/compiler/mlir/lite/utils/fake_quant_utils.cc +#include "tensorflow/compiler/mlir/quantization/tensorflow/utils/temp_fake_quant_utils.h" + +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h" + +namespace mlir { +namespace tf_quant { + +// Three instances of the rule to cover the three different types of +// TF::FakeQuant operators +using PreparePerTensorFakeQuant = ConvertFakeQuantOpToQuantOps< TF::FakeQuantWithMinMaxVarsOp, /*PerAxis=*/false, FetchConstantMinMaxInputs<TF::FakeQuantWithMinMaxVarsOp>>; + +using PreparePerChannelFakeQuant = ConvertFakeQuantOpToQuantOps< TF::FakeQuantWithMinMaxVarsPerChannelOp, /*PerAxis=*/true, FetchConstantMinMaxInputs<TF::FakeQuantWithMinMaxVarsPerChannelOp>>; + +using PreparePerTensorFakeQuantWithMinMaxArgs = ConvertFakeQuantOpToQuantOps< TF::FakeQuantWithMinMaxArgsOp, /*PerAxis=*/false, FetchMinMaxAttrs<TF::FakeQuantWithMinMaxArgsOp>>; + +// Removes the wrapper of the tf.FakeQuant* ops and creates the quant.qcast +// and quant.dcast pairs before tf.FakeQuant* ops are being folded. +LogicalResult ConvertFakeQuantOps(func::FuncOp func, MLIRContext* ctx, + bool use_fake_quant_num_bits) { + OpBuilder builder(func); + + // Insert the quant.qcast/quant.dcast ops in place of the tf.FakeQuant* ops to + // preserve the quantization parameters. 
+ func.walk([&](Operation* op) { + if (auto fake_quant = llvm::dyn_cast<TF::FakeQuantWithMinMaxArgsOp>(op)) { + (void)PreparePerTensorFakeQuantWithMinMaxArgs(use_fake_quant_num_bits) + .matchAndRewrite(fake_quant, builder); + } else if (auto fake_quant = + llvm::dyn_cast<TF::FakeQuantWithMinMaxVarsOp>(op)) { + (void)PreparePerTensorFakeQuant(use_fake_quant_num_bits) + .matchAndRewrite(fake_quant, builder); + } else if (auto fake_quant = + llvm::dyn_cast<TF::FakeQuantWithMinMaxVarsPerChannelOp>( + op)) { + (void)PreparePerChannelFakeQuant(use_fake_quant_num_bits) + .matchAndRewrite(fake_quant, builder); + } + }); + + return success(); +} + +} // namespace tf_quant +} // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/utils/temp_fake_quant_utils.h b/tensorflow/compiler/mlir/quantization/tensorflow/utils/temp_fake_quant_utils.h new file mode 100644 index 00000000000000..84119aa38b4a66 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/utils/temp_fake_quant_utils.h @@ -0,0 +1,160 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This header file defines common utils used by TF-Quant transformation +// passes to work with tf.FakeQuant* ops. 
Copied and modified from +// //third_party/tensorflow/compiler/mlir/lite/utils/fake_quant_utils.h +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_UTILS_TEMP_FAKE_QUANT_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_UTILS_TEMP_FAKE_QUANT_UTILS_H_ + +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/common/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_utils.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h" + +namespace mlir { +namespace tf_quant { + +template <typename TFFakeQuantOp> +struct FetchMinMaxAttrs { + using AttrType = FloatAttr; + bool operator()(TFFakeQuantOp tf_op, AttrType &min_value, + AttrType &max_value) const { + min_value = tf_op.getMinAttr(); + max_value = tf_op.getMaxAttr(); + return true; // Successfully matched and fetched. + } +}; + +template <typename TFFakeQuantOp> +struct FetchConstantMinMaxInputs { + using AttrType = DenseFPElementsAttr; + bool operator()(TFFakeQuantOp tf_op, AttrType &min_value, + AttrType &max_value) const { + Value min = tf_op.getMin(), max = tf_op.getMax(); + if (auto min_id = min.getDefiningOp<TF::IdentityOp>()) { + min = min_id.getInput(); + } + if (auto max_id = max.getDefiningOp<TF::IdentityOp>()) { + max = max_id.getInput(); + } + + if (!matchPattern(min, m_Constant(&min_value))) { + return false; + } + if (!matchPattern(max, m_Constant(&max_value))) { + return false; + } + return true; // Successfully matched and fetched. + } +}; + +// Inserts a "quant.qcast" and "quant.dcast" op pair (QDQs) in place of the +// tf.FakeQuantWithMinMax{Vars|VarsPerChannel|Args}Op +// before the op being constant folded. 
Since the constant +// folding logic will use a "arith.constant" op to replace the +// "tf.FakeQuantWithMinMaxVarsOp", the "quant.qcast" op is used to preserve +// the quantization parameters as a TypeAttr and "quant.dcast" op used to +// convert the output type to the next op. Here are the transformations: +// +// input min cst max cst input +// \ | | | +// \ (tf.Identity) (tf.Identity) => quant.qcast +// \ | | | +// tf.FakeQuantWithMinMaxVars quant.dcast +// | | +// +// Warns if the (most likely unwanted, currently not quite correctly handled) +// case of back-to-back tf.FakeQuant occurs +// +// tf.FakeQuant* +// | +// tf.FakeQuant* +// +template +class ConvertFakeQuantOpToQuantOps { + public: + explicit ConvertFakeQuantOpToQuantOps(bool use_fake_quant_num_bits) + : use_fake_quant_num_bits_(use_fake_quant_num_bits) {} + + FetchMinMax fetch_min_max_; + + using FetchAttrType = typename FetchMinMax::AttrType; + LogicalResult matchAndRewrite(TFFakeQuantOp tf_op, + OpBuilder &rewriter) const { + if (tf_op.getNumBits() != 8) { + return failure(); + } + + // Extract the min/max constant values from the operands. We also consider + // a special case that there are tf.Identity ops between the min/max + // constants and the tf.FakeQuantWithMinMaxVarsOp. + FetchAttrType min_value, max_value; + if (!fetch_min_max_(tf_op, min_value, max_value)) { + return failure(); + } + + Value input = tf_op.getInputs(); + int quant_dim = -1; + auto input_type = mlir::cast(input.getType()); + if (PerAxis) { + if (!input_type.hasRank()) { + tf_op.emitError("The input should have known rank for per-channel op."); + return failure(); + } + // This is a special case that the quant_dim is the last dimensions. + quant_dim = input_type.getRank() - 1; + } + // Use the min/max from the operands and the num_bits and narrow_range + // attribute to create the quantization parameter for the new quantize op. 
+ rewriter.setInsertionPointAfter(tf_op.getOperation()); + IntegerAttr num_bits = rewriter.getI64IntegerAttr(tf_op.getNumBits()); + BoolAttr narrow_range = rewriter.getBoolAttr(tf_op.getNarrowRange()); + Type res_type = tf_op.getType(); + TypeAttr qtype = tf_quant::GetQuantizedTypeAttr( + rewriter, input_type, min_value, max_value, quant_dim, num_bits, + narrow_range, /*is_signed=*/true, /*legacy_float_scale=*/false, + use_fake_quant_num_bits_); + if (!qtype) { + return failure(); + } + + // Finally, use the quantization parameter to create the quantize and + // dequantize ops, and insert them between the tf.FakeQuantWithMinMaxVarsOp + // and its users. + auto quantize = rewriter.create<mlir::quant::ir::QuantizeCastOp>( + tf_op.getLoc(), qtype.getValue(), input); + auto dequantize = rewriter.create<mlir::quant::ir::DequantizeCastOp>( + tf_op.getLoc(), res_type, quantize.getResult()); + tf_op.getOutputs().replaceAllUsesWith(dequantize); + + return success(); + } + + bool use_fake_quant_num_bits_; +}; + +// Removes the wrapper of the tf.FakeQuant* ops and creates the quant.qcast +// and quant.dcast pairs before tf.FakeQuant* ops are being folded. 
+LogicalResult ConvertFakeQuantOps(func::FuncOp func, MLIRContext *ctx, + bool use_fake_quant_num_bits); + +} // namespace tf_quant +} // namespace mlir +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_UTILS_TEMP_FAKE_QUANT_UTILS_H_ diff --git a/tensorflow/compiler/mlir/stablehlo/BUILD b/tensorflow/compiler/mlir/stablehlo/BUILD index 4d24ec1f0e6661..c0287521c7301f 100644 --- a/tensorflow/compiler/mlir/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/stablehlo/BUILD @@ -1,6 +1,9 @@ +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") load("@local_xla//xla/tsl:tsl.default.bzl", "tsl_pybind_extension") +load("@local_xla//xla/tsl/platform:build_config_root.bzl", "if_static") load("//tensorflow:pytype.default.bzl", "pytype_strict_library") load("//tensorflow:strict.default.bzl", "py_strict_test") +load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") package( @@ -17,7 +20,6 @@ package_group( packages = [ "//platforms/darwinn/tools/visualization/graph_conversions/...", "//tensorflow/compiler/mlir/lite/...", - "//tensorflow/compiler/mlir/lite/stablehlo/...", "//tensorflow/compiler/mlir/quantization/...", "//tensorflow/compiler/mlir/quantization/tensorflow/...", "//tensorflow/compiler/tests/...", @@ -67,6 +69,26 @@ py_strict_test( ], ) +gentbl_cc_library( + name = "legalize_tf_patterns_inc_gen", + compatible_with = get_compatible_with_portable(), + tbl_outs = [ + ( + ["-gen-rewriters"], + "transforms/generated_legalize_tf.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "transforms/legalize_tf_patterns.td", + deps = [ + "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncTdFiles", + "@llvm-project//mlir:TensorOpsTdFiles", + "@local_xla//xla/mlir_hlo:hlo_ops_td_files", + ], +) + cc_library( name = "fold_broadcast_pass", srcs = [ @@ -93,6 +115,112 @@ cc_library( alwayslink = 1, ) 
+cc_library( + name = "legalize_utils", + srcs = ["transforms/utils.cc"], + hdrs = ["transforms/utils.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@local_xla//xla/mlir_hlo", + ], +) + +tf_cc_test( + name = "legalize_utils_test", + srcs = ["transforms/utils_test.cc"], + compatible_with = get_compatible_with_portable(), + deps = [ + ":legalize_utils", + "@com_google_googletest//:gtest_main", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@local_xla//xla/mlir_hlo", + ], +) + +cc_library( + name = "legalize_tf", + srcs = [ + "transforms/generated_legalize_tf.inc", + "transforms/legalize_tf.cc", + ], + hdrs = [ + "transforms/legalize_tf_passes.h", + ], + compatible_with = get_compatible_with_portable(), + deps = [ + ":legalize_tf_patterns_inc_gen", + ":legalize_utils", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:dynamic_shape_utils", + "//tensorflow/compiler/mlir/tensorflow:xla_sharding_util", + "//tensorflow/core:framework", + "//tensorflow/core/kernels:conv_grad_shape_utils", + "@com_google_absl//absl/status", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:Dialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:MemRefDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:ShapeDialect", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:TransformUtils", + "@local_tsl//tsl/platform:bfloat16", + "@local_tsl//tsl/platform:tensor_float_32_hdr_lib", + "@local_xla//xla:xla_data_proto_cc", + "@local_xla//xla/hlo/builder:padding", + "@local_xla//xla/hlo/builder:sharding_builder", + "@local_xla//xla/hlo/builder/lib:conv_grad_size_util", + "@local_xla//xla/hlo/translate/hlo_to_mhlo:attribute_importer", + "@local_xla//xla/mlir_hlo", + "@local_xla//xla/mlir_hlo:convert_op_folder", + 
"@local_xla//xla/tsl/platform:status", + "@stablehlo//:chlo_ops", + "@stablehlo//:stablehlo_ops", + ] + if_static(["@local_tsl//tsl/platform:tensor_float_32_utils"]), +) + +cc_library( + name = "tf_stablehlo", + srcs = [ + "transforms/tf_stablehlo_pass.cc", + ], + hdrs = [ + "transforms/tf_stablehlo_pass.h", + ], + compatible_with = get_compatible_with_portable(), + copts = [ + "-Ithird_party", + ], + deps = [ + ":legalize_tf", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow/transforms:lower_tf_lib", + "//tensorflow/compiler/mlir/tf2xla/transforms:xla_legalize_tf_with_tf2xla", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:ShapeDialect", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + "@local_xla//xla/mlir_hlo", + "@local_xla//xla/mlir_hlo:hlo_dialect_registration", + "@local_xla//xla/mlir_hlo:mhlo_passes", + "@local_xla//xla/mlir_hlo:type_conversion", + "@stablehlo//:chlo_ops", + "@stablehlo//:register", + ], + alwayslink = 1, +) + # LINT.IfChange(legalize_tf_xla_call_module_to_stablehlo_pass) cc_library( name = "legalize_tf_xla_call_module_to_stablehlo_pass", @@ -140,8 +268,8 @@ cc_library( "-Ithird_party", ], deps = [ - "//tensorflow/compiler/mlir/lite:validators", "//tensorflow/compiler/mlir/quantization/common:attrs_and_constraints", + "//tensorflow/compiler/mlir/utils:validators", "@llvm-project//llvm:Support", "@llvm-project//mlir:Dialect", "@llvm-project//mlir:FuncDialect", diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/stablehlo/transforms/legalize_tf.cc similarity index 99% rename from tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf.cc rename to tensorflow/compiler/mlir/stablehlo/transforms/legalize_tf.cc index 
d1e7dd75dcfa9d..beca54296e3ca2 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/stablehlo/transforms/legalize_tf.cc @@ -15,18 +15,21 @@ limitations under the License. // This file implements logic for lowering TensorFlow dialect to XLA dialect. #include -#include +#include #include #include #include +#include #include #include #include #include #include +#include #include #include +#include "absl/status/status.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Sequence.h" @@ -54,14 +57,14 @@ limitations under the License. #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "stablehlo/dialect/ChloOps.h" // from @stablehlo -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_passes.h" -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/utils.h" +#include "tensorflow/compiler/mlir/stablehlo/transforms/legalize_tf_passes.h" +#include "tensorflow/compiler/mlir/stablehlo/transforms/utils.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/utils/dynamic_shape_utils.h" #include "tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h" -#include "xla/client/lib/conv_grad_size_util.h" -#include "xla/client/padding.h" -#include "xla/client/sharding_builder.h" +#include "xla/hlo/builder/lib/conv_grad_size_util.h" +#include "xla/hlo/builder/padding.h" +#include "xla/hlo/builder/sharding_builder.h" #include "xla/hlo/translate/hlo_to_mhlo/attribute_importer.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/mlir_hlo/utils/convert_op_folder.h" @@ -6842,7 +6845,7 @@ class LowerControlFlowOp : public OpConversionPattern { // Keep all these in the odml namespace to avoid collisions with the tf2xla // version for now. 
-#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/generated_legalize_tf.inc" +#include "tensorflow/compiler/mlir/stablehlo/transforms/generated_legalize_tf.inc" void PopulatePatterns(MLIRContext *context, RewritePatternSet *patterns) { populateWithGenerated(*patterns); diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_passes.h b/tensorflow/compiler/mlir/stablehlo/transforms/legalize_tf_passes.h similarity index 85% rename from tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_passes.h rename to tensorflow/compiler/mlir/stablehlo/transforms/legalize_tf_passes.h index 9594769e93f71c..a81cc57b4d2f7a 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_passes.h +++ b/tensorflow/compiler/mlir/stablehlo/transforms/legalize_tf_passes.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_TF_PASSES_H_ -#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_TF_PASSES_H_ +#ifndef TENSORFLOW_COMPILER_MLIR_STABLEHLO_TRANSFORMS_LEGALIZE_TF_PASSES_H_ +#define TENSORFLOW_COMPILER_MLIR_STABLEHLO_TRANSFORMS_LEGALIZE_TF_PASSES_H_ #include #include @@ -48,4 +48,4 @@ void PopulateLegalizeTfPatterns(MLIRContext* context, } // namespace odml } // namespace mlir -#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_TF_PASSES_H_ +#endif // TENSORFLOW_COMPILER_MLIR_STABLEHLO_TRANSFORMS_LEGALIZE_TF_PASSES_H_ diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/stablehlo/transforms/legalize_tf_patterns.td similarity index 95% rename from tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_patterns.td rename to tensorflow/compiler/mlir/stablehlo/transforms/legalize_tf_patterns.td index 
5e01eea4ed3435..24b1d05bce9735 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_patterns.td +++ b/tensorflow/compiler/mlir/stablehlo/transforms/legalize_tf_patterns.td @@ -33,8 +33,8 @@ def IEEEFloatTensor : TensorOf<[F16, F32, F64]>; // BatchNorm op patterns. //===----------------------------------------------------------------------===// -def FalseBoolAttr : AttrConstraint().getValue()">>; -def TrueBoolAttr : AttrConstraint().getValue()">>; +def FalseBoolAttr : AttrConstraint($_self).getValue()">>; +def TrueBoolAttr : AttrConstraint($_self).getValue()">>; def CastValueToI64: NativeCodeCall< "CastValueToI64($0.getLoc(), $1, &$_builder)">; @@ -47,18 +47,18 @@ def CastValueToElementType: NativeCodeCall< // the corresponding value of ranked tensor type whose axis is referred in $0. def GetHLOAxisFromTFAxis : NativeCodeCall< "GetHLOAxisFromTFAxis(" - "$0, $1.getType().cast().getRank(), &$_builder)">; + "$0, llvm::cast($1.getType()).getRank(), &$_builder)">; // Same as the above but with $1 of type operand_range from variadic TensorFlow // input. 
def GetHLOAxisFromTFAxisVariadic : NativeCodeCall< "GetHLOAxisFromTFAxis(" - "$0, (*$1.begin()).getType().cast().getRank(), " + "$0, llvm::cast((*$1.begin()).getType()).getRank(), " "&$_builder)">; def CastElementsToI64Elements : NativeCodeCall< - "hlo::convertElementsAttr(" - "$0.cast(), $_builder.getIntegerType(64)).cast()">; + "llvm::cast(hlo::convertElementsAttr(" + "llvm::cast($0), $_builder.getIntegerType(64)))">; def EmptyDotAlgorithmAttr : NativeCodeCall<"mlir::mhlo::DotAlgorithmAttr{}">; @@ -274,17 +274,17 @@ def : EqualityPat>; //===----------------------------------------------------------------------===// def OneElementAttrPred - : CPred<"$_self.cast().getShapedType().getNumElements() == 1">; + : CPred<"llvm::cast($_self).getShapedType().getNumElements() == 1">; def OneElementAttr : ElementsAttrBase, "Scalar ElementsAttr">; def HasRankedFirstOperand - : Constraint()">>; + : Constraint((*$0.begin()).getType())">>; def IsShapedTensor - : Constraint()">>; + : Constraint($0.getType())">>; // This pattern converts TensorFlow axis format to HLO axis format which // doesn't wrap around like TensorFlow and is always positive. 
For this @@ -332,10 +332,10 @@ class MHLO_FftTypeValue : ConstantAttr; def GetInnerDimFromValue : NativeCodeCall< - "GetInnerDimFromValue($0.getType().cast(), &$_builder)">; + "GetInnerDimFromValue(llvm::cast($0.getType()), &$_builder)">; def CheckInnerDimStatic - : Constraint(), &$_builder)">>; + : Constraint($0.getType()), &$_builder)">>; def : Pat<(TF_FFTOp:$res $input), (MHLO_FftOp $input, MHLO_FftTypeValue<"FFT">, (GetInnerDimFromValue $res)), @@ -364,14 +364,14 @@ def LegalizeGatherV2 : //===----------------------------------------------------------------------===// class SliceDenseIntElementsAttrColumn2D : NativeCodeCall< - "SliceDenseIntElementsAttrColumn2D($0.cast(), " # column # " )">; + "SliceDenseIntElementsAttrColumn2D(llvm::cast($0), " # column # " )">; class SliceDenseIntElementsAttr : NativeCodeCall< - "SliceDenseIntElementsAttr($0.cast(), " # index # ", " # axis # ")">; + "SliceDenseIntElementsAttr(llvm::cast($0), " # index # ", " # axis # ")">; // Interior padding attribute based on the TF padding. 
def GetInteriorPadding : NativeCodeCall < - "GetInteriorPadding($0.cast())">; + "GetInteriorPadding(llvm::cast($0))">; def : Pat<(TF_PadV2Op $input, (ConstantLikeMatcher ElementsAttr:$padding), $c), (MHLO_PadOp $input, $c, @@ -407,6 +407,9 @@ def : Pat<(TF_MatMulOp $a, $b, $transpose_a, $transpose_b, $grad_a, $grad_b), // Lower `tf.ZerosLike` //===----------------------------------------------------------------------===// +class MHLO_ConstantLike : NativeCodeCall< + "chlo::getConstantLike($_builder, $_loc, " # value # ", $0)">; + def : Pat<(TF_ZerosLikeOp AnyTensor:$arg), (MHLO_ConstantLike<"0"> $arg)>; @@ -511,10 +514,10 @@ def UnpackStartingIndices: NativeCodeCall< "UnpackTensorAlongZeroDim($0.getLoc(), $1, &$_builder).getOutput()">; def CanBeTranslatedToDynamicSlice : Constraint())">>; + "CanBeTranslatedToDynamicSlice($0, $1, llvm::cast($2))">>; def TFSliceSizes2HLOSliceSizes : NativeCodeCall< - "TFSliceSizes2HLOSliceSizes($0, $1, $2.cast()," + "TFSliceSizes2HLOSliceSizes($0, $1, llvm::cast($2)," "&$_builder)">; def : Pat<(TF_SliceOp:$op MHLO_Tensor:$input, MHLO_Tensor:$starting_indices, @@ -560,7 +563,7 @@ def : Pat<(TF_LegacyCallOp:$op $args, $args_attrs, $res_attrs, FlatSymbolRefAttr //===----------------------------------------------------------------------===// // Handles axis conversion for TF reverse. 
-def ConvertAxisAttr : NativeCodeCall<"ConvertAxisAttr($0, $1.cast(), &$_builder)">; +def ConvertAxisAttr : NativeCodeCall<"ConvertAxisAttr($0, llvm::cast($1), &$_builder)">; def : Pat<(TF_ReverseV2Op AnyRankedTensor:$values, (ConstantLikeMatcher ElementsAttr:$axis)), (MHLO_ReverseOp $values, (ConvertAxisAttr $values, $axis))>; diff --git a/tensorflow/compiler/mlir/stablehlo/transforms/mhlo_passes/fuse_convolution_pass.cc b/tensorflow/compiler/mlir/stablehlo/transforms/mhlo_passes/fuse_convolution_pass.cc index a701f7830841b0..2a6db05dffc98e 100644 --- a/tensorflow/compiler/mlir/stablehlo/transforms/mhlo_passes/fuse_convolution_pass.cc +++ b/tensorflow/compiler/mlir/stablehlo/transforms/mhlo_passes/fuse_convolution_pass.cc @@ -36,8 +36,8 @@ limitations under the License. #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project -#include "tensorflow/compiler/mlir/lite/utils/validators.h" #include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h" +#include "tensorflow/compiler/mlir/utils/validators.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" namespace mlir { @@ -95,7 +95,7 @@ class FuseMhloMulAndConvolutionPattern : public OpRewritePattern { // format and backprop input conv filter is in HWOI format. // Only fuses multiplier if all dimensions other than the out channel // dimension are equal to 1. 
- if (!TFL::IsDimensionsDegenerateExceptLastOne( + if (!TF::IsDimensionsDegenerateExceptLastOne( mul_value.getShapedType().getShape())) { return rewriter.notifyMatchFailure(mul_op, [&](::mlir::Diagnostic &diag) { diag << "entities 'mul_value' failed to satisfy constraint: " diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_stablehlo_pass.cc b/tensorflow/compiler/mlir/stablehlo/transforms/tf_stablehlo_pass.cc similarity index 96% rename from tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_stablehlo_pass.cc rename to tensorflow/compiler/mlir/stablehlo/transforms/tf_stablehlo_pass.cc index a3b2b47ac9f76a..b4f726ed4db858 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_stablehlo_pass.cc +++ b/tensorflow/compiler/mlir/stablehlo/transforms/tf_stablehlo_pass.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_stablehlo_pass.h" +#include "tensorflow/compiler/mlir/stablehlo/transforms/tf_stablehlo_pass.h" #include #include @@ -32,8 +32,7 @@ limitations under the License. 
#include "mlir/Transforms/Passes.h" // from @llvm-project #include "stablehlo/dialect/ChloOps.h" // from @stablehlo #include "stablehlo/dialect/Register.h" // from @stablehlo -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_passes.h" -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_util.h" +#include "tensorflow/compiler/mlir/stablehlo/transforms/legalize_tf_passes.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h" #include "tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_with_tf2xla_passes.h" diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_stablehlo_pass.h b/tensorflow/compiler/mlir/stablehlo/transforms/tf_stablehlo_pass.h similarity index 81% rename from tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_stablehlo_pass.h rename to tensorflow/compiler/mlir/stablehlo/transforms/tf_stablehlo_pass.h index c26a3f36daf675..2a1df5add974e7 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_stablehlo_pass.h +++ b/tensorflow/compiler/mlir/stablehlo/transforms/tf_stablehlo_pass.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_TF_STABLEHLO_PASS_H_ -#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_TF_STABLEHLO_PASS_H_ +#ifndef TENSORFLOW_COMPILER_MLIR_STABLEHLO_TRANSFORMS_TF_STABLEHLO_PASS_H_ +#define TENSORFLOW_COMPILER_MLIR_STABLEHLO_TRANSFORMS_TF_STABLEHLO_PASS_H_ #include "mlir/Pass/PassManager.h" // from @llvm-project @@ -30,4 +30,4 @@ void AddLegalizeTFToStablehloPasses(OpPassManager& pm, } // namespace odml } // namespace mlir -#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_TF_STABLEHLO_PASS_H_ +#endif // TENSORFLOW_COMPILER_MLIR_STABLEHLO_TRANSFORMS_TF_STABLEHLO_PASS_H_ diff --git a/tensorflow/compiler/mlir/stablehlo/transforms/utils.cc b/tensorflow/compiler/mlir/stablehlo/transforms/utils.cc new file mode 100644 index 00000000000000..d440f20e6d9779 --- /dev/null +++ b/tensorflow/compiler/mlir/stablehlo/transforms/utils.cc @@ -0,0 +1,55 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/mlir/stablehlo/transforms/utils.h" + +#include + +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "xla/mlir_hlo/utils/hlo_utils.h" + +namespace mlir { +namespace odml { + +mhlo::ConstantOp GetScalarConstOfType(Type ty, Location loc, int64_t raw_value, + OpBuilder* builder) { + return builder->create(loc, + hlo::getScalarOfType(ty, raw_value)); +} + +mhlo::ConstantOp GetScalarNegZeroOfType(Type ty, Location loc, + OpBuilder* builder) { + return builder->create(loc, + hlo::getScalarNegZeroOfType(ty)); +} + +DenseIntElementsAttr GetI64ElementsAttr(ArrayAttr attr) { + RankedTensorType ty = + RankedTensorType::get(static_cast(attr.size()), + IntegerType::get(attr.getContext(), 64)); + return DenseIntElementsAttr::get(ty, attr.getValue()); +} + +DenseIntElementsAttr GetI64ElementsAttr(ArrayRef values, + Builder* builder) { + RankedTensorType ty = RankedTensorType::get( + {static_cast(values.size())}, builder->getIntegerType(64)); + return DenseIntElementsAttr::get(ty, values); +} + +} // namespace odml +} // namespace mlir diff --git a/tensorflow/compiler/mlir/stablehlo/transforms/utils.h b/tensorflow/compiler/mlir/stablehlo/transforms/utils.h new file mode 100644 index 00000000000000..b048850056ea39 --- /dev/null +++ b/tensorflow/compiler/mlir/stablehlo/transforms/utils.h @@ -0,0 +1,63 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_STABLEHLO_TRANSFORMS_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_STABLEHLO_TRANSFORMS_UTILS_H_ + +#include + +#include "llvm/ADT/ArrayRef.h" +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" + +namespace mlir { +namespace odml { + +// Builds body for reduce op by using the template binary op as the +// reducer op. +template +void BuildReduceBody(Type element_type, Region* body, OpBuilder* builder) { + OpBuilder::InsertionGuard guard(*builder); + Block* block = builder->createBlock(body); + + // Block arguments are scalars of the given element type. + Type type = RankedTensorType::get(/*shape=*/{}, element_type); + Location loc = body->getLoc(); + block->addArguments({type, type}, SmallVector(2, loc)); + + auto reducer = + builder->create(loc, block->getArgument(0), block->getArgument(1)); + builder->create(loc, reducer.getResult()); +} + +mhlo::ConstantOp GetScalarConstOfType(Type ty, Location loc, int64_t raw_value, + OpBuilder* builder); + +mhlo::ConstantOp GetScalarNegZeroOfType(Type ty, Location loc, + OpBuilder* builder); + +// Converts an ArrayAttr to a 1D 64-bit dense elements attribute. 
+DenseIntElementsAttr GetI64ElementsAttr(ArrayAttr attr); +DenseIntElementsAttr GetI64ElementsAttr(llvm::ArrayRef values, + Builder* builder); + +} // namespace odml +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_STABLEHLO_TRANSFORMS_UTILS_H_ diff --git a/tensorflow/compiler/mlir/stablehlo/transforms/utils_test.cc b/tensorflow/compiler/mlir/stablehlo/transforms/utils_test.cc new file mode 100644 index 00000000000000..dd989d8971a774 --- /dev/null +++ b/tensorflow/compiler/mlir/stablehlo/transforms/utils_test.cc @@ -0,0 +1,82 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/compiler/mlir/stablehlo/transforms/utils.h" + +#include + +#include +#include +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" + +namespace mlir { +namespace odml { +namespace { + +TEST(UtilsTest, GetScalarConstOfType) { + MLIRContext context; + context.loadDialect(); + OpBuilder builder(&context); + Location loc = UnknownLoc::get(&context); + Type ty = builder.getI32Type(); + mhlo::ConstantOp op = GetScalarConstOfType(ty, loc, 123, &builder); + EXPECT_EQ(op.getValue().getValues()[0], 123); + + op->destroy(); +} + +TEST(UtilsTest, GetScalarNegZeroOfType) { + MLIRContext context; + context.loadDialect(); + OpBuilder builder(&context); + Location loc = UnknownLoc::get(&context); + Type ty = builder.getF32Type(); + mhlo::ConstantOp op = GetScalarNegZeroOfType(ty, loc, &builder); + EXPECT_EQ(op.getValue().getValues()[0], -0.f); + + op->destroy(); +} + +TEST(UtilsTest, GetI64ElementsAttr) { + MLIRContext context; + context.loadDialect(); + OpBuilder builder(&context); + Location loc = UnknownLoc::get(&context); + SmallVector values = {1, 2, 3}; + auto valuesAttr = builder.getI64ArrayAttr(values); + DenseIntElementsAttr attr = GetI64ElementsAttr(valuesAttr); + EXPECT_THAT(SmallVector(attr.getValues()), + testing::ElementsAreArray(values)); +} + +TEST(UtilsTest, GetI64ElementsAttrBuilder) { + MLIRContext context; + context.loadDialect(); + OpBuilder builder(&context); + Location loc = UnknownLoc::get(&context); + SmallVector values = {1, 2, 3}; + DenseIntElementsAttr attr = GetI64ElementsAttr(values, &builder); + EXPECT_THAT(SmallVector(attr.getValues()), + testing::ElementsAreArray(values)); +} + +} // 
namespace + +} // namespace odml +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index 81bf61234707c0..4cf0cfc3f9d08f 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -47,16 +47,10 @@ td_library( gentbl_cc_library( name = "tensorflow_op_interfaces_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-op-interface-decls"], - "ir/tf_op_interfaces.h.inc", - ), - ( - ["-gen-op-interface-defs"], - "ir/tf_op_interfaces.cc.inc", - ), - ], + tbl_outs = { + "ir/tf_op_interfaces.h.inc": ["-gen-op-interface-decls"], + "ir/tf_op_interfaces.cc.inc": ["-gen-op-interface-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "ir/tf_op_interfaces.td", test = True, @@ -68,12 +62,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_struct_doc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-dialect-doc"], - "g3doc/tf_ops.md", - ), - ], + tbl_outs = {"g3doc/tf_ops.md": ["-gen-dialect-doc"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "ir/tf_ops.td", test = True, @@ -107,16 +96,10 @@ cc_library( gentbl_cc_library( name = "tensorflow_all_ops_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-op-decls"], - "ir/tf_all_ops.h.inc", - ), - ( - ["-gen-op-defs"], - "ir/tf_all_ops.cc.inc", - ), - ], + tbl_outs = { + "ir/tf_all_ops.h.inc": ["-gen-op-decls"], + "ir/tf_all_ops.cc.inc": ["-gen-op-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "ir/tf_ops.td", deps = [ @@ -140,22 +123,16 @@ tf_ops_category_list = [ gentbl_cc_library( name = "tensorflow_" + target["name"] + "_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - [ - "-gen-op-decls", - "-op-include-regex=" + target["include"], - ], - "ir/tf_" + target["name"] + ".h.inc", - ), - ( - [ - "-gen-op-defs", - "-op-include-regex=" + 
target["include"], - ], - "ir/tf_" + target["name"] + ".cc.inc", - ), - ], + tbl_outs = { + "ir/tf_" + target["name"] + ".h.inc": [ + "-gen-op-decls", + "-op-include-regex=" + target["include"], + ], + "ir/tf_" + target["name"] + ".cc.inc": [ + "-gen-op-defs", + "-op-include-regex=" + target["include"], + ], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "ir/tf_ops.td", deps = [ @@ -167,22 +144,16 @@ tf_ops_category_list = [ gentbl_cc_library( name = "tensorflow_remaining_ops_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - [ - "-gen-op-decls", - "-op-exclude-regex=" + "|".join([target["include"] for target in tf_ops_category_list]), - ], - "ir/tf_remaining_ops.h.inc", - ), - ( - [ - "-gen-op-defs", - "-op-exclude-regex=" + "|".join([target["include"] for target in tf_ops_category_list]), - ], - "ir/tf_remaining_ops.cc.inc", - ), - ], + tbl_outs = { + "ir/tf_remaining_ops.h.inc": [ + "-gen-op-decls", + "-op-exclude-regex=" + "|".join([target["include"] for target in tf_ops_category_list]), + ], + "ir/tf_remaining_ops.cc.inc": [ + "-gen-op-defs", + "-op-exclude-regex=" + "|".join([target["include"] for target in tf_ops_category_list]), + ], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "ir/tf_ops.td", deps = [ @@ -193,20 +164,11 @@ gentbl_cc_library( gentbl_cc_library( name = "tf_saved_model_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-op-decls"], - "ir/tf_saved_model.h.inc", - ), - ( - ["-gen-op-defs"], - "ir/tf_saved_model.cc.inc", - ), - ( - ["-gen-dialect-doc"], - "g3doc/tf_saved_model.md", - ), - ], + tbl_outs = { + "ir/tf_saved_model.h.inc": ["-gen-op-decls"], + "ir/tf_saved_model.cc.inc": ["-gen-op-defs"], + "g3doc/tf_saved_model.md": ["-gen-dialect-doc"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "ir/tf_saved_model_ops.td", test = True, @@ -219,23 +181,14 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_executor_inc_gen", 
compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-op-decls"], - "ir/tf_executor.h.inc", - ), - ( - ["-gen-op-defs"], - "ir/tf_executor.cc.inc", - ), - ( - [ - "-gen-dialect-doc", - "-dialect=tf_executor", - ], - "g3doc/tf_executor.md", - ), - ], + tbl_outs = { + "ir/tf_executor.h.inc": ["-gen-op-decls"], + "ir/tf_executor.cc.inc": ["-gen-op-defs"], + "g3doc/tf_executor.md": [ + "-gen-dialect-doc", + "-dialect=tf_executor", + ], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "ir/tf_executor_ops.td", test = True, @@ -250,20 +203,11 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_device_ops_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-op-decls"], - "ir/tf_device.h.inc", - ), - ( - ["-gen-op-defs"], - "ir/tf_device.cc.inc", - ), - ( - ["-gen-dialect-doc"], - "g3doc/tf_device.md", - ), - ], + tbl_outs = { + "ir/tf_device.h.inc": ["-gen-op-decls"], + "ir/tf_device.cc.inc": ["-gen-op-defs"], + "g3doc/tf_device.md": ["-gen-dialect-doc"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "ir/tf_device_ops.td", test = True, @@ -1034,9 +978,9 @@ cc_library( ":mlir_roundtrip_flags", ":serialize_mlir_module_utils", ":tensorflow", - "//tensorflow/compiler/mlir/lite/tools:translate_cl_options", "//tensorflow/compiler/mlir/tensorflow/translate/tools:parsers", "//tensorflow/compiler/mlir/tf2xla:compile_mlir_util", + "//tensorflow/compiler/mlir/tools:translate_cl_options", "//tensorflow/compiler/mlir/utils:string_container_utils", "//tensorflow/compiler/tf2xla:layout_util", "//tensorflow/compiler/tf2xla:xla_argument", @@ -1695,6 +1639,7 @@ cc_library( deps = [ "tensorflow_side_effects", "tensorflow_types", + "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:SideEffectInterfaces", "@llvm-project//mlir:Support", diff --git a/tensorflow/compiler/mlir/tensorflow/ir/host_runtime/BUILD b/tensorflow/compiler/mlir/tensorflow/ir/host_runtime/BUILD index 
ccf7b0b547ab90..f1ab2432181e36 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/host_runtime/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/ir/host_runtime/BUILD @@ -31,16 +31,10 @@ td_library( gentbl_cc_library( name = "tensorflow_tfrt_ops_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-op-decls"], - "tfrt_ops.h.inc", - ), - ( - ["-gen-op-defs"], - "tfrt_ops.cc.inc", - ), - ], + tbl_outs = { + "tfrt_ops.h.inc": ["-gen-op-decls"], + "tfrt_ops.cc.inc": ["-gen-op-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "tfrt_ops.td", deps = [ diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor_ops.td index d58b2c7bd65039..e6cee35a820276 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor_ops.td @@ -47,11 +47,11 @@ def TfExecutor_Dialect : Dialect { } // Control type. -def TfeControlType : Type()">, "control">, +def TfeControlType : Type($_self)">, "control">, BuildableType<"$_builder.getType()">; // Token type. 
-def TfeTokenType : Type()">, "token">, +def TfeTokenType : Type($_self)">, "token">, BuildableType<"$_builder.getType()">; // TODO(hinsu): Define and use TensorType instead of AnyType for data operands diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td index 127210340114a5..d7ae0542890a79 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td @@ -144,24 +144,24 @@ def TF_UniqueResourceAllocation: TraitList<[ //===----------------------------------------------------------------------===// class TF_OperandIsUnrankedPred : - CPred<"$_op.getOperand(" # n # ").getType().isa()">; + CPred<"llvm::isa($_op.getOperand(" # n # ").getType())">; class TF_ResultIsUnrankedPred : - CPred<"$_op.getResult(" # n # ").getType().isa()">; + CPred<"llvm::isa($_op.getResult(" # n # ").getType())">; // Returns true if the n-th operand has unknown rank or has rank m. class TF_OperandHasRank : PredOpTrait<"operand " # n # " is " # m # "-D", Or<[TF_OperandIsUnrankedPred, - CPred<"$_op.getOperand(" # n # - ").getType().cast().getRank() == " # m>]>>; + CPred<"llvm::cast($_op.getOperand(" # n # + ").getType()).getRank() == " # m>]>>; // Returns true if the n-th result has unknown rank or has rank m. 
class TF_ResultHasRank : PredOpTrait<"result " # n # " is " # m # "-D", Or<[TF_ResultIsUnrankedPred, - CPred<"$_op.getResult(" # n # - ").getType().cast().getRank() == " # m>]>>; + CPred<"llvm::cast($_op.getResult(" # n # + ").getType()).getRank() == " # m>]>>; //===----------------------------------------------------------------------===// // TensorFlow resources and side effects @@ -282,12 +282,12 @@ class TF_Op traits = []> : //===----------------------------------------------------------------------===// class TF_TensorFlowAttr : - Attr()">, + Attr($_self)">, "TensorFlow " # description # " attribute">; def TF_ShapeAttr : TF_TensorFlowAttr<"Shape", "shape"> { let returnType = "std::optional>"; - let convertFromStorage = "$_self.cast().getValue()"; + let convertFromStorage = "llvm::cast($_self).getValue()"; // Create a ranked shape attr by default. let constBuilderCall = "mlir::TF::ShapeAttr::get($_builder.getContext(), $0)"; @@ -309,11 +309,11 @@ def TF_SymbolRefArrayAttr : // Any tensor element type defined in the TensorFlow dialect def TF_TFDialectType : - Type()">, "TensorFlow type">; + Type($_self)">, "TensorFlow type">; // Class for any TensorFlow dialect specific type class TF_TensorFlowType : - Type()">, + Type($_self)">, "TensorFlow " # description # " type">, BuildableType<"getType()">; @@ -547,9 +547,9 @@ def TF_Tensor : TensorOf<[TF_ElementType]>; // A string attribute whose value are one of the values in `cases`. 
class TF_AnyStrAttrOf cases> : StringBasedAttr< CPred().getValue() == \"" # !head(cases) # "\"", + "llvm::cast($_self).getValue() == \"" # !head(cases) # "\"", !foreach(case, !tail(cases), - "$_self.cast().getValue() == \"" # case # "\""), + "llvm::cast($_self).getValue() == \"" # case # "\""), prev, cur, prev # " || " # cur)>, "string attribute whose value is " # !foldl(/*init*/!head(cases), /*list*/!tail(cases), @@ -558,8 +558,8 @@ class TF_AnyStrAttrOf cases> : StringBasedAttr< // TODO: Use EnumAttr to define the common attribute cases def TF_ConvnetDataFormatAttr : StringBasedAttr< - CPred<"$_self.cast().getValue() == \"NHWC\" || " # - "$_self.cast().getValue() == \"NCHW\"">, + CPred<"llvm::cast($_self).getValue() == \"NHWC\" || " # + "llvm::cast($_self).getValue() == \"NCHW\"">, "'NHWC' or 'NCHW' convnet data format">; //===----------------------------------------------------------------------===// @@ -679,7 +679,7 @@ class TF_DerivedResultShapeListAttr : DerivedAttr< // A derived attribute that returns the shape of the first result type. 
def TF_DerivedResultShapeAttr : DerivedAttr<"ShapedType", - "return (*getOperation()->result_type_begin()).cast();", + "return llvm::cast((*getOperation()->result_type_begin()));", [{ mlir::TF::ShapeAttr::get($_ctxt, $_self) }]>; def TF_IntTypeAttr : TypeAttrBase<"IntegerType", "integer type"> { @@ -713,14 +713,14 @@ class WithBroadcastableCmpOpBuilder { OpBuilder<(ins "Value":$x, "Value":$y), [{ Type resultType; - if (x.getType().isa() || - y.getType().isa()) { + if (llvm::isa(x.getType()) || + llvm::isa(y.getType())) { resultType = UnrankedTensorType::get($_builder.getI1Type()); } else { SmallVector resultShape; if (!OpTrait::util::getBroadcastedShape( - x.getType().cast().getShape(), - y.getType().cast().getShape(), resultShape)) { + llvm::cast(x.getType()).getShape(), + llvm::cast(y.getType()).getShape(), resultShape)) { mlir::emitError($_state.location, "operands have no broadcastable shapes"); } diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td index 83dca69fc1a9d8..c989178f5fb463 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td @@ -57,7 +57,7 @@ class TF_TensorListInitOp : TF_Op { // Returns data type of the result handle. Returned type contains type of // the TensorList element as a subtype. VariantType handle_dtype() { - return getElementTypeOrSelf(getHandle().getType()).cast(); + return llvm::cast(getElementTypeOrSelf(getHandle().getType())); } }]; } @@ -118,7 +118,7 @@ An n-way switch statement, implementing the following: // Prefer passing in SymbolTableCollection to reduce lookup costs by // enabling reusing cached symbol table lookup. 
func::FuncOp ResolveBranchFunction(::mlir::SymbolTableCollection* table, int index) { - auto flat_sym_ref = getBranches()[index].cast(); + auto flat_sym_ref = llvm::cast(getBranches()[index]); if (table) return table->lookupNearestSymbolFrom(*this, flat_sym_ref); return SymbolTable::lookupNearestSymbolFrom(*this, flat_sym_ref); @@ -854,14 +854,14 @@ Example: "return getElementTypeOrSelf(resource_subtype());">; DerivedAttr shape = DerivedAttr< "ShapedType", - "return resource_subtype().cast();", + "return llvm::cast(resource_subtype());", [{ mlir::TF::ShapeAttr::get($_ctxt, $_self) }]>; let extraClassDeclaration = [{ TensorType resource_subtype() { return resource_type().getSubtypes()[0]; } ResourceType resource_type() { - return getElementTypeOrSelf(getResource()).cast(); + return llvm::cast(getElementTypeOrSelf(getResource())); } }]; @@ -2210,6 +2210,36 @@ def TF_XlaSparseDenseMatmulWithCsrInputOp : TF_Op<"XlaSparseDenseMatmulWithCsrIn ); } +def TF_XlaSparseDenseMatmulCustomCombinerOnTcWithCsrInputOp : TF_Op<"XlaSparseDenseMatmulCustomCombinerOnTcWithCsrInput", [Pure]> { + let summary = "This op looks up the embedding vectors on SparseCores and performs the given combiner computation on TensorCores."; + + let arguments = (ins + TF_Int32Tensor:$row_pointers, + TF_Int32Tensor:$sorted_sample_ids, + TF_Int32Tensor:$sorted_token_ids, + TF_Int32Tensor:$sorted_pos_ids, + TF_Float32Tensor:$sorted_gains, + TF_Float32Tensor:$embedding_table, + TF_Float32Tensor:$weights, + + ConfinedAttr]>:$input_size, + ConfinedAttr]>:$max_valency, + ConfinedAttr]>:$num_weights, + OptionalAttr:$quantization_config_low, + OptionalAttr:$quantization_config_high, + OptionalAttr:$quantization_config_num_buckets, + + SymbolRefAttr:$combiner_computation, + StrAttr:$table_name + ); + + let results = (outs + TF_Float32Tensor:$activations, + TF_Int32Tensor:$preserved_valencies, + TF_Float32Tensor:$preserved_vectors + ); +} + def TF_XlaSparseDenseMatmulGradWithSgdAndCsrInputOp : 
TF_Op<"XlaSparseDenseMatmulGradWithSgdAndCsrInput", [Pure]> { let summary = ""; @@ -2819,6 +2849,282 @@ def TF_XlaSparseDenseMatmulGradWithCsrInputOp : TF_Op<"XlaSparseDenseMatmulGradW TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<5>; } +def TF_XlaSparseDenseMatmulCustomCombinerOnTcGradWithSgdAndCsrInputOp : TF_Op<"XlaSparseDenseMatmulCustomCombinerOnTcGradWithSgdAndCsrInput", [Pure]> { + let summary = "This op back-propagates the activation gradients to the embedding table and the combiner weights."; + + let arguments = (ins + TF_Int32Tensor:$row_pointers, + TF_Int32Tensor:$sorted_sample_ids, + TF_Int32Tensor:$sorted_token_ids, + TF_Int32Tensor:$sorted_pos_ids, + TF_Float32Tensor:$sorted_gains, + // Custom combiner learnable weights to be updated in this backward pass. + TF_Float32Tensor:$weights, + // Preserved outputs of the SparseCore embedding forward pass (for TC + // combiner VJP). + TF_Int32Tensor:$preserved_valencies, + TF_Float32Tensor:$preserved_vectors, + TF_Float32Tensor:$preserved_weights, + // Gradients of the activation. + TF_Float32Tensor:$activation_gradients, + // Learning rate of the embedding table. + TF_Float32Tensor:$learning_rate, + // Learning rate of the custom combiner weights (using SGD). 
+ TF_Float32Tensor:$combiner_weights_learning_rate, + TF_Float32Tensor:$embedding_table, + + ConfinedAttr]>:$max_valency, + ConfinedAttr]>:$num_weights, + F32Attr:$clip_weight_min, + F32Attr:$clip_weight_max, + + SymbolRefAttr:$combiner_table_vjp_computation, + SymbolRefAttr:$combiner_weights_vjp_computation, + StrAttr:$table_name + ); + + let results = (outs + TF_Float32Tensor:$updated_embedding_table, + TF_Float32Tensor:$updated_weights + ); +} + +def TF_XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdagradAndCsrInputOp : TF_Op<"XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdagradAndCsrInput", [Pure]> { + let summary = "This op back-propagates the activation gradients to the embedding table and the combiner weights."; + + let arguments = (ins + TF_Int32Tensor:$row_pointers, + TF_Int32Tensor:$sorted_sample_ids, + TF_Int32Tensor:$sorted_token_ids, + TF_Int32Tensor:$sorted_pos_ids, + TF_Float32Tensor:$sorted_gains, + // Custom combiner learnable weights to be updated in this backward pass. + TF_Float32Tensor:$weights, + // Preserved outputs of the SparseCore embedding forward pass (for TC + // combiner VJP). + TF_Int32Tensor:$preserved_valencies, + TF_Float32Tensor:$preserved_vectors, + TF_Float32Tensor:$preserved_weights, + // Gradients of the activation. + TF_Float32Tensor:$activation_gradients, + // Learning rate of the embedding table. + TF_Float32Tensor:$learning_rate, + // Learning rate of the custom combiner weights (using SGD). 
+ TF_Float32Tensor:$combiner_weights_learning_rate, + TF_Float32Tensor:$embedding_table, + TF_Float32Tensor:$accumulator, + + ConfinedAttr]>:$max_valency, + ConfinedAttr]>:$num_weights, + F32Attr:$clip_weight_min, + F32Attr:$clip_weight_max, + + SymbolRefAttr:$combiner_table_vjp_computation, + SymbolRefAttr:$combiner_weights_vjp_computation, + StrAttr:$table_name + ); + + let results = (outs + TF_Float32Tensor:$updated_embedding_table, + TF_Float32Tensor:$updated_accumulator, + TF_Float32Tensor:$updated_weights + ); +} + +def TF_XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdagradMomentumAndCsrInputOp : TF_Op<"XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdagradMomentumAndCsrInput", [Pure]> { + let summary = "This op back-propagates the activation gradients to the embedding table and the combiner weights."; + + let arguments = (ins + TF_Int32Tensor:$row_pointers, + TF_Int32Tensor:$sorted_sample_ids, + TF_Int32Tensor:$sorted_token_ids, + TF_Int32Tensor:$sorted_pos_ids, + TF_Float32Tensor:$sorted_gains, + // Custom combiner learnable weights to be updated in this backward pass. + TF_Float32Tensor:$weights, + // Preserved outputs of the SparseCore embedding forward pass (for TC + // combiner VJP). + TF_Int32Tensor:$preserved_valencies, + TF_Float32Tensor:$preserved_vectors, + TF_Float32Tensor:$preserved_weights, + // Gradients of the activation. + TF_Float32Tensor:$activation_gradients, + // Learning rate of the embedding table. + TF_Float32Tensor:$learning_rate, + // Learning rate of the custom combiner weights (using SGD). 
+ TF_Float32Tensor:$combiner_weights_learning_rate, + TF_Float32Tensor:$embedding_table, + TF_Float32Tensor:$accumulator, + TF_Float32Tensor:$momenta, + + ConfinedAttr]>:$max_valency, + ConfinedAttr]>:$num_weights, + F32Attr:$clip_weight_min, + F32Attr:$clip_weight_max, + + BoolAttr:$use_nesterov, + F32Attr:$exponent, + F32Attr:$beta1, + F32Attr:$beta2, + F32Attr:$epsilon, + + SymbolRefAttr:$combiner_table_vjp_computation, + SymbolRefAttr:$combiner_weights_vjp_computation, + StrAttr:$table_name + ); + + let results = (outs + TF_Float32Tensor:$updated_embedding_table, + TF_Float32Tensor:$updated_accumulator, + TF_Float32Tensor:$updated_momenta, + TF_Float32Tensor:$updated_weights + ); +} + +def TF_XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdamAndCsrInputOp : TF_Op<"XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdamAndCsrInput", [Pure]> { + let summary = "This op back-propagates the activation gradients to the embedding table and the combiner weights."; + + let arguments = (ins + TF_Int32Tensor:$row_pointers, + TF_Int32Tensor:$sorted_sample_ids, + TF_Int32Tensor:$sorted_token_ids, + TF_Int32Tensor:$sorted_pos_ids, + TF_Float32Tensor:$sorted_gains, + // Custom combiner learnable weights to be updated in this backward pass. + TF_Float32Tensor:$weights, + // Preserved outputs of the SparseCore embedding forward pass (for TC + // combiner VJP). + TF_Int32Tensor:$preserved_valencies, + TF_Float32Tensor:$preserved_vectors, + TF_Float32Tensor:$preserved_weights, + // Gradients of the activation. + TF_Float32Tensor:$activation_gradients, + // Learning rate of the embedding table. + TF_Float32Tensor:$learning_rate, + // Learning rate of the custom combiner weights (using SGD). 
+ TF_Float32Tensor:$combiner_weights_learning_rate, + TF_Float32Tensor:$embedding_table, + TF_Float32Tensor:$momenta, + TF_Float32Tensor:$velocity, + + ConfinedAttr]>:$max_valency, + ConfinedAttr]>:$num_weights, + F32Attr:$clip_weight_min, + F32Attr:$clip_weight_max, + + BoolAttr:$use_sum_inside_sqrt, + F32Attr:$beta1, + F32Attr:$beta2, + F32Attr:$epsilon, + + SymbolRefAttr:$combiner_table_vjp_computation, + SymbolRefAttr:$combiner_weights_vjp_computation, + StrAttr:$table_name + ); + + let results = (outs + TF_Float32Tensor:$updated_embedding_table, + TF_Float32Tensor:$updated_momenta, + TF_Float32Tensor:$updated_velocity, + TF_Float32Tensor:$updated_weights + ); +} + +def TF_XlaSparseDenseMatmulCustomCombinerOnTcGradWithFtrlAndCsrInputOp : TF_Op<"XlaSparseDenseMatmulCustomCombinerOnTcGradWithFtrlAndCsrInput", [Pure]> { + let summary = "This op back-propagates the activation gradients to the embedding table and the combiner weights."; + + let arguments = (ins + TF_Int32Tensor:$row_pointers, + TF_Int32Tensor:$sorted_sample_ids, + TF_Int32Tensor:$sorted_token_ids, + TF_Int32Tensor:$sorted_pos_ids, + TF_Float32Tensor:$sorted_gains, + // Custom combiner learnable weights to be updated in this backward pass. + TF_Float32Tensor:$weights, + // Preserved outputs of the SparseCore embedding forward pass (for TC + // combiner VJP). + TF_Int32Tensor:$preserved_valencies, + TF_Float32Tensor:$preserved_vectors, + TF_Float32Tensor:$preserved_weights, + // Gradients of the activation. + TF_Float32Tensor:$activation_gradients, + // Learning rate of the embedding table. + TF_Float32Tensor:$learning_rate, + // Learning rate of the custom combiner weights (using SGD). 
+ TF_Float32Tensor:$combiner_weights_learning_rate, + TF_Float32Tensor:$embedding_table, + TF_Float32Tensor:$accumulator, + TF_Float32Tensor:$linear, + + ConfinedAttr]>:$max_valency, + ConfinedAttr]>:$num_weights, + F32Attr:$clip_weight_min, + F32Attr:$clip_weight_max, + + BoolAttr:$multiply_linear_by_learning_rate, + F32Attr:$beta, + F32Attr:$learning_rate_power, + F32Attr:$l1_regularization_strength, + F32Attr:$l2_regularization_strength, + + SymbolRefAttr:$combiner_table_vjp_computation, + SymbolRefAttr:$combiner_weights_vjp_computation, + StrAttr:$table_name + ); + + let results = (outs + TF_Float32Tensor:$updated_embedding_table, + TF_Float32Tensor:$updated_accumulator, + TF_Float32Tensor:$updated_linear, + TF_Float32Tensor:$updated_weights + ); +} + +def TF_XlaSparseDenseMatmulCustomCombinerOnTcGradWithCsrInputOp : TF_Op<"XlaSparseDenseMatmulCustomCombinerOnTcGradWithCsrInput", [AttrSizedOperandSegments, Pure]> { + let summary = "This op back-propagates the activation gradients to the embedding table and the combiner weights."; + + let arguments = (ins + TF_Int32Tensor:$row_pointers, + TF_Int32Tensor:$sorted_sample_ids, + TF_Int32Tensor:$sorted_token_ids, + TF_Int32Tensor:$sorted_pos_ids, + TF_Float32Tensor:$sorted_gains, + // Custom combiner learnable weights to be updated in this backward pass. + TF_Float32Tensor:$weights, + // Preserved outputs of the SparseCore embedding forward pass (for TC + // combiner VJP). + TF_Int32Tensor:$preserved_valencies, + TF_Float32Tensor:$preserved_vectors, + TF_Float32Tensor:$preserved_weights, + // Gradients of the activation. + TF_Float32Tensor:$activation_gradients, + // The embedding table and the associated slot variables. + Variadic:$tables, + // Hyperparameters of the current optimizer. + Variadic:$hyperparameters, + // Learning rate of the custom combiner weights (using SGD). 
+ TF_Float32Tensor:$combiner_weights_learning_rate, + + ConfinedAttr]>:$max_valency, + ConfinedAttr]>:$num_weights, + + SymbolRefAttr:$combiner_table_vjp_computation, + SymbolRefAttr:$combiner_weights_vjp_computation, + SymbolRefAttr:$optimizer_custom_computation, + StrAttr:$table_name + ); + + let results = (outs + Variadic:$updated_tables, + TF_Float32Tensor:$updated_weights + ); + + // Number of embedding table + its associated slot variables. + TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<10>; + // Number of hyperparameters. + TF_DerivedOperandSizeAttr M = TF_DerivedOperandSizeAttr<11>; +} + // b/394499589: move back to tf_generated_ops.td def TF_PartitionedCallOp : TF_Op<"PartitionedCall", [CallOpInterface, DeclareOpInterfaceMethods, Pure]> { let summary = [{ diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc index 905f4864655a33..ce586b43fd38b6 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc @@ -98,7 +98,7 @@ namespace { // Returns the equivalent Value skipping through identity nodes. 
Value LookThroughIdentity(Value result) { while (isa_and_nonnull(result.getDefiningOp())) { - auto op_result = result.cast(); + auto op_result = cast(result); result = op_result.getOwner()->getOperand(op_result.getResultNumber()); } return result; @@ -195,7 +195,7 @@ LogicalResult OneHotOp::verify() { OneHotOp op = *this; int64_t axis = op.getAxis(); - auto indices_ty = op.getIndices().getType().dyn_cast(); + auto indices_ty = llvm::dyn_cast(op.getIndices().getType()); if (indices_ty && !(axis == -1 || (axis >= 0 && axis <= indices_ty.getShape().size()))) { return op.emitOpError() @@ -234,11 +234,11 @@ LogicalResult OneHotOp::verify() { static TensorType InferOneHotOpType(Value indices, Value depth, Value on_value, Value off_value, IntegerAttr axis) { int64_t axis_val = axis.getInt(); - Type element_ty = on_value.getType().cast().getElementType(); + Type element_ty = llvm::cast(on_value.getType()).getElementType(); auto unranked_ty = UnrankedTensorType::get(element_ty); if (axis_val < -1) return unranked_ty; - auto indices_ty = indices.getType().dyn_cast(); + auto indices_ty = llvm::dyn_cast(indices.getType()); if (!indices_ty) return unranked_ty; auto shape = llvm::to_vector<2>(indices_ty.getShape()); @@ -278,7 +278,7 @@ LogicalResult PackOp::verify() { int64_t inputs_rank = -1; for (Value value : values) { - if (auto ty = value.getType().dyn_cast()) { + if (auto ty = llvm::dyn_cast(value.getType())) { // Exit early as input types are verified to be compatible so all ranked // tensors have the same rank. 
inputs_rank = ty.getRank(); @@ -346,7 +346,7 @@ OpFoldResult PackOp::fold(FoldAdaptor) { auto const_op = dyn_cast_or_null(value.getDefiningOp()); if (!const_op) return std::nullopt; - auto value_attr = const_op.getValue().dyn_cast(); + auto value_attr = llvm::dyn_cast(const_op.getValue()); if (!value_attr || value_attr.getNumElements() != 1) return std::nullopt; auto value_ty = value_attr.getType(); @@ -378,7 +378,7 @@ OpFoldResult PackOp::fold(FoldAdaptor) { return {}; // First tensor dimension is dynamic. - auto arg_ty = tensor.getType().dyn_cast(); + auto arg_ty = llvm::dyn_cast(tensor.getType()); if (!arg_ty || !arg_ty.hasRank() || arg_ty.getNumDynamicDims() != 1 || !arg_ty.isDynamicDim(0)) return {}; @@ -416,8 +416,8 @@ struct ConvertPackToReshape : public OpRewritePattern { } // Check if input and output are static. - auto input_ty = pack_op.getOperand(0).getType().cast(); - auto output_ty = pack_op.getOutput().getType().cast(); + auto input_ty = llvm::cast(pack_op.getOperand(0).getType()); + auto output_ty = llvm::cast(pack_op.getOutput().getType()); if (!input_ty.hasStaticShape() || !output_ty.hasStaticShape()) { return failure(); } @@ -467,7 +467,8 @@ LogicalResult PadOp::FoldOperandsPermutation(ArrayRef permutation) { dyn_cast_or_null(getPaddings().getDefiningOp()); if (!paddings_op) return failure(); - auto paddings_value = paddings_op.getValue().dyn_cast(); + auto paddings_value = + llvm::dyn_cast(paddings_op.getValue()); if (!paddings_value || paddings_value.getNumElements() != permutation.size() * 2) return failure(); @@ -493,9 +494,8 @@ LogicalResult PadOp::FoldOperandsPermutation(ArrayRef permutation) { setOperand(1, shuffled_paddings_op); // Change the result type. 
- getResult().setType(ShuffleRankedTensorType(getResult().getType(), - ReversePermutation(permutation)) - .cast()); + getResult().setType(llvm::cast(ShuffleRankedTensorType( + getResult().getType(), ReversePermutation(permutation)))); return success(); } @@ -561,7 +561,7 @@ LogicalResult ParseExampleV2Op::verify() { template static LogicalResult VerifyPartitionedCall(CallOpClass op, SymbolTableCollection &symbolTable) { - SymbolRefAttr func = op->getAttr("f").template cast(); + SymbolRefAttr func = llvm::cast(op->getAttr("f")); auto function = symbolTable.lookupNearestSymbolFrom(op, func); if (!function) { return op.emitError("'f' attribute refers to an undefined function: ") @@ -625,10 +625,10 @@ void TPUPartitionedCallOp::setCalleeFromCallable( OpFoldResult PowOp::fold(FoldAdaptor adaptor) { auto operands = adaptor.getOperands(); - auto constant_y = operands[1].dyn_cast_or_null(); + auto constant_y = llvm::dyn_cast_if_present(operands[1]); if (constant_y && constant_y.isSplat()) { APFloat y_value = constant_y.getSplatValue(); - auto output_type = getType().cast(); + auto output_type = llvm::cast(getType()); if (y_value.isZero() && output_type.hasStaticShape()) { return DenseElementsAttr::get( output_type, @@ -661,7 +661,7 @@ void QuantizeAndDequantizeV2Op::getCanonicalizationPatterns( // LogicalResult QrOp::verify() { QrOp op = *this; - auto ttype = op.getInput().getType().cast(); + auto ttype = llvm::cast(op.getInput().getType()); if (!ttype.hasRank()) return success(); if (!HasRankAtLeast(op.getInput(), 2)) return op.emitOpError( @@ -765,29 +765,29 @@ void RangeOp::build(OpBuilder &builder, OperationState &result, Value start, builder, result, tensorflow::GetTypeFromTFTensorShape( size.getSExtValue(), - start.getType().cast().getElementType()), + llvm::cast(start.getType()).getElementType()), start, limit, delta); } return RangeOp::build( builder, result, tensorflow::GetTypeFromTFTensorShape( - {-1}, start.getType().cast().getElementType()), + {-1}, 
llvm::cast(start.getType()).getElementType()), start, limit, delta); } OpFoldResult RangeOp::fold(FoldAdaptor adaptor) { auto operands = adaptor.getOperands(); assert(operands.size() == 3); - auto start_tensor = operands[0].dyn_cast_or_null(); - auto limit_tensor = operands[1].dyn_cast_or_null(); - auto delta_tensor = operands[2].dyn_cast_or_null(); + auto start_tensor = llvm::dyn_cast_if_present(operands[0]); + auto limit_tensor = llvm::dyn_cast_if_present(operands[1]); + auto delta_tensor = llvm::dyn_cast_if_present(operands[2]); if (!(start_tensor && limit_tensor && delta_tensor)) return nullptr; // Operands should all be scalars assert(start_tensor.getShapedType().getRank() == 0 && limit_tensor.getShapedType().getRank() == 0 && delta_tensor.getShapedType().getRank() == 0); - Type elem_type = getType().cast().getElementType(); + Type elem_type = llvm::cast(getType()).getElementType(); if (elem_type.isSignlessInteger() || elem_type.isUnsignedInteger()) { auto start_attr = start_tensor.getValues()[0]; auto limit_attr = limit_tensor.getValues()[0]; @@ -809,7 +809,7 @@ OpFoldResult RangeOp::fold(FoldAdaptor adaptor) { } return BuildConstRangeTensor(elem_type, num_elements, start_attr, delta_attr); - } else if (elem_type.isa()) { + } else if (isa(elem_type)) { auto start_attr = start_tensor.getValues()[0]; auto limit_attr = limit_tensor.getValues()[0]; auto delta_attr = delta_tensor.getValues()[0]; @@ -836,12 +836,12 @@ void RankOp::build(OpBuilder &builder, OperationState &result, Value input) { // This will create a constant value for RankOp of a ranked tensor. OpFoldResult RankOp::fold(FoldAdaptor) { auto type = getInput().getType(); - auto ranked_type = type.dyn_cast(); + auto ranked_type = llvm::dyn_cast(type); if (!ranked_type) return {}; // DenseIntElementsAttr::get requires the output type be ranked with static // shape. 
- auto output_type = getType().dyn_cast(); + auto output_type = llvm::dyn_cast(getType()); if (!output_type || !output_type.hasStaticShape()) return {}; int32_t rank = ranked_type.getRank(); @@ -882,11 +882,11 @@ using ReshapeErrorHandler = LogicalResult GetReshapeOutputType(Value tensor, Value shape, ReshapeErrorHandler error_handler, TensorType &output_ty) { - auto tensor_ty = tensor.getType().cast(); + auto tensor_ty = llvm::cast(tensor.getType()); auto element_ty = tensor_ty.getElementType(); output_ty = UnrankedTensorType::get(element_ty); - auto shape_ty = shape.getType().dyn_cast(); + auto shape_ty = llvm::dyn_cast(shape.getType()); if (!shape_ty) return success(); if (shape_ty.getRank() != 1) return error_handler(llvm::formatv( @@ -982,9 +982,9 @@ LogicalResult ReshapeOp::verify() { expected_ty))) return failure(); - auto output_ty = op.getType().dyn_cast(); + auto output_ty = llvm::dyn_cast(op.getType()); if (!output_ty) return success(); - auto tensor_ty = op.getTensor().getType().cast(); + auto tensor_ty = llvm::cast(op.getTensor().getType()); if (output_ty.hasStaticShape() && tensor_ty.hasStaticShape()) { const int64_t output_ty_size = output_ty.getNumElements(); const int64_t tensor_ty_size = tensor_ty.getNumElements(); @@ -1027,7 +1027,7 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor) { // Fold reshape if operand and result types are the same and all dimensions // are statically known (no-op reshape). - auto result_ty = getType().dyn_cast(); + auto result_ty = llvm::dyn_cast(getType()); if (result_ty && result_ty.hasStaticShape() && result_ty == tensor.getType()) { return tensor; @@ -1049,8 +1049,8 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor) { // first dimension equal to `cond`. 
LogicalResult SelectOp::verify() { SelectOp op = *this; - auto then_tensor = op.getThenValue().getType().cast(); - auto else_tensor = op.getElseValue().getType().cast(); + auto then_tensor = llvm::cast(op.getThenValue().getType()); + auto else_tensor = llvm::cast(op.getElseValue().getType()); // Check (1). if (!AreCastCompatible({then_tensor, else_tensor})) return op.emitOpError() << "requires t and e have compatible shapes"; @@ -1081,7 +1081,8 @@ LogicalResult SelectOp::verify() { return success(); } - auto cond_tensor = op.getCondition().getType().dyn_cast(); + auto cond_tensor = + llvm::dyn_cast(op.getCondition().getType()); if (!cond_tensor) return success(); auto cond_rank = cond_tensor.getRank(); // Check (2a) and (2b). @@ -1111,15 +1112,15 @@ LogicalResult SelectOp::verify() { //===----------------------------------------------------------------------===// static Type InferSelectV2OpType(Value condition, Value e, Value t) { - Type element_ty = e.getType().cast().getElementType(); + Type element_ty = llvm::cast(e.getType()).getElementType(); auto unranked_ty = UnrankedTensorType::get(element_ty); Type broadcasted_ty = OpTrait::util::getBroadcastedType(e.getType(), t.getType()); if (!broadcasted_ty) return unranked_ty; - auto cond_ranked_ty = condition.getType().dyn_cast(); - auto broadcasted_ranked_ty = broadcasted_ty.dyn_cast(); + auto cond_ranked_ty = llvm::dyn_cast(condition.getType()); + auto broadcasted_ranked_ty = llvm::dyn_cast(broadcasted_ty); if (!cond_ranked_ty || !broadcasted_ranked_ty) return unranked_ty; // Explicitly get broadcasted output type as element types of condition may @@ -1149,12 +1150,13 @@ LogicalResult VerifyShapeOperandAndResult(Operation *op, Type operand_type, std::string variadic_idx_str = variadic_idx < 0 ? 
"" : llvm::formatv(" #{0}", variadic_idx).str(); - auto result_ranked_type = result_type.dyn_cast(); + auto result_ranked_type = llvm::dyn_cast(result_type); if (!result_ranked_type) return success(); if (result_ranked_type.getShape().size() != 1) return op->emitOpError("requires 1D type for result") << variadic_idx_str; - auto operand_ranked_type = operand_type.dyn_cast_or_null(); + auto operand_ranked_type = + llvm::dyn_cast_or_null(operand_type); if (operand_ranked_type) { // The operand is a ranked tensor. if (result_ranked_type.hasStaticShape() && @@ -1197,7 +1199,7 @@ LogicalResult ShapeOp::verify() { // Converts shape of the given type to attribute if it is of ranked tensor type. // Returned attribute has integer elements of the given width. static Attribute ConvertShapeToAttr(Type input_ty, int out_width) { - auto ranked_ty = input_ty.dyn_cast(); + auto ranked_ty = llvm::dyn_cast(input_ty); if (!ranked_ty || !ranked_ty.hasStaticShape()) return {}; auto shape = ranked_ty.getShape(); @@ -1214,14 +1216,15 @@ static Attribute ConvertShapeToAttr(Type input_ty, int out_width) { } OpFoldResult ShapeOp::fold(FoldAdaptor) { - int width = - getType().cast().getElementType().getIntOrFloatBitWidth(); + int width = llvm::cast(getType()) + .getElementType() + .getIntOrFloatBitWidth(); return ConvertShapeToAttr(getOperand().getType(), width); } void ShapeOp::build(OpBuilder &builder, OperationState &result, Value input, BoolAttr use32Bit) { - auto rankedTensorType = input.getType().dyn_cast(); + auto rankedTensorType = llvm::dyn_cast(input.getType()); int64_t rank = rankedTensorType ? rankedTensorType.getRank() : -1; auto out_type = use32Bit.getValue() ? 
builder.getIntegerType(32) : builder.getIntegerType(64); @@ -1347,9 +1350,9 @@ LogicalResult SizeOp::verify() { } OpFoldResult SizeOp::fold(FoldAdaptor) { - ShapedType output_type = getType().cast(); + ShapedType output_type = llvm::cast(getType()); if (!output_type.hasRank()) return {}; - ShapedType input_type = getOperand().getType().cast(); + ShapedType input_type = llvm::cast(getOperand().getType()); if (!input_type.hasStaticShape()) return {}; int size = input_type.getNumElements(); return DenseElementsAttr::get( @@ -1395,13 +1398,13 @@ LogicalResult SliceOp::verify() { " same number of elements"; } - auto input_ty = op.getInput().getType().dyn_cast(); + auto input_ty = llvm::dyn_cast(op.getInput().getType()); if (input_ty && begin_ty.getNumElements() != input_ty.getRank()) { return op.emitOpError() << "requires number of elements in begin and size " "are equal to input rank"; } - auto output_ty = op.getOutput().getType().dyn_cast(); + auto output_ty = llvm::dyn_cast(op.getOutput().getType()); if (output_ty && input_ty && output_ty.getRank() != input_ty.getRank()) { return op.emitOpError() << "requires output to have the same rank as input, but got input " @@ -1488,9 +1491,8 @@ LogicalResult SoftmaxOp::verify() { LogicalResult SoftmaxCrossEntropyWithLogitsOp::verify() { SoftmaxCrossEntropyWithLogitsOp op = *this; auto broadcasted_ty = - OpTrait::util::getBroadcastedType(op.getFeatures().getType(), - op.getLabels().getType()) - .dyn_cast_or_null(); + llvm::dyn_cast_or_null(OpTrait::util::getBroadcastedType( + op.getFeatures().getType(), op.getLabels().getType())); if (!broadcasted_ty || (broadcasted_ty.hasRank() && broadcasted_ty.getRank() != 2)) return op.emitOpError( @@ -1516,9 +1518,10 @@ int64_t SpaceToBatchNDBlockRank(const TensorType block_shape_type, LogicalResult SpaceToBatchNDOp::verify() { SpaceToBatchNDOp op = *this; - const auto input_type = op.getInput().getType().cast(); - const auto block_shape_type = op.getBlockShape().getType().cast(); - const 
auto paddings_type = op.getPaddings().getType().cast(); + const auto input_type = llvm::cast(op.getInput().getType()); + const auto block_shape_type = + llvm::cast(op.getBlockShape().getType()); + const auto paddings_type = llvm::cast(op.getPaddings().getType()); // Check that block_shape has rank 1. if (!IsOfRankOrUnranked(op.getBlockShape(), 1)) { @@ -1626,8 +1629,9 @@ LogicalResult SparseSoftmaxCrossEntropyWithLogitsOp::verify() { if (!IsOfRankOrUnranked(op.getLabels(), 1)) { return op.emitOpError("requires labels operand of rank one"); } - auto features_ty = op.getFeatures().getType().dyn_cast(); - auto labels_ty = op.getLabels().getType().dyn_cast(); + auto features_ty = + llvm::dyn_cast(op.getFeatures().getType()); + auto labels_ty = llvm::dyn_cast(op.getLabels().getType()); if (features_ty && labels_ty) { int64_t features_batches = features_ty.getDimSize(0); int64_t labels_batches = labels_ty.getDimSize(0); @@ -1653,7 +1657,8 @@ LogicalResult VerifySplitInputAndSplitDim(Op op, *dim_index = std::nullopt; Value split_dim = op.getSplitDim(); - if (auto split_dim_type = split_dim.getType().dyn_cast()) + if (auto split_dim_type = + llvm::dyn_cast(split_dim.getType())) if (split_dim_type.getRank() != 0) return op.emitOpError( "split dimension should be an integer scalar tensor"); @@ -1661,8 +1666,7 @@ LogicalResult VerifySplitInputAndSplitDim(Op op, // We can perform further verification if the input tensor to be split has // known rank and the split dimension tensor is a constant. 
- auto input_type = - op.getValue().getType().template dyn_cast(); + auto input_type = llvm::dyn_cast(op.getValue().getType()); if (!input_type) return success(); int64_t input_rank = input_type.getRank(); @@ -1691,8 +1695,8 @@ LogicalResult SplitOp::verify() { if (failed(VerifySplitInputAndSplitDim(op, &dim_index))) return failure(); if (!dim_index) return success(); - int64_t input_dim_size = - op.getValue().getType().cast().getDimSize(*dim_index); + int64_t input_dim_size = llvm::cast(op.getValue().getType()) + .getDimSize(*dim_index); if (ShapedType::isDynamic(input_dim_size)) return success(); if (op.getNumResults() == 0) return failure(); @@ -1711,7 +1715,7 @@ LogicalResult SplitOp::verify() { LogicalResult SplitVOp::verify() { SplitVOp op = *this; auto split_sizes_type = - op.getSizeSplits().getType().dyn_cast(); + llvm::dyn_cast(op.getSizeSplits().getType()); if (!split_sizes_type) return success(); if (split_sizes_type.getRank() != 1 || @@ -1724,8 +1728,8 @@ LogicalResult SplitVOp::verify() { if (failed(VerifySplitInputAndSplitDim(op, &dim_index))) return failure(); if (!dim_index) return success(); - int64_t input_dim_size = - op.getValue().getType().cast().getDimSize(*dim_index); + int64_t input_dim_size = llvm::cast(op.getValue().getType()) + .getDimSize(*dim_index); if (ShapedType::isDynamic(input_dim_size)) return success(); // If split sizes come from a constant, they must sum to the dimension size @@ -1739,7 +1743,7 @@ LogicalResult SplitVOp::verify() { SmallVector split_sizes; split_sizes.reserve( - split_sizes_attr.getType().cast().getNumElements()); + llvm::cast(split_sizes_attr.getType()).getNumElements()); for (const auto &dim : llvm::enumerate(split_sizes_attr)) { int64_t dim_val = dim.value().getSExtValue(); @@ -1785,7 +1789,7 @@ void SquareOp::getCanonicalizationPatterns(RewritePatternSet &results, LogicalResult SqueezeOp::verify() { SqueezeOp op = *this; - auto input_type = op.getInput().getType().dyn_cast(); + auto input_type = 
llvm::dyn_cast(op.getInput().getType()); if (!input_type) return success(); // Can't verify squeeze dims. @@ -1829,9 +1833,9 @@ void SumOp::build(OpBuilder &builder, OperationState &result, Value input, // TODO: Templatize this fold for all reduction ops. OpFoldResult SumOp::fold(FoldAdaptor) { - auto input_ty = getInput().getType().template dyn_cast(); + auto input_ty = llvm::dyn_cast(getInput().getType()); if (!input_ty) return {}; - auto result_ty = getType().template dyn_cast(); + auto result_ty = llvm::dyn_cast(getType()); if (!result_ty) return {}; // Bypass this op if the result has the same shape and type. This can happen @@ -1866,7 +1870,7 @@ static LogicalResult VerifyStridedSliceBase(OpTy op) { int64_t expected_size = -1; for (Value val : {op.getBegin(), op.getEnd(), op.getStrides()}) { - auto operand_ty = val.getType().dyn_cast(); + auto operand_ty = llvm::dyn_cast(val.getType()); if (!operand_ty || !operand_ty.hasStaticShape()) { // TensorFlow constant ops may have non-static shape because the shape is // not propagated during constant folding. If the defining op for this @@ -2151,7 +2155,7 @@ bool StridedSliceOp::GetSlicedBoundRanges( !matchPattern(getStrides(), m_Constant(&sparse_strides_attr))) return false; - auto input_ty = this->getInput().getType().dyn_cast(); + auto input_ty = llvm::dyn_cast(this->getInput().getType()); if (!input_ty || !input_ty.hasStaticShape()) return false; auto input_shape = llvm::to_vector<4>(input_ty.getShape()); @@ -2210,7 +2214,8 @@ OpFoldResult StridedSliceOp::fold(FoldAdaptor) { // pattern. if (getNewAxisMask() != 0) return {}; - auto tensor_ty = shape_op.getInput().getType().dyn_cast(); + auto tensor_ty = + llvm::dyn_cast(shape_op.getInput().getType()); // Only ranked tensor can be folded. if (!tensor_ty) return {}; @@ -2269,8 +2274,8 @@ OpFoldResult StridedSliceOp::fold(FoldAdaptor) { // scalar or a vector based on `shrink_axis_mask` because we have rejected // the case of `new_axis_mask` != 0. 
auto output_elt_ty = - getOutput().getType().cast().getElementType(); - auto output_ty = getOutput().getType().dyn_cast(); + llvm::cast(getOutput().getType()).getElementType(); + auto output_ty = llvm::dyn_cast(getOutput().getType()); if (!output_ty || !output_ty.hasStaticShape()) { if (getShrinkAxisMask() == 1) { output_ty = tensorflow::GetTypeFromTFTensorShape({}, output_elt_ty); @@ -2296,7 +2301,7 @@ OpFoldResult StridedSliceOp::fold(FoldAdaptor) { LogicalResult StridedSliceGradOp::verify() { StridedSliceGradOp op = *this; - auto shape_type = op.getShape().getType().dyn_cast(); + auto shape_type = llvm::dyn_cast(op.getShape().getType()); if (shape_type && shape_type.getRank() != 1) return op.emitOpError("'shape' operand must be 1D tensor, but got ") << shape_type.getRank() << "D tensor"; @@ -2418,7 +2423,7 @@ LogicalResult TPUExecuteAndUpdateVariablesOp::verify() { TPUExecuteAndUpdateVariablesOp op = *this; int num_resource_args = 0; for (Type arg_type : op.getArgs().getTypes()) - if (arg_type.cast().getElementType().isa()) + if (isa(cast(arg_type).getElementType())) ++num_resource_args; auto check_attr = [&](ArrayAttr indices, llvm::StringRef name, @@ -2431,7 +2436,7 @@ LogicalResult TPUExecuteAndUpdateVariablesOp::verify() { << num_resource_args << "), but got " << indices.size(); for (const auto &entry : llvm::enumerate(indices.getValue())) { - auto int_attr = entry.value().cast(); + auto int_attr = llvm::cast(entry.value()); if (int_attr.getInt() < min) return op.emitOpError() << "requires '" << name << "' to contain values of at least " @@ -2457,20 +2462,16 @@ void TPUExecuteAndUpdateVariablesOp::getEffects( ResourceEffects::TPUExecute::get()); auto resource_handles = llvm::make_filter_range(getArgsMutable(), [](OpOperand &op_operand) { - return op_operand.get() - .getType() - .cast() - .getElementType() - .isa(); + return isa( + cast(op_operand.get().getType()).getElementType()); }); for (const auto& entry : llvm::enumerate(resource_handles)) { OpOperand 
&op_operand = entry.value(); effects.emplace_back(MemoryEffects::Read::get(), &op_operand, ResourceEffects::Variable::get()); - if (getDeviceVarUpdatesIndices() - .getValue()[entry.index()] - .cast() + if (llvm::cast( + getDeviceVarUpdatesIndices().getValue()[entry.index()]) .getInt() >= 0) effects.emplace_back(MemoryEffects::Write::get(), &op_operand, ResourceEffects::Variable::get()); @@ -2544,10 +2545,11 @@ LogicalResult TensorListReserveOp::verify() { //===----------------------------------------------------------------------===// OpFoldResult TensorListElementShapeOp::fold(FoldAdaptor) { - int width = - getType().cast().getElementType().getIntOrFloatBitWidth(); - auto variant_type = - getElementTypeOrSelf(getOperand().getType()).cast(); + int width = llvm::cast(getType()) + .getElementType() + .getIntOrFloatBitWidth(); + auto variant_type = llvm::cast( + getElementTypeOrSelf(getOperand().getType())); if (variant_type.getSubtypes().empty()) return {}; return ConvertShapeToAttr(variant_type.getSubtypes()[0], width); } @@ -2578,8 +2580,8 @@ LogicalResult TensorScatterUpdateOp::verify() { return op.emitOpError( "requires indices operand to have at least 1 dimension"); - auto tensor_ty = op.getTensor().getType().dyn_cast(); - auto indices_ty = op.getIndices().getType().dyn_cast(); + auto tensor_ty = llvm::dyn_cast(op.getTensor().getType()); + auto indices_ty = llvm::dyn_cast(op.getIndices().getType()); if (!tensor_ty || !indices_ty) return success(); int64_t num_index_dims = indices_ty.getShape().back(); @@ -2608,10 +2610,10 @@ LogicalResult TensorScatterUpdateOp::verify() { LogicalResult TileOp::verify() { TileOp op = *this; - auto input_type = op.getInput().getType().dyn_cast(); + auto input_type = llvm::dyn_cast(op.getInput().getType()); auto multiples_type = - op.getMultiples().getType().dyn_cast(); - auto output_type = op.getOutput().getType().dyn_cast(); + llvm::dyn_cast(op.getMultiples().getType()); + auto output_type = 
llvm::dyn_cast(op.getOutput().getType()); if (multiples_type && multiples_type.getRank() != 1) { return op.emitOpError() << "expected multiples to be rank 1, got rank = " @@ -2745,7 +2747,7 @@ class FuseWithBroadcastCompatibleOp continue; } - auto shape = tile.getInput().getType().dyn_cast(); + auto shape = llvm::dyn_cast(tile.getInput().getType()); if (!shape) { continue; } @@ -2837,13 +2839,13 @@ class ToBoolOfRankedTensor : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(ToBoolOp op, PatternRewriter &rewriter) const override { - auto type = op.getOperand().getType().dyn_cast(); + auto type = llvm::dyn_cast(op.getOperand().getType()); // If the input is an unranked tensor, cannpt rewrite. if (!type) return failure(); // Expected return type of the ToBool operation. The return type of ToBool // operation is always 0D tensor of bool type. - auto result_type = op.getResult().getType().cast(); + auto result_type = llvm::cast(op.getResult().getType()); // If input is already a tensor, it can be folded into an identity. 
if (type == result_type) { @@ -2858,7 +2860,7 @@ class ToBoolOfRankedTensor : public OpRewritePattern { Attribute zero_attr; if (element_type.isIntOrFloat()) zero_attr = rewriter.getZeroAttr(type); - else if (element_type.isa()) + else if (isa(element_type)) zero_attr = DenseStringElementsAttr::get(type, {""}); if (!zero_attr) return failure(); @@ -2905,7 +2907,7 @@ LogicalResult TPUPartitionedInputV2Op::verify() { int num_partitions = 1; const mlir::ArrayAttr partition_dims = op.getPartitionDims(); for (const mlir::Attribute &dim : partition_dims) { - num_partitions *= dim.cast().getInt(); + num_partitions *= llvm::cast(dim).getInt(); } const bool is_packed = op.getIsPacked(); @@ -2926,9 +2928,9 @@ LogicalResult TPUPartitionedInputV2Op::verify() { LogicalResult TransposeOp::verify() { TransposeOp op = *this; - auto perm_type = op.getPerm().getType().dyn_cast(); - auto x_type = op.getX().getType().dyn_cast(); - auto y_type = op.getY().getType().dyn_cast(); + auto perm_type = llvm::dyn_cast(op.getPerm().getType()); + auto x_type = llvm::dyn_cast(op.getX().getType()); + auto y_type = llvm::dyn_cast(op.getY().getType()); if (perm_type && perm_type.getRank() != 1) { return op.emitOpError() @@ -2985,7 +2987,7 @@ LogicalResult TransposeOp::verify() { // TODO(jpienaar): perm could be optional too. void TransposeOp::build(OpBuilder &builder, OperationState &result, Value x, Value perm) { - auto x_type = x.getType().cast(); + auto x_type = llvm::cast(x.getType()); // If value is unranked, then so is results. if (!x_type.hasRank()) return TransposeOp::build(builder, result, @@ -2995,7 +2997,7 @@ void TransposeOp::build(OpBuilder &builder, OperationState &result, Value x, // TODO(jpienaar): Handle unknown perm case. // TODO(jpienaar): Extract utility function. 
- auto etype = x_type.cast().getElementType(); + auto etype = llvm::cast(x_type).getElementType(); DenseIntElementsAttr attr_shape; if (matchPattern(perm, m_Constant(&attr_shape))) { llvm::SmallVector const_shape; @@ -3040,7 +3042,7 @@ OpFoldResult FoldCancellableTranspose(TransposeOp op) { if (transpose->getBlock() != op->getBlock()) { tensorflow::DataType dtype; auto status = tensorflow::ConvertToDataType( - op.getX().getType().cast().getElementType(), &dtype); + llvm::cast(op.getX().getType()).getElementType(), &dtype); if (status.ok()) { // We can only leave the transpose op on host if its dtype is supported on // host. @@ -3104,7 +3106,7 @@ class NMSV3ToNMSV4Op : public OpRewritePattern { } SmallVector new_result_types; new_result_types.push_back(nms_op.getType()); - auto input_ty = nms_op.getType().template cast(); + auto input_ty = llvm::cast(nms_op.getType()); // corresponds to the second result type of nmsv4 RankedTensorType valid_output_type = tensorflow::GetTypeFromTFTensorShape({}, input_ty.getElementType()); @@ -3184,7 +3186,7 @@ LogicalResult XlaCallModuleOp::verifySymbolUses( SymbolTableCollection &symbolTable) { for (auto f : getFunctionList()) { auto func = symbolTable.lookupNearestSymbolFrom( - getOperation(), f.cast()); + getOperation(), llvm::cast(f)); if (!func) { return emitOpError() << "refers to an undefined function: " << f; } @@ -3223,7 +3225,7 @@ std::optional XlaLaunchOp::GetResourceInstanceStr() { LogicalResult UnpackOp::verify() { UnpackOp op = *this; - auto value_type = op.getValue().getType().dyn_cast(); + auto value_type = llvm::dyn_cast(op.getValue().getType()); if (!value_type) return success(); int64_t value_rank = value_type.getRank(); @@ -3321,9 +3323,9 @@ static LogicalResult VerifyUnsortedSegmentReduction(Op op) { if (!HasRankAtMost(op.getNumSegments(), 0)) return op.emitOpError("number of segments should be a 0-D tensor"); - auto data_type = op.getData().getType().template dyn_cast(); + auto data_type = 
llvm::dyn_cast(op.getData().getType()); auto segment_ids_type = - op.getSegmentIds().getType().template dyn_cast(); + llvm::dyn_cast(op.getSegmentIds().getType()); if (data_type && segment_ids_type) { if (data_type.getRank() < segment_ids_type.getRank()) return op.emitOpError( @@ -3434,11 +3436,12 @@ void VariableOp::getCanonicalizationPatterns(RewritePatternSet &results, LogicalResult VariableShapeOp::verify() { VariableShapeOp op = *this; - auto input_type = op.getInput().getType().cast(); + auto input_type = llvm::cast(op.getInput().getType()); if (input_type.hasStaticShape() && input_type.getNumElements() != 1) return op.emitOpError("requires input to have one resource"); - auto resource_type = input_type.getElementType().cast(); + auto resource_type = + llvm::cast(input_type.getElementType()); auto subtypes = resource_type.getSubtypes(); switch (subtypes.size()) { case 1: @@ -3453,10 +3456,11 @@ LogicalResult VariableShapeOp::verify() { } OpFoldResult VariableShapeOp::fold(FoldAdaptor) { - int width = - getType().cast().getElementType().getIntOrFloatBitWidth(); - auto resource_type = - getElementTypeOrSelf(getOperand().getType()).cast(); + int width = llvm::cast(getType()) + .getElementType() + .getIntOrFloatBitWidth(); + auto resource_type = llvm::cast( + getElementTypeOrSelf(getOperand().getType())); if (resource_type.getSubtypes().empty()) return {}; return ConvertShapeToAttr(resource_type.getSubtypes()[0], width); } @@ -3566,7 +3570,7 @@ LogicalResult WhileRegionOp::verify() { << "condition should yield a tensor and forward the arguments"; auto cond_type = - cond_yield->getOperand(0).getType().dyn_cast(); + llvm::dyn_cast(cond_yield->getOperand(0).getType()); if (!cond_type || !cond_type.getShape().equals({}) || !cond_type.getElementType().isInteger(/*width=*/1)) return op.emitOpError() @@ -3852,8 +3856,8 @@ LogicalResult XlaBroadcastHelperOp::inferReturnTypeComponents( return success(); }; - RankedTensorType lhs_ty = lhs.getType().dyn_cast(); - 
RankedTensorType rhs_ty = rhs.getType().dyn_cast(); + RankedTensorType lhs_ty = llvm::dyn_cast(lhs.getType()); + RankedTensorType rhs_ty = llvm::dyn_cast(rhs.getType()); if (!lhs_ty || !rhs_ty) return set_unranked_results(); int64_t lhs_rank = lhs_ty.getRank(); @@ -3871,8 +3875,8 @@ LogicalResult XlaBroadcastHelperOp::inferReturnTypeComponents( "if broadcast_dims is empty, both arguments must have equal rank or " "at least one argument must be a scalar"); } - inferredReturnShapes.emplace_back(lhs_ty.cast()); - inferredReturnShapes.emplace_back(rhs_ty.cast()); + inferredReturnShapes.emplace_back(llvm::cast(lhs_ty)); + inferredReturnShapes.emplace_back(llvm::cast(rhs_ty)); return success(); } @@ -3904,9 +3908,9 @@ LogicalResult XlaBroadcastHelperOp::inferReturnTypeComponents( if (broadcast_lhs) { inferredReturnShapes.emplace_back(broadcast_shape, lhs_ty.getElementType()); - inferredReturnShapes.emplace_back(rhs_ty.cast()); + inferredReturnShapes.emplace_back(llvm::cast(rhs_ty)); } else { - inferredReturnShapes.emplace_back(lhs_ty.cast()); + inferredReturnShapes.emplace_back(llvm::cast(lhs_ty)); inferredReturnShapes.emplace_back(broadcast_shape, rhs_ty.getElementType()); } return success(); @@ -3984,7 +3988,7 @@ LogicalResult XlaSetDynamicDimensionSizeOp::inferReturnTypeComponents( SmallVectorImpl &inferredReturnShapes) { XlaSetDynamicDimensionSizeOpAdaptor op(operands.getValues(), attributes); - TensorType operand_ty = op.getInput().getType().cast(); + TensorType operand_ty = llvm::cast(op.getInput().getType()); Type element_ty = operand_ty.getElementType(); TensorType result_ty; @@ -4009,7 +4013,7 @@ LogicalResult XlaSetDynamicDimensionSizeOp::inferReturnTypeComponents( result_ty = UnrankedTensorType::get(element_ty); } - inferredReturnShapes.emplace_back(result_ty.cast()); + inferredReturnShapes.emplace_back(llvm::cast(result_ty)); return success(); } @@ -4045,7 +4049,7 @@ void XlaReduceOp::getCanonicalizationPatterns(RewritePatternSet &results, LogicalResult 
XlaReduceWindowOp::verify() { XlaReduceWindowOp op = *this; - const auto &input_ty = op.getInput().getType().cast(); + const auto &input_ty = llvm::cast(op.getInput().getType()); auto check = [&](mlir::Value val, std::string attr_name) -> LogicalResult { ElementsAttr attr; @@ -4114,7 +4118,7 @@ LogicalResult XlaReduceWindowOp::verify() { LogicalResult XlaSelectAndScatterOp::verify() { XlaSelectAndScatterOp op = *this; - auto input_ty = op.getOperand().getType().cast(); + auto input_ty = llvm::cast(op.getOperand().getType()); auto check = [&](mlir::Value val, std::string attr_name) -> LogicalResult { ElementsAttr attr; @@ -4188,9 +4192,9 @@ LogicalResult XlaVariadicReduceOp::verify() { // We rely on V2 for the majority of the checks. const auto &input_ty = op.getInput().getType(); if (input_ty.empty()) return op.emitOpError() << "No input"; - const auto &dtype = input_ty[0].cast().getElementType(); + const auto &dtype = llvm::cast(input_ty[0]).getElementType(); for (const auto &ty : input_ty) { - if (ty.cast().getElementType() != dtype) + if (llvm::cast(ty).getElementType() != dtype) return op.emitOpError() << "This version is limited to operands of the same dtype"; } @@ -4234,10 +4238,10 @@ LogicalResult XlaVariadicReduceV2Op::verify() { << n_init_values << ")"; } - auto input_ty_0 = inputs_ty[0].cast(); + auto input_ty_0 = llvm::cast(inputs_ty[0]); if (input_ty_0.hasStaticShape()) { for (int i = 1; i < n_inputs; ++i) { - auto input_ty_i = inputs_ty[i].cast(); + auto input_ty_i = llvm::cast(inputs_ty[i]); if (input_ty_i.hasStaticShape() && input_ty_i.getShape() != input_ty_0.getShape()) { return op.emitOpError() @@ -4254,7 +4258,7 @@ LogicalResult XlaVariadicReduceV2Op::verify() { } for (int i = 0; i < n_inputs; ++i) { - auto init_value_ty_i = init_values_ty[i].cast(); + auto init_value_ty_i = llvm::cast(init_values_ty[i]); if (init_value_ty_i.hasRank() && init_value_ty_i.getRank() != 0) { return op.emitOpError() << "init_values[" << i << "] must be a scalar but 
got [" @@ -4280,10 +4284,10 @@ LogicalResult XlaVariadicSortOp::verify() { XlaVariadicSortOp op = *this; const auto &inputs_ty = op.getInputs().getType(); int n_inputs = inputs_ty.size(); - auto input_ty_0 = inputs_ty[0].cast(); + auto input_ty_0 = llvm::cast(inputs_ty[0]); if (input_ty_0.hasStaticShape()) { for (int i = 1; i < n_inputs; ++i) { - auto input_ty_i = inputs_ty[i].cast(); + auto input_ty_i = llvm::cast(inputs_ty[i]); if (input_ty_i.hasStaticShape() && input_ty_i.getShape() != input_ty_0.getShape()) { return op.emitOpError() @@ -4318,10 +4322,9 @@ LogicalResult XlaVariadicSortOp::verify() { LogicalResult SetStaticDimensionBoundsOp::verify() { SetStaticDimensionBoundsOp op = *this; - mlir::ShapedType input_type = - op.getInput().getType().cast(); + mlir::ShapedType input_type = llvm::cast(op.getInput().getType()); mlir::ShapedType static_shape_type = - op.getStaticShape().getType().cast(); + llvm::cast(op.getStaticShape().getType()); int input_type_rank = input_type.hasRank() ? 
input_type.getRank() : -1; if (input_type_rank > 2) { return op.emitOpError() << "was used with an input tensor with rank > 2, " @@ -4348,8 +4351,8 @@ template LogicalResult VerifyScalesAndZeroPoints(UniformQuantizedOp op, Value scales, Value zero_points, int32_t quantization_axis) { - ShapedType scales_type = scales.getType().cast(); - ShapedType zero_points_type = zero_points.getType().cast(); + ShapedType scales_type = llvm::cast(scales.getType()); + ShapedType zero_points_type = llvm::cast(zero_points.getType()); if (quantization_axis == -1) { if (scales_type.hasRank() && scales_type.getRank() != 0) { diff --git a/tensorflow/compiler/mlir/tensorflow/tests/convert_control_to_data_outputs.mlir b/tensorflow/compiler/mlir/tensorflow/tests/convert_control_to_data_outputs.mlir index 5f59e35498151e..abff7aeb61a2d6 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/convert_control_to_data_outputs.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/convert_control_to_data_outputs.mlir @@ -656,7 +656,6 @@ func.func @incomplete_composite_devices_while_body(%arg0: !tf_res, %arg1: !tf_re %mul, %mul_control = tf_executor.island wraps "tf.Mul"(%arg2, %arg2) : (tensor, tensor) -> tensor %control_barrier = tf_executor.island(%assign_control_0, %assign_control_1, %add_control, %exe_control) wraps "tf.NoOp"() : () -> () // CHECK: [[exe]]{{.*}}"tf.Identity" - // CHECK-NOT: "tf.Identity" // CHECK: tf_executor.fetch tf_executor.fetch %arg0, %arg1, %add, %control_barrier, %mul_control : tensor>>, tensor>>, tensor, !tf_executor.control, !tf_executor.control } @@ -816,11 +815,11 @@ func.func @tpu_execute_with_non_resource_operands(%arg0: !tf_res {tf._composite_ func.func @double_tpu_execute_while_body(%arg0: !tf_res, %arg1: !tf_res, %arg2: tensor) -> (!tf_res, !tf_res, tensor) { - // CHECK: "tf.Identity" %graph:3 = tf_executor.graph { // CHECK: {{.*}}, [[ctrl1:%.*]] = tf_executor.island wraps "tf.Identity" // CHECK: {{.*}}, [[ctrl2:%.*]] = tf_executor.island wraps "tf.Identity" // 
CHECK: "tf.Identity" + // CHECK: "tf.Identity" %key, %key_control = tf_executor.island wraps "tf.Const"() {value = dense<"">: !tf_str} : () -> !tf_str // CHECK: [[exe_ctrl1:%.*]] = tf_executor.island([[ctrl1]]) wraps "tf.TPUExecuteAndUpdateVariables" %exe_control1 = tf_executor.island wraps "tf.TPUExecuteAndUpdateVariables"(%arg2, %arg0, %arg1, %key) { @@ -887,9 +886,9 @@ func.func @tpu_executes_on_same_device_while_body(%arg0: !tf_res, %arg1: !tf_res %arg2: tensor) -> (!tf_res, !tf_res, tensor) { %graph:3 = tf_executor.graph { - // CHECK: "tf.Identity" // CHECK: {{.*}}, [[id_ctrl:%.*]] = tf_executor.island wraps "tf.Identity" // CHECK: "tf.Identity" + // CHECK: "tf.Identity" %key, %key_control = tf_executor.island wraps "tf.Const"() {value = dense<"">: !tf_str} : () -> !tf_str // CHECK: [[exe_ctrl1:%.*]] = tf_executor.island([[id_ctrl]]) wraps "tf.TPUExecuteAndUpdateVariables" %exe_control1 = tf_executor.island wraps "tf.TPUExecuteAndUpdateVariables"(%arg2, %arg0, %arg1, %key) { @@ -911,8 +910,8 @@ func.func @tpu_executes_on_same_device_while_body(%arg0: !tf_res, %arg1: !tf_res %mul, %mul_control = tf_executor.island wraps "tf.Mul"(%arg2, %arg2) : (tensor, tensor) -> tensor %control_barrier = tf_executor.island(%assign_control_0, %assign_control_1, %add_control, %exe_control1, %exe_control2) wraps "tf.NoOp"() : () -> () - // CHECK: "tf.Identity"(%arg3) // CHECK: tf_executor.island([[exe_ctrl1]], [[exe_ctrl2]]) wraps "tf.Identity" + // CHECK: "tf.Identity"(%arg4) // CHECK: "tf.Identity"(%arg5) // CHECK-NEXT: tf_executor.fetch tf_executor.fetch %arg0, %arg1, %add, %control_barrier, %mul_control : tensor>>, tensor>>, tensor, !tf_executor.control, !tf_executor.control diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/BUILD b/tensorflow/compiler/mlir/tensorflow/transforms/BUILD index 19a1137b20de4f..54d92b5b2ece06 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/transforms/BUILD @@ -13,12 +13,7 @@ 
package( gentbl_cc_library( name = "tensorflow_canonicalize_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "generated_canonicalize.inc", - ), - ], + tbl_outs = {"generated_canonicalize.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "canonicalize.td", deps = [ @@ -29,12 +24,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_reduce_patterns_inc_gen", - tbl_outs = [ - ( - ["-gen-rewriters"], - "reducer/tf_reduce_patterns.inc", - ), - ], + tbl_outs = {"reducer/tf_reduce_patterns.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "reducer/tf_mlir_reduce_patterns.td", deps = [ @@ -89,12 +79,7 @@ cc_library( gentbl_cc_library( name = "decompose_resource_ops_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "generated_decompose_resource_ops.inc", - ), - ], + tbl_outs = {"generated_decompose_resource_ops.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "decompose_resource_ops.td", deps = [ @@ -118,6 +103,7 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", "//tensorflow/core:framework", + "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", ], ) @@ -152,12 +138,7 @@ cc_library( gentbl_cc_library( name = "tf_data_optimization_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "generated_tf_data_optimization.inc", - ), - ], + tbl_outs = {"generated_tf_data_optimization.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "tf_data_optimization.td", deps = [ @@ -376,19 +357,13 @@ cc_library( gentbl_cc_library( name = "tf_pass_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=TensorFlow", - ], - "tf_passes.h.inc", - ), - ( - ["-gen-pass-doc"], - 
"g3doc/_includes/tf_passes.md", - ), - ], + tbl_outs = { + "tf_passes.h.inc": [ + "-gen-pass-decls", + "-name=TensorFlow", + ], + "g3doc/_includes/tf_passes.md": ["-gen-pass-doc"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "tf_passes.td", deps = [ @@ -399,19 +374,13 @@ gentbl_cc_library( gentbl_cc_library( name = "tf_device_pass_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=TensorFlowDevice", - ], - "tf_device_passes.h.inc", - ), - ( - ["-gen-pass-doc"], - "g3doc/includes/tf_device_passes.md", - ), - ], + tbl_outs = { + "tf_device_passes.h.inc": [ + "-gen-pass-decls", + "-name=TensorFlowDevice", + ], + "g3doc/includes/tf_device_passes.md": ["-gen-pass-doc"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "tf_device_passes.td", deps = [ @@ -422,19 +391,13 @@ gentbl_cc_library( gentbl_cc_library( name = "tf_savedmodel_pass_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=TensorFlowSavedModel", - ], - "tf_savedmodel_passes.h.inc", - ), - ( - ["-gen-pass-doc"], - "g3doc/includes/tf_savedmodel_passes.md", - ), - ], + tbl_outs = { + "tf_savedmodel_passes.h.inc": [ + "-gen-pass-decls", + "-name=TensorFlowSavedModel", + ], + "g3doc/includes/tf_savedmodel_passes.md": ["-gen-pass-doc"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "tf_savedmodel_passes.td", deps = [ @@ -445,19 +408,13 @@ gentbl_cc_library( gentbl_cc_library( name = "tf_test_passes_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=TensorFlowTest", - ], - "test_passes.h.inc", - ), - ( - ["-gen-pass-doc"], - "g3doc/includes/tf_test_passes.md", - ), - ], + tbl_outs = { + "test_passes.h.inc": [ + "-gen-pass-decls", + "-name=TensorFlowTest", + ], + "g3doc/includes/tf_test_passes.md": ["-gen-pass-doc"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "tf_test_passes.td", 
deps = [ @@ -601,7 +558,6 @@ cc_library( ":verify_no_outside_compilation_markers_pass", "//tensorflow/compiler/jit:flags_headers", "//tensorflow/compiler/mlir:op_or_arg_name_mapper", - "//tensorflow/compiler/mlir/lite:validators", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:attribute_utils", "//tensorflow/compiler/mlir/tensorflow:bridge_logger", @@ -643,6 +599,7 @@ cc_library( "//tensorflow/compiler/mlir/tf2xla/transforms:split_into_island_per_op_pass", "//tensorflow/compiler/mlir/tf2xla/transforms:xla_legalize_tf", "//tensorflow/compiler/mlir/tf2xla/transforms:xla_legalize_tf_with_tf2xla", + "//tensorflow/compiler/mlir/utils:validators", "//tensorflow/compiler/tf2xla:side_effect_util", "//tensorflow/compiler/tf2xla/kernels:xla_call_module_loader", "//tensorflow/core:core_cpu_base", @@ -840,6 +797,7 @@ cc_library( "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/hlo/translate/hlo_to_mhlo:hlo_utils", "@local_xla//xla/hlo/translate/mhlo_to_hlo:type_to_shape", + "@local_xla//xla/mlir_hlo", "@local_xla//xla/service:shape_inference", "@local_xla//xla/tsl/platform:errors", "@local_xla//xla/tsl/util:env_var", @@ -1025,12 +983,7 @@ filegroup( gentbl_cc_library( name = "tensorflow_optimize_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "generated_optimize.inc", - ), - ], + tbl_outs = {"generated_optimize.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "optimize.td", deps = [ @@ -1045,12 +998,7 @@ gentbl_cc_library( gentbl_cc_library( name = "lower_tf_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "generated_lower_tf.inc", - ), - ], + tbl_outs = {"generated_lower_tf.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "lower_tf.td", deps = [ diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/convert_control_to_data_outputs.cc 
b/tensorflow/compiler/mlir/tensorflow/transforms/convert_control_to_data_outputs.cc index 6262cad26ca6e3..d63ace094451a6 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/convert_control_to_data_outputs.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/convert_control_to_data_outputs.cc @@ -424,7 +424,7 @@ void ChainResourceOps( for (auto class_iter = resource_equivalence_classes.begin(); class_iter != resource_equivalence_classes.end(); ++class_iter) { // Only visit one element per class, the leader. - if (!class_iter->isLeader()) continue; + if (!(*class_iter)->isLeader()) continue; // Create chain source and sink identity islands for current equivalence // class. @@ -445,7 +445,7 @@ void ChainResourceOps( // by `class_iter`). Keep track of ops that have already been processed. llvm::SmallDenseSet processed_ops; for (auto member_iter = - resource_equivalence_classes.member_begin(class_iter); + resource_equivalence_classes.member_begin(**class_iter); member_iter != resource_equivalence_classes.member_end(); ++member_iter) { ResourceAndDevice resource_and_device = *member_iter; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.cc b/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.cc index 970e55c20855fb..144bdb44018649 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "llvm/Support/Casting.h" #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" @@ -32,10 +33,8 @@ namespace { // Returns subtype of `resource` if present. Otherwise an unranked tensor type // of `element_type` is returned. 
static Type GetResourceSubtypeOrDefault(Value resource, Type element_type) { - auto resource_type = resource.getType() - .cast() - .getElementType() - .cast(); + auto resource_type = llvm::cast( + llvm::cast(resource.getType()).getElementType()); if (resource_type.getSubtypes().size() == 1) return resource_type.getSubtypes().front(); @@ -43,19 +42,15 @@ static Type GetResourceSubtypeOrDefault(Value resource, Type element_type) { } static bool HasResourceSubtype(Value resource) { - return resource.getType() - .cast() - .getElementType() - .cast() + return llvm::cast( + llvm::cast(resource.getType()).getElementType()) .getSubtypes() .size() == 1; } static Type GetResourceSubtype(Value resource) { - return resource.getType() - .cast() - .getElementType() - .cast() + return llvm::cast( + llvm::cast(resource.getType()).getElementType()) .getSubtypes() .front(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.td b/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.td index f466c1d48d6835..1fc666da4a8d95 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.td @@ -30,7 +30,7 @@ def CreateTFReadVariableOp : NativeCodeCall< "$_builder.create(" " $0.getLoc()," " GetResourceSubtypeOrDefault(" - " $2, $1.getType().cast().getElementType())," + " $2, llvm::cast($1.getType()).getElementType())," " $2)" >; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_inline_tpu_island.cc b/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_inline_tpu_island.cc index 61bba38454afd8..8bdd088b2ddef2 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_inline_tpu_island.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_inline_tpu_island.cc @@ -27,6 +27,7 @@ limitations under the License. 
#include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Transforms/Inliner.h" // from @llvm-project #include "mlir/Transforms/InliningUtils.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" @@ -57,6 +58,7 @@ void ExecutorTPUV1IslandInliningPass::runOnOperation() { if (!nested_module) return; InlinerInterface inliner(&getContext()); + InlinerConfig config; auto walk_result = getOperation().walk([&](TF::PartitionedCallOp call_op) { if (!call_op.getF().getRootReference().getValue().starts_with( kNestedModule)) @@ -69,7 +71,7 @@ void ExecutorTPUV1IslandInliningPass::runOnOperation() { auto called_func = dyn_cast_or_null(call_interface.resolveCallable()); - if (failed(inlineCall(inliner, call_interface, + if (failed(inlineCall(inliner, config.getCloneCallback(), call_interface, cast(called_func.getOperation()), called_func.getCallableRegion(), /* shouldCloneInlinedRegion = */ false))) { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/hoist_replicate_invariant_resource_writes.cc b/tensorflow/compiler/mlir/tensorflow/transforms/hoist_replicate_invariant_resource_writes.cc index d7db0f03ed1a8c..2c70a078fbb13a 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/hoist_replicate_invariant_resource_writes.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/hoist_replicate_invariant_resource_writes.cc @@ -52,9 +52,10 @@ struct HoistReplicateInvariantResourceWritesPass // TODO(prakalps): This is a common utility and other passes use something // similar. Move to common utils. 
bool IsResourceType(Type type) { - return type.isa() || - (type.isa() && - type.cast().getElementType().isa()); + return llvm::isa(type) || + (llvm::isa(type) && + llvm::isa( + llvm::cast(type).getElementType())); } SmallVector GetAccessedResources(Operation& op) { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/BUILD b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/BUILD index abeec429e2fe1d..838dc1eb6fb8c1 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/BUILD @@ -6,7 +6,6 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [ - "//learning/serving/contrib/tfrt/mlir/saved_model_analysis:__pkg__", "//tensorflow/compiler/mlir:__pkg__", "//tensorflow/compiler/mlir/tensorflow/transforms:__pkg__", "//tensorflow/compiler/mlir/tf2xla/api:__subpackages__", @@ -142,15 +141,10 @@ tf_cc_test( gentbl_cc_library( name = "runtime_passes_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=RuntimeLowering", - ], - "runtime_passes.h.inc", - ), - ], + tbl_outs = {"runtime_passes.h.inc": [ + "-gen-pass-decls", + "-name=RuntimeLowering", + ]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "runtime_passes.td", deps = [ diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/tpu_merge_variables_with_execute.cc b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/tpu_merge_variables_with_execute.cc index 714fefaca8cded..9492c007b07ca5 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/tpu_merge_variables_with_execute.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/tpu_merge_variables_with_execute.cc @@ -40,6 +40,7 @@ limitations under the License. 
#include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/DebugStringHelper.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" @@ -147,8 +148,8 @@ bool AddAccessedResourceIds( bool IsResourceMergeable(Attribute& resource_attr, Attribute& device_attr) { return resource_attr && ((resource_attr == device_attr) || - (resource_attr.cast().getValue().find( - "COMPOSITE") != llvm::StringRef::npos)); + (llvm::cast(resource_attr).getValue().find("COMPOSITE") != + llvm::StringRef::npos)); } // Finds the variable access info for a TPUExecute op. @@ -196,7 +197,7 @@ VariableAccessesForTPUExecute BuildVariableAccessInfo( // Check device matching for the node defining the resource. if (!IsResourceMergeable(resource_attr, device_attr)) continue; } else { - auto resource_arg = resource.dyn_cast(); + auto resource_arg = dyn_cast(resource); assert(resource_arg); if (resource_arg.getOwner() != &func.front()) continue; // Check device matching for the argument defining the resource. @@ -518,8 +519,8 @@ LogicalResult MergeForOneTPUExecute( // Check that all resources are either read or written to. 
for (auto it : llvm::enumerate(var_access_info.new_operand_values)) { Type type = it.value().getType(); - if (type.isa() && - type.cast().getElementType().isa()) { + if (isa(type) && + isa(cast(type).getElementType())) { if (!llvm::is_contained(device_var_reads_indices, it.index()) && !llvm::is_contained(device_var_updates_indices, it.index())) { return execute_launch.GetBody().front().emitError("operand #") diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td index 4c7810f8df51b1..a9ff5a8f76268a 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td @@ -34,7 +34,7 @@ class GetI64ScalarElementsAttr : class GetF32Scalar : NativeCodeCall<"GetF32Scalar(&$_builder, " # value # ")">; -def TrueBoolAttr : AttrConstraint().getValue()">>; +def TrueBoolAttr : AttrConstraint($_self).getValue()">>; def CreateTFShapeOp : NativeCodeCall< "$_builder.create($0.getLoc(), $1, $2)">; @@ -74,7 +74,7 @@ def LowerAddOp : Pat<(TF_AddOp TF_NumberNotQuantizedTensor:$x, def GetBiasAddGradReductionIndices : NativeCodeCall< "GetBiasAddGradReductionIndices(" - "$0.getType().cast().getRank(), $1, &$_builder)">; + "llvm::cast($0.getType()).getRank(), $1, &$_builder)">; def LowerBiasAddGradOp : Pat<(TF_BiasAddGradOp AnyRankedTensor:$out_backprop, $data_format), @@ -120,12 +120,12 @@ def LowerSoftmaxCrossEntropyWithLogitsOp : Pattern< // dimension should be known. class GetDimSizeOfType : NativeCodeCall< "GetScalarOfType(getElementTypeOrSelf($1), " - "$0.getType().cast().getDimSize(" # dim # "))">; + "llvm::cast($0.getType()).getDimSize(" # dim # "))">; // Same as the above with i32 element type. 
class GetDimSizeAsI32 : NativeCodeCall< "GetScalarOfType($_builder.getIntegerType(32), " - "$0.getType().cast().getDimSize(" # dim # "))">; + "llvm::cast($0.getType()).getDimSize(" # dim # "))">; // Sparse version of SoftmaxCrossEntropyWithLogits is lowered to dense by // expanding the sparse labels using: @@ -285,7 +285,7 @@ def LowerIsNanOp : Pat<(TF_IsNanOp $x), def GetAllAxes : NativeCodeCall< "GetI64ElementsAttrForSeq(" - "0, $0.getType().cast().getRank(), &$_builder)">; + "0, llvm::cast($0.getType()).getRank(), &$_builder)">; // L2Loss is lowered using the formula, // L2Loss(input) = Sum(input * input) / 2 diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc b/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc index be7e914bd29846..f02dffc5d6f2f3 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc @@ -27,10 +27,10 @@ limitations under the License. #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project -#include "tensorflow/compiler/mlir/lite/utils/validators.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/utils/verification_utils.h" +#include "tensorflow/compiler/mlir/utils/validators.h" // IWYU pragma: keep namespace mlir { namespace TF { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/optimize.td b/tensorflow/compiler/mlir/tensorflow/transforms/optimize.td index be01d276902047..9ad34d2064c764 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/optimize.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/optimize.td @@ -23,27 +23,27 @@ def IsDataFormatNHWC : ConstantAttr; // Get the last dimension size as a 1-d single element attr. 
def GetLastDimSizeAsI32 : NativeCodeCall< "DenseElementsAttr::get(RankedTensorType::get({1}, $_builder.getIntegerType(32)), " - "static_cast($0.getType().cast().getDimSize( " - " $0.getType().cast().getRank() - 1)))">; + "static_cast(llvm::cast($0.getType()).getDimSize( " + " llvm::cast($0.getType()).getRank() - 1)))">; // Check whether the tensor is ranked and whether its last dim is static. def IsRankedShapeLastDimStatic : Constraint()">, - CPred<"!$0.getType().cast().isDynamicDim( " - " $0.getType().cast().getRank() - 1)">]>>; + CPred<"llvm::isa($0.getType())">, + CPred<"!llvm::cast($0.getType()).isDynamicDim( " + " llvm::cast($0.getType()).getRank() - 1)">]>>; def IsNotComplexType : Constraint()">, - CPred<"!$0.getType().cast().getElementType().isa()"> + CPred<"llvm::isa($0.getType())">, + CPred<"!llvm::isa(llvm::cast($0.getType()).getElementType())"> ]>>; // Only fuse multiplier if all dimensions other than the channel dimension // are equal to 1. def CanFuseMulAndConv2D : - Constraint>; + Constraint>; def F32ElementsAttr : ElementsAttrBase< - CPred<"$_self.cast().getShapedType().getElementType().isF32()">, "float constant tensor">; + CPred<"llvm::cast($_self).getShapedType().getElementType().isF32()">, "float constant tensor">; def DefinedByConv2D : Constraint($0.getDefiningOp())">>; // Checks if the value has only one user. 
def HasOneUse : Constraint>; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/prepare_tpu_computation_for_tf_export.cc b/tensorflow/compiler/mlir/tensorflow/transforms/prepare_tpu_computation_for_tf_export.cc index 03e3072732b92f..8b5d2e0de1e26a 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/prepare_tpu_computation_for_tf_export.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/prepare_tpu_computation_for_tf_export.cc @@ -67,21 +67,19 @@ class PrepareTpuComputationForTfExportPass }; class RewriteXlaHostComputeMlir - : public OpRewritePattern::SplitMatchAndRewrite { + : public OpRewritePattern { public: - using SplitMatchAndRewrite::SplitMatchAndRewrite; + using OpRewritePattern::OpRewritePattern; - LogicalResult match(TF::_XlaHostComputeMlirOp op) const override { + LogicalResult matchAndRewrite(TF::_XlaHostComputeMlirOp op, + PatternRewriter& rewriter) const override { if (op.getManualSharding()) { // This rewrite does not support manual_sharding. It is expected that the // _XlaHostComputeMlirOp registered as an MlirXlaOpKernel will handle this // case later once the XlaBuilder graph reaches it. 
return failure(); } - return success(); - } - void rewrite(TF::_XlaHostComputeMlirOp op, - PatternRewriter& rewriter) const override { + llvm::SmallVector shape_attrs; shape_attrs.reserve(op.getNumResults()); for (Type ty : op.getResultTypes()) { @@ -141,6 +139,7 @@ class RewriteXlaHostComputeMlir op.getRecvKeyAttr(), /*cost_estimate_ns=*/rewriter.getI64IntegerAttr(kDefaultCostEstimate), /*tpu_core=*/rewriter.getI64IntegerAttr(0)); + return success(); } }; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc index 3928faaa280398..4b699773371e3c 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc @@ -94,7 +94,8 @@ LogicalResult GetDeviceOrdinal(const std::optional& devices, << " to be present in 'tf.device.replicate' op"; } llvm::StringRef tpu_device = - tpu_replica.cast()[replica_id].cast().getValue(); + llvm::cast(llvm::cast(tpu_replica)[replica_id]) + .getValue(); return tensorflow::GetDeviceOrdinalFromDeviceString(op->getLoc(), tpu_device, &device_ordinal); } @@ -136,9 +137,9 @@ LogicalResult UpdateRegionReplicateVariantOps( // Map aliased devices to explicit devices based on replica. 
if (auto launch = dyn_cast(op)) if (auto device_by_replica = devices.value().get(launch.getDevice())) - launch->setAttr( - kDeviceAttr, - device_by_replica.cast()[replica_id].cast()); + launch->setAttr(kDeviceAttr, + llvm::cast(llvm::cast( + device_by_replica)[replica_id])); return WalkResult::advance(); }); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc index 50f6cc54c4e12c..106c65368a18f8 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc @@ -92,6 +92,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.h" #include "xla/hlo/translate/hlo_to_mhlo/hlo_utils.h" #include "xla/hlo/translate/mhlo_to_hlo/type_to_shape.h" +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/service/shape_inference.h" #include "xla/shape.h" #include "xla/tsl/platform/errors.h" @@ -510,7 +511,7 @@ Type GetNewArgType(Type old_arg_type, ArrayRef shape, } new_arg_type = tensorflow::GetTypeFromTFTensorShape( new_shape, element_type, - mhlo::TypeExtensionsAttr::get(context, new_bounds)); + mlir::mhlo::TypeExtensionsAttr::get(context, new_bounds)); } } return new_arg_type; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/BUILD b/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/BUILD index d19d5e8e8ab5aa..74f952a6cb7db6 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/BUILD @@ -16,15 +16,10 @@ package( gentbl_cc_library( name = "sparsecore_passes_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=SparseCore", - ], - "sparsecore_passes.h.inc", - ), - ], + tbl_outs = {"sparsecore_passes.h.inc": [ + "-gen-pass-decls", + "-name=SparseCore", + ]}, tblgen = 
"@llvm-project//mlir:mlir-tblgen", td_file = "sparsecore_passes.td", deps = [ diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/embedding_pipelining.cc b/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/embedding_pipelining.cc index ccd246bd0d85a2..d22180fdbe45f5 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/embedding_pipelining.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/embedding_pipelining.cc @@ -148,6 +148,7 @@ return selected_results #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/Inliner.h" // from @llvm-project #include "mlir/Transforms/InliningUtils.h" // from @llvm-project #include "mlir/Transforms/RegionUtils.h" // from @llvm-project #include "tensorflow/compiler/jit/flags.h" @@ -422,6 +423,7 @@ struct Inliner : public InlinerInterface { LogicalResult InlineCallsInFunc(func::FuncOp func, bool inline_all_funcs = false) { llvm::SetVector ops_to_erase; + InlinerConfig config; for (auto caller : func.getRegion().getOps()) { if (!inline_all_funcs && @@ -441,7 +443,8 @@ struct Inliner : public InlinerInterface { auto callee = llvm::dyn_cast(symbol_table.lookup(caller.getF())); auto& src_region = callee.getRegion(); - auto result = inlineCall(*this, caller, callee, &src_region, true); + auto result = inlineCall(*this, config.getCloneCallback(), caller, callee, + &src_region, true); if (failed(result)) { func.emitError("Inliner failed"); return result; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc index 47b046d9fdaee2..7326c0bde1201b 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc @@ 
-88,7 +88,8 @@ LogicalResult GetSplitElementTypeAndCount(TF::TensorArraySplitV3Op split, if (!lengths_const) return split.emitOpError("non-constant split lengths"); *count = lengths_const.getValue().getNumElements(); if (*count <= 0) return split.emitOpError("non-positive split count"); - auto buffer_type = split.getValue().getType().dyn_cast(); + auto buffer_type = + llvm::dyn_cast(split.getValue().getType()); if (!buffer_type || !buffer_type.hasStaticShape() || buffer_type.getRank() < 1) { return split.emitOpError("unknown or invalid split tensor shape"); @@ -110,7 +111,7 @@ LogicalResult GetSplitElementTypeAndCount(TF::TensorArraySplitV3Op split, // Tries to infer the tensor array element shape. std::optional> GetTensorArrayElementShape( TF::TensorArrayV3Op ta, ModuleOp module) { - auto element_shape = ta.getElementShapeAttr().cast(); + auto element_shape = llvm::cast(ta.getElementShapeAttr()); if (element_shape.hasStaticShape()) { auto shape = element_shape.getShape(); // Convert int64 to int64_t. @@ -142,20 +143,22 @@ std::optional> GetTensorArrayElementShape( // TensorArrayScatter writes vector of tensors to TensorArray. We can // deduce the shape of TensorArray by dropping the 0th dim of // TensorArrayScatter `value`. - auto t = scatter.getValue().getType().dyn_cast(); + auto t = + llvm::dyn_cast(scatter.getValue().getType()); if (!t || t.getShape().empty()) return std::nullopt; return RankedTensorType::get(t.getShape().drop_front(), t.getElementType()); } else if (auto gather = llvm::dyn_cast(user)) { // Try to infer from result type of gather. - auto t = gather.getValue().getType().dyn_cast(); + auto t = + llvm::dyn_cast(gather.getValue().getType()); if (t && !t.getShape().empty()) return RankedTensorType::get(t.getShape().drop_front(), t.getElementType()); // Try to infer from `element_shape` attribute of gather. 
- auto element_shape = gather.getElementShapeAttr() - .dyn_cast_or_null(); + auto element_shape = llvm::dyn_cast_if_present( + gather.getElementShapeAttr()); if (element_shape && element_shape.hasStaticShape()) { return RankedTensorType::get(element_shape.getShape(), gather.getDtype()); @@ -211,7 +214,7 @@ LogicalResult HandleTensorArrayV3Op( } auto var_type = RankedTensorType::get( {}, TF::ResourceType::get( - ArrayRef{buffer.getType().cast()}, + ArrayRef{llvm::cast(buffer.getType())}, ta.getContext())); auto local_var = builder.create( ta.getLoc(), ArrayRef{var_type}, ArrayRef{}); @@ -270,7 +273,7 @@ LogicalResult HandleTensorArrayWriteV3Op( cutil::GetElement(index_reshape, buffer, builder, write.getLoc(), /*keep_slice_shape=*/true); // Add a size-1 leading dimension to elem. - auto slice_type = original_elem.getType().cast(); + auto slice_type = llvm::cast(original_elem.getType()); elem = builder.create( write.getLoc(), ArrayRef{slice_type}, ArrayRef{elem, cutil::GetR1Const(slice_type.getShape(), builder, @@ -295,7 +298,7 @@ LogicalResult HandleTensorArrayConcatV3Op( } OpBuilder builder(concat); auto buffer = cutil::ReadLocalVariable(local_var, builder, concat.getLoc()); - auto buffer_type = buffer.getType().cast(); + auto buffer_type = llvm::cast(buffer.getType()); if (buffer_type.getShape().size() <= 1) { return concat.emitOpError("cannot concat on scalar-element tensor array"); } @@ -369,10 +372,9 @@ LogicalResult HandleTensorArraySizeV3Op( if (stats.count(local_var) == 0) { return size.emitOpError("unknown tensor array"); } - auto buffer_type = getElementTypeOrSelf(local_var.getType()) - .cast() - .getSubtypes()[0] - .cast(); + auto buffer_type = llvm::cast( + llvm::cast(getElementTypeOrSelf(local_var.getType())) + .getSubtypes()[0]); OpBuilder builder(size); auto result = cutil::CreateScalarConst(buffer_type.getDimSize(0), builder, size.getLoc()); @@ -387,10 +389,9 @@ LogicalResult CreateAndInitializeGradVariable(Type local_var_type, *var = builder.create( 
op->getLoc(), ArrayRef{local_var_type}, ArrayRef{}); Value buffer; - auto buffer_type = getElementTypeOrSelf(local_var_type) - .cast() - .getSubtypes()[0] - .cast(); + auto buffer_type = llvm::cast( + llvm::cast(getElementTypeOrSelf(local_var_type)) + .getSubtypes()[0]); if (failed(cutil::CreateInitBufferValue( buffer_type.getShape().drop_front(), buffer_type.getDimSize(0), op, buffer_type.getElementType(), builder, &buffer))) { @@ -478,7 +479,7 @@ llvm::SmallDenseMap> AccessedGradients( llvm::SmallDenseMap> result; llvm::SmallDenseMap> result_sets; auto insert = [&](Value v, const string& source, const Block& func_block) { - auto arg = v.dyn_cast(); + auto arg = dyn_cast(v); if (!arg || arg.getOwner() != &func_block) return; auto insert_res = result_sets[arg.getArgNumber()].insert(source); if (!insert_res.second) return; @@ -594,7 +595,7 @@ LogicalResult HandleWhileOp(TF::WhileOp while_op, ModuleOp module, for (int64_t i = 0; i < while_op.getNumResults(); ++i) { if (!ta_arg_buffer_type(i)) continue; auto retval = old_body_ret->getOperand(i); - auto arg = retval.dyn_cast(); + auto arg = dyn_cast(retval); if (!arg) { return while_op.emitOpError( "output tensor array does not alias input in a while loop"); @@ -702,13 +703,13 @@ LogicalResult HandleIfOp(TF::IfOp if_op, ModuleOp module, if_op->getAttrs()); auto ret_forwards_input = [](func::FuncOp f, int64_t ret_ind) -> int64_t { auto retval = f.front().getTerminator()->getOperand(ret_ind); - auto arg = retval.dyn_cast(); + auto arg = dyn_cast(retval); if (!arg) return -1; return arg.getArgNumber(); }; for (int64_t i = 0; i < if_op.getNumResults(); ++i) { - if (!getElementTypeOrSelf(if_op.getResult(i).getType()) - .isa()) { + if (!isa( + getElementTypeOrSelf(if_op.getResult(i).getType()))) { if_op.getResult(i).replaceAllUsesWith(new_if.getResult(i)); continue; } @@ -811,8 +812,8 @@ LogicalResult HandlePartitionedCallOp( } for (int64_t i = 0; i < call.getNumResults(); ++i) { auto ret = 
lowered_callee.front().getTerminator()->getOperand(i); - if (!getElementTypeOrSelf(ret.getType()).isa()) continue; - auto arg = ret.dyn_cast(); + if (!isa(getElementTypeOrSelf(ret.getType()))) continue; + auto arg = dyn_cast(ret); if (!arg) continue; info.ret_forward_input.emplace_back(i, arg.getArgNumber()); } @@ -842,7 +843,7 @@ LogicalResult HandleRegionControlFlowOps( llvm::StringMap* decomposed_partitioned_call_callees) { for (OpOperand& operand : op.getOpOperands()) { - if (getElementTypeOrSelf(operand.get().getType()).isa()) { + if (isa(getElementTypeOrSelf(operand.get().getType()))) { return op.emitOpError() << "found unexpected type " << operand.get().getType() << " of operand #" << operand.getOperandNumber() @@ -851,7 +852,7 @@ LogicalResult HandleRegionControlFlowOps( } } for (OpResult result : op.getResults()) { - if (getElementTypeOrSelf(result.getType()).isa()) { + if (isa(getElementTypeOrSelf(result.getType()))) { return op.emitOpError() << "found unexpected type " << result.getType() << " of result #" << result.getResultNumber() diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_resource_partitioning.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_resource_partitioning.cc index fdacf313d30240..18344894ff4cf3 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_resource_partitioning.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_resource_partitioning.cc @@ -23,6 +23,7 @@ limitations under the License. 
#include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h" @@ -49,19 +50,18 @@ struct TPUResourceReadsWritesPartitioningPass bool AllResourceTypesHaveSubtypes(TypeRange resources) { for (Type resource : resources) - if (!llvm::hasSingleElement(resource.cast() - .getElementType() - .cast() - .getSubtypes())) + if (!llvm::hasSingleElement( + llvm::cast( + llvm::cast(resource).getElementType()) + .getSubtypes())) return false; return true; } Type GetResourceSubtype(Type type) { - return type.cast() - .getElementType() - .cast() + return llvm::cast( + llvm::cast(type).getElementType()) .getSubtypes() .front(); } @@ -118,7 +118,7 @@ mlir::Attribute GetDeviceOfResource(mlir::func::FuncOp func, if (auto* resource_op = resource.getDefiningOp()) { return resource_op->getAttr(kDeviceAttr); } else { - const auto resource_arg = resource.dyn_cast_or_null(); + const auto resource_arg = dyn_cast_or_null(resource); if (resource_arg && (resource_arg.getOwner() == &(func.front()))) { return func.getArgAttrOfType( resource_arg.getArgNumber(), kFuncDeviceAttr); @@ -129,7 +129,7 @@ mlir::Attribute GetDeviceOfResource(mlir::func::FuncOp func, } bool IsCompositeDevice(mlir::Attribute attr) { - const auto str_attr = attr.dyn_cast_or_null(); + const auto str_attr = llvm::dyn_cast_if_present(attr); return str_attr && (str_attr.getValue().find("COMPOSITE") != llvm::StringRef::npos); } diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_type.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_type.cc index a9eb45e5da3c25..bf786ac1a06cf8 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_type.cc +++ 
b/tensorflow/compiler/mlir/tensorflow/utils/convert_type.cc @@ -180,10 +180,10 @@ absl::Status ConvertScalarTypeToDataType(Type type, DataType* dtype) { absl::StrCat("Converting ", debugString(type), " to DataType")); } -#define HANDLE_TF_TYPE(tftype, enumerant, name) \ - if (type.isa()) { \ - *dtype = DT_##enumerant; \ - return OkStatus(); \ +#define HANDLE_TF_TYPE(tftype, enumerant, name) \ + if (llvm::isa(type)) { \ + *dtype = DT_##enumerant; \ + return OkStatus(); \ } // NOLINTNEXTLINE #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def" diff --git a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util_test.cc index 2efd63b29b04ef..aa818d2ae73bd2 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util_test.cc @@ -126,7 +126,6 @@ TEST(DumpCrashReproducerTest, RoundtripDumpAndReadValid) { registry, mlir::MlirOptMainConfig{} .splitInputFile("") - .verifyDiagnostics(false) .verifyPasses(false) .allowUnregisteredDialects(false) .setPassPipelineParser(passPipeline)) diff --git a/tensorflow/compiler/mlir/tensorflow/utils/side_effect_analysis_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/side_effect_analysis_util.cc index 9d4305b8e033f4..56dcee5430157a 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/side_effect_analysis_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/side_effect_analysis_util.cc @@ -17,6 +17,7 @@ limitations under the License. 
#include +#include "llvm/Support/Casting.h" #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project @@ -41,11 +42,8 @@ void MarkResourceAsReadAndWrite( OpOperand& op_operand, SmallVectorImpl>& effects) { - if (op_operand.get() - .getType() - .cast() - .getElementType() - .isa()) { + if (llvm::isa(llvm::cast(op_operand.get().getType()) + .getElementType())) { effects.emplace_back(MemoryEffects::Read::get(), &op_operand, ResourceEffects::Variable::get()); effects.emplace_back(MemoryEffects::Write::get(), &op_operand, @@ -57,11 +55,8 @@ void MarkResourceAsReadOnly( OpOperand& op_operand, SmallVectorImpl>& effects) { - if (op_operand.get() - .getType() - .cast() - .getElementType() - .isa()) { + if (llvm::isa(llvm::cast(op_operand.get().getType()) + .getElementType())) { effects.emplace_back(MemoryEffects::Read::get(), &op_operand, ResourceEffects::Variable::get()); } diff --git a/tensorflow/compiler/mlir/tensorflow/utils/tf_xla_mlir_translate.cc b/tensorflow/compiler/mlir/tensorflow/utils/tf_xla_mlir_translate.cc index 348ae41e3d2ebb..34917780dc80cb 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/tf_xla_mlir_translate.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/tf_xla_mlir_translate.cc @@ -48,13 +48,13 @@ limitations under the License. 
#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Tools/mlir-translate/Translation.h" // from @llvm-project #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo -#include "tensorflow/compiler/mlir/lite/tools/tf_mlir_translate_cl.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "tensorflow/compiler/mlir/tensorflow/translate/tools/parsers.h" #include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h" #include "tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.h" +#include "tensorflow/compiler/mlir/tools/tf_mlir_translate_cl.h" #include "tensorflow/compiler/mlir/utils/string_container_utils.h" #include "tensorflow/compiler/tf2xla/layout_util.h" #include "tensorflow/compiler/tf2xla/xla_argument.h" diff --git a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc index ac8ecf1090b2a7..b87afe63412551 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc @@ -380,7 +380,7 @@ mlir::LogicalResult HandleTileShardedInputsUsingXlaSplitOps( std::vector paddings; paddings.reserve(rank); auto shape = llvm::to_vector<4>( - original_source.getType().cast().getShape()); + mlir::cast(original_source.getType()).getShape()); for (int dim = 0; dim < rank; ++dim) { paddings.push_back( GetPadding(dim, input_sharding.tile_assignment_dimensions(dim), diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD b/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD index 3fe8f0cb052062..cf83f71d0a6629 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD @@ -77,6 +77,7 @@ cc_library( "@local_xla//xla/service:hlo_proto_cc", 
"@local_xla//xla/tsl/platform:errors", "@local_xla//xla/tsl/platform:statusor", + "@stablehlo//:base", "@stablehlo//:register", ], ) diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.cc b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.cc index 6281ea68e37807..829d3ca5819379 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.cc @@ -54,6 +54,7 @@ limitations under the License. #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project +#include "stablehlo/dialect/Base.h" // from @stablehlo #include "stablehlo/dialect/Register.h" // from @stablehlo #include "tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.h" #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" @@ -201,12 +202,12 @@ mlir::RankedTensorType GetBufferType(mlir::Type ty) { int64_t rank = ranked_ty.getRank(); llvm::SmallVector dims = llvm::to_vector<4>(ranked_ty.getShape()); - auto encoding = mlir::dyn_cast_or_null( - ranked_ty.getEncoding()); - if (encoding && !encoding.getBounds().empty()) { + llvm::ArrayRef bounds = + mlir::hlo::encodingToBounds(ranked_ty.getEncoding()); + if (!bounds.empty()) { for (int64_t dim = 0; dim < rank; ++dim) { if (dims[dim] == mlir::ShapedType::kDynamic) { - dims[dim] = encoding.getBounds()[dim]; + dims[dim] = bounds[dim]; } } } diff --git a/tensorflow/compiler/mlir/tf2xla/internal/inference/BUILD b/tensorflow/compiler/mlir/tf2xla/internal/inference/BUILD index e80d33abb5cb37..d87efdfbf146f5 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/inference/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/internal/inference/BUILD @@ -13,15 +13,10 @@ package( gentbl_cc_library( name = "inference_passes_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=TF2XLA", - 
], - "inference_passes.h.inc", - ), - ], + tbl_outs = {"inference_passes.h.inc": [ + "-gen-pass-decls", + "-name=TF2XLA", + ]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "inference_passes.td", deps = [ diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/BUILD b/tensorflow/compiler/mlir/tf2xla/internal/passes/BUILD index becdc528044f86..6e342a751da93e 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/passes/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/BUILD @@ -7,7 +7,7 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [ - "//learning/pathways/serving/transforms:__pkg__", + "//learning/brain/tfrt/tpu/compiler/mlir:__pkg__", "//tensorflow/compiler/mlir:__pkg__", "//tensorflow/compiler/mlir/tf2xla/api:__subpackages__", "//tensorflow/compiler/mlir/tf2xla/internal:__subpackages__", @@ -71,15 +71,10 @@ cc_library( gentbl_cc_library( name = "clustering_passes_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=TFXLABridgeClustering", - ], - "clustering_passes.h.inc", - ), - ], + tbl_outs = {"clustering_passes.h.inc": [ + "-gen-pass-decls", + "-name=TFXLABridgeClustering", + ]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "clustering_passes.td", deps = [ @@ -229,15 +224,10 @@ cc_library( gentbl_cc_library( name = "mlir_to_graph_passes_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=TFXLABridgeMlirToGraph", - ], - "mlir_to_graph_passes.h.inc", - ), - ], + tbl_outs = {"mlir_to_graph_passes.h.inc": [ + "-gen-pass-decls", + "-name=TFXLABridgeMlirToGraph", + ]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "mlir_to_graph_passes.td", deps = [ @@ -459,15 +449,10 @@ cc_library( gentbl_cc_library( name = "lowering_passes_inc_gen", compatible_with = get_compatible_with_portable(), 
- tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=TFXLABridgeLowering", - ], - "lowering_passes.h.inc", - ), - ], + tbl_outs = {"lowering_passes.h.inc": [ + "-gen-pass-decls", + "-name=TFXLABridgeLowering", + ]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "lowering_passes.td", deps = [ @@ -570,7 +555,7 @@ cc_library( "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:sharding_builder", + "@local_xla//xla/hlo/builder:sharding_builder", ], ) diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_cluster_formation.cc b/tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_cluster_formation.cc index db39ca12d9ce91..0da0cc4fc4ddfb 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_cluster_formation.cc +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_cluster_formation.cc @@ -872,7 +872,7 @@ LogicalResult FormClustersInBlock( block, cluster_ops, results, cluster_successor_ops.getArrayRef()); auto num_replicas = cluster_metadata->getSecond().get(kNumReplicasAttr); - if (!num_replicas || !num_replicas.isa()) + if (!num_replicas || !mlir::isa(num_replicas)) return cluster.emitError() << "requires '" << kNumReplicasAttr << "' int attribute"; @@ -881,9 +881,9 @@ LogicalResult FormClustersInBlock( cluster_metadata->getSecond().get(kNumCoresPerReplicaAttr)); if (num_cores_per_replica_attr) num_cores_per_replica = num_cores_per_replica_attr.getInt(); - if (failed(ReplicateCluster(cluster, - num_replicas.cast().getInt(), - num_cores_per_replica))) + if (failed(ReplicateCluster( + cluster, mlir::cast(num_replicas).getInt(), + num_cores_per_replica))) return mlir::failure(); // Copy TPUReplicateMetadata attributes to `tf_device.cluster`. 
diff --git a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf.mlir index 92754a181e8551..6188395f648bc8 100644 --- a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf.mlir @@ -559,7 +559,7 @@ func.func @diag_part(%arg0: tensor<4x3x4x3xf32>) -> tensor<4x3xf32> { // CHECK: %[[RS:.*]] = mhlo.reshape %[[ARG]] : (tensor<4x3x4x3xf32>) -> tensor<12x12xf32> // CHECK-DAG: %[[IOTA0:.*]] = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<12x12xi32> // CHECK-DAG: %[[IOTA1:.*]] = "mhlo.iota"() <{iota_dimension = 1 : i64}> : () -> tensor<12x12xi32> - // CHECK-DAG: %[[COMP:.*]] = mhlo.compare EQ, %[[IOTA0]], %[[IOTA1]], NOTYPE : (tensor<12x12xi32>, tensor<12x12xi32>) -> tensor<12x12xi1> + // CHECK-DAG: %[[COMP:.*]] = mhlo.compare EQ, %[[IOTA0]], %[[IOTA1]] : (tensor<12x12xi32>, tensor<12x12xi32>) -> tensor<12x12xi1> // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor // CHECK-DAG: %[[ZERO_MAT:.*]] = "mhlo.broadcast"(%[[ZERO]]) <{broadcast_sizes = dense<12> : tensor<2xi64>}> : (tensor) -> tensor<12x12xf32> // CHECK-DAG: %[[SEL:.*]] = mhlo.select %[[COMP]], %[[RS]], %[[ZERO_MAT]] : tensor<12x12xi1>, tensor<12x12xf32> @@ -622,7 +622,7 @@ func.func @matrix_diag_part(%arg0: tensor<7x140x128xi32>) -> tensor<7x22x128xi32 // CHECK-DAG: %[[V40:.*]] = mhlo.and %[[V36]], %[[V39]] : tensor<1x22x128xi1> // CHECK-DAG: %[[V41:.*]] = mhlo.reshape %[[V40]] : (tensor<1x22x128xi1>) -> tensor<22x128xi1> // CHECK-DAG: %[[V42:.*]] = "mhlo.concatenate"(%[[V33]], %[[V32]]) <{dimension = 0 : i64}> : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<2x22x128xi32> - // CHECK-DAG: %[[V43:.*]] = "mhlo.gather"(%[[ARG]], %[[V42]]) <{dimension_numbers = #mhlo.gather, indices_are_sorted = false, slice_sizes = dense<[7, 1, 1]> : tensor<3xi64>}> : (tensor<7x140x128xi32>, tensor<2x22x128xi32>) -> tensor<7x22x128xi32> + // CHECK-DAG: %[[V43:.*]] = 
"mhlo.gather"(%[[ARG]], %[[V42]]) <{dimension_numbers = #mhlo.gather, slice_sizes = dense<[7, 1, 1]> : tensor<3xi64>}> : (tensor<7x140x128xi32>, tensor<2x22x128xi32>) -> tensor<7x22x128xi32> // CHECK-DAG: %[[V44:.*]] = "mhlo.broadcast"(%[[V41]]) <{broadcast_sizes = dense<7> : tensor<1xi64>}> : (tensor<22x128xi1>) -> tensor<7x22x128xi1> // CHECK-DAG: %[[V45:.*]] = "mhlo.broadcast"(%[[V0]]) <{broadcast_sizes = dense<[7, 22, 128]> : tensor<3xi64>}> : (tensor) -> tensor<7x22x128xi32> // CHECK: %[[V46:.*]] = mhlo.select %[[V44]], %[[V43]], %[[V45]] : tensor<7x22x128xi1>, tensor<7x22x128xi32> @@ -731,6 +731,80 @@ func.func @matrix_diag_part_align_7d(%arg0: tensor<3x5x7x9x11x13x17xf32>) -> ten func.return %2: tensor<3x5x7x9x11x4x10xf32> } +//===----------------------------------------------------------------------===// +// Conv +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: func @conv2d_NHWC +func.func @conv2d_NHWC(%arg0: tensor<1x4x4x2xf32> {tf_saved_model.index_path = ["input_2"]}, %arg1: tensor<3x3x2x2xf32>, %arg2: tensor<2xf32>, %arg3: tensor<2xf32>, %arg4: tensor<2xf32>, %arg5: tensor<2xf32>, %arg6: tensor<2xf32>, %arg7: tensor<2xf32>) -> (tensor<1x4x4x2xf32> {tf_saved_model.index_path = [""]}) { + // CHECK{LITERAL}: mhlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [1, 1], pad = [[1, 1], [1, 1]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#mhlo, #mhlo]} : (tensor<1x4x4x2xf32>, tensor<3x3x2x2xf32>) -> tensor<1x4x4x2xf32> + %0 = "tf.Conv2D"(%arg0, %arg1) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true}> {device = ""} : (tensor<1x4x4x2xf32>, tensor<3x3x2x2xf32>) -> tensor<1x4x4x2xf32> + %1 = "tf.Mul"(%0, %arg6) : (tensor<1x4x4x2xf32>, tensor<2xf32>) -> tensor<1x4x4x2xf32> + %2 = "tf.AddV2"(%1, %arg7) : 
(tensor<1x4x4x2xf32>, tensor<2xf32>) -> tensor<1x4x4x2xf32> + return %2 : tensor<1x4x4x2xf32> +} + +// ----- + +// CHECK-LABEL: func @conv2d_backprop_input +func.func @conv2d_backprop_input(%arg0: tensor<3x3x8x8xf32>, %arg1: tensor<1x128x192x8xf32>) -> tensor<1x256x384x8xf32> { + %cst = "tf.Const"() <{value = dense<[1, 256, 384, 8]> : tensor<4xi32>}> : () -> tensor<4xi32> + %0 = "tf.Conv2DBackpropInput"(%cst, %arg0, %arg1) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 2, 2, 1], use_cudnn_on_gpu = true}> {device = ""} : (tensor<4xi32>, tensor<3x3x8x8xf32>, tensor<1x128x192x8xf32>) -> tensor<1x256x384x8xf32> + return %0 : tensor<1x256x384x8xf32> + } + +//===----------------------------------------------------------------------===// +// Cumulative +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: func @cumsum +func.func @cumsum(%arg0: tensor<1x4x1xf32>) -> tensor<1x4x1xf32> { + // CHECK: mhlo.reduce_window + // CHECK-SAME{LITERAL}: padding = dense<[[0, 0], [3, 0], [0, 0]]> : tensor<3x2xi64>, window_dimensions = dense<[1, 4, 1]> : tensor<3xi64>, window_strides = dense<1> : tensor<3xi64> + // CHECK: mhlo.add + %cst = "tf.Const"() <{value = dense<1> : tensor}> : () -> tensor + %0 = "tf.Cumsum"(%arg0, %cst) <{exclusive = false, reverse = false}> {device = ""} : (tensor<1x4x1xf32>, tensor) -> tensor<1x4x1xf32> + return %0 : tensor<1x4x1xf32> +} + +// ----- + +// CHECK-LABEL: func @cumprod +func.func @cumprod(%arg0: tensor<1x4x1xf32>) -> tensor<1x4x1xf32> { + // CHECK: mhlo.reduce_window + // CHECK-SAME{LITERAL}: padding = dense<0> : tensor<3x2xi64>, window_dimensions = dense<1> : tensor<3xi64>, window_strides = dense<1> : tensor<3xi64> + // CHECK: mhlo.multiply + %cst = "tf.Const"() <{value = dense<2> : tensor}> : () -> tensor + %0 = "tf.Cumprod"(%arg0, %cst) <{exclusive = false, reverse = false}> {device = ""} : (tensor<1x4x1xf32>, tensor) -> 
tensor<1x4x1xf32> + return %0 : tensor<1x4x1xf32> +} + +//===----------------------------------------------------------------------===// +// DynamicSlice +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: func @dynamic_slice_i32 +func.func @dynamic_slice_i32(%arg0: tensor<8x512x384xbf16>, %arg1: tensor<3xi32>) -> tensor<1x512x384xbf16> attributes {allow_soft_placement = false, tf.entry_function = {control_outputs = "", inputs = "_arg0,_arg1,_arg2", outputs = "_retval0"}} { + %cst = "tf.Const"() <{value = dense<[1, 512, 384]> : tensor<3xi32>}> : () -> tensor<3xi32> + // CHECK: "mhlo.dynamic_slice"{{.*}}slice_sizes = dense<[1, 512, 384]> : tensor<3xi64> + %0 = "tf.XlaDynamicSlice"(%arg0, %arg1, %cst) {device = ""} : (tensor<8x512x384xbf16>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x512x384xbf16> + return %0 : tensor<1x512x384xbf16> +} + +// ----- + +// CHECK-LABEL: func @dynamic_slice_i64 +func.func @dynamic_slice_i64(%arg0: tensor<8x512x384xbf16>, %arg1: tensor<3xi32>) -> tensor<1x512x384xbf16> attributes {allow_soft_placement = false, tf.entry_function = {control_outputs = "", inputs = "_arg0,_arg1,_arg2", outputs = "_retval0"}} { + %cst = "tf.Const"() <{value = dense<[1, 512, 384]> : tensor<3xi64>}> : () -> tensor<3xi64> + // CHECK: "mhlo.dynamic_slice"{{.*}}slice_sizes = dense<[1, 512, 384]> : tensor<3xi64> + %0 = "tf.XlaDynamicSlice"(%arg0, %arg1, %cst) {device = ""} : (tensor<8x512x384xbf16>, tensor<3xi32>, tensor<3xi64>) -> tensor<1x512x384xbf16> + return %0 : tensor<1x512x384xbf16> +} + //===----------------------------------------------------------------------===// // Erf //===----------------------------------------------------------------------===// @@ -739,7 +813,8 @@ func.func @matrix_diag_part_align_7d(%arg0: tensor<3x5x7x9x11x13x17xf32>) -> ten // CHECK-LABEL: func @erf func.func @erf(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { - // CHECK: mhlo.erf %arg0 : tensor<2x3xf32> + // CHECK: 
chlo.erf %arg0 : tensor<2x3xf32> + // CHLO: mhlo.erf %arg0 : tensor<2x3xf32> %0 = "tf.Erf"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32> func.return %0 : tensor<2x3xf32> } @@ -1488,7 +1563,7 @@ func.func @max_pool_grad_valid(%orig_input: tensor<10x24x24x64xf32>, %orig_outpu // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor // CHECK: %[[RESULT:.*]] = "mhlo.select_and_scatter"(%[[INPUT]], %[[GRAD]], %[[ZERO]]) <{window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>}> ({ // CHECK: ^bb0(%[[VALUE_A:.*]]: tensor, %[[VALUE_B:.*]]: tensor): - // CHECK: %[[SELECT_RESULT:.*]] = mhlo.compare GE, %[[VALUE_A]], %[[VALUE_B]], NOTYPE : (tensor, tensor) -> tensor + // CHECK: %[[SELECT_RESULT:.*]] = mhlo.compare GE, %[[VALUE_A]], %[[VALUE_B]] : (tensor, tensor) -> tensor // CHECK: mhlo.return %[[SELECT_RESULT]] : tensor // CHECK: }, { // CHECK: ^bb0(%[[VALUE_A:.*]]: tensor, %[[VALUE_B:.*]]: tensor): @@ -1513,7 +1588,7 @@ func.func @max_pool_3d_grad_valid(%orig_input: tensor<10x8x24x24x64xf32>, %orig_ // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor // CHECK: %[[RESULT:.*]] = "mhlo.select_and_scatter"(%[[INPUT]], %[[GRAD]], %[[ZERO]]) <{window_dimensions = dense<[1, 1, 2, 2, 1]> : tensor<5xi64>, window_strides = dense<[1, 1, 2, 2, 1]> : tensor<5xi64>}> ({ // CHECK: ^bb0(%[[VALUE_A:.*]]: tensor, %[[VALUE_B:.*]]: tensor): - // CHECK: %[[SELECT_RESULT:.*]] = mhlo.compare GE, %[[VALUE_A]], %[[VALUE_B]], NOTYPE : (tensor, tensor) -> tensor + // CHECK: %[[SELECT_RESULT:.*]] = mhlo.compare GE, %[[VALUE_A]], %[[VALUE_B]] : (tensor, tensor) -> tensor // CHECK: mhlo.return %[[SELECT_RESULT]] : tensor // CHECK: }, { // CHECK: ^bb0(%[[VALUE_A:.*]]: tensor, %[[VALUE_B:.*]]: tensor): @@ -1558,7 +1633,7 @@ func.func @max_pool_3d_grad_same(%orig_input: tensor<2x8x13x25x7xf32>, %orig_out func.func @one_hot(%indices: tensor<3xi32>, %on_value: tensor, %off_value: tensor) -> tensor<3x5xf32> { // CHECK: 
%[[IOTA:.*]] = "mhlo.iota"() <{iota_dimension = 1 : i64}> : () -> tensor<3x5xi32> // CHECK: %[[BCAST_ARG0:.+]] = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<0> : tensor<1xi64>}> : (tensor<3xi32>) -> tensor<3x5xi32> - // CHECK: %[[COMPARE:.*]] = mhlo.compare EQ, %[[BCAST_ARG0]], %[[IOTA]], NOTYPE : (tensor<3x5xi32>, tensor<3x5xi32>) -> tensor<3x5xi1> + // CHECK: %[[COMPARE:.*]] = mhlo.compare EQ, %[[BCAST_ARG0]], %[[IOTA]] : (tensor<3x5xi32>, tensor<3x5xi32>) -> tensor<3x5xi1> // CHECK: %[[ON_VALUE:.*]] = "mhlo.broadcast"(%arg1) <{broadcast_sizes = dense<[3, 5]> : tensor<2xi64>}> : (tensor) -> tensor<3x5xf32> // CHECK: %[[OFF_VALUE:.*]] = "mhlo.broadcast"(%arg2) <{broadcast_sizes = dense<[3, 5]> : tensor<2xi64>}> : (tensor) -> tensor<3x5xf32> // CHECK: %[[RESULT:.*]] = mhlo.select %[[COMPARE]], %[[ON_VALUE]], %[[OFF_VALUE]] : tensor<3x5xi1>, tensor<3x5xf32> @@ -1763,7 +1838,7 @@ func.func @stateful_pcall_multi_in_out(%arg0: tensor, %arg1: tensor) - // CHECK-LABEL: func @elu func.func @elu(%arg0: tensor<1xf32>) -> tensor<1xf32> { - // CHECK-DAG: %[[ZERO:.*]] = "chlo.constant_like"(%arg0) <{value = 0.000000e+00 : f32}> : (tensor<1xf32>) -> tensor<1xf32> + // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<1xf32> // CHECK-DAG: %[[PRED:.*]] = mhlo.compare GT, %arg0, %[[ZERO]] // CHECK-DAG: %[[EXP:.*]] = mhlo.exponential_minus_one %arg0 // CHECK: %[[RESULT:.*]] = mhlo.select %[[PRED]], %arg0, %[[EXP]] @@ -1841,7 +1916,7 @@ func.func @leaky_relu(%arg0: tensor<1x4x4x3xf32>) -> tensor<1x4x4x3xf32> attribu // CHECK-NEXT: %[[ALPHA:.*]] = "chlo.constant_like"(%arg0) <{value = 2.000000e-01 : f32}> : (tensor<1x4x4x3xf32>) -> tensor<1x4x4x3xf32> // CHECK-NEXT: %[[ZERO:.*]] = "chlo.constant_like"(%arg0) <{value = 0.000000e+00 : f32}> : (tensor<1x4x4x3xf32>) -> tensor<1x4x4x3xf32> // CHECK-NEXT: %[[LEAKY:.*]] = mhlo.multiply %[[INP:.*]], %[[ALPHA]] : tensor<1x4x4x3xf32> - // CHECK-NEXT: %[[CMP:.*]] = mhlo.compare GT, %[[INP]], %[[ZERO]], 
NOTYPE : (tensor<1x4x4x3xf32>, tensor<1x4x4x3xf32>) -> tensor<1x4x4x3xi1> + // CHECK-NEXT: %[[CMP:.*]] = mhlo.compare GT, %[[INP]], %[[ZERO]] : (tensor<1x4x4x3xf32>, tensor<1x4x4x3xf32>) -> tensor<1x4x4x3xi1> // CHECK-NEXT: %[[RES:.*]] = mhlo.select %[[CMP]], %[[INP]], %[[LEAKY]] : tensor<1x4x4x3xi1>, tensor<1x4x4x3xf32> // CHECK-NEXT: return %[[RES]] : tensor<1x4x4x3xf32> %0 = "tf.LeakyRelu"(%arg0) {alpha = 2.000000e-01 : f32, device = ""} : (tensor<1x4x4x3xf32>) -> tensor<1x4x4x3xf32> @@ -1855,7 +1930,7 @@ func.func @leaky_relu_grad(%arg0: tensor<1x4x4xf32>, %arg1: tensor<1x4x4xf32>) - // CHECK-NEXT: %[[ALPHA:.*]] = "chlo.constant_like"(%arg1) <{value = 2.000000e-01 : f32}> : (tensor<1x4x4xf32>) -> tensor<1x4x4xf32> // CHECK-NEXT: %[[ZERO:.*]] = "chlo.constant_like"(%arg1) <{value = 0.000000e+00 : f32}> : (tensor<1x4x4xf32>) -> tensor<1x4x4xf32> // CHECK-NEXT: %[[LEAKYGRAD:.*]] = mhlo.multiply %[[GRADIENT:.*]], %[[ALPHA]] : tensor<1x4x4xf32> - // CHECK-NEXT: %[[CMP:.*]] = mhlo.compare GT, %[[INP:.*]], %[[ZERO]], NOTYPE : (tensor<1x4x4xf32>, tensor<1x4x4xf32>) -> tensor<1x4x4xi1> + // CHECK-NEXT: %[[CMP:.*]] = mhlo.compare GT, %[[INP:.*]], %[[ZERO]] : (tensor<1x4x4xf32>, tensor<1x4x4xf32>) -> tensor<1x4x4xi1> // CHECK-NEXT: %[[RES:.*]] = mhlo.select %[[CMP]], %[[GRADIENT]], %[[LEAKYGRAD]] : tensor<1x4x4xi1>, tensor<1x4x4xf32> // CHECK-NEXT: return %[[RES]] : tensor<1x4x4xf32> %0 = "tf.LeakyReluGrad"(%arg0, %arg1) {alpha = 2.000000e-01 : f32, device = ""} : (tensor<1x4x4xf32>, tensor<1x4x4xf32>) -> tensor<1x4x4xf32> @@ -1866,7 +1941,7 @@ func.func @leaky_relu_grad(%arg0: tensor<1x4x4xf32>, %arg1: tensor<1x4x4xf32>) - // CHECK-LABEL: func @softsign func.func @softsign(%arg0: tensor<4x10xf32>) -> tensor<4x10xf32> { - // CHECK-NEXT: %[[ONE:.*]] = "chlo.constant_like"(%arg0) <{value = 1.000000e+00 : f32}> : (tensor<4x10xf32>) -> tensor<4x10xf32> + // CHECK-NEXT: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor<4x10xf32> // CHECK-NEXT: %[[ABS:.*]] = mhlo.abs 
%{{.*}} : tensor<4x10xf32> // CHECK-NEXT: %[[ADD:.*]] = mhlo.add %[[ONE]], %[[ABS]] : tensor<4x10xf32> // CHECK-NEXT: %[[DIV:.*]] = mhlo.divide %{{.*}}, %[[ADD]] : tensor<4x10xf32> diff --git a/tensorflow/compiler/mlir/tf2xla/tests/registration/BUILD b/tensorflow/compiler/mlir/tf2xla/tests/registration/BUILD index a5d8d8d8c5183f..f46627f0e43565 100644 --- a/tensorflow/compiler/mlir/tf2xla/tests/registration/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/tests/registration/BUILD @@ -12,11 +12,11 @@ cc_library( "graph_to_tf_executor_registration.cc", ], deps = [ - "//tensorflow/compiler/mlir/lite/tools:translate_cl_options", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow/translate:mlir_roundtrip_flags", "//tensorflow/compiler/mlir/tensorflow/translate/tools:file_tf_mlir_translate", "//tensorflow/compiler/mlir/tf2xla/api/v2:tf_executor_to_graph", + "//tensorflow/compiler/mlir/tools:translate_cl_options", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/core:core_cpu_base", diff --git a/tensorflow/compiler/mlir/tf2xla/tests/registration/graph_to_tf_executor_registration.cc b/tensorflow/compiler/mlir/tf2xla/tests/registration/graph_to_tf_executor_registration.cc index 8a9811c8dcbcbc..7b7b5771f5a4be 100644 --- a/tensorflow/compiler/mlir/tf2xla/tests/registration/graph_to_tf_executor_registration.cc +++ b/tensorflow/compiler/mlir/tf2xla/tests/registration/graph_to_tf_executor_registration.cc @@ -26,11 +26,11 @@ limitations under the License. 
#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Tools/mlir-translate/Translation.h" // from @llvm-project -#include "tensorflow/compiler/mlir/lite/tools/tf_mlir_translate_cl.h" #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "tensorflow/compiler/mlir/tensorflow/translate/tools/file_tf_mlir_translate.h" #include "tensorflow/compiler/mlir/tf2xla/api/v2/tf_executor_to_graph.h" +#include "tensorflow/compiler/mlir/tools/tf_mlir_translate_cl.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "xla/client/client_library.h" diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/BUILD b/tensorflow/compiler/mlir/tf2xla/transforms/BUILD index abd057643629c6..1e85dbff84e248 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/transforms/BUILD @@ -16,12 +16,7 @@ package( gentbl_cc_library( name = "legalize_tf_patterns_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "generated_legalize_tf.inc", - ), - ], + tbl_outs = {"generated_legalize_tf.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "legalize_tf_patterns.td", deps = [ @@ -30,21 +25,18 @@ gentbl_cc_library( "@llvm-project//mlir:FuncTdFiles", "@llvm-project//mlir:TensorOpsTdFiles", "@local_xla//xla/mlir_hlo:hlo_ops_td_files", + "@stablehlo//:chlo_ops_td_files", + "@stablehlo//:stablehlo_ops_td_files", ], ) gentbl_cc_library( name = "xla_legalize_tf_passes_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=LegalizeTf", - ], - "xla_legalize_tf_passes.h.inc", - ), - ], + tbl_outs = {"xla_legalize_tf_passes.h.inc": [ + "-gen-pass-decls", + "-name=LegalizeTf", + ]}, tblgen = 
"@llvm-project//mlir:mlir-tblgen", td_file = "xla_legalize_tf_passes.td", deps = [ @@ -55,15 +47,10 @@ gentbl_cc_library( gentbl_cc_library( name = "tf_xla_passes_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=TfXla", - ], - "tf_xla_passes.h.inc", - ), - ], + tbl_outs = {"tf_xla_passes.h.inc": [ + "-gen-pass-decls", + "-name=TfXla", + ]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "tf_xla_passes.td", deps = [ @@ -177,6 +164,8 @@ cc_library( "@local_xla//xla/mlir_hlo:convert_op_folder", "@local_xla//xla/tsl/platform:status", "@stablehlo//:chlo_ops", + "@stablehlo//:stablehlo_ops", + "@stablehlo//:stablehlo_pass_utils", ] + if_static(["@local_tsl//tsl/platform:tensor_float_32_utils"]), ) @@ -262,7 +251,6 @@ cc_library( ":xla_legalize_targets", ":xla_legalize_tf_passes_inc_gen", ":xla_legalize_tf_with_tf2xla", - "//tensorflow/compiler/mlir/quantization/stablehlo:bridge_passes", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:mangling_util", "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", @@ -305,8 +293,10 @@ cc_library( "@local_xla//xla/mlir_hlo:type_conversion", "@local_xla//xla/stream_executor/tpu:c_api_conversions", "@local_xla//xla/stream_executor/tpu:tpu_api", + "@stablehlo//:base", "@stablehlo//:chlo_ops", "@stablehlo//:stablehlo_ops", + "@stablehlo//:stablehlo_passes", ], ) @@ -339,7 +329,6 @@ cc_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core/framework:allocator", "//tensorflow/core/protobuf:for_core_protos_cc", - "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -361,12 +350,13 @@ cc_library( "@local_xla//xla/hlo/translate/hlo_to_mhlo:hlo_function_importer", "@local_xla//xla/hlo/translate/hlo_to_mhlo:hlo_to_mlir_hlo", "@local_xla//xla/hlo/translate/mhlo_to_hlo:type_to_shape", - "@local_xla//xla/mlir_hlo", 
"@local_xla//xla/service:hlo_proto_cc", "@local_xla//xla/tsl/platform:env", "@local_xla//xla/tsl/platform:errors", "@local_xla//xla/tsl/platform:status", "@local_xla//xla/tsl/platform:statusor", + "@stablehlo//:base", + "@stablehlo//:stablehlo_ops", ], ) @@ -381,9 +371,7 @@ tf_cc_test( "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/tf2xla:xla_op_registry", "//tensorflow/compiler/tf2xla/kernels:xla_ops", - "//tensorflow/core:framework", "//tensorflow/core:ops", - "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_googletest//:gtest_main", @@ -396,11 +384,11 @@ tf_cc_test( "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/hlo/builder:xla_builder", "@local_xla//xla/hlo/builder:xla_computation", - "@local_xla//xla/mlir_hlo", "@local_xla//xla/tsl/lib/core:status_test_util", "@local_xla//xla/tsl/platform:errors", "@local_xla//xla/tsl/platform:status", "@local_xla//xla/tsl/platform:statusor", + "@stablehlo//:stablehlo_ops", ], ) @@ -442,7 +430,7 @@ cc_library( "@llvm-project//mlir:Support", "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:TransformUtils", - "@local_xla//xla/mlir_hlo", + "@stablehlo//:base", ], ) diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config.cc b/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config.cc index 7df70e4de558a2..69c6a47e9ab707 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config.cc @@ -358,6 +358,7 @@ bool IsOpTypeAllowedTf2XlaFallback(const TypeID& type_id) { TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get< TF::XlaSparseDenseMatmulGradWithAdagradAndStaticBufferSizeOp>(), @@ -370,6 +371,18 @@ bool IsOpTypeAllowedTf2XlaFallback(const TypeID& type_id) { TypeID::get< TF::XlaSparseDenseMatmulGradWithSgdAndStaticBufferSizeOp>(), // NOLINT TypeID::get(), + TypeID::get< 
+ TF::XlaSparseDenseMatmulCustomCombinerOnTcGradWithSgdAndCsrInputOp>(), // NOLINT + TypeID::get< + TF::XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdagradAndCsrInputOp>(), // NOLINT + TypeID::get< + TF::XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdagradMomentumAndCsrInputOp>(), // NOLINT + TypeID::get< + TF::XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdamAndCsrInputOp>(), // NOLINT + TypeID::get< + TF::XlaSparseDenseMatmulCustomCombinerOnTcGradWithFtrlAndCsrInputOp>(), // NOLINT + TypeID::get< + TF::XlaSparseDenseMatmulCustomCombinerOnTcGradWithCsrInputOp>(), TypeID::get(), TypeID::get(), TypeID::get(), diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config_test.cc b/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config_test.cc index 113b088b3db7d2..4a49785ed101b2 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config_test.cc @@ -83,7 +83,7 @@ TEST(LegalizationOpConfigTest, CountLoweringsSet) { // from MLIR to TF2XLA), these numbers should change. Or if TF Dialect adds // a new op, we should expect these to change too. EXPECT_EQ(mlir_lowering_count, 67); - EXPECT_EQ(tf2xla_fallback_count, 323); + EXPECT_EQ(tf2xla_fallback_count, 330); EXPECT_EQ(non_categorized_count, 431); } diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc index 047a5fb7b46bbc..c6d9761a95cefe 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc @@ -42,6 +42,8 @@ limitations under the License. 
#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project #include "mlir/Dialect/Traits.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project @@ -57,6 +59,8 @@ limitations under the License. #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "stablehlo/dialect/ChloOps.h" // from @stablehlo +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo +#include "stablehlo/transforms/PassUtils.h" // from @stablehlo // IWYU pragma: keep, legalize_tf_patterns.td #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/utils/dynamic_shape_utils.h" #include "tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h" @@ -66,7 +70,6 @@ limitations under the License. #include "xla/hlo/builder/padding.h" #include "xla/hlo/builder/sharding_builder.h" #include "xla/hlo/translate/hlo_to_mhlo/attribute_importer.h" -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/mlir_hlo/utils/convert_op_folder.h" #include "xla/mlir_hlo/utils/hlo_utils.h" #include "xla/tsl/platform/status.h" @@ -80,7 +83,14 @@ limitations under the License. 
#include "tsl/platform/tensor_float_32_utils.h" namespace mlir { -namespace mhlo { +namespace hlo { + +// Methods from utils.h +using mhlo::BuildReduceBody; +using mhlo::GetI64ElementsAttr; +using mhlo::GetScalarConstOfType; +using mhlo::GetScalarNegZeroOfType; + namespace { constexpr char kShardingAttr[] = "mhlo.sharding"; @@ -99,6 +109,34 @@ void GetI64ArrayAttrValues(Attribute attr, SmallVectorImpl *values) { values->push_back(mlir::cast(val).getValue().getSExtValue()); } +DenseI64ArrayAttr GetI64ArrayAttr(ArrayRef values, Builder *builder) { + return builder->getDenseI64ArrayAttr(values); +} + +static DenseI64ArrayAttr ToDenseI64ArrayAttr(DenseIntElementsAttr attr, + Builder *builder) { + if (!attr) return {}; + if (attr.getElementType().isInteger(64)) { + return GetI64ArrayAttr(llvm::to_vector(attr.getValues()), builder); + } + + // Requires conversion to i64 first. + std::vector values; + values.reserve(attr.getNumElements()); + for (auto value : attr.getValues()) { + values.push_back(value.getValue().getSExtValue()); + } + return GetI64ArrayAttr(values, builder); +} + +static DenseI64ArrayAttr ToDenseI64ArrayAttr(ElementsAttr attr, + Builder *builder) { + return ToDenseI64ArrayAttr( + mlir::cast( + hlo::convertElementsAttr(attr, builder->getIntegerType(64))), + builder); +} + // Returns 1D 32-bit dense elements attribute with the given values. static DenseIntElementsAttr GetI32ElementsAttr(ArrayRef values, Builder *builder) { @@ -109,26 +147,22 @@ static DenseIntElementsAttr GetI32ElementsAttr(ArrayRef values, // Returns a 1-d i64 elements attribute populated with numbers from start to // end, excluding. 
-static DenseIntElementsAttr GetI64ElementsAttrForSeq(int start, int end, - Builder *builder) { +static DenseI64ArrayAttr GetI64ArrayAttrForSeq(int start, int end, + Builder *builder) { int size = end - start; SmallVector vals; vals.resize(size); std::iota(vals.begin(), vals.end(), start); - - TensorType ty = - tensorflow::GetTypeFromTFTensorShape({size}, builder->getIntegerType(64)); - return DenseIntElementsAttr::get(ty, vals); + return builder->getDenseI64ArrayAttr(vals); } // Returns a 1-d i64 elements attribute populated with `val` repeated `size` // times. -static DenseIntElementsAttr GetI64ElementsAttrForValue(int size, int64_t val, - Builder *builder) { - TensorType ty = - tensorflow::GetTypeFromTFTensorShape({size}, builder->getIntegerType(64)); - return DenseIntElementsAttr::get(ty, val); +static DenseI64ArrayAttr GetI64ArrayAttrForValue(int size, int64_t val, + Builder *builder) { + llvm::SmallVector vals(size, val); + return builder->getDenseI64ArrayAttr(vals); } // Returns the corresponding type that should be used for performing sum @@ -164,14 +198,14 @@ static IntegerAttr GetHLOAxisFromTFAxis(Attribute attr, int64_t rank, // Returns a PrecisionConfig as an array attribute based on whether TF32 // execution is enabled static ArrayAttr GetPrecisionConfig(Builder *builder) { - mlir::mhlo::Precision precision = tsl::tensor_float_32_execution_enabled() - ? mhlo::Precision::DEFAULT - : mlir::mhlo::Precision::HIGHEST; + mlir::stablehlo::Precision precision = + tsl::tensor_float_32_execution_enabled() ? 
stablehlo::Precision::DEFAULT + : stablehlo::Precision::HIGHEST; llvm::SmallVector attr_vec; const int num_inputs = 2; for (int i = 0; i < num_inputs; i++) { attr_vec.push_back( - mlir::mhlo::PrecisionAttr::get(builder->getContext(), precision)); + mlir::stablehlo::PrecisionAttr::get(builder->getContext(), precision)); } return builder->getArrayAttr(attr_vec); } @@ -193,9 +227,10 @@ static std::optional GetIntegerHLOAxisFromTFAxis(Value value, /// Returns a `ConvertOp` that casts the elements to a i64 type while retaining /// the shape of the input value. -static ConvertOp CastValueToI64(Location loc, Value value, - PatternRewriter *rewriter) { - return rewriter->create(loc, value, rewriter->getIntegerType(64)); +static stablehlo::ConvertOp CastValueToI64(Location loc, Value value, + PatternRewriter *rewriter) { + return rewriter->create(loc, value, + rewriter->getIntegerType(64)); } // Creates an unpack op along the 0th dimension of the tensor. The `value` input @@ -239,10 +274,11 @@ tensorflow::TensorShape ToTensorShape( // Returns a limit scalar const op for the given type. 
// Requires FloatType or IntegerType -static ConstantOp GetScalarLimitConstOfType(Type ty, Location loc, - hlo::ScalarLimit limit, - OpBuilder *builder) { - return builder->create(loc, hlo::getScalarLimitOfType(ty, limit)); +static stablehlo::ConstantOp GetScalarLimitConstOfType(Type ty, Location loc, + hlo::ScalarLimit limit, + OpBuilder *builder) { + return builder->create( + loc, hlo::getScalarLimitOfType(ty, limit)); } // Deprecated: This is maintained to aid in porting old code that is not yet @@ -311,22 +347,24 @@ static Value StaticBinaryBroadcast(Location loc, Value x, Value y, return nullptr; } auto larger_broadcast_dims = - GetI64ElementsAttrForSeq(0, result_type.getRank(), &builder); + GetI64ArrayAttrForSeq(0, result_type.getRank(), &builder); if (x_type.getRank() < y_type.getRank()) { if (x_type != result_type) { - x = builder.create(loc, result_type, x, broadcast_dims); + x = builder.create(loc, result_type, x, + broadcast_dims); } if (y_type != result_type) { - y = builder.create(loc, result_type, y, - larger_broadcast_dims); + y = builder.create(loc, result_type, y, + larger_broadcast_dims); } } else { if (x_type != result_type) { - x = builder.create(loc, result_type, x, - larger_broadcast_dims); + x = builder.create(loc, result_type, x, + larger_broadcast_dims); } if (y_type != result_type) { - y = builder.create(loc, result_type, y, broadcast_dims); + y = builder.create(loc, result_type, y, + broadcast_dims); } } return builder.create(loc, x, y); @@ -356,13 +394,13 @@ static RankedTensorType GetExtentsTensorTypeFor(TensorType value_type) { static Value Broadcast1DToFeatureDim(Location loc, Value broadcast_to, Value broadcast_from, int64_t feature_dim, OpBuilder &builder) { - auto broadcast_dims = GetI64ElementsAttr({feature_dim}, &builder); + auto broadcast_dims = GetI64ArrayAttr({feature_dim}, &builder); auto to_type = mlir::cast(broadcast_to.getType()); auto result_shape = builder.create(loc, broadcast_to); auto result_extents_type = 
GetExtentsTensorTypeFor(to_type); auto result_extents = builder.create( loc, result_extents_type, result_shape); - return builder.create( + return builder.create( loc, to_type, broadcast_from, result_extents, broadcast_dims); } @@ -381,8 +419,8 @@ static Value BroadcastToShapeOf(Location loc, Value input, Value broadcast_to, auto result_extents = builder.create( loc, result_extents_type, result_shape); int64_t rank = mlir::cast(input.getType()).getRank(); - auto broadcast_dims = GetI64ElementsAttrForSeq(0, rank, &builder); - return builder.create( + auto broadcast_dims = GetI64ArrayAttrForSeq(0, rank, &builder); + return builder.create( loc, to_type, input, result_extents, broadcast_dims); } @@ -391,33 +429,35 @@ static Value BroadcastToShapeOf(Location loc, Value input, Value broadcast_to, static Value ApplyReduction(Location loc, Value input, DenseIntElementsAttr reduce_dims, OpBuilder *builder) { - auto reduce_dims_op = builder->create(loc, reduce_dims); + auto reduce_dims_op = + builder->create(loc, reduce_dims); return builder->create(loc, input, reduce_dims_op, builder->getBoolAttr(false)); } -// Creates a mhlo.rng_uniform op with `builder` to generate `num_elements` +// Creates a stablehlo.rng_uniform op with `builder` to generate `num_elements` // 32-bit integer numbers in the range of [`lower_limit`, `upper_limit`). 
-static mhlo::RngOp CreateRngUniform32(Location loc, int num_elements, - int lower_limit, int upper_limit, - OpBuilder *builder) { - auto shape_tensor = builder->create( +static stablehlo::RngOp CreateRngUniform32(Location loc, int num_elements, + int lower_limit, int upper_limit, + OpBuilder *builder) { + auto shape_tensor = builder->create( loc, GetI64ElementsAttr({num_elements}, builder)); - auto lower = builder->create( + auto lower = builder->create( loc, builder->getI32IntegerAttr(lower_limit)); - auto upper = builder->create( + auto upper = builder->create( loc, builder->getI32IntegerAttr(upper_limit)); - return builder->create(loc, lower, upper, shape_tensor, - ::mlir::mhlo::RngDistribution::UNIFORM); + return builder->create( + loc, lower, upper, shape_tensor, + ::mlir::stablehlo::RngDistribution::UNIFORM); } using WhileBodyFnType = llvm::function_ref old_values, SmallVectorImpl *new_values, OpBuilder *builder)>; -// Creates a mhlo.while op with `builder` to loop `num_interations` times, +// Creates a stablehlo.while op with `builder` to loop `num_iterations` times, // each time calling the given `body_fn` on a set of values to generate a new // set of values. Returns the final set of values via `final_values`. The // initial set of values is passed in via `init_values`. @@ -449,8 +489,8 @@ static void CreateWhile32(Location loc, int num_iterations, init_types_with_loop_iv.reserve(value_count); // The initial value for the loop induction variable is 0. - init_values_with_loop_iv.push_back( - builder->create(loc, builder->getI32IntegerAttr(0))); + init_values_with_loop_iv.push_back(builder->create( + loc, builder->getI32IntegerAttr(0))); init_values_with_loop_iv.append(init_values.begin(), init_values.end()); // Accumulate types of all the init values. @@ -458,8 +498,8 @@ static void CreateWhile32(Location loc, int num_iterations, init_types_with_loop_iv.push_back(init_value_with_loop_iv.getType()); // Create the while op.
- auto while_op = builder->create(loc, init_types_with_loop_iv, - init_values_with_loop_iv); + auto while_op = builder->create( + loc, init_types_with_loop_iv, init_values_with_loop_iv); auto ivs_count = init_types_with_loop_iv.size(); { @@ -473,12 +513,12 @@ static void CreateWhile32(Location loc, int num_iterations, // Get the loop induction variable and compare it against the upper limit. auto loop_iv = block->getArgument(0); - auto upper_limit = builder->create( + auto upper_limit = builder->create( loc, builder->getI32IntegerAttr(num_iterations)); - Value compare = builder->create(loc, loop_iv, upper_limit, - ComparisonDirection::LT); + Value compare = builder->create( + loc, loop_iv, upper_limit, stablehlo::ComparisonDirection::LT); - builder->create(loc, compare); + builder->create(loc, compare); } { @@ -500,15 +540,15 @@ static void CreateWhile32(Location loc, int num_iterations, &new_values, builder); // Increment the loop induction variable by one. - auto one = - builder->create(loc, builder->getI32IntegerAttr(1)); + auto one = builder->create( + loc, builder->getI32IntegerAttr(1)); auto scalar_broadcast_dims = builder->getDenseI64ArrayAttr({}); auto plus_one = builder->create( loc, block->getArgument(0), one, scalar_broadcast_dims); // Prepend with the updated loop induction variable. new_values.insert(new_values.begin(), plus_one); - builder->create(loc, new_values); + builder->create(loc, new_values); } // TODO(jpienaar): Support multi-operand while op. @@ -534,12 +574,12 @@ static IntegerAttr getFeatureDimensionAttr(Builder &b, // Returns the 1D i64 elements attribute populated with the inner-most dim of // the value. 
-static DenseIntElementsAttr GetInnerDimFromValue(ShapedType type, - Builder *builder) { +static DenseI64ArrayAttr GetInnerDimFromValue(ShapedType type, + Builder *builder) { if (type.getRank() == 0) { - return builder->getI64TensorAttr({}); + return builder->getDenseI64ArrayAttr({}); } - return builder->getI64TensorAttr(type.getShape().back()); + return builder->getDenseI64ArrayAttr(type.getShape().back()); } // Returns True if the inner-most dim is static. @@ -569,13 +609,13 @@ static DenseIntElementsAttr Get2DTransposePerm(BoolAttr transpose, Builder *b) { // // Always returns 64 bit integer attribute regardless of bitwidth of the input // attribute. -static DenseIntElementsAttr SliceDenseIntElementsAttrColumn2D( - ElementsAttr input, int column) { +static DenseI64ArrayAttr SliceDenseIntElementsAttrColumn2D(ElementsAttr input, + int column) { auto int_attr = mlir::cast(input); auto shaped_type = int_attr.getType(); auto shape = shaped_type.getShape(); - if (shape.size() != 2) return DenseIntElementsAttr(); + if (shape.size() != 2) return DenseI64ArrayAttr(); llvm::SmallVector values; values.reserve(shaped_type.getNumElements() / shape[1]); @@ -586,18 +626,15 @@ static DenseIntElementsAttr SliceDenseIntElementsAttrColumn2D( } } - auto element_type = IntegerType::get(input.getContext(), 64); - return DenseIntElementsAttr::get( - tensorflow::GetTypeFromTFTensorShape({shape[0]}, element_type), values); + return DenseI64ArrayAttr::get(input.getContext(), values); } // Returns interior padding to use in HLO Pad op based on the TensorFlow padding // in TensorFlow PadV2 op. 
-static DenseIntElementsAttr GetInteriorPadding(ElementsAttr tf_padding) { +static DenseI64ArrayAttr GetInteriorPadding(ElementsAttr tf_padding) { auto length = tf_padding.getShapedType().getShape()[0]; - auto element_type = IntegerType::get(tf_padding.getContext(), 64); - return DenseIntElementsAttr::get( - tensorflow::GetTypeFromTFTensorShape({length}, element_type), 0); + std::vector padding(length, 0); + return DenseI64ArrayAttr::get(tf_padding.getContext(), padding); } //===----------------------------------------------------------------------===// @@ -689,10 +726,10 @@ static DenseElementsAttr GetEpsilonValue(Type ty) { // ArgMax/ArgMin op utilities. //===----------------------------------------------------------------------===// -static void BuildArgMinMaxReductionBody(Type input_element_type, - Type index_element_type, - ComparisonDirection direction, - Region *body, OpBuilder *builder) { +static void BuildArgMinMaxReductionBody( + Type input_element_type, Type index_element_type, + stablehlo::ComparisonDirection direction, Region *body, + OpBuilder *builder) { OpBuilder::InsertionGuard insertion_point_gurad(*builder); Type input_type = @@ -710,20 +747,21 @@ static void BuildArgMinMaxReductionBody(Type input_element_type, Value rhs_index = block->getArgument(3); ImplicitLocOpBuilder b(loc, *builder); - Value compare_dt = b.create(lhs_val, rhs_val, direction); + Value compare_dt = + b.create(lhs_val, rhs_val, direction); Value selected_input = - b.create(input_type, compare_dt, lhs_val, rhs_val); + b.create(input_type, compare_dt, lhs_val, rhs_val); - Value compare_eq = - b.create(lhs_val, rhs_val, ComparisonDirection::EQ); - Value min_index = b.create(lhs_index, rhs_index); - Value min_val_index = - b.create(index_type, compare_dt, lhs_index, rhs_index); - Value selected_index = - b.create(index_type, compare_eq, min_index, min_val_index); + Value compare_eq = b.create( + lhs_val, rhs_val, stablehlo::ComparisonDirection::EQ); + Value min_index = 
b.create(lhs_index, rhs_index); + Value min_val_index = b.create(index_type, compare_dt, + lhs_index, rhs_index); + Value selected_index = b.create( + index_type, compare_eq, min_index, min_val_index); Value return_values[] = {selected_input, selected_index}; - b.create(return_values); + b.create(return_values); } //===----------------------------------------------------------------------===// @@ -780,13 +818,12 @@ static bool CanBeTranslatedToDynamicSlice(Value input, Value start_indices, // TF slice size can be -1, which represents all elements from start_index to // the end. HLO slice size can't be -1. As such, we need to translate TF slice // size -1 to HLO slice size. -static DenseIntElementsAttr TFSliceSizes2HLOSliceSizes( +static DenseI64ArrayAttr TFSliceSizes2HLOSliceSizes( Value input, Value start_indices, DenseIntElementsAttr slice_sizes, Builder *builder) { DenseIntElementsAttr constant_start_indices; if (!matchPattern(start_indices, m_Constant(&constant_start_indices))) { - return mlir::cast( - hlo::convertElementsAttr(slice_sizes, builder->getIntegerType(64))); + return ToDenseI64ArrayAttr(slice_sizes, builder); } auto input_ty = mlir::dyn_cast(input.getType()); @@ -803,7 +840,7 @@ static DenseIntElementsAttr TFSliceSizes2HLOSliceSizes( : slice_size); } - return GetI64ElementsAttr(normalized_sizes, builder); + return GetI64ArrayAttr(normalized_sizes, builder); } //===----------------------------------------------------------------------===// @@ -815,11 +852,11 @@ bool HasValidGatherDims(StringAttr attr) { return dims.ParseFromString(attr.getValue().str()); } -GatherDimensionNumbersAttr GetGatherDimNumsAttr(StringAttr attr, - Builder *builder) { +stablehlo::GatherDimensionNumbersAttr GetGatherDimNumsAttr(StringAttr attr, + Builder *builder) { ::xla::GatherDimensionNumbers dims; if (!dims.ParseFromString(attr.getValue().str())) return {}; - return ::xla::ConvertGatherDimensionNumbers(dims, builder); + return 
::xla::stablehlo::ConvertGatherDimensionNumbers(dims, builder); } //===----------------------------------------------------------------------===// @@ -831,10 +868,11 @@ bool HasValidDotDims(StringAttr attr) { return dims.ParseFromString(attr.getValue().str()); } -DotDimensionNumbersAttr GetDotDimNumsAttr(StringAttr attr, Builder *builder) { +stablehlo::DotDimensionNumbersAttr GetDotDimNumsAttr(StringAttr attr, + Builder *builder) { ::xla::DotDimensionNumbers dims; if (!dims.ParseFromString(attr.getValue().str())) return {}; - return ::xla::ConvertDotDimensionNumbers(dims, builder); + return ::xla::stablehlo::ConvertDotDimensionNumbers(dims, builder); } bool HasValidPrecisionConfig(StringAttr attr) { @@ -845,7 +883,7 @@ bool HasValidPrecisionConfig(StringAttr attr) { mlir::ArrayAttr GetPrecisionConfigAttr(StringAttr attr, Builder *builder) { ::xla::PrecisionConfig precision; if (!precision.ParseFromString(attr.getValue().str())) return {}; - return ::xla::ConvertPrecisionConfig(&precision, builder); + return ::xla::stablehlo::ConvertPrecisionConfig(&precision, builder); } //===----------------------------------------------------------------------===// @@ -862,7 +900,7 @@ static void BuildBodyWithCall(PatternRewriter &rewriter, const Location &loc, block->addArguments(inputs, SmallVector(inputs.size(), loc)); mlir::func::CallOp call_op = rewriter.create( loc, func, func_ty.getResults(), block->getArguments()); - rewriter.create(loc, call_op.getResults()); + rewriter.create(loc, call_op.getResults()); } //===----------------------------------------------------------------------===// @@ -889,7 +927,7 @@ NamedAttribute GetConvDimensionNumbersAttr(ArrayRef spatial_dims, return builder->getNamedAttr( "dimension_numbers", - ConvDimensionNumbersAttr::get( + stablehlo::ConvDimensionNumbersAttr::get( builder->getContext(), batch_dim, feature_dim, spatial_dims, kernel_input_feature_dim, kernel_output_feature_dim, kernel_spatial_dimensions, batch_dim, feature_dim, 
spatial_dims)); @@ -916,7 +954,8 @@ class ConvertBiasAddOp : public OpRewritePattern { auto feature_dim = GetFeatureDimension(data_format, value_type); auto bias_broadcast = Broadcast1DToFeatureDim( loc, op.getValue(), op.getBias(), feature_dim, rewriter); - Value add = rewriter.create(loc, op.getValue(), bias_broadcast); + Value add = + rewriter.create(loc, op.getValue(), bias_broadcast); if (add.getType() != op.getType()) { add = rewriter.create(loc, op.getType(), add); } @@ -925,7 +964,7 @@ class ConvertBiasAddOp : public OpRewritePattern { } }; -// Conterts tf.Conv2D to mhlo.dynamic_conv. +// Converts tf.Conv2D to stablehlo.dynamic_conv. // TODO(disc): To recover static special case's performance with adding folding, // canonicalization func and removing ConvertConvOp. template @@ -1082,10 +1121,10 @@ class ConvertConvDynamic : public OpRewritePattern { paddings.push_back(pad_high); } auto rhs_dilations_attr = rewriter.getNamedAttr( - "rhs_dilation", GetI64ElementsAttr(rhs_dilations, &rewriter)); + "rhs_dilation", GetI64ArrayAttr(rhs_dilations, &rewriter)); auto window_strides_attr = rewriter.getNamedAttr( - "window_strides", GetI64ElementsAttr(window_strides, &rewriter)); + "window_strides", GetI64ArrayAttr(window_strides, &rewriter)); auto dimension_numbers_attr = GetConvDimensionNumbersAttr( spatial_dim_indices, data_format, &rewriter); @@ -1127,7 +1166,7 @@ class ConvertConvDynamic : public OpRewritePattern { new_shape.push_back(1); new_shape.push_back(filter_shape[num_spatial_dims] * filter_shape[num_spatial_dims + 1]); - operands[1] = rewriter.create( + operands[1] = rewriter.create( op.getLoc(), tensorflow::GetTypeFromTFTensorShape(new_shape, filter_ty.getElementType()), @@ -1136,8 +1175,8 @@ class ConvertConvDynamic : public OpRewritePattern { NamedAttribute attrs[] = {rhs_dilations_attr, window_strides_attr, dimension_numbers_attr, feature_group_count_attr, batch_group_count_attr, precision_config_attr}; - rewriter.replaceOpWithNewOp(op, op.getType(),
operands, - llvm::ArrayRef(attrs)); + rewriter.replaceOpWithNewOp( + op, op.getType(), operands, llvm::ArrayRef(attrs)); return success(); } @@ -1155,7 +1194,7 @@ using ConvertConv2DDynamic = // // Sample result for Conv2D: // -// %conv = "mhlo.convolution"(%input, %filter) { +// %conv = "stablehlo.convolution"(%input, %filter) { // strides = [1, 2], // paddings = [[1, 0], [1, 1]], // ... @@ -1241,10 +1280,10 @@ class ConvertConvOp : public OpRewritePattern { } auto rhs_dilations_attr = rewriter.getNamedAttr( - "rhs_dilation", GetI64ElementsAttr(rhs_dilations, &rewriter)); + "rhs_dilation", GetI64ArrayAttr(rhs_dilations, &rewriter)); auto window_strides_attr = rewriter.getNamedAttr( - "window_strides", GetI64ElementsAttr(window_strides, &rewriter)); + "window_strides", GetI64ArrayAttr(window_strides, &rewriter)); auto dimension_numbers_attr = GetConvDimensionNumbersAttr( spatial_dim_indices, data_format, &rewriter); @@ -1285,7 +1324,7 @@ class ConvertConvOp : public OpRewritePattern { new_shape.push_back(1); new_shape.push_back(filter_shape[num_spatial_dims] * filter_shape[num_spatial_dims + 1]); - operands[1] = rewriter.create( + operands[1] = rewriter.create( op.getLoc(), tensorflow::GetTypeFromTFTensorShape(new_shape, filter_ty.getElementType()), @@ -1295,8 +1334,8 @@ class ConvertConvOp : public OpRewritePattern { dimension_numbers_attr, feature_group_count_attr, batch_group_count_attr, paddings_attr, precision_config_attr}; - rewriter.replaceOpWithNewOp(op, op.getType(), operands, - llvm::ArrayRef(attrs)); + rewriter.replaceOpWithNewOp( + op, op.getType(), operands, llvm::ArrayRef(attrs)); return success(); } }; @@ -1307,7 +1346,7 @@ using ConvertDepthConv2DOp = ConvertConvOp; -// Converts tf.PadV2Op to mhlo.DynamicPadOp. Padding values must be const. +// Converts tf.PadV2Op to stablehlo.DynamicPadOp. Padding values must be const. 
class ConvertPadOpDynamic : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -1334,38 +1373,38 @@ class ConvertPadOpDynamic : public OpRewritePattern { auto interior_attr = GetI64ElementsAttr(interior_values, &rewriter); Value interior_padding_tensor = - rewriter.create(loc, interior_attr); + rewriter.create(loc, interior_attr); Type paddings_elem_ty = paddings_type.getElementType(); if (!paddings_elem_ty.isInteger(64)) { - interior_padding_tensor = rewriter.create( + interior_padding_tensor = rewriter.create( loc, interior_padding_tensor, paddings_elem_ty); } llvm::SmallVector transposed_shape = {2, input_rank}; - auto transpose_attr = GetI64ElementsAttr({1, 0}, &rewriter); + auto transpose_attr = GetI64ArrayAttr({1, 0}, &rewriter); Value transposed_paddings = - rewriter.create(loc, paddings, transpose_attr); - Value reshaped_paddings = rewriter.create( + rewriter.create(loc, paddings, transpose_attr); + Value reshaped_paddings = rewriter.create( loc, tensorflow::GetTypeFromTFTensorShape({input_rank * 2}, paddings_elem_ty), transposed_paddings); - auto left_padding_start_attr = GetI64ElementsAttr({0}, &rewriter); - auto left_padding_limit_attr = GetI64ElementsAttr({input_rank}, &rewriter); - auto left_padding_stride_attr = GetI64ElementsAttr({1}, &rewriter); - Value left_padding_tensor = rewriter.create( + auto left_padding_start_attr = GetI64ArrayAttr({0}, &rewriter); + auto left_padding_limit_attr = GetI64ArrayAttr({input_rank}, &rewriter); + auto left_padding_stride_attr = GetI64ArrayAttr({1}, &rewriter); + Value left_padding_tensor = rewriter.create( loc, reshaped_paddings, left_padding_start_attr, left_padding_limit_attr, left_padding_stride_attr); - auto right_padding_start_attr = GetI64ElementsAttr({input_rank}, &rewriter); + auto right_padding_start_attr = GetI64ArrayAttr({input_rank}, &rewriter); auto right_padding_limit_attr = - GetI64ElementsAttr({2 * input_rank}, &rewriter); - auto right_padding_stride_attr = 
GetI64ElementsAttr({1}, &rewriter); - Value right_padding_tensor = rewriter.create( + GetI64ArrayAttr({2 * input_rank}, &rewriter); + auto right_padding_stride_attr = GetI64ArrayAttr({1}, &rewriter); + Value right_padding_tensor = rewriter.create( loc, reshaped_paddings, right_padding_start_attr, right_padding_limit_attr, right_padding_stride_attr); - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, op.getType(), input, constant_values, left_padding_tensor, right_padding_tensor, interior_padding_tensor); @@ -1375,11 +1414,11 @@ class ConvertPadOpDynamic : public OpRewritePattern { class ConvertGatherNdOpDynamic : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - // Converts tf.GatherNdOp to mhlo.DynamicGatherOp. + // Converts tf.GatherNdOp to stablehlo.DynamicGatherOp. // Here we leave 'slice_sizes' as an Attr, without defining a new // DynamicGatherOp, since GatherDimensionNumbers has already provide enough - // information for shape inference and code generation of mhlo::GatherOp. '?' - // will be filled into slice_sizes for dimensions that are dynamic sized. + // information for shape inference and code generation of stablehlo::GatherOp. + // '?' will be filled into slice_sizes for dimensions that are dynamic sized. // TODO(disc): To recover static special case's performance with folding and // canonicalization. LogicalResult matchAndRewrite(TF::GatherNdOp op, @@ -1450,18 +1489,18 @@ class ConvertGatherNdOpDynamic : public OpRewritePattern { // index_vector_dim int64_t index_vector_dim = indices_rank - 1; - auto dims_attr = GatherDimensionNumbersAttr::get( + auto dims_attr = stablehlo::GatherDimensionNumbersAttr::get( rewriter.getContext(), offset_dims, collapsed_slice_dims, /*operandBatchingDims=*/{}, /*startIndicesBatchingDims=*/{}, start_index_map, index_vector_dim); // TODO(disc): Remove this if-statement once fold and canonicalization is // implemented. 
if (params_ty.hasStaticShape() && indices_ty.hasStaticShape()) { - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, op.getType(), op.getParams(), op.getIndices(), dims_attr, - GetI64ElementsAttr(slice_sizes, &rewriter)); + GetI64ArrayAttr(slice_sizes, &rewriter)); } else { - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, op.getType(), op.getParams(), op.getIndices(), slice_sizes_value, dims_attr); } @@ -1496,16 +1535,18 @@ class ConvertBF16FloorDivOp : public OpRewritePattern { auto out_type = op.getZ().getType(); - l = rewriter.create(op.getLoc(), l, rewriter.getF32Type()); - r = rewriter.create(op.getLoc(), r, rewriter.getF32Type()); + l = rewriter.create(op.getLoc(), l, + rewriter.getF32Type()); + r = rewriter.create(op.getLoc(), r, + rewriter.getF32Type()); auto intermediate = rewriter.create( op.getLoc(), ChangeTensorElementType(&rewriter, out_type, rewriter.getF32Type()), l, r); - auto floor_op = - rewriter.create(op.getLoc(), out_type, intermediate); + auto floor_op = rewriter.create(op.getLoc(), out_type, + intermediate); rewriter.replaceOp(op, floor_op.getResult()); return success(); } @@ -1534,9 +1575,9 @@ class ConvertBroadcastToOp : public OpRewritePattern { broadcast_dimensions = llvm::to_vector<4>( llvm::seq(rank_diff, ranked_output_type.getRank())); } - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, output_type, op.getInput(), op.getShape(), - rewriter.getI64TensorAttr(broadcast_dimensions)); + GetI64ArrayAttr(broadcast_dimensions, &rewriter)); return success(); } }; @@ -1574,25 +1615,27 @@ class ConvertRollOp : public OpRewritePattern { // offset = ((offset % axis_size) + axis_size) % axis_size ImplicitLocOpBuilder b(op.getLoc(), rewriter); Value offset = op.getShift(); - auto axis_size = b.create(b.getIntegerAttr( + auto axis_size = b.create(b.getIntegerAttr( getElementTypeOrSelf(offset.getType()), input_shape[axis])); - offset = b.create( - b.create(b.create(offset, axis_size), axis_size), + offset 
= b.create( + b.create( + b.create(offset, axis_size), axis_size), axis_size); // Stack two copies of the dimension, then slice from the calculated // offset. This also works if shift is not constant. // DynamicSliceOp requires the sizes being integer, and we can get the // information from input shape. - auto concat = b.create( + auto concat = b.create( ValueRange{op.getInput(), op.getInput()}, b.getI64IntegerAttr(axis)); - Value zero = b.create( + Value zero = b.create( b.getIntegerAttr(getElementTypeOrSelf(offset.getType()), 0)); SmallVector slice_begin_indices(input_rank, zero); - slice_begin_indices[axis] = b.create(axis_size, offset); - rewriter.replaceOpWithNewOp( + slice_begin_indices[axis] = + b.create(axis_size, offset); + rewriter.replaceOpWithNewOp( op, input_ty, concat, slice_begin_indices, - rewriter.getI64TensorAttr(input_shape)); + GetI64ArrayAttr(input_shape, &rewriter)); return success(); } }; @@ -1613,13 +1656,13 @@ class ConvertLeakyReluOp : public OpRewritePattern { Value zeroVal = chlo::getConstantLike(rewriter, loc, 0.0, features); Value leakyActivationVal = - rewriter.create(loc, features, alphaVal); + rewriter.create(loc, features, alphaVal); - Value compareGtZero = rewriter.create( - loc, features, zeroVal, ComparisonDirection::GT); + Value compareGtZero = rewriter.create( + loc, features, zeroVal, stablehlo::ComparisonDirection::GT); - rewriter.replaceOpWithNewOp(op, compareGtZero, features, - leakyActivationVal); + rewriter.replaceOpWithNewOp( + op, compareGtZero, features, leakyActivationVal); return success(); } }; @@ -1643,29 +1686,29 @@ class ConvertLeakyReluGradOp : public OpRewritePattern { Value zeroVal = chlo::getConstantLike(rewriter, loc, 0.0, features); Value leakyGradientVal = - rewriter.create(loc, gradients, alphaVal); + rewriter.create(loc, gradients, alphaVal); - Value compareGtZero = rewriter.create( - loc, features, zeroVal, ComparisonDirection::GT); + Value compareGtZero = rewriter.create( + loc, features, zeroVal, 
stablehlo::ComparisonDirection::GT); - rewriter.replaceOpWithNewOp(op, featureType, compareGtZero, - gradients, leakyGradientVal); + rewriter.replaceOpWithNewOp( + op, featureType, compareGtZero, gradients, leakyGradientVal); return success(); } }; // Converts TensorFlow DiagPartOp to HLO ops using reduction on masked matrix. // For a Rank-2 input, it creates the following ops: -// %1 = "mhlo.iota"() {iota_dimension = 0 : i64} -// %2 = "mhlo.iota"() {iota_dimension = 1 : i64} -// %3 = "mhlo.compare"(%1, %2) {comparison_direction = "EQ"} -// %4 = mhlo.constant dense<0.000000e+00> : tensor -// %5 = "mhlo.broadcast"(%4) -// %6 = "mhlo.select"(%3, %input, %5) -// %7 = "mhlo.reduce"(%6, %4) ({ +// %1 = "stablehlo.iota"() {iota_dimension = 0 : i64} +// %2 = "stablehlo.iota"() {iota_dimension = 1 : i64} +// %3 = "stablehlo.compare"(%1, %2) {comparison_direction = "EQ"} +// %4 = stablehlo.constant dense<0.000000e+00> : tensor +// %5 = "stablehlo.broadcast"(%4) +// %6 = "stablehlo.select"(%3, %input, %5) +// %7 = "stablehlo.reduce"(%6, %4) ({ // ^bb0(%arg1: tensor, %arg2: tensor): -// %9 = mhlo.add %arg1, %arg2 : tensor -// "mhlo.return"(%9) : (tensor) -> () +// %9 = stablehlo.add %arg1, %arg2 : tensor +// "stablehlo.return"(%9) : (tensor) -> () // }) {dimensions = dense<0> : tensor<1xi64>} // // If the input's rank N is greater than 2, we will reshape it to R2 first and @@ -1690,35 +1733,35 @@ class ConvertDiagPartOp : public OpRewritePattern { new_size *= input_type.getDimSize(i); new_dims.push_back(input_type.getDimSize(i)); } - Value reshaped_input = rewriter.create( + Value reshaped_input = rewriter.create( op.getLoc(), tensorflow::GetTypeFromTFTensorShape({new_size, new_size}, input_type.getElementType()), op.getInput()); auto iota_type = tensorflow::GetTypeFromTFTensorShape( {new_size, new_size}, rewriter.getIntegerType(32)); - auto iota0 = rewriter.create(op.getLoc(), iota_type, - rewriter.getI64IntegerAttr(0)); - auto iota1 = rewriter.create(op.getLoc(), iota_type, 
- rewriter.getI64IntegerAttr(1)); - Value compare = rewriter.create(op.getLoc(), iota0, iota1, - ComparisonDirection::EQ); + auto iota0 = rewriter.create( + op.getLoc(), iota_type, rewriter.getI64IntegerAttr(0)); + auto iota1 = rewriter.create( + op.getLoc(), iota_type, rewriter.getI64IntegerAttr(1)); + Value compare = rewriter.create( + op.getLoc(), iota0, iota1, stablehlo::ComparisonDirection::EQ); Value zero = GetScalarConstOfType(input_type.getElementType(), op.getLoc(), 0, &rewriter); - Value zero_matrix = rewriter.create( + Value zero_matrix = rewriter.create( op.getLoc(), reshaped_input.getType(), zero, - GetI64ElementsAttr({new_size, new_size}, &rewriter)); - Value masked = - rewriter.create(op.getLoc(), reshaped_input.getType(), - compare, reshaped_input, zero_matrix); - auto reduce = rewriter.create(op.getLoc(), masked, zero, - GetI64ElementsAttr({0}, &rewriter), - input_type.getElementType()); + GetI64ArrayAttr({new_size, new_size}, &rewriter)); + Value masked = rewriter.create( + op.getLoc(), reshaped_input.getType(), compare, reshaped_input, + zero_matrix); + auto reduce = rewriter.create( + op.getLoc(), masked, zero, GetI64ArrayAttr({0}, &rewriter), + input_type.getElementType()); assert(!input_type.getElementType().isInteger(1) && "data type should not be i1"); - BuildReduceBody(input_type.getElementType(), &reduce.getBody(), - &rewriter); - rewriter.replaceOpWithNewOp( + BuildReduceBody(input_type.getElementType(), + &reduce.getBody(), &rewriter); + rewriter.replaceOpWithNewOp( op, tensorflow::GetTypeFromTFTensorShape(new_dims, input_type.getElementType()), @@ -1756,15 +1799,16 @@ class ConvertMatrixDiagPartV3Op } // Utility method for broadcasting integer constants to a given shape. 
- BroadcastOp BroadcastConstant(Location loc, Shape shape, int32_t constant, - int int_size, PatternRewriter &rewriter) const { - return rewriter.create( + stablehlo::BroadcastOp BroadcastConstant(Location loc, Shape shape, + int32_t constant, int int_size, + PatternRewriter &rewriter) const { + return rewriter.create( loc, tensorflow::GetTypeFromTFTensorShape(shape, rewriter.getIntegerType(int_size)), GetScalarConstOfType(rewriter.getIntegerType(int_size), loc, constant, &rewriter), - GetI64ElementsAttr(shape, &rewriter)); + GetI64ArrayAttr(shape, &rewriter)); } public: @@ -1834,10 +1878,10 @@ class ConvertMatrixDiagPartV3Op RankedTensorType iota_type = tensorflow::GetTypeFromTFTensorShape( indices_shape, rewriter.getIntegerType(32)); - Value iotaM = - rewriter.create(loc, iota_type, rewriter.getI64IntegerAttr(1)); - Value iotaN = - rewriter.create(loc, iota_type, rewriter.getI64IntegerAttr(2)); + Value iotaM = rewriter.create( + loc, iota_type, rewriter.getI64IntegerAttr(1)); + Value iotaN = rewriter.create( + loc, iota_type, rewriter.getI64IntegerAttr(2)); // Boradcasted constants, of the same shape as iotaM and iotaN. Value b_zero = BroadcastConstant(loc, indices_shape, 0, 32, rewriter); @@ -1854,17 +1898,17 @@ class ConvertMatrixDiagPartV3Op // subtract m here. This means we start with the superdiagonals and // move downwards towards the subdiagonals. So the start indices will // be decreasing.) - Value d = rewriter.create(loc, b_k1, iotaM); - Value neg_d = rewriter.create(loc, d); + Value d = rewriter.create(loc, b_k1, iotaM); + Value neg_d = rewriter.create(loc, d); // diag_len_d = min(rows + min(d, 0), cols - max(d, 0)) // (Length of a diagonal for a given d. Same as max_diag_len for m = 0.) 
- Value diag_len_d = rewriter.create( + Value diag_len_d = rewriter.create( loc, - rewriter.create(loc, b_rows, - rewriter.create(loc, d, b_zero)), - rewriter.create(loc, b_cols, - rewriter.create(loc, d, b_zero))); + rewriter.create( + loc, b_rows, rewriter.create(loc, d, b_zero)), + rewriter.create( + loc, b_cols, rewriter.create(loc, d, b_zero))); // offset is max_diag_len - diag_len_d if we're padding, 0 otherwise. Value cmp; @@ -1883,43 +1927,44 @@ class ConvertMatrixDiagPartV3Op // This offset shifts the diagonals to the "left" or "right", depending // on alignment. - Value offset = rewriter.create( + Value offset = rewriter.create( loc, b_zero.getType(), cmp, - rewriter.create(loc, b_max_diag_len, diag_len_d), b_zero); + rewriter.create(loc, b_max_diag_len, diag_len_d), + b_zero); // x = max(d, 0) - offset // y = max(-d, 0) - offset - Value x = rewriter.create( - loc, rewriter.create(loc, d, b_zero), offset); - Value y = rewriter.create( - loc, rewriter.create(loc, neg_d, b_zero), offset); + Value x = rewriter.create( + loc, rewriter.create(loc, d, b_zero), offset); + Value y = rewriter.create( + loc, rewriter.create(loc, neg_d, b_zero), offset); - Value n_plus_x = rewriter.create(loc, iotaN, x); - Value n_plus_y = rewriter.create(loc, iotaN, y); + Value n_plus_x = rewriter.create(loc, iotaN, x); + Value n_plus_y = rewriter.create(loc, iotaN, y); // GatherOp is happy about letting us index out of bounds values, but those // values will be undefined. So we mask them later. Set up the boolean // expression that tells us which entries, in the output shape, are out of // bounds and thus become the padding_value. 
- Value x_in_bounds = rewriter.create( + Value x_in_bounds = rewriter.create( loc, rewriter.create(loc, b_false.getType(), n_plus_x, b_zero), rewriter.create(loc, b_false.getType(), n_plus_x, b_cols)); - Value y_in_bounds = rewriter.create( + Value y_in_bounds = rewriter.create( loc, rewriter.create(loc, b_false.getType(), n_plus_y, b_zero), rewriter.create(loc, b_false.getType(), n_plus_y, b_rows)); - Value in_bounds = rewriter.create( + Value in_bounds = rewriter.create( loc, tensorflow::GetTypeFromTFTensorShape(Shape({num_diags, max_diag_len}), rewriter.getIntegerType(1)), - rewriter.create(loc, x_in_bounds, y_in_bounds)); + rewriter.create(loc, x_in_bounds, y_in_bounds)); // Now combine x and y into the index data structure needed for gather. Shape concat_shape({2, num_diags, max_diag_len}); - Value start_indices = rewriter.create( + Value start_indices = rewriter.create( loc, tensorflow::GetTypeFromTFTensorShape(concat_shape, rewriter.getIntegerType(32)), @@ -1957,16 +2002,16 @@ class ConvertMatrixDiagPartV3Op // Gather the diagonal entries. // TODO(kramm): For a single diagonal, this might be slower than the // mask + sum approach. Special-case num_diags==1? - auto dims_attr = GatherDimensionNumbersAttr::get( + auto dims_attr = stablehlo::GatherDimensionNumbersAttr::get( rewriter.getContext(), /*offsetDims=*/llvm::to_vector<4>(llvm::seq(0, num_dims - 2)), /*collapsedSliceDims=*/collapsed_dims, /*operandBatchingDims=*/{}, /*startIndicesBatchingDims=*/{}, start_index_map, /*indexVectorDim=*/0); - Value gather = rewriter.create( + Value gather = rewriter.create( loc, op.getInput(), start_indices, dims_attr, - GetI64ElementsAttr(slice_sizes, &rewriter)); + GetI64ArrayAttr(slice_sizes, &rewriter)); // We now need to broadcast the "in_bounds" boolean expression, as well as // the padding value, to do the final select. 
@@ -1974,22 +2019,22 @@ class ConvertMatrixDiagPartV3Op for (int i = 0; i < output_shape.size() - 2; i++) { broadcast_bounds.push_back(output_shape[i]); } - Value b_in_bounds = rewriter.create( + Value b_in_bounds = rewriter.create( loc, tensorflow::GetTypeFromTFTensorShape(output_shape, rewriter.getIntegerType(1)), - in_bounds, GetI64ElementsAttr(broadcast_bounds, &rewriter)); - Value b_padding = rewriter.create( - loc, op.getPaddingValue(), GetI64ElementsAttr(output_shape, &rewriter)); + in_bounds, GetI64ArrayAttr(broadcast_bounds, &rewriter)); + Value b_padding = rewriter.create( + loc, op.getPaddingValue(), GetI64ArrayAttr(output_shape, &rewriter)); // Replace all out-of-bounds values in the result with padding_value. - Value result = - rewriter.create(loc, b_in_bounds, gather, b_padding); + Value result = rewriter.create(loc, b_in_bounds, + gather, b_padding); if (num_diags == 1) { // matrix_diag_part folds away the 1-sized band dimension if we only // extract a single diagonal. - result = rewriter.create(loc, op.getType(), result); + result = rewriter.create(loc, op.getType(), result); } rewriter.replaceOp(op, result); @@ -2012,7 +2057,7 @@ class ConvertEinsumOp : public OpRewritePattern { // creates a scalar constant 1.0 for first operand. 
if (op.getN() == 1) { equation_str = "," + equation_str; - inputs.push_back(rewriter.create( + inputs.push_back(rewriter.create( op.getLoc(), hlo::getScalarOfType( mlir::getElementTypeOrSelf(op.getOperand(0)), 1))); } @@ -2022,8 +2067,8 @@ class ConvertEinsumOp : public OpRewritePattern { inputs.insert(inputs.end(), operands.begin(), operands.end()); assert(inputs.size() == 2); - rewriter.replaceOpWithNewOp(op, op.getType(), inputs[0], - inputs[1], equation_str); + rewriter.replaceOpWithNewOp( + op, op.getType(), inputs[0], inputs[1], equation_str); return success(); } }; @@ -2084,13 +2129,13 @@ class ConvertFFTOp : public OpRewritePattern { // Last dim larger than expected_dim, slice the input if (input_shape.back() > expected_dim) { - reshaped = rewriter.create( + reshaped = rewriter.create( op.getLoc(), tensorflow::GetTypeFromTFTensorShape(expected_shape, input_ty.getElementType()), - op.getInput(), GetI64ElementsAttr(begin_indices, &rewriter), - GetI64ElementsAttr(expected_shape, &rewriter), - GetI64ElementsAttr(strides, &rewriter)); + op.getInput(), GetI64ArrayAttr(begin_indices, &rewriter), + GetI64ArrayAttr(expected_shape, &rewriter), + GetI64ArrayAttr(strides, &rewriter)); // Last dim smaller than expected_dim, zero-pad the input } else if (input_ty.getShape().back() < expected_dim) { @@ -2099,20 +2144,21 @@ class ConvertFFTOp : public OpRewritePattern { padding.push_back(expected_dim - input_shape.back()); Value zero = GetScalarConstOfType(input_ty.getElementType(), loc, 0, &rewriter); - reshaped = rewriter.create( + reshaped = rewriter.create( loc, tensorflow::GetTypeFromTFTensorShape(expected_shape, input_ty.getElementType()), - op.getInput(), zero, GetI64ElementsAttr(no_padding, &rewriter), - GetI64ElementsAttr(padding, &rewriter), - GetI64ElementsAttr(no_padding, &rewriter)); + op.getInput(), zero, GetI64ArrayAttr(no_padding, &rewriter), + GetI64ArrayAttr(padding, &rewriter), + GetI64ArrayAttr(no_padding, &rewriter)); } - rewriter.replaceOpWithNewOp( + 
rewriter.replaceOpWithNewOp( op, op.getType(), reshaped, - FftTypeAttr::get(rewriter.getContext(), - symbolizeFftType(fft_string).value()), - rewriter.getI64TensorAttr(fft_length)); + stablehlo::FftTypeAttr::get( + rewriter.getContext(), + stablehlo::symbolizeFftType(fft_string).value()), + GetI64ArrayAttr(fft_length, &rewriter)); return success(); } }; @@ -2147,8 +2193,8 @@ class ConvertFusedBatchNormGradBase // To support mixed precision, the statistics type, which maybe more // precise than the input types, are used for this op. Type kernel_type = mlir::cast(scale.getType()).getElementType(); - grad = rewriter.create(loc, grad, kernel_type); - act = rewriter.create(loc, act, kernel_type); + grad = rewriter.create(loc, grad, kernel_type); + act = rewriter.create(loc, act, kernel_type); tensorflow::TensorFormat data_format; if (!FormatFromString(op.getDataFormat().str(), &data_format)) @@ -2167,7 +2213,7 @@ class ConvertFusedBatchNormGradBase SmallVector operand_types = {act.getType(), feature_type, feature_type}; - auto training_op = rewriter.create( + auto training_op = rewriter.create( loc, operand_types, act, scale, mean, var, grad, op.getEpsilon(), feature_dim); @@ -2188,43 +2234,45 @@ class ConvertFusedBatchNormGradBase // scratch1 = rsqrt(var + epsilon) RankedTensorType scalar_float = tensorflow::GetTypeFromTFTensorShape({}, kernel_type); - auto epsilon = rewriter.create( + auto epsilon = rewriter.create( loc, DenseFPElementsAttr::get(scalar_float, {op.getEpsilon()})); auto add_op = rewriter.create( loc, var, epsilon.getResult(), scalar_broadcast_dims); - Value scratch1 = rewriter.create(loc, add_op); + Value scratch1 = rewriter.create(loc, add_op); // scratch2 = sum(y_backprop * (x - mean)) - auto sub_op = rewriter.create( + auto sub_op = rewriter.create( loc, act, Broadcast1DToFeatureDim(loc, act, mean, feature_dim, rewriter)); - auto weighted_grad = rewriter.create(loc, grad, sub_op); + auto weighted_grad = rewriter.create(loc, grad, sub_op); Value 
scratch2 = ApplyReduction(loc, weighted_grad, reduce_dims, &rewriter); // x_backprop = y_backprop * (scale * scratch1) auto scaled_grad = - rewriter.create(loc, op.getScale(), scratch1); - x_backprop = rewriter.create( + rewriter.create(loc, op.getScale(), scratch1); + x_backprop = rewriter.create( loc, grad, Broadcast1DToFeatureDim(loc, act, scaled_grad, feature_dim, rewriter)); // scale_backprop = scratch2 * scratch1 - scale_backprop = rewriter.create(loc, scratch1, scratch2); + scale_backprop = + rewriter.create(loc, scratch1, scratch2); // offset_backprop = sum(y_backprop) offset_backprop = ApplyReduction(loc, grad, reduce_dims, &rewriter); } - x_backprop = rewriter.create(loc, x_backprop, act_ele_type); + x_backprop = + rewriter.create(loc, x_backprop, act_ele_type); Value last_val[2]; if (op.getResult(3).use_empty() && op.getResult(4).use_empty()) { // It doesn't matter what values we provide for the last 2 results. last_val[0] = last_val[1] = op.getX(); } else { - auto const_val = rewriter.create( + auto const_val = rewriter.create( op.getLoc(), DenseElementsAttr::get( tensorflow::GetTypeFromTFTensorShape( {0}, getElementTypeOrSelf(op.getResult(3))), @@ -2285,7 +2333,7 @@ class ConvertFusedBatchNormBase : public OpRewritePattern { // TODO(b/69928690): Support mixed precision in the XLA batch // normalization operators. As a workaround, create a new x with the same // element type as scale (which may be more precise than the input type). - Value bn_train_input = rewriter.create( + Value bn_train_input = rewriter.create( op.getLoc(), op.getX(), scale_element_type); TensorType bn_train_input_type_tensor = mlir::cast(bn_train_input.getType()); @@ -2303,7 +2351,7 @@ class ConvertFusedBatchNormBase : public OpRewritePattern { // batch_mean, and batch_var. 
SmallVector operand_types = {bn_train_input_type_tensor, mean_var_type, mean_var_type}; - auto bn_train_op = rewriter.create( + auto bn_train_op = rewriter.create( op.getLoc(), operand_types, bn_train_input, op.getScale(), op.getOffset(), op.getEpsilon(), feature_dim.getInt()); // HLO op outputs a tuple of tensors. Extract those results. @@ -2320,7 +2368,7 @@ class ConvertFusedBatchNormBase : public OpRewritePattern { int sample_size_minus_one = std::max(1, sample_size - 1); double factor = static_cast(sample_size) / static_cast(sample_size_minus_one); - auto factor_const_op = rewriter.create( + auto factor_const_op = rewriter.create( op.getLoc(), rewriter.getFloatAttr(scale_element_type, factor)); Value corrected_variance = rewriter.create( @@ -2329,16 +2377,16 @@ class ConvertFusedBatchNormBase : public OpRewritePattern { // Convert back to input type to stay aligned with expected output type // for TF op. - y_out = rewriter.create(op.getLoc(), y_out, - input_element_type); + y_out = rewriter.create(op.getLoc(), y_out, + input_element_type); float exponential_avg_factor = op.getExponentialAvgFactor().convertToFloat(); if (exponential_avg_factor != 1.0f) { - auto alpha = rewriter.create( + auto alpha = rewriter.create( op.getLoc(), rewriter.getFloatAttr(mean_element_type, 1.0f - exponential_avg_factor)); - auto beta = rewriter.create( + auto beta = rewriter.create( op.getLoc(), rewriter.getFloatAttr(mean_element_type, exponential_avg_factor)); @@ -2385,7 +2433,7 @@ class ConvertFusedBatchNormBase : public OpRewritePattern { : 0; auto const_attr_type = tensorflow::GetTypeFromTFTensorShape( {num_elements}, getElementTypeOrSelf(reserve_space_3_type)); - Value dummy_const = rewriter.create( + Value dummy_const = rewriter.create( op.getLoc(), DenseElementsAttr::get(const_attr_type, 0.0)); if (const_attr_type != reserve_space_3_type) dummy_const = rewriter.create( @@ -2397,7 +2445,7 @@ class ConvertFusedBatchNormBase : public OpRewritePattern { 
/*reserve_space_3=*/dummy_const}); } } else { // Inference case. - auto bn_train_op = rewriter.create( + auto bn_train_op = rewriter.create( op.getLoc(), /*result_type=*/bn_train_input_type_tensor, bn_train_input, op.getScale(), op.getOffset(), op.getMean(), op.getVariance(), @@ -2405,8 +2453,8 @@ class ConvertFusedBatchNormBase : public OpRewritePattern { // Convert back to input type to stay aligned with expected output type // for TF op. - auto y_out = rewriter.create(op.getLoc(), bn_train_op, - input_element_type); + auto y_out = rewriter.create( + op.getLoc(), bn_train_op, input_element_type); // The mean, variance, and reserved space outputs of the batch norm op are // not used for inference. It doesn't matter what values we provide for @@ -2429,7 +2477,7 @@ class ConvertFusedBatchNormBase : public OpRewritePattern { : 0; auto const_attr_type = tensorflow::GetTypeFromTFTensorShape( {num_elements}, getElementTypeOrSelf(reserve_space_3_type)); - Value dummy_const = rewriter.create( + Value dummy_const = rewriter.create( op.getLoc(), DenseElementsAttr::get(const_attr_type, 0.0)); if (const_attr_type != reserve_space_3_type) dummy_const = rewriter.create( @@ -2541,7 +2589,7 @@ Operation *AvgPoolDivideByCount( // Build all-ones tensor of same shape as the original input. ElementsAttr splat = hlo::getSplat(&rewriter, orig_input_type, 1); - auto all_ones_tensor = rewriter.create(loc, splat); + auto all_ones_tensor = rewriter.create(loc, splat); // Get padding for the input. DenseIntElementsAttr input_padding_attr = @@ -2551,20 +2599,23 @@ Operation *AvgPoolDivideByCount( // Count the 1's in each window, using the same padding as for the input, // which gives us the window counts by which `pooled` needs to be divided. 
- auto divisor = rewriter.create( + auto divisor = rewriter.create( loc, pooled_type, /*operand=*/all_ones_tensor, /*init_value=*/zero, - /*window_dimensions=*/GetI64ElementsAttr(op.getKsize()), - /*window_strides=*/GetI64ElementsAttr(op.getStrides()), - /*base_dilations=*/DenseIntElementsAttr(), - /*window_dilations=*/DenseIntElementsAttr(), + /*window_dimensions=*/ + ToDenseI64ArrayAttr(GetI64ElementsAttr(op.getKsize()), &rewriter), + /*window_strides=*/ + ToDenseI64ArrayAttr(GetI64ElementsAttr(op.getStrides()), &rewriter), + /*base_dilations=*/DenseI64ArrayAttr(), + /*window_dilations=*/DenseI64ArrayAttr(), /*padding=*/input_padding_attr); - BuildReduceBody(element_type, &divisor.getBody(), &rewriter); + BuildReduceBody(element_type, &divisor.getBody(), + &rewriter); // Divide `pooled` by window counts. - result = rewriter.create(loc, pooled_type, pooled, - divisor.getResult(0)); + result = rewriter.create(loc, pooled_type, pooled, + divisor.getResult(0)); } return result; } @@ -2600,8 +2651,8 @@ class ConvertAvgPoolOp : public OpRewritePattern { // Convert if we need enlarge the element type's bitwidth. if (input_element_type != sum_element_type) - input_value = rewriter.create(op.getLoc(), input_value, - sum_element_type); + input_value = rewriter.create( + op.getLoc(), input_value, sum_element_type); // Create the ReduceWindow op. 
Value init = @@ -2609,12 +2660,14 @@ class ConvertAvgPoolOp : public OpRewritePattern { DenseIntElementsAttr paddings_attr = GetReduceWindowPaddingAsAttr( input_type.getShape(), op.getKsize(), op.getStrides(), op.getPadding(), &rewriter); - auto reduce = rewriter.create( + auto reduce = rewriter.create( op.getLoc(), result_type, input_value, init, - GetI64ElementsAttr(op.getKsize()), GetI64ElementsAttr(op.getStrides()), - /*base_dilations=*/DenseIntElementsAttr(), - /*window_dilations=*/DenseIntElementsAttr(), paddings_attr); - BuildReduceBody(sum_element_type, &reduce.getBody(), &rewriter); + ToDenseI64ArrayAttr(GetI64ElementsAttr(op.getKsize()), &rewriter), + ToDenseI64ArrayAttr(GetI64ElementsAttr(op.getStrides()), &rewriter), + /*base_dilations=*/DenseI64ArrayAttr(), + /*window_dilations=*/DenseI64ArrayAttr(), paddings_attr); + BuildReduceBody(sum_element_type, &reduce.getBody(), + &rewriter); // Count the number of elements in the window. The following calculation // is only valid for no paddings. @@ -2630,8 +2683,8 @@ class ConvertAvgPoolOp : public OpRewritePattern { // Convert back if we enlarged the element type's bitwidth. 
Value result = result_op->getOpResult(0); if (input_element_type != sum_element_type) - result = - rewriter.create(op.getLoc(), result, input_element_type); + result = rewriter.create(op.getLoc(), result, + input_element_type); rewriter.replaceOp(op, result); return success(); @@ -2772,13 +2825,13 @@ class ConvertAvgPoolGradOp : public OpRewritePattern { out_grad_shape[dim] = low_padding[dim] + high_padding[dim] + (out_grad_shape[dim] - 1) * strides[dim] + 1; } - Value reduce_window_input = rewriter.create( + Value reduce_window_input = rewriter.create( loc, tensorflow::GetTypeFromTFTensorShape(out_grad_shape, element_type), /*operand=*/out_grad_divided->getOpResult(0), /*padding_value=*/zero, - /*edge_padding_low=*/GetI64ElementsAttr(low_padding, &rewriter), - /*edge_padding_high=*/GetI64ElementsAttr(high_padding, &rewriter), - /*interior_padding=*/GetI64ElementsAttr(interior_padding, &rewriter)); + /*edge_padding_low=*/GetI64ArrayAttr(low_padding, &rewriter), + /*edge_padding_high=*/GetI64ArrayAttr(high_padding, &rewriter), + /*interior_padding=*/GetI64ArrayAttr(interior_padding, &rewriter)); // Compute result by convolving `reduce_window_input` with an all-ones // kernel, using `ReduceWindowOp` with `AddOp` body. @@ -2786,29 +2839,31 @@ class ConvertAvgPoolGradOp : public OpRewritePattern { Type sum_element_type = GetSumAccumulationType(element_type); if (element_type != sum_element_type) { // Convert to appropriate sum accumulation type to avoid precision loss. 
- reduce_window_input = rewriter.create(loc, reduce_window_input, - sum_element_type); + reduce_window_input = rewriter.create( + loc, reduce_window_input, sum_element_type); zero = GetScalarConstOfType(sum_element_type, loc, 0, &rewriter); } - auto ones = GetI64ElementsAttr(DimVector(num_dims, 1), &rewriter); - auto reduce_window_op = rewriter.create( + auto ones = GetI64ArrayAttr(DimVector(num_dims, 1), &rewriter); + auto reduce_window_op = rewriter.create( loc, tensorflow::GetTypeFromTFTensorShape(orig_input_shape, sum_element_type), /*operand=*/reduce_window_input, /*init_value=*/zero, - /*window_dimensions=*/GetI64ElementsAttr(op.getKsize()), + /*window_dimensions=*/ + ToDenseI64ArrayAttr(GetI64ElementsAttr(op.getKsize()), &rewriter), /*window_strides=*/ones, - /*base_dilations=*/DenseIntElementsAttr(), - /*window_dilations=*/DenseIntElementsAttr(), + /*base_dilations=*/DenseI64ArrayAttr(), + /*window_dilations=*/DenseI64ArrayAttr(), /*padding=*/DenseIntElementsAttr()); - BuildReduceBody(sum_element_type, &reduce_window_op.getBody(), - &rewriter); + BuildReduceBody(sum_element_type, + &reduce_window_op.getBody(), &rewriter); Value result = reduce_window_op.getResult(0); if (element_type != sum_element_type) { // Convert back to original element type. - result = rewriter.create(op.getLoc(), result, element_type); + result = rewriter.create(op.getLoc(), result, + element_type); } rewriter.replaceOp(op, {result}); return success(); @@ -2826,7 +2881,7 @@ using ConvertAvgPool3DGradOp = // Sample result for VALID padding mode: // // %init = arith.constant dense<...> : tensor -// %max_pool = "mhlo.reduce"(%inp, %init) ["mhlo.maximum"] +// %max_pool = "stablehlo.reduce"(%inp, %init) ["stablehlo.maximum"] // {window_dimensions = ..., window_strides = ... 
} // template @@ -2846,7 +2901,7 @@ class ConvertMaxPoolOp : public OpRewritePattern { return failure(); } Location loc = op.getLoc(); - ConstantOp init = GetScalarLimitConstOfType( + stablehlo::ConstantOp init = GetScalarLimitConstOfType( element_type, loc, hlo::kInfinityLowest, &rewriter); auto input_ty = mlir::dyn_cast(op.getInput().getType()); @@ -2854,12 +2909,14 @@ class ConvertMaxPoolOp : public OpRewritePattern { DenseIntElementsAttr paddings_attr = GetReduceWindowPaddingAsAttr( input_ty.getShape(), op.getKsize(), op.getStrides(), op.getPadding(), &rewriter); - auto reduce = rewriter.create( + auto reduce = rewriter.create( loc, op.getType(), op.getInput(), init, - GetI64ElementsAttr(op.getKsize()), GetI64ElementsAttr(op.getStrides()), - /*base_dilations=*/DenseIntElementsAttr(), - /*window_dilations=*/DenseIntElementsAttr(), paddings_attr); - BuildReduceBody(element_type, &reduce.getBody(), &rewriter); + ToDenseI64ArrayAttr(GetI64ElementsAttr(op.getKsize()), &rewriter), + ToDenseI64ArrayAttr(GetI64ElementsAttr(op.getStrides()), &rewriter), + /*base_dilations=*/DenseI64ArrayAttr(), + /*window_dilations=*/DenseI64ArrayAttr(), paddings_attr); + BuildReduceBody(element_type, &reduce.getBody(), + &rewriter); rewriter.replaceOp(op, reduce.getResult(0)); return success(); @@ -2869,8 +2926,8 @@ class ConvertMaxPoolOp : public OpRewritePattern { using ConvertMaxPool2DOp = ConvertMaxPoolOp; using ConvertMaxPool3DOp = ConvertMaxPoolOp; -// Converts tf.Select (SelectV1) to mhlo.select. It has optional broadcasting on -// the condition only. +// Converts tf.Select (SelectV1) to stablehlo.select. It has optional +// broadcasting on the condition only. 
class ConvertSelectOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -2931,13 +2988,13 @@ class ConvertSelectOp : public OpRewritePattern { if (needs_broadcast) { Value result_extents = b.create( GetExtentsTensorTypeFor(result_type), then_shape); - cond = b.create( + cond = b.create( tensorflow::GetTypeFromTFTensorShape(result_type.getShape(), b.getI1Type()), cond, result_extents, - GetI64ElementsAttrForSeq(0, cond_type.getRank(), &b)); + GetI64ArrayAttrForSeq(0, cond_type.getRank(), &b)); } - Value select = b.create( + Value select = b.create( result_type, cond, op.getThenValue(), op.getElseValue()); b.create(select); rewriter.replaceOp(op, {assuming_op.getResult(0)}); @@ -2945,7 +3002,7 @@ class ConvertSelectOp : public OpRewritePattern { } }; -// Converts the tf.Slice op into mhlo.real_dynamic_slice +// Converts the tf.Slice op into stablehlo.real_dynamic_slice // TODO(disc): To recover static special case's performance with folding and // canonicalization. class ConvertSliceOpDynamic : public OpRewritePattern { @@ -3025,7 +3082,7 @@ class ConvertSliceOpDynamic : public OpRewritePattern { {static_cast(stride_values.size())}, index_ty), stride_values); - auto d_slice = rewriter.create( + auto d_slice = rewriter.create( loc, op.getOperation()->getResult(0).getType(), input, start_indices, end_indices, stride_indices); rewriter.replaceOp(op, d_slice.getOperation()->getResults()); @@ -3100,8 +3157,8 @@ static void BroadcastBatchMatMulV2Operands(Value lhs, Value rhs, Location loc, class ConvertBatchMatMulV2Op : public OpRewritePattern { public: // TODO(hinsu): Legalize this op to Einsum op. HLO Einsum op needs to be moved - // to CHLO and it is missing legalization to MHLO. Once that is done, this - // pattern's benefit can be changed back to one as well as the fallback + // to CHLO and it is missing legalization to StableHLO. 
Once that is done, + // this pattern's benefit can be changed back to one as well as the fallback // lowering pattern for the op can be removed. // // Set benefit of this pattern to zero to prefer the fallback pattern when @@ -3138,7 +3195,7 @@ class ConvertBatchMatMulV2Op : public OpRewritePattern { llvm::ArrayRef({op.getAdjX() ? rank - 2 : rank - 1})); auto rhs_contracting_dimensions = llvm::to_vector<4>( llvm::ArrayRef({op.getAdjY() ? rank - 1 : rank - 2})); - auto dimension_numbers = DotDimensionNumbersAttr::get( + auto dimension_numbers = stablehlo::DotDimensionNumbersAttr::get( rewriter.getContext(), /*lhs_batching_dimensions=*/batch_dimensions, /*rhs_batching_dimensions=*/batch_dimensions, @@ -3146,10 +3203,10 @@ class ConvertBatchMatMulV2Op : public OpRewritePattern { /*rhs_contracting_dimensions=*/rhs_contracting_dimensions); // TODO(silvasean): Emit shape checks for contracting dimensions. // (The batch dimensions are checked by the broadcasting logic) - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, op.getType(), lhs, rhs, dimension_numbers, /*precision_config=*/GetPrecisionConfig(&rewriter), - /*algorithm=*/DotAlgorithmAttr{}); + /*algorithm=*/stablehlo::DotAlgorithmAttr{}); return success(); } }; @@ -3170,20 +3227,20 @@ class ConvertBatchMatMulV2Op : public OpRewritePattern { // // will be converted into: // -// %0 = "mhlo.slice"(%input) { -// limit_indices = dense<[4, 2]> : tensor<2xi64>, -// start_indices = dense<0> : tensor<2xi64>, -// strides = dense<1> : tensor<2xi64>} : +// %0 = "stablehlo.slice"(%input) { +// limit_indices = array, +// start_indices = array, +// strides = array} : // (tensor<4x6xf32>) -> tensor<4x2xf32> -// %1 = "mhlo.slice"(%input) { -// limit_indices = dense<4> : tensor<2xi64>, -// start_indices = dense<[0, 2]> : tensor<2xi64>, -// strides = dense<1> : tensor<2xi64>} : +// %1 = "stablehlo.slice"(%input) { +// limit_indices = array, +// start_indices = array, +// strides = array} : // (tensor<4x6xf32>) -> 
tensor<4x2xf32> -// %2 = "mhlo.slice"(%input) { -// limit_indices = dense<[4, 6]> : tensor<2xi64>, -// start_indices = dense<[0, 4]> : tensor<2xi64>, -// strides = dense<1> : tensor<2xi64>} : +// %2 = "stablehlo.slice"(%input) { +// limit_indices = array, +// start_indices = array, +// strides = array} : // (tensor<4x6xf32>) -> tensor<4x2xf32> // TODO(antiagainst): consider lowering into TF ops so the pattern can be more // applicable. @@ -3231,11 +3288,11 @@ class ConvertSplitOp : public OpRewritePattern { for (int i = 0; i < num_splits; ++i) { begin_indices[dim_index] = i * slice_size; end_indices[dim_index] = (i + 1) * slice_size; - slices.push_back( - rewriter.create(op.getLoc(), slice_type, op.getValue(), - GetI64ElementsAttr(begin_indices, &rewriter), - GetI64ElementsAttr(end_indices, &rewriter), - GetI64ElementsAttr(strides, &rewriter))); + slices.push_back(rewriter.create( + op.getLoc(), slice_type, op.getValue(), + GetI64ArrayAttr(begin_indices, &rewriter), + GetI64ArrayAttr(end_indices, &rewriter), + GetI64ArrayAttr(strides, &rewriter))); } rewriter.replaceOp(op, slices); @@ -3243,8 +3300,8 @@ class ConvertSplitOp : public OpRewritePattern { } }; -// Converts the tf.Split op into a series of mhlo.real_dynamic_slice ops the -// dimension to split is a constant. +// Converts the tf.Split op into a series of stablehlo.real_dynamic_slice ops +// the dimension to split is a constant. // TODO(disc): To recover static special case's performance with folding and // canonicalization. 
delete ConvertSplitOp class ConvertSplitOpDynamic : public OpRewritePattern { @@ -3320,7 +3377,7 @@ class ConvertSplitOpDynamic : public OpRewritePattern { tensorflow::GetTypeFromTFTensorShape( {static_cast(strides.size())}, index_ty), strides); - slices.push_back(rewriter.create( + slices.push_back(rewriter.create( loc, op.getOperation()->getResult(i).getType(), input, begin_value, end_value, stride_value)); } @@ -3347,20 +3404,20 @@ class ConvertSplitOpDynamic : public OpRewritePattern { // (tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>) // // We will generate slices following slices: -// %0 = "mhlo.slice"(%input) { -// limit_indices = dense<[4, 1]> : tensor<2xi64>, -// start_indices = dense<0> : tensor<2xi64>, -// strides = dense<1> : tensor<2xi64>} : +// %0 = "stablehlo.slice"(%input) { +// limit_indices = array, +// start_indices = array, +// strides = array} : // (tensor<4x6xf32>) -> tensor<4x1xf32> -// %1 = "mhlo.slice"(%input) { -// limit_indices = dense<[4, 3]> : tensor<2xi64>, -// start_indices = dense<[0, 1]> : tensor<2xi64>, -// strides = dense<1> : tensor<2xi64>} : +// %1 = "stablehlo.slice"(%input) { +// limit_indices = array, +// start_indices = array, +// strides = array} : // (tensor<4x6xf32>) -> tensor<4x2xf32> -// %2 = "mhlo.slice"(%input) { -// limit_indices = dense<[4, 6]> : tensor<2xi64>, -// start_indices = dense<[0, 3]> : tensor<2xi64>, -// strides = dense<1> : tensor<2xi64>} : +// %2 = "stablehlo.slice"(%input) { +// limit_indices = array, +// start_indices = array, +// strides = array} : // (tensor<4x6xf32>) -> tensor<4x3xf32> class ConvertSplitVOp : public OpRewritePattern { public: @@ -3427,11 +3484,10 @@ class ConvertSplitVOp : public OpRewritePattern { for (int i = 0, end = op.getNumResults(); i < end; ++i) { end_indices[dim_index] = begin_indices[dim_index] + split_sizes[i]; - slices.push_back(rewriter.create( - op.getLoc(), op.getValue(), - GetI64ElementsAttr(begin_indices, &rewriter), - GetI64ElementsAttr(end_indices, 
&rewriter), - GetI64ElementsAttr(strides, &rewriter))); + slices.push_back(rewriter.create( + op.getLoc(), op.getValue(), GetI64ArrayAttr(begin_indices, &rewriter), + GetI64ArrayAttr(end_indices, &rewriter), + GetI64ArrayAttr(strides, &rewriter))); // Prepare the begin indice for the next slice. begin_indices[dim_index] = end_indices[dim_index]; } @@ -3446,19 +3502,19 @@ class ConvertSplitVOp : public OpRewritePattern { // strides operands are converted to attributes with non-negative indexing. // // If the begin input is not a compile time constant, the begin input needs to -// be sliced and the slice needs to be lowered to mhlo.DynamicSlice. In this -// case, strides must have a known value of 1 (otherwise we have insufficient -// information to conform to XLA's op semantics). +// be sliced and the slice needs to be lowered to stablehlo.DynamicSlice. In +// this case, strides must have a known value of 1 (otherwise we have +// insufficient information to conform to XLA's op semantics). 
// // For example with an op like following, // tf.StridedSlice(%input, %begin, %end, %strides) {shrink_axis_mask = 1} // : tensor -> tensor // // If the %begin input is constant, output would be: -// %reversed = "mhlo.Reverse" (%input) {dimensions = ...} -// %sliced = "mhlo.Slice" (%input) +// %reversed = "stablehlo.Reverse" (%input) {dimensions = ...} +// %sliced = "stablehlo.Slice" (%input) // {start_indices = ..., limit_indices = ..., strides = ...} -// %output = "mhlo.Reshape" (%sliced) : tensor<1xPxf32> -> tensor +// %output = "stablehlo.Reshape" (%sliced) : tensor<1xPxf32> -> tensor // class ConvertStridedSliceOp : public OpRewritePattern { public: @@ -3512,17 +3568,17 @@ class ConvertStridedSliceOp : public OpRewritePattern { Location loc = op.getLoc(); Value input = op.getInput(); if (!dims_to_reverse.empty()) - input = rewriter.create( + input = rewriter.create( loc, input_ty, op.getInput(), - GetI64ElementsAttr(dims_to_reverse, &rewriter)); - auto sliced = rewriter.create( - loc, input, GetI64ElementsAttr(hlo_begin_indices, &rewriter), - GetI64ElementsAttr(hlo_end_indices, &rewriter), - GetI64ElementsAttr(hlo_strides, &rewriter)); + GetI64ArrayAttr(dims_to_reverse, &rewriter)); + auto sliced = rewriter.create( + loc, input, GetI64ArrayAttr(hlo_begin_indices, &rewriter), + GetI64ArrayAttr(hlo_end_indices, &rewriter), + GetI64ArrayAttr(hlo_strides, &rewriter)); // Reshape slice result so that the shape is updated depending on // 'new_axis_mask' or 'shrink_axis_mask' attributes. 
- rewriter.replaceOpWithNewOp(op, op.getType(), sliced); + rewriter.replaceOpWithNewOp(op, op.getType(), sliced); return success(); } @@ -3607,12 +3663,12 @@ class ConvertStridedSliceOp : public OpRewritePattern { continue; } - auto index = rewriter.create( - loc, op.getBegin(), GetI64ElementsAttr({d}, &rewriter), - GetI64ElementsAttr({d + 1}, &rewriter), - GetI64ElementsAttr({1}, &rewriter)); + auto index = rewriter.create( + loc, op.getBegin(), GetI64ArrayAttr({d}, &rewriter), + GetI64ArrayAttr({d + 1}, &rewriter), GetI64ArrayAttr({1}, &rewriter)); // Convert index to scalar. - auto reshaped_index = rewriter.create(loc, type, index); + auto reshaped_index = + rewriter.create(loc, type, index); // If the index is negative, wrap it around with dimension size. auto index_negative = rewriter.create(loc, reshaped_index, zero); @@ -3620,23 +3676,23 @@ class ConvertStridedSliceOp : public OpRewritePattern { input_shape[d], &rewriter); auto wrapped_index = rewriter.create(loc, input_val, reshaped_index); - auto final_index = rewriter.create( + auto final_index = rewriter.create( loc, type, index_negative, wrapped_index, reshaped_index); slice_begin_indices.push_back(final_index); slice_sizes.push_back(1); } - auto slice_sizes_attr = GetI64ElementsAttr(slice_sizes, &rewriter); + auto slice_sizes_attr = GetI64ArrayAttr(slice_sizes, &rewriter); auto sliced_type = tensorflow::GetTypeFromTFTensorShape( slice_sizes, op.getType().getElementType()); // This must be an xla DynamicSlice op due to the inputs that aren't // constant. - auto sliced = rewriter.create( + auto sliced = rewriter.create( loc, sliced_type, op.getInput(), slice_begin_indices, slice_sizes_attr); // Reshape slice result so that the shape is updated depending on // 'new_axis_mask' or 'shrink_axis_mask' attributes. 
- rewriter.replaceOpWithNewOp(op, op.getType(), sliced); + rewriter.replaceOpWithNewOp(op, op.getType(), sliced); return success(); } @@ -3704,7 +3760,7 @@ class ConvertStridedSliceGradOp Type element_type = mlir::cast(grad.getType()).getElementType(); // Perform reshape to undo any new/shrink axes done by strided slice. - grad = rewriter.create( + grad = rewriter.create( op.getLoc(), tensorflow::GetTypeFromTFTensorShape(shape, element_type), grad); @@ -3741,22 +3797,21 @@ class ConvertStridedSliceGradOp } if (!dims_to_reverse.empty()) { - grad = rewriter.create( + grad = rewriter.create( op.getLoc(), grad.getType(), grad, - GetI64ElementsAttr(dims_to_reverse, &rewriter)); + GetI64ArrayAttr(dims_to_reverse, &rewriter)); } auto zero = GetScalarConstOfType(element_type, op.getLoc(), 0, &rewriter); - rewriter.replaceOpWithNewOp( - op, op.getType(), grad, zero, - GetI64ElementsAttr(padding_low, &rewriter), - GetI64ElementsAttr(padding_high, &rewriter), - GetI64ElementsAttr(padding_interm, &rewriter)); + rewriter.replaceOpWithNewOp( + op, op.getType(), grad, zero, GetI64ArrayAttr(padding_low, &rewriter), + GetI64ArrayAttr(padding_high, &rewriter), + GetI64ArrayAttr(padding_interm, &rewriter)); return success(); } }; -/// Converts the RangeOp tensorflow op to a mhlo.iota op with a scaling and +/// Converts the RangeOp tensorflow op to a stablehlo.iota op with a scaling and /// offset applied to generate the range values. The output tensor needs to /// have a static shape. 
/// @@ -3765,11 +3820,11 @@ class ConvertStridedSliceGradOp /// : (tensor, tensor, tensor) -> tensor<5xf32> /// /// Output would be: -/// %iota = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<5xf32> -/// %scaled = "mhlo.multiply"(%iota, %delta) +/// %iota = "stablehlo.iota"() {iota_dimension = 0 : i64} : () -> +/// tensor<5xf32> %scaled = "stablehlo.multiply"(%iota, %delta) /// {broadcast_dimensions = dense<[]> : tensor<0xi64>} : /// (tensor<5xf32>, tensor) -> tensor<5xf32> -/// %result = "mhlo.add"(%scaled, %offset) +/// %result = "stablehlo.add"(%scaled, %offset) /// {broadcast_dimensions = dense<[]> : tensor<0xi64>} : /// (tensor<5xf32>, tensor) -> tensor<5xf32> /// @@ -3785,8 +3840,8 @@ class ConvertRangeOp : public OpRewritePattern { return failure(); } - auto iota = rewriter.create(op.getLoc(), result_type, - rewriter.getI64IntegerAttr(0)); + auto iota = rewriter.create( + op.getLoc(), result_type, rewriter.getI64IntegerAttr(0)); auto scaled = rewriter.create( op.getLoc(), result_type, iota, op.getDelta(), hlo::getBroadcastDimensionsAttr(&rewriter, iota, op.getDelta())); @@ -3837,24 +3892,25 @@ class ConvertDynamicRangeOp : public OpRewritePattern { // some conversion to float for the operations. // // %size = ceil(abs((%limit - %start) / %delta)) - auto range = rewriter.create(op.getLoc(), limit, start); - auto abs = rewriter.create(op.getLoc(), range); + auto range = + rewriter.create(op.getLoc(), limit, start); + auto abs = rewriter.create(op.getLoc(), range); // Delta is not necessarily the same type as start and limit. auto abs_cast = - rewriter.create(op.getLoc(), compute_type, abs); + rewriter.create(op.getLoc(), compute_type, abs); auto delta_cast = - rewriter.create(op.getLoc(), compute_type, delta); + rewriter.create(op.getLoc(), compute_type, delta); // Compute the total number of integer steps and convert to the HLO // dimension tensor. 
auto normalized = - rewriter.create(op.getLoc(), abs_cast, delta_cast); - auto ceil = rewriter.create(op.getLoc(), normalized); - auto steps = rewriter.create( + rewriter.create(op.getLoc(), abs_cast, delta_cast); + auto ceil = rewriter.create(op.getLoc(), normalized); + auto steps = rewriter.create( op.getLoc(), tensorflow::GetTypeFromTFTensorShape({}, rewriter.getI64Type()), ceil); - auto reshape = rewriter.create( + auto reshape = rewriter.create( op.getLoc(), tensorflow::GetTypeFromTFTensorShape({1}, rewriter.getI64Type()), steps); @@ -3864,12 +3920,12 @@ class ConvertDynamicRangeOp : public OpRewritePattern { // %range = %start + %delta * iota(%size) auto out_scalar_type = tensorflow::GetTypeFromTFTensorShape( {}, getElementTypeOrSelf(result_type)); - auto start_out_cast = - rewriter.create(op.getLoc(), out_scalar_type, start); - auto delta_out_cast = - rewriter.create(op.getLoc(), out_scalar_type, delta); + auto start_out_cast = rewriter.create( + op.getLoc(), out_scalar_type, start); + auto delta_out_cast = rewriter.create( + op.getLoc(), out_scalar_type, delta); - auto iota = rewriter.create( + auto iota = rewriter.create( op.getLoc(), result_type, reshape, rewriter.getI64IntegerAttr(0)); auto scaled = rewriter.create( op.getLoc(), result_type, iota, delta_out_cast, @@ -3881,7 +3937,8 @@ class ConvertDynamicRangeOp : public OpRewritePattern { } }; -ElementsAttr ConvertAxisAttr(Value val, ElementsAttr attr, Builder *builder) { +DenseI64ArrayAttr ConvertAxisAttr(Value val, ElementsAttr attr, + Builder *builder) { auto int_attr = mlir::cast(attr); auto type = mlir::cast(val.getType()); @@ -3893,10 +3950,10 @@ ElementsAttr ConvertAxisAttr(Value val, ElementsAttr attr, Builder *builder) { axis.push_back((val.getSExtValue() + rank) % rank); } - return builder->getI64TensorAttr(axis); + return builder->getDenseI64ArrayAttr(axis); } -/// Converts the LinSpace tensorflow op to a mhlo.iota op with a scaling +/// Converts the LinSpace tensorflow op to a stablehlo.iota 
op with a scaling /// and offset applied to generate the linspace values. The output tensor needs /// to have a static shape. The implementation is defined in C++ because there /// is no type inference for the iota op. @@ -3926,7 +3983,7 @@ class ConvertLinSpaceOp : public OpRewritePattern { op.getLoc(), op.getStart().getType(), op.getStop(), op.getStart(), hlo::getBroadcastDimensionsAttr(&rewriter, op.getStop(), op.getStart())); - Value step_denominator = rewriter.create( + Value step_denominator = rewriter.create( op.getLoc(), op.getNum(), result_type.getElementType()); if (num > 1) { Value one = GetScalarConstOfType(result_type.getElementType(), @@ -3941,8 +3998,8 @@ class ConvertLinSpaceOp : public OpRewritePattern { step_denominator)); // Scale the iota and add the offset. - auto iota = rewriter.create(op.getLoc(), result_type, - rewriter.getI64IntegerAttr(0)); + auto iota = rewriter.create( + op.getLoc(), result_type, rewriter.getI64IntegerAttr(0)); auto scaled = rewriter.create( op.getLoc(), result_type, iota, step, hlo::getBroadcastDimensionsAttr(&rewriter, iota, step)); @@ -3953,7 +4010,7 @@ class ConvertLinSpaceOp : public OpRewritePattern { } }; -/// Converts a generic OpTy tensorflow op to a mhlo.reduce op over +/// Converts a generic OpTy tensorflow op to a stablehlo.reduce op over /// ReductionOp. /// `is_accumulation` controls whether it uses higher precision for the actual /// reduction. This is set to false for ops like max where there is no precision @@ -4011,15 +4068,15 @@ class GenericConvertReductionOp : public OpRewritePattern { // repeated arithmetic operations. Type reduce_element_type = is_accumulation ? GetAccumulationType(element_type) : element_type; - auto casted_input = - rewriter.create(loc, op.getInput(), reduce_element_type); + auto casted_input = rewriter.create( + loc, op.getInput(), reduce_element_type); // Each reduction op can have a different initial value. 
Value init = Derived::GetInitialValue(reduce_element_type, loc, &rewriter); - auto reduction = rewriter.create( + auto reduction = rewriter.create( loc, casted_input.getResult(), init, - GetI64ElementsAttr(xla_dimensions, &rewriter), reduce_element_type); + GetI64ArrayAttr(xla_dimensions, &rewriter), reduce_element_type); BuildReduceBody(reduce_element_type, &reduction.getBody(), &rewriter); Value result = reduction.getResult(0); @@ -4043,7 +4100,7 @@ class GenericConvertReductionOp : public OpRewritePattern { Value divisor_tensor = rewriter.create( loc, tensorflow::GetTypeFromTFTensorShape({}, rewriter.getI64Type()), divisor_casted); - Value divisor = rewriter.create( + Value divisor = rewriter.create( loc, tensorflow::GetTypeFromTFTensorShape({}, reduce_element_type), divisor_tensor); auto broadcast_dims = rewriter.getDenseI64ArrayAttr({}); @@ -4051,7 +4108,7 @@ class GenericConvertReductionOp : public OpRewritePattern { broadcast_dims); } - result = rewriter.create(loc, result, element_type); + result = rewriter.create(loc, result, element_type); // Need to reshape back after the reduction if we're keeping the reduced // dimensions. Note that we do this through successive (nominally 1) @@ -4079,12 +4136,13 @@ class GenericConvertReductionOp : public OpRewritePattern { // Converts Mean op to HLO Reduce op. // // %init = arith.constant dense<...> : tensor -// %sum = "mhlo.reduce"(%inp, %init) ["mhlo.add"] +// %sum = "stablehlo.reduce"(%inp, %init) ["stablehlo.add"] // {dimensions = ...} // %divisor = arith.constant dense<...> : tensor -// %mean = "mhlo.divide"(%sum, %divisor) +// %mean = "stablehlo.divide"(%sum, %divisor) class ConvertMeanOp - : public GenericConvertReductionOp { + : public GenericConvertReductionOp { public: using GenericConvertReductionOp::GenericConvertReductionOp; static Value GetInitialValue(Type reduce_element_type, Location loc, @@ -4096,10 +4154,10 @@ class ConvertMeanOp // Converts Sum op to HLO Reduce op. 
// // %init = arith.constant dense<...> : tensor -// %sum = "mhlo.reduce"(%inp, %init) ["mhlo.add"] +// %sum = "stablehlo.reduce"(%inp, %init) ["stablehlo.add"] // {dimensions = ...} -class ConvertSumOp - : public GenericConvertReductionOp { +class ConvertSumOp : public GenericConvertReductionOp { public: using GenericConvertReductionOp::GenericConvertReductionOp; @@ -4113,10 +4171,11 @@ class ConvertSumOp // Converts Max op to HLO Reduce op. // // %init = arith.constant dense<...> : tensor -// %max = "mhlo.reduce"(%inp, %init) ["mhlo.maximum"] +// %max = "stablehlo.reduce"(%inp, %init) ["stablehlo.maximum"] // {dimensions = ...} class ConvertMaxOp - : public GenericConvertReductionOp { public: using GenericConvertReductionOp::GenericConvertReductionOp; @@ -4131,10 +4190,11 @@ class ConvertMaxOp // Converts Min op to HLO Reduce op. // // %init = arith.constant dense<...> : tensor -// %min = "mhlo.reduce"(%inp, %init) ["mhlo.minimum"] +// %min = "stablehlo.reduce"(%inp, %init) ["stablehlo.minimum"] // {dimensions = ...} class ConvertMinOp - : public GenericConvertReductionOp { public: using GenericConvertReductionOp::GenericConvertReductionOp; @@ -4149,10 +4209,11 @@ class ConvertMinOp // Converts Prod op to HLO Reduce op. // // %init = arith.constant dense<...> : tensor -// %prod = "mhlo.reduce"(%inp, %init) ["mhlo.multiply"] +// %prod = "stablehlo.reduce"(%inp, %init) ["stablehlo.multiply"] // {dimensions = ...} class ConvertProdOp - : public GenericConvertReductionOp { + : public GenericConvertReductionOp { public: using GenericConvertReductionOp::GenericConvertReductionOp; @@ -4165,10 +4226,10 @@ class ConvertProdOp // Converts All op to HLO Reduce op. 
// // %init = arith.constant dense<...> : tensor -// %max = "mhlo.reduce"(%inp, %init) ["mhlo.and"] +// %max = "stablehlo.reduce"(%inp, %init) ["stablehlo.and"] // {dimensions = ...} -class ConvertAllOp - : public GenericConvertReductionOp { +class ConvertAllOp : public GenericConvertReductionOp { public: using GenericConvertReductionOp::GenericConvertReductionOp; static Value GetInitialValue(Type reduce_element_type, Location loc, @@ -4180,10 +4241,10 @@ class ConvertAllOp // Converts Any op to HLO Reduce op. // // %init = arith.constant dense<...> : tensor -// %max = "mhlo.reduce"(%inp, %init) ["mhlo.or"] +// %max = "stablehlo.reduce"(%inp, %init) ["stablehlo.or"] // {dimensions = ...} -class ConvertAnyOp - : public GenericConvertReductionOp { +class ConvertAnyOp : public GenericConvertReductionOp { public: using GenericConvertReductionOp::GenericConvertReductionOp; static Value GetInitialValue(Type reduce_element_type, Location loc, @@ -4240,17 +4301,15 @@ class ConvertArgMinMaxOp : public OpRewritePattern { IntegerAttr iota_dimension = IntegerAttr::get(rewriter.getIntegerType(64), axis); Value input_shape = rewriter.create(loc, op.getInput()); - Value index_values = rewriter.create( + Value index_values = rewriter.create( loc, index_type, input_shape, iota_dimension); Value operands[] = {op.getInput(), index_values}; Value init_values[] = {init_value, index_init_value}; - DenseIntElementsAttr reduction_dimensions = - GetI64ElementsAttr({axis}, &rewriter); - auto reduction = rewriter.create( + auto reduction = rewriter.create( loc, llvm::ArrayRef(operands), - llvm::ArrayRef(init_values), reduction_dimensions, + llvm::ArrayRef(init_values), GetI64ArrayAttr({axis}, &rewriter), TypeRange({input_element_type, index_element_type})); auto direction = Derived::GetDirection(); BuildArgMinMaxReductionBody(input_element_type, index_element_type, @@ -4266,8 +4325,8 @@ class ConvertArgMinMaxOp : public OpRewritePattern { // // %init_index = arith.constant dense<...> : 
tensor // %init = arith.constant dense<...> : tensor -// %reduce = "mhlo.reduce"(%selected_input, %select_index, %init, -// %init_index) ["mhlo.arg_max"] +// %reduce = "stablehlo.reduce"(%selected_input, %select_index, %init, +// %init_index) ["stablehlo.arg_max"] class ConvertArgMaxOp : public ConvertArgMinMaxOp { public: @@ -4279,7 +4338,9 @@ class ConvertArgMaxOp hlo::kInfinityLowest, &rewriter); } - static ComparisonDirection GetDirection() { return ComparisonDirection::GE; } + static stablehlo::ComparisonDirection GetDirection() { + return stablehlo::ComparisonDirection::GE; + } }; // Converts tensorflow ArgMin op to mhlo operations. The actual @@ -4287,8 +4348,8 @@ class ConvertArgMaxOp // // %init_index = arith.constant dense<...> : tensor // %init = arith.constant dense<...> : tensor -// %reduce = "mhlo.reduce"(%selected_input, %select_index, %init, -// %init_index) ["mhlo.arg_min"] +// %reduce = "stablehlo.reduce"(%selected_input, %select_index, %init, +// %init_index) ["stablehlo.arg_min"] class ConvertArgMinOp : public ConvertArgMinMaxOp { public: @@ -4300,13 +4361,15 @@ class ConvertArgMinOp hlo::kInfinityMax, &rewriter); } - static ComparisonDirection GetDirection() { return ComparisonDirection::LE; } + static stablehlo::ComparisonDirection GetDirection() { + return stablehlo::ComparisonDirection::LE; + } }; // Converts TF TensorScatterUpdate/Min/Max/Add/Sub op into Scatter Op with // assignment: // -// %result = "mhlo.scatter"(%tensor, %indices, %updates) +// %result = "stablehlo.scatter"(%tensor, %indices, %updates) // { dimensions = ... 
} // template @@ -4381,7 +4444,7 @@ class ConvertTensorScatterOp : public OpRewritePattern { mlir::dyn_cast(updates.getType()).getRank(); int64_t window_dims = tensor_rank - num_index_dims; - auto dims_attr = ScatterDimensionNumbersAttr::get( + auto dims_attr = stablehlo::ScatterDimensionNumbersAttr::get( rewriter.getContext(), llvm::to_vector<4>( llvm::seq(updates_rank - window_dims, updates_rank)), @@ -4392,7 +4455,7 @@ class ConvertTensorScatterOp : public OpRewritePattern { indices_rank - 1); Location loc = op.getLoc(); - auto scatter = rewriter.create( + auto scatter = rewriter.create( loc, op.getType(), ValueRange(Value(op.getTensor())), op.getIndices(), updates, dims_attr); Derived::BuildScatterBody(tensor_ty.getElementType(), @@ -4416,7 +4479,7 @@ class ConvertTensorScatterUpdateOp Type type = tensorflow::GetTypeFromTFTensorShape(/*shape=*/{}, element_type); block->addArguments({type, type}, SmallVector(2, loc)); - builder.create(loc, block->getArgument(1)); + builder.create(loc, block->getArgument(1)); } }; @@ -4433,9 +4496,9 @@ class ConvertTensorScatterAddOp Type type = tensorflow::GetTypeFromTFTensorShape(/*shape=*/{}, element_type); block->addArguments({type, type}, SmallVector(2, loc)); - auto add_op = builder.create(loc, block->getArgument(0), - block->getArgument(1)); - builder.create(loc, add_op.getResult()); + auto add_op = builder.create(loc, block->getArgument(0), + block->getArgument(1)); + builder.create(loc, add_op.getResult()); } }; @@ -4452,9 +4515,9 @@ class ConvertTensorScatterSubOp Type type = tensorflow::GetTypeFromTFTensorShape(/*shape=*/{}, element_type); block->addArguments({type, type}, SmallVector(2, loc)); - auto sub_op = builder.create(loc, block->getArgument(0), - block->getArgument(1)); - builder.create(loc, sub_op.getResult()); + auto sub_op = builder.create( + loc, block->getArgument(0), block->getArgument(1)); + builder.create(loc, sub_op.getResult()); } }; @@ -4471,9 +4534,9 @@ class ConvertTensorScatterMinOp Type type = 
tensorflow::GetTypeFromTFTensorShape(/*shape=*/{}, element_type); block->addArguments({type, type}, SmallVector(2, loc)); - auto min_op = builder.create(loc, block->getArgument(0), - block->getArgument(1)); - builder.create(loc, min_op.getResult()); + auto min_op = builder.create(loc, block->getArgument(0), + block->getArgument(1)); + builder.create(loc, min_op.getResult()); } }; @@ -4490,9 +4553,9 @@ class ConvertTensorScatterMaxOp Type type = tensorflow::GetTypeFromTFTensorShape(/*shape=*/{}, element_type); block->addArguments({type, type}, SmallVector(2, loc)); - auto max_op = builder.create(loc, block->getArgument(0), - block->getArgument(1)); - builder.create(loc, max_op.getResult()); + auto max_op = builder.create(loc, block->getArgument(0), + block->getArgument(1)); + builder.create(loc, max_op.getResult()); } }; @@ -4500,10 +4563,10 @@ class ConvertTensorScatterMaxOp // For shape [S1, S2] and multiples [M1, M2], // MS1 = M1 * S1; MS2 = M2 * S2 // -// %broadcast = mhlo.broadcast_in_dim(%input) { +// %broadcast = stablehlo.broadcast_in_dim(%input) { // broadcast_dimensions = [0, 2] // } -// %result = "mhlo.reshape"(%broadcast) : (tensor) +// %result = "stablehlo.reshape"(%broadcast) : (tensor) // -> tensor class ConvertTileOp : public OpRewritePattern { public: @@ -4556,12 +4619,12 @@ class ConvertTileOp : public OpRewritePattern { tensorflow::GetTypeFromTFTensorShape(broadcasted_shape, element_type); Type output_type = op.getType(); - Value result = rewriter.create( + Value result = rewriter.create( loc, broadcasted_type, op.getInput(), - GetI64ElementsAttr(broadcast_dimensions, &rewriter)); + GetI64ArrayAttr(broadcast_dimensions, &rewriter)); if (output_type != broadcasted_type) { - result = rewriter.create(loc, output_type, result); + result = rewriter.create(loc, output_type, result); } rewriter.replaceOp(op, {result}); @@ -4570,7 +4633,7 @@ class ConvertTileOp : public OpRewritePattern { } }; -// Converts the tf.TileOp op into mhlo.dynamic_reshape +// 
Converts the tf.TileOp op into stablehlo.dynamic_reshape // TODO(disc): To recover static special case's performance with folding and // canonicalization. class ConvertTileOpDynamic : public OpRewritePattern { @@ -4583,9 +4646,11 @@ class ConvertTileOpDynamic : public OpRewritePattern { // // %out_dim_size = [S1, M1, S2, M2] // %broadcast_dimensions = [1, 3]; - // %broadcast = mhlo.d_broadcast_in_dim(%input, %out_dim_size, %braodcast_dimensions); + // %broadcast = stablehlo.d_broadcast_in_dim( + // %input, %out_dim_size, %braodcast_dimensions); // %shape = [MS1, MS2] - // %result = "mhlo.d_reshape"(%broadcast, %shape) : (tensor) -> tensor + // %result = "stablehlo.d_reshape"(%broadcast, %shape) + // : (tensor) -> tensor // clang-format on LogicalResult matchAndRewrite(TF::TileOp op, PatternRewriter &rewriter) const final { @@ -4640,8 +4705,7 @@ class ConvertTileOpDynamic : public OpRewritePattern { for (int64_t dim_idx = 0; dim_idx < input_rank; ++dim_idx) { broadcast_dimensions.push_back(1 + 2 * dim_idx); } - auto broadcast_dims_attr = - GetI64ElementsAttr(broadcast_dimensions, &rewriter); + auto broadcast_dims_attr = GetI64ArrayAttr(broadcast_dimensions, &rewriter); Value out_dim_size_tensor = rewriter.create( loc, @@ -4652,7 +4716,7 @@ class ConvertTileOpDynamic : public OpRewritePattern { ShapedType::kDynamic); RankedTensorType broadcast_type = tensorflow::GetTypeFromTFTensorShape(broadcast_shape, element_type); - Value broadcast = rewriter.create( + Value broadcast = rewriter.create( loc, broadcast_type, input, out_dim_size_tensor, broadcast_dims_attr); // %shape = [MS1, MS2] @@ -4666,8 +4730,8 @@ class ConvertTileOpDynamic : public OpRewritePattern { Value shape = rewriter.create( loc, tensorflow::GetTypeFromTFTensorShape({input_rank}, index_ty), shape_values); - rewriter.replaceOpWithNewOp(op, op.getType(), - broadcast, shape); + rewriter.replaceOpWithNewOp(op, op.getType(), + broadcast, shape); return success(); } }; @@ -4694,13 +4758,15 @@ class 
ConvertMaxPoolGradOp : public OpRewritePattern { input_ty.getShape(), op.getKsize(), op.getStrides(), op.getPadding(), &rewriter); - auto result = rewriter.create( + auto result = rewriter.create( loc, op.getType(), op.getOrigInput(), op.getGrad(), GetScalarConstOfType(element_type, loc, 0, &rewriter), - GetI64ElementsAttr(op.getKsize()), GetI64ElementsAttr(op.getStrides()), + ToDenseI64ArrayAttr(GetI64ElementsAttr(op.getKsize()), &rewriter), + ToDenseI64ArrayAttr(GetI64ElementsAttr(op.getStrides()), &rewriter), paddings_attr); - BuildReduceBody(element_type, &result.getScatter(), &rewriter); + BuildReduceBody(element_type, &result.getScatter(), + &rewriter); { OpBuilder::InsertionGuard guard(rewriter); Block *block = rewriter.createBlock(&result.getSelect()); @@ -4710,10 +4776,10 @@ class ConvertMaxPoolGradOp : public OpRewritePattern { tensorflow::GetTypeFromTFTensorShape(/*shape=*/{}, element_type); block->addArguments({type, type}, SmallVector(2, loc)); - auto reducer = rewriter.create(loc, block->getArgument(0), - block->getArgument(1), - ComparisonDirection::GE); - rewriter.create(loc, reducer.getResult()); + auto reducer = rewriter.create( + loc, block->getArgument(0), block->getArgument(1), + stablehlo::ComparisonDirection::GE); + rewriter.create(loc, reducer.getResult()); } rewriter.replaceOp(op, result); @@ -4728,8 +4794,8 @@ using ConvertMaxPool3DGradOp = ConvertMaxPoolGradOp; // Converts tf.Conv?DBackpropInputOp into: -// %rev_filter = "mhlo.reverse"(%filter) -// %result = "mhlo.convolution"(%out_backprop, %rev_filter) +// %rev_filter = "stablehlo.reverse"(%filter) +// %result = "stablehlo.convolution"(%out_backprop, %rev_filter) template class ConvertConvBackpropInputOp : public OpRewritePattern { public: @@ -4858,8 +4924,8 @@ class ConvertConvBackpropInputOp : public OpRewritePattern { int64_t expanded_output_size = (output_size - 1) * stride + 1; int64_t pad_after = padded_out_size - expanded_output_size - pad_before; - // Populate metadata for the 
upcoming mhlo.conv op using the result of - // the computations performed above. + // Populate metadata for the upcoming stablehlo.conv op using the result + // of the computations performed above. lhs_dilation.push_back(stride); rhs_dilation.push_back(dilation); paddings.push_back(pad_before); @@ -4889,7 +4955,7 @@ class ConvertConvBackpropInputOp : public OpRewritePattern { Type filter_element_ty = filter_ty.getElementType(); auto ty = tensorflow::GetTypeFromTFTensorShape(new_shape, filter_element_ty); - filter = rewriter.create(op.getLoc(), ty, filter); + filter = rewriter.create(op.getLoc(), ty, filter); // 2. Transpose to [H, W, ..., G, filter_in_depth, out_depth / G]. llvm::SmallVector perm(num_dims + 1); @@ -4897,15 +4963,15 @@ class ConvertConvBackpropInputOp : public OpRewritePattern { std::swap(perm[num_spatial_dims], perm[num_spatial_dims + 1]); std::swap(new_shape[num_spatial_dims], new_shape[num_spatial_dims + 1]); ty = tensorflow::GetTypeFromTFTensorShape(new_shape, filter_element_ty); - filter = rewriter.create( - op.getLoc(), ty, filter, GetI64ElementsAttr(perm, &rewriter)); + filter = rewriter.create( + op.getLoc(), ty, filter, GetI64ArrayAttr(perm, &rewriter)); // 3. Reshape to [H, W, ..., in_depth, out_depth / G]. new_shape[num_spatial_dims] *= new_shape[num_spatial_dims + 1]; new_shape[num_spatial_dims + 1] = new_shape.back(); new_shape.pop_back(); ty = tensorflow::GetTypeFromTFTensorShape(new_shape, filter_element_ty); - filter = rewriter.create(op.getLoc(), ty, filter); + filter = rewriter.create(op.getLoc(), ty, filter); } SmallVector kernel_spatial_dims; @@ -4913,21 +4979,21 @@ class ConvertConvBackpropInputOp : public OpRewritePattern { std::iota(kernel_spatial_dims.begin(), kernel_spatial_dims.end(), 0); // Mirror the filter in the spatial dimensions. 
- filter = rewriter.create( - op.getLoc(), filter, - GetI64ElementsAttr(kernel_spatial_dims, &rewriter)); + filter = rewriter.create( + op.getLoc(), filter, GetI64ArrayAttr(kernel_spatial_dims, &rewriter)); // activation gradients // = gradients (with padding and dilation) mirrored_weights - Value result = rewriter.create( + Value result = rewriter.create( op.getLoc(), op.getType(), op.getOutBackprop(), filter, /*window_strides=*/ - GetI64ElementsAttrForValue(/*size=*/num_spatial_dims, /*val=*/1, - &rewriter), - /*padding=*/paddings_attr, GetI64ElementsAttr(lhs_dilation, &rewriter), - GetI64ElementsAttr(rhs_dilation, &rewriter), + GetI64ArrayAttrForValue(/*size=*/num_spatial_dims, /*val=*/1, + &rewriter), + /*padding=*/paddings_attr, + /*lhs_dilation=*/GetI64ArrayAttr(lhs_dilation, &rewriter), + /*rhs_dilation=*/GetI64ArrayAttr(rhs_dilation, &rewriter), /*window_reversal=*/nullptr, - ConvDimensionNumbersAttr::get( + stablehlo::ConvDimensionNumbersAttr::get( rewriter.getContext(), /*inputBatchDimension=*/batch_dim, /*inputFeatureDimension=*/feature_dim, @@ -4961,7 +5027,7 @@ using ConvertConv3DBackpropInputOp = /*num_spatial_dims=*/3>; // Converts tf.Conv?DBackpropFilterOp into: -// %result = "mhlo.convolution"(%input, %out_backprop) +// %result = "stablehlo.convolution"(%input, %out_backprop) template class ConvertConvBackpropFilterOp : public OpRewritePattern { public: @@ -5125,15 +5191,15 @@ class ConvertConvBackpropFilterOp : public OpRewritePattern { const int batch_dim = tensorflow::GetTensorBatchDimIndex(num_dims, data_format); - Value result = rewriter.create( + Value result = rewriter.create( op.getLoc(), op.getType(), op.getInput(), op.getOutBackprop(), - /*window_strides=*/GetI64ElementsAttr(window_strides, &rewriter), + /*window_strides=*/GetI64ArrayAttr(window_strides, &rewriter), /*padding=*/paddings_attr, /*lhs_dilation=*/ - GetI64ElementsAttrForValue(/*size=*/num_spatial_dims, /*val=*/1, - &rewriter), - GetI64ElementsAttr(rhs_dilation, &rewriter), + 
GetI64ArrayAttrForValue(/*size=*/num_spatial_dims, /*val=*/1, + &rewriter), + GetI64ArrayAttr(rhs_dilation, &rewriter), /*window_reversal=*/nullptr, - ConvDimensionNumbersAttr::get( + stablehlo::ConvDimensionNumbersAttr::get( rewriter.getContext(), // Swap batch_dim and feature_dim in the activations. /*inputBatchDimension=*/feature_dim, @@ -5203,22 +5269,22 @@ class ConvertOneHotOp : public OpRewritePattern { // just using static broadcasting. auto index_type = tensorflow::GetTypeFromTFTensorShape(output_dims, element_type); - auto iota = rewriter.create( + auto iota = rewriter.create( loc, index_type, IntegerAttr::get(rewriter.getIntegerType(64), axis)); - auto broadcast_indices = rewriter.create( + auto broadcast_indices = rewriter.create( loc, index_type, op.getIndices(), - GetI64ElementsAttr(broadcast_dims, &rewriter)); + GetI64ArrayAttr(broadcast_dims, &rewriter)); - Value compare = rewriter.create( - loc, broadcast_indices, iota, ComparisonDirection::EQ); - Value on_value = rewriter.create( + Value compare = rewriter.create( + loc, broadcast_indices, iota, stablehlo::ComparisonDirection::EQ); + Value on_value = rewriter.create( loc, op.getType(), op.getOnValue(), - GetI64ElementsAttr(output_dims, &rewriter)); - Value off_value = rewriter.create( + GetI64ArrayAttr(output_dims, &rewriter)); + Value off_value = rewriter.create( loc, op.getType(), op.getOffValue(), - GetI64ElementsAttr(output_dims, &rewriter)); - Value result = rewriter.create(loc, op.getType(), compare, - on_value, off_value); + GetI64ArrayAttr(output_dims, &rewriter)); + Value result = rewriter.create( + loc, op.getType(), compare, on_value, off_value); rewriter.replaceOp(op, {result}); @@ -5234,17 +5300,17 @@ class ConvertOneHotOp : public OpRewritePattern { // operations within a computation. The token type can come from other // infeed/outfeed/send/recv ops or can be generated using create_token op with // no operands. 
Here we emit a create_token op to generate the token type -// operand of infeed. The mhlo.InfeedOp can produce multiple results and later -// will be exported to XLA infeed op with single tuple return type. +// operand of infeed. The stablehlo.InfeedOp can produce multiple results and +// later will be exported to XLA infeed op with single tuple return type. // // For example the following IR: // %0:2 = "tf.InfeedDequeueTuple"() : () -> (tensor<3xi32>, tensor<4xf32>) // // would be lowered to // -// %token = "mhlo.create_token"() : () -> !mhlo.token -// %data_and_token = "mhlo.infeed"(%token) {infeed_config = ""} : -// (!mhlo.token) -> tensor<3xi32>, tensor<4xf32>, !mhlo.token> +// %token = "stablehlo.create_token"() : () -> !stablehlo.token +// %data_and_token = "stablehlo.infeed"(%token) {infeed_config = ""} : +// (!stablehlo.token) -> tensor<3xi32>, tensor<4xf32>, !stablehlo.token> // class ConvertInfeedDequeueTupleOp : public OpRewritePattern { @@ -5265,16 +5331,16 @@ class ConvertInfeedDequeueTupleOp // Infeed takes a single token operand. Generate the token using // create_token op to pass to the infeed op. - auto token = rewriter.create( - op.getLoc(), mhlo::TokenType::get(rewriter.getContext())); + auto token = rewriter.create( + op.getLoc(), stablehlo::TokenType::get(rewriter.getContext())); result_types.push_back(token.getType()); ArrayAttr layout; // filled in during the xla-adjust-layout pass - auto data_and_token = - rewriter.create(op.getLoc(), result_types, token, - /*infeed_config=*/rewriter.getStringAttr(""), - /*layout=*/layout); + auto data_and_token = rewriter.create( + op.getLoc(), result_types, token, + /*infeed_config=*/rewriter.getStringAttr(""), + /*layout=*/layout); result_types.pop_back(); // remove the token type. 
@@ -5301,9 +5367,9 @@ class ConvertInfeedDequeueTupleOp } if (op->hasAttr("layouts")) { - // Append a UnitAttr for the "token" operand of the mhlo.infeed op here to - // avoid compilation failure when exporting "layouts" attribute of the - // corresponding InfeedDequeueTupleOp to a graph node. + // Append a UnitAttr for the "token" operand of the stablehlo.infeed op + // here to avoid compilation failure when exporting "layouts" attribute of + // the corresponding InfeedDequeueTupleOp to a graph node. data_and_token->setAttr("layout", op->getAttr("layouts")); } llvm::SmallVector results; @@ -5328,10 +5394,11 @@ class ConvertInfeedDequeueTupleOp // // would be lowered to // -// %token = "mhlo.create_token"() : () -> !mhlo.token -// %outfeed_token = "mhlo.outfeed"(%val_1, %val_2, %token) {outfeed_config = ""} +// %token = "stablehlo.create_token"() : () -> !stablehlo.token +// %outfeed_token = "stablehlo.outfeed"(%val_1, %val_2, %token) {outfeed_config +// = ""} // : -// (tensor<3xi32>, tensor<4xf32>, !mhlo.token) -> !mhlo.token +// (tensor<3xi32>, tensor<4xf32>, !stablehlo.token) -> !stablehlo.token // class ConvertOutfeedEnqueueTupleOp : public OpRewritePattern { @@ -5340,11 +5407,13 @@ class ConvertOutfeedEnqueueTupleOp LogicalResult matchAndRewrite(TF::OutfeedEnqueueTupleOp op, PatternRewriter &rewriter) const override { - auto token_type = mhlo::TokenType::get(rewriter.getContext()); - auto token = rewriter.create(op.getLoc(), token_type); + auto token_type = stablehlo::TokenType::get(rewriter.getContext()); + auto token = + rewriter.create(op.getLoc(), token_type); - rewriter.create(op.getLoc(), token_type, op.getInputs(), token, - /*outfeed_config=*/rewriter.getStringAttr("")); + rewriter.create( + op.getLoc(), token_type, op.getInputs(), token, + /*outfeed_config=*/rewriter.getStringAttr("")); rewriter.eraseOp(op); return success(); } @@ -5406,11 +5475,10 @@ class ConvertUnpackOp : public OpRewritePattern { begin_indices[axis] = i; end_indices[axis] = i + 1; 
- auto slice_op = rewriter.create( - op.getLoc(), op.getValue(), - GetI64ElementsAttr(begin_indices, &rewriter), - GetI64ElementsAttr(end_indices, &rewriter), - GetI64ElementsAttr(strides, &rewriter)); + auto slice_op = rewriter.create( + op.getLoc(), op.getValue(), GetI64ArrayAttr(begin_indices, &rewriter), + GetI64ArrayAttr(end_indices, &rewriter), + GetI64ArrayAttr(strides, &rewriter)); // Reshape to drop the axis dimension. auto result = rewriter.create( op.getLoc(), op.getType(i), slice_op, @@ -5487,7 +5555,7 @@ class ConvertUnpackOpDynamic : public OpRewritePattern { for (int64_t i = 0; i < op.getNumResults(); ++i) { begin_indices[axis] = rewriter.create(loc, i, 32); end_indices[axis] = rewriter.create(loc, i + 1, 32); - Value slice_op = rewriter.create( + Value slice_op = rewriter.create( loc, tensorflow::GetTypeFromTFTensorShape(slice_shape, value_type.getElementType()), @@ -5513,8 +5581,8 @@ class ConvertUnpackOpDynamic : public OpRewritePattern { tensorflow::GetTypeFromTFTensorShape( {static_cast(shape_values.size())}, i32_ty), shape_values); - Value reshape_op = rewriter.create(loc, op.getType(i), - slice_op, new_shape); + Value reshape_op = rewriter.create( + loc, op.getType(i), slice_op, new_shape); results.push_back(reshape_op); } @@ -5551,7 +5619,7 @@ class ConvertSigmoidGradOpDynamic : public OpRewritePattern { assert(mlir::isa(elem_tp)); attr = rewriter.getFloatAttr(elem_tp, 1); } - Value one = rewriter.create( + Value one = rewriter.create( loc, DenseElementsAttr::get( tensorflow::GetTypeFromTFTensorShape({}, elem_tp), attr)); @@ -5616,9 +5684,9 @@ class GenericConvertUnsortedSegmentReductionOp : public OpRewritePattern { // 'operand' parameter to scatter to for the final scatter op. 
Value init = ConcreteClass::GetInitialValue(data_type.getElementType(), op.getLoc(), &rewriter); - auto broadcasted_init = rewriter.create( + auto broadcasted_init = rewriter.create( op.getLoc(), output_type, init, - GetI64ElementsAttr(output_shape, &rewriter)); + GetI64ArrayAttr(output_shape, &rewriter)); // Parameters for the generated scatter op. SmallVector inserted_window_dims(1, 0); @@ -5626,7 +5694,7 @@ class GenericConvertUnsortedSegmentReductionOp : public OpRewritePattern { int64_t index_vector_dim = segment_ids_rank; // Put all parameters in a StructAttr. - auto dims_attr = ScatterDimensionNumbersAttr::get( + auto dims_attr = stablehlo::ScatterDimensionNumbersAttr::get( rewriter.getContext(), llvm::to_vector<4>(llvm::seq(segment_ids_rank, data_rank)), inserted_window_dims, @@ -5634,7 +5702,7 @@ class GenericConvertUnsortedSegmentReductionOp : public OpRewritePattern { /*scatterIndicesBatchingDims=*/{}, scatter_dims_to_operand_dims, index_vector_dim); - auto scatter = rewriter.create( + auto scatter = rewriter.create( op.getLoc(), op.getType(), ValueRange(Value(broadcasted_init)), op.getSegmentIds(), op.getData(), dims_attr); BuildReduceBody(data_type.getElementType(), @@ -5647,7 +5715,8 @@ class GenericConvertUnsortedSegmentReductionOp : public OpRewritePattern { class ConvertUnsortedSegmentMaxOp : public GenericConvertUnsortedSegmentReductionOp< - ConvertUnsortedSegmentMaxOp, TF::UnsortedSegmentMaxOp, MaxOp> { + ConvertUnsortedSegmentMaxOp, TF::UnsortedSegmentMaxOp, + stablehlo::MaxOp> { public: using GenericConvertUnsortedSegmentReductionOp:: GenericConvertUnsortedSegmentReductionOp; @@ -5661,7 +5730,8 @@ class ConvertUnsortedSegmentMaxOp class ConvertUnsortedSegmentMinOp : public GenericConvertUnsortedSegmentReductionOp< - ConvertUnsortedSegmentMinOp, TF::UnsortedSegmentMinOp, MinOp> { + ConvertUnsortedSegmentMinOp, TF::UnsortedSegmentMinOp, + stablehlo::MinOp> { public: using GenericConvertUnsortedSegmentReductionOp:: 
GenericConvertUnsortedSegmentReductionOp; @@ -5675,7 +5745,8 @@ class ConvertUnsortedSegmentMinOp class ConvertUnsortedSegmentProdOp : public GenericConvertUnsortedSegmentReductionOp< - ConvertUnsortedSegmentProdOp, TF::UnsortedSegmentProdOp, MulOp> { + ConvertUnsortedSegmentProdOp, TF::UnsortedSegmentProdOp, + stablehlo::MulOp> { public: using GenericConvertUnsortedSegmentReductionOp:: GenericConvertUnsortedSegmentReductionOp; @@ -5688,7 +5759,8 @@ class ConvertUnsortedSegmentProdOp class ConvertUnsortedSegmentSumOp : public GenericConvertUnsortedSegmentReductionOp< - ConvertUnsortedSegmentSumOp, TF::UnsortedSegmentSumOp, AddOp> { + ConvertUnsortedSegmentSumOp, TF::UnsortedSegmentSumOp, + stablehlo::AddOp> { public: using GenericConvertUnsortedSegmentReductionOp:: GenericConvertUnsortedSegmentReductionOp; @@ -5780,11 +5852,11 @@ class ConvertRandomShuffleOp : public OpRewritePattern { auto keys = CreateRngUniform32(op.getLoc(), num_elements, /*lower_limit=*/0, /*upper_limit=*/u32_max, &rewriter); - auto sorted = createSortOp( + auto sorted = stablehlo::createSortOp( &rewriter, op.getLoc(), {keys, current}, {rewriter.getIntegerType(32), input_type.getElementType()}, /*dimension=*/-1, /*isStable=*/false, - /*direction=*/ComparisonDirection::LT); + /*direction=*/stablehlo::ComparisonDirection::LT); current = sorted.getResult(1); } rewriter.replaceOp(op, current); @@ -5796,7 +5868,7 @@ class ConvertRandomShuffleOp : public OpRewritePattern { // Generate range(n) as the initial value for the indices to be swapped. auto indices_type = tensorflow::GetTypeFromTFTensorShape( {first_dim_size}, rewriter.getIntegerType(32)); - Value indices = rewriter.create( + Value indices = rewriter.create( op.getLoc(), indices_type, rewriter.getI64IntegerAttr(0)); // Generate random numbers to be used as swaps for the indices. 
@@ -5812,28 +5884,26 @@ class ConvertRandomShuffleOp : public OpRewritePattern { auto scalar_i32_type = tensorflow::GetTypeFromTFTensorShape({}, builder->getIntegerType(32)); - auto one_cross_i64_type = tensorflow::GetTypeFromTFTensorShape( - {1}, builder->getIntegerType(64)); - auto scalar_one = - DenseIntElementsAttr::get(one_cross_i64_type, ArrayRef(1)); + auto scalar_one = builder->getDenseI64ArrayAttr({1}); // We need to swap the indices[i] with indices[swaps[i]]. First get // these index values. - Value source_index = - builder->create(loc, indices, i, scalar_one); - Value swap_index = builder->create( + Value source_index = builder->create( + loc, indices, i, scalar_one); + Value swap_index = builder->create( loc, scalar_i32_type, - builder->create(loc, swaps, i, scalar_one)); - Value target_index = builder->create( + builder->create(loc, swaps, i, + scalar_one)); + Value target_index = builder->create( loc, indices, swap_index, scalar_one); // Then perform the swap. // indices[i] <- indices[swaps[i]] - indices = builder->create( + indices = builder->create( loc, indices.getType(), indices, target_index, llvm::ArrayRef(i)); // indices[swaps[i]] <- indices[i] - indices = builder->create( + indices = builder->create( loc, indices.getType(), indices, source_index, llvm::ArrayRef(swap_index)); @@ -5850,7 +5920,7 @@ class ConvertRandomShuffleOp : public OpRewritePattern { // Gather the data using the swapped indices as the shuffled order. 
auto slice_sizes = tensorflow::ConvertMlirShapeToTF(input_type.getShape()); slice_sizes[0] = 1; - auto dims_attr = GatherDimensionNumbersAttr::get( + auto dims_attr = stablehlo::GatherDimensionNumbersAttr::get( rewriter.getContext(), /*offsetDims=*/llvm::to_vector<4>(llvm::seq(1, input_rank)), /*collapsedSliceDims=*/{0}, @@ -5874,14 +5944,14 @@ class ConvertRandomShuffleOp : public OpRewritePattern { index_to_i64); slice_sizes_values.push_back(i64_to_tensor); } else { - slice_sizes_values.push_back(rewriter.create( + slice_sizes_values.push_back(rewriter.create( op.getLoc(), GetI64ElementsAttr({slice_sizes[i]}, &rewriter))); } } - auto slice_sizes_concat = rewriter.create( + auto slice_sizes_concat = rewriter.create( op.getLoc(), slice_sizes_values, rewriter.getI64IntegerAttr(0)); - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, op.getType(), op.getValue(), swaped_indices, slice_sizes_concat, dims_attr); @@ -5903,7 +5973,7 @@ class ConvertXlaShardingOp : public OpRewritePattern { NamedAttribute call_target_name = rewriter.getNamedAttr( "call_target_name", rewriter.getStringAttr("Sharding")); - auto custom_call = rewriter.create( + auto custom_call = rewriter.create( op.getLoc(), op.getType(), op.getInput(), ArrayRef{call_target_name}); custom_call->setAttr(kShardingAttr, op.get_XlaShardingAttr()); @@ -5959,8 +6029,8 @@ class ConvertInplaceUpdateOp : public OpRewritePattern { tensorflow::GetTypeFromTFTensorShape(split_updates_shape, updates_type.getElementType())); - auto cst = - rewriter.create(op.getLoc(), zero_attr).getResult(); + auto cst = rewriter.create(op.getLoc(), zero_attr) + .getResult(); auto split_updates = rewriter.create( op.getLoc(), split_updates_type, cst, updates); @@ -5970,7 +6040,7 @@ class ConvertInplaceUpdateOp : public OpRewritePattern { for (auto pair : llvm::zip(unpacked_indices.getOutput(), split_updates.getOutput())) { input_indices.front() = std::get<0>(pair); - input = rewriter.create( + input = rewriter.create( 
op.getLoc(), op.getType(), input, std::get<1>(pair), input_indices); } @@ -5999,7 +6069,7 @@ class ConvertXlaDynamicUpdateSliceOp auto unpacked_indices = rewriter.create( op.getLoc(), unpacked_indices_type, op.getIndices(), IntegerAttr::get(rewriter.getIntegerType(64), 0)); - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, op.getType(), op.getInput(), op.getUpdate(), unpacked_indices.getOutput()); return success(); @@ -6029,30 +6099,30 @@ class ConvertXlaReduceScatterOp Location loc = op.getLoc(); Type element_type = getElementTypeOrSelf(op.getInput().getType()); - auto reduce_scatter = rewriter.create( + auto reduce_scatter = rewriter.create( loc, op.getType(), op.getInput(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), scatter_dimension.getSExtValue()), - replica_groups, ChannelHandleAttr()); + replica_groups, stablehlo::ChannelHandleAttr()); StringRef reduce_op = op.getReduceOp(); if (reduce_op == "Add") { - BuildReduceBody(element_type, &reduce_scatter.getComputation(), - &rewriter); + BuildReduceBody( + element_type, &reduce_scatter.getComputation(), &rewriter); } else if (reduce_op == "Mul") { - BuildReduceBody(element_type, &reduce_scatter.getComputation(), - &rewriter); + BuildReduceBody( + element_type, &reduce_scatter.getComputation(), &rewriter); } else if (reduce_op == "Min") { - BuildReduceBody(element_type, &reduce_scatter.getComputation(), - &rewriter); + BuildReduceBody( + element_type, &reduce_scatter.getComputation(), &rewriter); } else if (reduce_op == "Max") { - BuildReduceBody(element_type, &reduce_scatter.getComputation(), - &rewriter); + BuildReduceBody( + element_type, &reduce_scatter.getComputation(), &rewriter); } else { // For mean, add replicas in the same group. Then divide the sum by the // number of replicas in each group below. 
assert(reduce_op == "Mean"); - BuildReduceBody(element_type, &reduce_scatter.getComputation(), - &rewriter); + BuildReduceBody( + element_type, &reduce_scatter.getComputation(), &rewriter); } Value result = reduce_scatter.getResult(); @@ -6072,7 +6142,7 @@ class ConvertXlaReduceScatterOp } }; -// Converts tf.XlaReduceWindow to mhlo.ReduceWindow +// Converts tf.XlaReduceWindow to stablehlo.ReduceWindow class ConvertXlaReduceWindowOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -6093,17 +6163,13 @@ class ConvertXlaReduceWindowOp Location loc = op.getLoc(); SmallVector result_types{op.getResult().getType()}; - // Create the mhlo.SelectAndScatter op. - auto reduce_window_op = rewriter.create( + // Create the stablehlo.SelectAndScatter op. + auto reduce_window_op = rewriter.create( loc, result_types, op.getInput(), op.getInitValue(), - mlir::cast(hlo::convertElementsAttr( - window_dimensions, rewriter.getIntegerType(64))), - mlir::cast(hlo::convertElementsAttr( - window_strides, rewriter.getIntegerType(64))), - mlir::cast(hlo::convertElementsAttr( - base_dilations, rewriter.getIntegerType(64))), - mlir::cast(hlo::convertElementsAttr( - window_dilations, rewriter.getIntegerType(64))), + ToDenseI64ArrayAttr(window_dimensions, &rewriter), + ToDenseI64ArrayAttr(window_strides, &rewriter), + ToDenseI64ArrayAttr(base_dilations, &rewriter), + ToDenseI64ArrayAttr(window_dilations, &rewriter), mlir::cast( hlo::convertElementsAttr(padding, rewriter.getIntegerType(64)))); // Insert a call to the reducer in the region of the mhlo op. 
@@ -6156,7 +6222,8 @@ class ConvertClipByValueOp : public OpRewritePattern { rewriter.create(op.getLoc(), input_ty, max, shape); } - rewriter.replaceOpWithNewOp(op, input_ty, min, input, max); + rewriter.replaceOpWithNewOp(op, input_ty, min, input, + max); return success(); } }; @@ -6176,7 +6243,7 @@ class ConvertConstOp : public OpRewritePattern { return failure(); Location loc = op.getLoc(); - Value result = rewriter.create(loc, op.getValue()); + Value result = rewriter.create(loc, op.getValue()); if (result.getType() != op.getType()) result = rewriter.create(loc, op.getType(), result); rewriter.replaceOp(op, result); @@ -6196,10 +6263,12 @@ class ConvertCumOp : public OpRewritePattern { LogicalResult matchAndRewrite(OpT op, PatternRewriter &rewriter) const override { auto input = mlir::dyn_cast>(op.getX()); - if (!input) return failure(); + if (!input) { + return rewriter.notifyMatchFailure(op, "input X not ranked tensor"); + } auto input_type = mlir::dyn_cast(input.getType()); if (!input_type || !input_type.hasStaticShape()) { - return failure(); + return rewriter.notifyMatchFailure(op, "input not static shape"); } ArrayRef input_shape = input_type.getShape(); @@ -6208,7 +6277,7 @@ class ConvertCumOp : public OpRewritePattern { // We can only match when the axis is a constant scalar. DenseIntElementsAttr axis_attr; if (!matchPattern(op.getAxis(), m_Constant(&axis_attr))) { - return failure(); + return rewriter.notifyMatchFailure(op, "axis not constant"); } // Get the dimension to apply the reduction on, and offset properly if it is @@ -6222,8 +6291,8 @@ class ConvertCumOp : public OpRewritePattern { // the input and then later reverse the output. 
if (op.getReverse()) { llvm::SmallVector dims_to_reverse({axis}); - input = rewriter.create( - op.getLoc(), input, GetI64ElementsAttr(dims_to_reverse, &rewriter)); + input = rewriter.create( + op.getLoc(), input, GetI64ArrayAttr(dims_to_reverse, &rewriter)); } // Convert if we need to enlarge the element type's bitwidth to avoid @@ -6231,10 +6300,14 @@ class ConvertCumOp : public OpRewritePattern { Type input_element_type = input_type.getElementType(); // TODO(hinsu): Handle complex element types. - if (!input_element_type.isIntOrFloat()) return failure(); + if (!input_element_type.isIntOrFloat()) { + return rewriter.notifyMatchFailure(op, + "input element type not int or float"); + } Type sum_element_type = GetSumAccumulationType(input_element_type); - input = rewriter.create(op.getLoc(), input, sum_element_type); + input = rewriter.create(op.getLoc(), input, + sum_element_type); SmallVector window_dims(rank, 1); SmallVector window_strides(rank, 1); @@ -6248,16 +6321,17 @@ class ConvertCumOp : public OpRewritePattern { {rank, 2}, rewriter.getIntegerType(64)), paddings); - int64_t init_value = (std::is_same::value) ? 0 : 1; + int64_t init_value = + (std::is_same::value) ? 
0 : 1; Value init = GetScalarConstOfType(sum_element_type, op.getLoc(), init_value, &rewriter); - auto reduce = rewriter.create( + auto reduce = rewriter.create( op.getLoc(), input.getType(), input, init, - GetI64ElementsAttr(rewriter.getI64ArrayAttr(window_dims)), - GetI64ElementsAttr(rewriter.getI64ArrayAttr(window_strides)), - /*base_dilations=*/DenseIntElementsAttr(), - /*window_dilations=*/DenseIntElementsAttr(), paddings_attr); + GetI64ArrayAttr(window_dims, &rewriter), + GetI64ArrayAttr(window_strides, &rewriter), + /*base_dilations=*/DenseI64ArrayAttr(), + /*window_dilations=*/DenseI64ArrayAttr(), paddings_attr); BuildReduceBody(sum_element_type, &reduce.getBody(), &rewriter); Value result = reduce.getResult(0); @@ -6272,20 +6346,20 @@ class ConvertCumOp : public OpRewritePattern { llvm::SmallVector interior_padding(rank, 0); low_padding[axis] = 1; high_padding[axis] = -1; - result = rewriter.create( - op.getLoc(), result, init, GetI64ElementsAttr(low_padding, &rewriter), - GetI64ElementsAttr(high_padding, &rewriter), - GetI64ElementsAttr(interior_padding, &rewriter)); + result = rewriter.create( + op.getLoc(), result, init, GetI64ArrayAttr(low_padding, &rewriter), + GetI64ArrayAttr(high_padding, &rewriter), + GetI64ArrayAttr(interior_padding, &rewriter)); } // Convert back if we enlarged the element type's bitwidth. 
- result = - rewriter.create(op.getLoc(), result, input_element_type); + result = rewriter.create(op.getLoc(), result, + input_element_type); if (op.getReverse()) { llvm::SmallVector dims_to_reverse({axis}); - result = rewriter.create( - op.getLoc(), result, GetI64ElementsAttr(dims_to_reverse, &rewriter)); + result = rewriter.create( + op.getLoc(), result, GetI64ArrayAttr(dims_to_reverse, &rewriter)); } rewriter.replaceOp(op, result); @@ -6293,8 +6367,8 @@ class ConvertCumOp : public OpRewritePattern { } }; -using ConvertCumsumOp = ConvertCumOp; -using ConvertCumprodOp = ConvertCumOp; +using ConvertCumsumOp = ConvertCumOp; +using ConvertCumprodOp = ConvertCumOp; // Converts the Tensorflow ShapeOp to a sequence of Shape dialect and Standard // dialect lowerings. This involves extracting the shape type, extracting and @@ -6374,8 +6448,8 @@ class ConvertDynamicExpandDimsOp : public OpRewritePattern { auto from_extents = rewriter.create(op.getLoc(), dims); - rewriter.replaceOpWithNewOp(op, result_ty, input, - from_extents); + rewriter.replaceOpWithNewOp( + op, result_ty, input, from_extents); return success(); } }; @@ -6421,13 +6495,13 @@ class ConvertDynamicSqueezeOp : public OpRewritePattern { auto from_extents = rewriter.create(op.getLoc(), dims); - rewriter.replaceOpWithNewOp(op, result_ty, input, - from_extents); + rewriter.replaceOpWithNewOp( + op, result_ty, input, from_extents); return success(); } }; -// Converts tf.XlaConvV2 to mhlo.Conv +// Converts tf.XlaConvV2 to stablehlo.Conv class ConvertXlaConvV2Op : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -6446,23 +6520,17 @@ class ConvertXlaConvV2Op : public OpRewritePattern { return failure(); auto window_strides_named_attr = rewriter.getNamedAttr( - "window_strides", - mlir::cast(hlo::convertElementsAttr( - window_strides_attr, rewriter.getIntegerType(64)))); + "window_strides", ToDenseI64ArrayAttr(window_strides_attr, &rewriter)); auto padding_named_attr = 
rewriter.getNamedAttr( "padding", mlir::cast(hlo::convertElementsAttr( padding_attr, rewriter.getIntegerType(64)))); auto lhs_dilation_named_attr = rewriter.getNamedAttr( - "lhs_dilation", - mlir::cast(hlo::convertElementsAttr( - lhs_dilation_attr, rewriter.getIntegerType(64)))); + "lhs_dilation", ToDenseI64ArrayAttr(lhs_dilation_attr, &rewriter)); auto rhs_dilation_named_attr = rewriter.getNamedAttr( - "rhs_dilation", - mlir::cast(hlo::convertElementsAttr( - rhs_dilation_attr, rewriter.getIntegerType(64)))); + "rhs_dilation", ToDenseI64ArrayAttr(rhs_dilation_attr, &rewriter)); int64_t feature_group_count_val = feature_group_count_attr.getValues()[0].getInt(); @@ -6477,14 +6545,14 @@ class ConvertXlaConvV2Op : public OpRewritePattern { dnums.ParseFromString(op.getDimensionNumbersAttr().getValue().str()); auto dimension_numbers_named_attr = rewriter.getNamedAttr( "dimension_numbers", - xla::ConvertConvDimensionNumbers(dnums, &rewriter)); + xla::stablehlo::ConvertConvDimensionNumbers(dnums, &rewriter)); xla::PrecisionConfig precision_config; precision_config.ParseFromString( op.getPrecisionConfigAttr().getValue().str()); auto precision_config_named_attr = rewriter.getNamedAttr( "precision_config", - xla::ConvertPrecisionConfig(&precision_config, &rewriter)); + xla::stablehlo::ConvertPrecisionConfig(&precision_config, &rewriter)); SmallVector operands{op.getLhs(), op.getRhs()}; NamedAttribute attrs[] = { @@ -6492,13 +6560,13 @@ class ConvertXlaConvV2Op : public OpRewritePattern { lhs_dilation_named_attr, rhs_dilation_named_attr, feature_group_count_named_attr, batch_group_count_named_attr, dimension_numbers_named_attr, precision_config_named_attr}; - rewriter.replaceOpWithNewOp(op, op.getType(), operands, - llvm::ArrayRef(attrs)); + rewriter.replaceOpWithNewOp( + op, op.getType(), operands, llvm::ArrayRef(attrs)); return success(); } }; -// Converts tf.XlaSelectAndScatter to mhlo.SelectAndScatter +// Converts tf.XlaSelectAndScatter to stablehlo.SelectAndScatter class 
ConvertXlaSelectAndScatterOp : public OpRewritePattern { public: @@ -6516,13 +6584,11 @@ class ConvertXlaSelectAndScatterOp Location loc = op.getLoc(); SmallVector result_types{op.getResult().getType()}; - // Create the mhlo.SelectAndScatter op. - auto select_and_scatter_op = rewriter.create( + // Create the stablehlo.SelectAndScatter op. + auto select_and_scatter_op = rewriter.create( loc, result_types, op.getOperand(), op.getSource(), op.getInitValue(), - mlir::cast(hlo::convertElementsAttr( - window_dimensions, rewriter.getIntegerType(64))), - mlir::cast(hlo::convertElementsAttr( - window_strides, rewriter.getIntegerType(64))), + ToDenseI64ArrayAttr(window_dimensions, &rewriter), + ToDenseI64ArrayAttr(window_strides, &rewriter), mlir::cast( hlo::convertElementsAttr(padding, rewriter.getIntegerType(64)))); @@ -6545,7 +6611,7 @@ class ConvertXlaSelectAndScatterOp } }; -// Convert tf.XlaSort to mhlo.Sort +// Convert tf.XlaSort to stablehlo.Sort class ConvertXlaSortOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -6554,10 +6620,10 @@ class ConvertXlaSortOp : public OpRewritePattern { PatternRewriter &rewriter) const override { // Create the sort op. Type element_type = getElementTypeOrSelf(op.getInput().getType()); - auto sort_op = - createSortOp(&rewriter, op.getLoc(), {op.getInput()}, {element_type}, - /*dimension=*/-1, /*isStable=*/false, - /*direction=*/ComparisonDirection::LT); + auto sort_op = stablehlo::createSortOp( + &rewriter, op.getLoc(), {op.getInput()}, {element_type}, + /*dimension=*/-1, /*isStable=*/false, + /*direction=*/stablehlo::ComparisonDirection::LT); rewriter.replaceOp(op, sort_op.getResult(0)); return success(); } @@ -6575,7 +6641,7 @@ inline std::optional TensorFlowRngAlgToXla( return std::nullopt; } -// Converts tf.XlaRngBitGenerator op to mhlo.RngBitGenerator op. +// Converts tf.XlaRngBitGenerator op to stablehlo.RngBitGenerator op. 
class ConvertXlaRngBitGeneratorOp : public OpRewritePattern { public: @@ -6596,10 +6662,10 @@ class ConvertXlaRngBitGeneratorOp return op.emitOpError() << "unknown algorithm"; } - auto algorithm_attr = mlir::mhlo::RngAlgorithmAttr::get( + auto algorithm_attr = mlir::stablehlo::RngAlgorithmAttr::get( rewriter.getContext(), - *mlir::mhlo::symbolizeRngAlgorithm(xla_alg.value())); - auto rng_bit_generator_op = rewriter.create( + *mlir::stablehlo::symbolizeRngAlgorithm(xla_alg.value())); + auto rng_bit_generator_op = rewriter.create( loc, op.getResultTypes(), algorithm_attr, op.getInitialState()); rewriter.replaceOp(op, rng_bit_generator_op.getResults()); @@ -6608,7 +6674,7 @@ class ConvertXlaRngBitGeneratorOp } }; -// Converts tf.XlaVariadicReduceV2 to mhlo.Reduce +// Converts tf.XlaVariadicReduceV2 to stablehlo.Reduce class ConvertXlaVariadicReduceV2Op : public OpRewritePattern { public: @@ -6626,10 +6692,12 @@ class ConvertXlaVariadicReduceV2Op func_ty.getResults(), [](Type ty) { return mlir::cast(ty).getElementType(); })}; - // Create the mhlo.reduce op. - auto reduce_op = rewriter.create( + // Create the stablehlo.reduce op. + auto reduce_op = rewriter.create( loc, op.getInputs(), op.getInitValues(), - GetI64ElementsAttr(op.getDimensionsToReduce()), elementTypes); + ToDenseI64ArrayAttr(GetI64ElementsAttr(op.getDimensionsToReduce()), + &rewriter), + elementTypes); // Insert a call to the reducer in the region of the mhlo op. BuildBodyWithCall(rewriter, loc, func, func_ty, &reduce_op.getBody()); @@ -6640,7 +6708,7 @@ class ConvertXlaVariadicReduceV2Op } }; -// Convert tf.XlaVariadicSort to mhlo.Sort +// Convert tf.XlaVariadicSort to stablehlo.Sort class ConvertXlaVariadicSortOp : public OpRewritePattern { public: @@ -6651,8 +6719,8 @@ class ConvertXlaVariadicSortOp Location loc = op.getLoc(); ElementsAttr dimension; matchPattern(op.getDimension(), m_Constant(&dimension)); - // Create the mhlo.sort op. 
- auto sort_op = rewriter.create( + // Create the stablehlo.sort op. + auto sort_op = rewriter.create( loc, op.getInputs(), dimension.getValues()[0].getInt(), op.getIsStable()); mlir::SymbolRefAttr func = op.getComparator(); @@ -6667,7 +6735,7 @@ class ConvertXlaVariadicSortOp } }; -// Convert tf.XlaReducePrecision to mhlo.ReducePrecision +// Convert tf.XlaReducePrecision to stablehlo.ReducePrecision class ConvertXlaReducePrecisionOp : public OpRewritePattern { public: @@ -6685,7 +6753,7 @@ class ConvertXlaReducePrecisionOp APInt mantissa_bits = op.getMantissaBitsAttr().getValue(); IntegerAttr new_mantissa_attr = IntegerAttr::get(int32_type, mantissa_bits.truncSSat(32)); - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, op.getType(), op.getOperand(), new_exponent_attr, new_mantissa_attr); return success(); @@ -6699,7 +6767,7 @@ class LowerYieldOp : public OpConversionPattern { LogicalResult matchAndRewrite( TF::YieldOp op, TF::YieldOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp(op, adaptor.getOperands()); + rewriter.replaceOpWithNewOp(op, adaptor.getOperands()); return success(); } }; @@ -6731,20 +6799,20 @@ class LowerControlFlowOp : public OpConversionPattern { // result types. This is only done for the While op for now. llvm::SmallVector element_types; int64_t num_results = op.getNumResults(); - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { element_types.reserve(num_results); for (Value value : adaptor.getOperands()) { element_types.push_back(getElementTypeOrSelf(value.getType())); } } - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { // Explicitly handle the Case op because it has variadic regions and takes // the number of regions as an input along with the operands. 
mhlo_op = rewriter.create(loc, op.getResultTypes(), adaptor.getBranchIndex(), op.getBranches().size()); - } else if constexpr (std::is_same::value) { + } else if constexpr (std::is_same::value) { llvm::SmallVector while_result_types; while_result_types.reserve(num_results); for (int64_t idx = 0; idx < num_results; ++idx) { @@ -6766,7 +6834,7 @@ class LowerControlFlowOp : public OpConversionPattern { // Update region's entry blocks argument types to handle quantized element // types. - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { TypeConverter::SignatureConversion signature(num_results); Block &block = region.front(); for (const auto &[block_idx, original_ty] : @@ -6787,6 +6855,7 @@ class LowerControlFlowOp : public OpConversionPattern { } // end namespace #include "tensorflow/compiler/mlir/tf2xla/transforms/generated_legalize_tf.inc" + // LINT.IfChange void PopulateLegalizeTfPatterns(MLIRContext *context, RewritePatternSet *patterns) { @@ -6886,12 +6955,21 @@ void PopulateLegalizeTfPatterns(MLIRContext *context, ConvertConv2DDynamic, ConvertPadOpDynamic, ConvertGatherNdOpDynamic, - LowerControlFlowOp, - LowerControlFlowOp, - LowerControlFlowOp, + LowerControlFlowOp, + LowerControlFlowOp, + LowerControlFlowOp, LowerYieldOp>(context); // clang-format on } // LINT.ThenChange(:MlirAlwaysOps) -} // end namespace mhlo + +} // end namespace hlo + +namespace mhlo { +// Passthrough to avoid updating downstream users namespacing +void PopulateLegalizeTfPatterns(MLIRContext *context, + RewritePatternSet *patterns) { + hlo::PopulateLegalizeTfPatterns(context, patterns); +} +} // namespace mhlo } // end namespace mlir diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_patterns.td index 46f3ebfe19104d..5507c82bc6f479 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_patterns.td +++ 
b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_patterns.td @@ -20,8 +20,9 @@ include "mlir/Dialect/Shape/IR/ShapeOps.td" include "mlir/Dialect/Func/IR/FuncOps.td" include "mlir/Dialect/Tensor/IR/TensorOps.td" include "stablehlo/dialect/ChloOps.td" +include "stablehlo/dialect/StablehloOps.td" include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" -include "mhlo/IR/hlo_ops.td" +include "mhlo/IR/hlo_ops.td" // for hlo_utils.td def SignedIntTensor : TensorOf<[I1, I8, I16, I32, I64]>; def UnsignedIntTensor : TensorOf<[UI8, UI16, UI32, UI64]>; @@ -33,44 +34,51 @@ def IEEEFloatTensor : TensorOf<[F16, F32, F64]>; // BatchNorm op patterns. //===----------------------------------------------------------------------===// -def FalseBoolAttr : AttrConstraint().getValue()">>; -def TrueBoolAttr : AttrConstraint().getValue()">>; +def FalseBoolAttr : AttrConstraint($_self).getValue()">>; +def TrueBoolAttr : AttrConstraint($_self).getValue()">>; def CastValueToI64: NativeCodeCall< "CastValueToI64($0.getLoc(), $1, &$_builder)">; def CastValueToElementType: NativeCodeCall< - "$_builder.create($0.getLoc(), $1, " + "$_builder.create($0.getLoc(), $1, " "getElementTypeOrSelf($2.getType()))">; // Here, $0 is an ElementsAttr with exactly one element of type integer. $1 is // the corresponding value of ranked tensor type whose axis is referred in $0. def GetHLOAxisFromTFAxis : NativeCodeCall< "GetHLOAxisFromTFAxis(" - "$0, $1.getType().cast().getRank(), &$_builder)">; + "$0, llvm::cast($1.getType()).getRank(), &$_builder)">; // Same as the above but with $1 of type operand_range from variadic TensorFlow // input. 
def GetHLOAxisFromTFAxisVariadic : NativeCodeCall< "GetHLOAxisFromTFAxis(" - "$0, (*$1.begin()).getType().cast().getRank(), " + "$0, llvm::cast((*$1.begin()).getType()).getRank(), " "&$_builder)">; -def CastElementsToI64Elements : NativeCodeCall< - "hlo::convertElementsAttr(" - "$0.cast(), $_builder.getIntegerType(64)).cast()">; +def CastElementsToI64Elements : NativeCodeCall<[{ + llvm::cast(hlo::convertElementsAttr( + llvm::cast($0), $_builder.getIntegerType(64))) + }]>; -def EmptyDotAlgorithmAttr : NativeCodeCall<"mlir::mhlo::DotAlgorithmAttr{}">; +def CastElementsToI64Array : NativeCodeCall<[{ + ToDenseI64ArrayAttr( + llvm::cast(hlo::convertElementsAttr( + llvm::cast($0), $_builder.getIntegerType(64))), &$_builder) + }]>; + +def EmptyDotAlgorithmAttr : NativeCodeCall<"mlir::stablehlo::DotAlgorithmAttr{}">; def ConstDefaultResultAccuracyAttr : - ConstantAttr; + ConstantAttr; //===----------------------------------------------------------------------===// // ApproximateEqual op pattern. //===----------------------------------------------------------------------===// -class MHLO_ComparisonDirectionValue : - ConstantAttr; +class StableHLO_ComparisonDirectionValue : + ConstantAttr; class CHLO_ComparisonDirectionValue : ConstantAttr; @@ -78,8 +86,8 @@ class CHLO_ComparisonDirectionValue : // TODO(b/228291745): Assert that $x and $y have the same shape. 
def : Pat<(TF_ApproximateEqualOp:$result $x, $y, $tolerance), (CHLO_BroadcastCompareOp - (MHLO_AbsOp:$abs (MHLO_SubtractOp $x, $y)), - (CastValueToElementType $result, (MHLO_ConstantOp $tolerance), $abs), + (StableHLO_AbsOp:$abs (StableHLO_SubtractOp $x, $y)), + (CastValueToElementType $result, (StableHLO_ConstantOp $tolerance), $abs), (NullDenseI64ArrayAttr), CHLO_ComparisonDirectionValue<"LT">, (CHLO_DEFAULT_COMPARISON_TYPE))>; @@ -136,7 +144,7 @@ def LowerRightShiftUnsigned : // // return floor(div(x, y)) def : Pat<(TF_FloorDivOp AnyTensor:$l, AnyTensor:$r), - (MHLO_FloorOp + (StableHLO_FloorOp (CHLO_BroadcastDivOp $l, $r, (BinBroadcastDimensions $l, $r))), [(IEEEFloatTensor $l)]>; @@ -151,7 +159,7 @@ def : Pat<(TF_FloorDivOp AnyTensor:$l, AnyTensor:$r), // dimensions. This computes the broadcast of 'l' to broadcast('l', 'r') // without returning the broadcast of 'r' to broadcast('l', 'r'). def : Pat<(TF_FloorDivOp AnyTensor:$l, AnyTensor:$r), - (MHLO_SelectOp + (StableHLO_SelectOp (CHLO_BroadcastAndOp (CHLO_BroadcastCompareOp (CHLO_BroadcastMulOp:$mul @@ -162,18 +170,18 @@ def : Pat<(TF_FloorDivOp AnyTensor:$l, AnyTensor:$r), (CHLO_DEFAULT_COMPARISON_TYPE)), (CHLO_BroadcastCompareOp (CHLO_BroadcastCompareOp:$l_cmp $l, - (MHLO_ConstantOp:$l_zeros (GetScalarOfType<0> $l)), + (StableHLO_ConstantOp:$l_zeros (GetScalarOfType<0> $l)), (NullDenseI64ArrayAttr), CHLO_ComparisonDirectionValue<"LT">, (CHLO_DEFAULT_COMPARISON_TYPE)), (CHLO_BroadcastCompareOp:$r_cmp $r, - (MHLO_ConstantOp:$r_zeros (GetScalarOfType<0> $r)), + (StableHLO_ConstantOp:$r_zeros (GetScalarOfType<0> $r)), (NullDenseI64ArrayAttr), CHLO_ComparisonDirectionValue<"LT">, (CHLO_DEFAULT_COMPARISON_TYPE)), (BinBroadcastDimensions $l_cmp, $r_cmp), CHLO_ComparisonDirectionValue<"NE">, (CHLO_DEFAULT_COMPARISON_TYPE)), (NullDenseI64ArrayAttr)), (CHLO_BroadcastSubOp $div, - (MHLO_ConstantOp:$ones (GetScalarOfType<1> $div)), + (StableHLO_ConstantOp:$ones (GetScalarOfType<1> $div)), (NullDenseI64ArrayAttr)), 
$div), [(SignedIntTensor $l)]>; @@ -189,16 +197,16 @@ def : Pat<(TF_FloorDivOp AnyTensor:$l, AnyTensor:$r), // return trunc_mod != 0 && (y < 0 != trunc_mod < 0) ? trunc_mod + y // : trunc_mod def : Pat<(TF_FloorModOp AnyTensor:$l, AnyTensor:$r), - (MHLO_SelectOp + (StableHLO_SelectOp (CHLO_BroadcastAndOp (CHLO_BroadcastCompareOp (CHLO_BroadcastRemOp:$rem $l, $r, (BinBroadcastDimensions $l, $r)), - (MHLO_ConstantOp:$l_zeros (GetScalarOfType<0> $l)), + (StableHLO_ConstantOp:$l_zeros (GetScalarOfType<0> $l)), (NullDenseI64ArrayAttr), CHLO_ComparisonDirectionValue<"NE">, (CHLO_DEFAULT_COMPARISON_TYPE)), (CHLO_BroadcastCompareOp (CHLO_BroadcastCompareOp:$r_cmp $r, - (MHLO_ConstantOp:$r_zeros (GetScalarOfType<0> $r)), + (StableHLO_ConstantOp:$r_zeros (GetScalarOfType<0> $r)), (NullDenseI64ArrayAttr), CHLO_ComparisonDirectionValue<"LT">, (CHLO_DEFAULT_COMPARISON_TYPE)), (CHLO_BroadcastCompareOp:$rem_cmp $rem, $r_zeros, @@ -219,10 +227,10 @@ def : Pat<(TF_FloorModOp AnyTensor:$l, AnyTensor:$r), def Get2DTransposePerm: NativeCodeCall< "Get2DTransposePerm($0, &$_builder)">; -def : Pat<(TF_RiscAddOp $l, $r), (MHLO_AddOp $l, $r)>; +def : Pat<(TF_RiscAddOp $l, $r), (StableHLO_AddOp $l, $r)>; def : Pat<(TF_RiscDotOp $a, $b, $transpose_a, $transpose_b), - (MHLO_DotOp + (StableHLO_DotOp (TF_TransposeOp $a, (TF_ConstOp (Get2DTransposePerm $transpose_a))), (TF_TransposeOp $b, (TF_ConstOp (Get2DTransposePerm $transpose_b))), /*precision_config=*/(NullArrayAttr))>; @@ -264,7 +272,7 @@ class EqualityPat (CHLO_BroadcastCompareOp $l, $r, (BinBroadcastDimensions $l, $r), direction, (CHLO_DEFAULT_COMPARISON_TYPE)), - [(MHLO_Tensor $l)]>; + [(HLO_Tensor $l)]>; def : EqualityPat>; def : EqualityPat>; @@ -274,17 +282,17 @@ def : EqualityPat>; //===----------------------------------------------------------------------===// def OneElementAttrPred - : CPred<"$_self.cast().getShapedType().getNumElements() == 1">; + : CPred<"llvm::cast($_self).getShapedType().getNumElements() == 1">; def 
OneElementAttr : ElementsAttrBase, "Scalar ElementsAttr">; def HasRankedFirstOperand - : Constraint()">>; + : Constraint((*$0.begin()).getType())">>; def IsShapedTensor - : Constraint()">>; + : Constraint($0.getType())">>; // This pattern converts TensorFlow axis format to HLO axis format which // doesn't wrap around like TensorFlow and is always positive. For this @@ -295,7 +303,7 @@ def IsShapedTensor // if HLO constant op is introduced as an replacement for the TensorFlow // Constant op. def : Pat<(TF_ConcatV2Op $inputs, (ConstantLikeMatcher OneElementAttr:$axis)), - (MHLO_ConcatenateOp $inputs, + (StableHLO_ConcatenateOp $inputs, (GetHLOAxisFromTFAxisVariadic $axis, $inputs)), [(HasRankedFirstOperand $inputs)]>; @@ -304,16 +312,16 @@ def : Pat<(TF_ConcatV2Op $inputs, (ConstantLikeMatcher OneElementAttr:$axis)), //===----------------------------------------------------------------------===// def : Pat<(TF_CollectivePermuteOp $input, (ConstantLikeMatcher ElementsAttr:$source_target_pairs)), - (MHLO_CollectivePermuteOp $input, + (StableHLO_CollectivePermuteOp $input, (CastElementsToI64Elements $source_target_pairs), - (NullChannelHandleAttr))>; + (StableHLO_NullChannelHandleAttr))>; //===----------------------------------------------------------------------===// // CrossReplicaSum op patterns. 
//===----------------------------------------------------------------------===// def : Pat<(TF_CrossReplicaSumOp $input, (ConstantLikeMatcher ElementsAttr:$group_assignment)), - (MHLO_CrossReplicaSumOp $input, + (StableHLO_CrossReplicaSumOp $input, (CastElementsToI64Elements $group_assignment))>; //===----------------------------------------------------------------------===// @@ -322,27 +330,27 @@ def : Pat<(TF_CrossReplicaSumOp $input, (ConstantLikeMatcher ElementsAttr:$group def ValueToVariadic: NativeCodeCall<"SmallVector{$0}">; def : Pat<(TF_AllToAllOp AnyRankedTensor:$input, (ConstantLikeMatcher ElementsAttr:$group_assignment), I64Attr:$concat_dimension, $split_dimension, $split_count), - (MHLO_AllToAllOp (ValueToVariadic $input), $split_dimension, $concat_dimension, $split_count, (CastElementsToI64Elements $group_assignment), (NullChannelHandleAttr))>; + (StableHLO_AllToAllOp (ValueToVariadic $input), $split_dimension, $concat_dimension, $split_count, (CastElementsToI64Elements $group_assignment), (StableHLO_NullChannelHandleAttr))>; //===----------------------------------------------------------------------===// // FFT op patterns. 
//===----------------------------------------------------------------------===// -class MHLO_FftTypeValue : - ConstantAttr; +class StableHLO_FftTypeValue : + ConstantAttr; def GetInnerDimFromValue : NativeCodeCall< - "GetInnerDimFromValue($0.getType().cast(), &$_builder)">; + "GetInnerDimFromValue(llvm::cast($0.getType()), &$_builder)">; def CheckInnerDimStatic - : Constraint(), &$_builder)">>; + : Constraint($0.getType()), &$_builder)">>; def : Pat<(TF_FFTOp:$res $input), - (MHLO_FftOp $input, MHLO_FftTypeValue<"FFT">, (GetInnerDimFromValue $res)), + (StableHLO_FftOp $input, StableHLO_FftTypeValue<"FFT">, (GetInnerDimFromValue $res)), [(CheckInnerDimStatic $input)]>; def : Pat<(TF_IFFTOp:$res $input), - (MHLO_FftOp $input, MHLO_FftTypeValue<"IFFT">, (GetInnerDimFromValue $res)), + (StableHLO_FftOp $input, StableHLO_FftTypeValue<"IFFT">, (GetInnerDimFromValue $res)), [(CheckInnerDimStatic $input)]>; //===----------------------------------------------------------------------===// @@ -355,7 +363,7 @@ def : Pat<(TF_IFFTOp:$res $input), def LegalizeGatherV2 : Pat<(TF_GatherV2Op AnyRankedTensor:$params, AnyRankedTensor:$indices, (ConstantLikeMatcher ElementsAttr:$axis), $batch_dims), - (MHLO_TorchIndexSelectOp $params, $indices, + (StableHLO_TorchIndexSelectOp $params, $indices, (GetHLOAxisFromTFAxis $axis, $params), (GetHLOAxisFromTFAxis $batch_dims, $indices))>; @@ -364,17 +372,17 @@ def LegalizeGatherV2 : //===----------------------------------------------------------------------===// class SliceDenseIntElementsAttrColumn2D : NativeCodeCall< - "SliceDenseIntElementsAttrColumn2D($0.cast(), " # column # " )">; + "SliceDenseIntElementsAttrColumn2D(llvm::cast($0), " # column # " )">; class SliceDenseIntElementsAttr : NativeCodeCall< - "SliceDenseIntElementsAttr($0.cast(), " # index # ", " # axis # ")">; + "SliceDenseIntElementsAttr(llvm::cast($0), " # index # ", " # axis # ")">; // Interior padding attribute based on the TF padding. 
-def GetInteriorPadding : NativeCodeCall < - "GetInteriorPadding($0.cast())">; +def GetInteriorPadding : NativeCodeCall< + "GetInteriorPadding(llvm::cast($0))">; def : Pat<(TF_PadV2Op $input, (ConstantLikeMatcher ElementsAttr:$padding), $c), - (MHLO_PadOp $input, $c, + (StableHLO_PadOp $input, $c, (SliceDenseIntElementsAttrColumn2D<"0"> $padding), (SliceDenseIntElementsAttrColumn2D<"1"> $padding), (GetInteriorPadding $padding))>; @@ -394,55 +402,55 @@ foreach src = [TF_PreventGradientOp, TF_CheckNumericsOp] in // MatMul op patterns. //===----------------------------------------------------------------------===// -def GetPrecisionConfig: NativeCodeCall< +def StableHLO_GetPrecisionConfig: NativeCodeCall< "GetPrecisionConfig(&$_builder)">; def : Pat<(TF_MatMulOp $a, $b, $transpose_a, $transpose_b, $grad_a, $grad_b), - (MHLO_DotOp + (StableHLO_DotOp (TF_TransposeOp $a, (TF_ConstOp (Get2DTransposePerm $transpose_a))), (TF_TransposeOp $b, (TF_ConstOp (Get2DTransposePerm $transpose_b))), - /*precision_config=*/(GetPrecisionConfig))>; + /*precision_config=*/(StableHLO_GetPrecisionConfig))>; //===----------------------------------------------------------------------===// // Lower `tf.ZerosLike` //===----------------------------------------------------------------------===// def : Pat<(TF_ZerosLikeOp AnyTensor:$arg), - (MHLO_ConstantLike<"0"> $arg)>; + (StableHLO_ConstantLike<"0"> $arg)>; //===----------------------------------------------------------------------===// // Lower `tf.OnesLike` //===----------------------------------------------------------------------===// def : Pat<(TF_OnesLikeOp AnyTensor:$arg), - (MHLO_ConstantLike<"1"> $arg)>; + (StableHLO_ConstantLike<"1"> $arg)>; //===----------------------------------------------------------------------===// // Elu op patterns. 
//===----------------------------------------------------------------------===// def : Pat<(TF_EluOp AnyTensor:$features), - (MHLO_SelectOp - (MHLO_CompareOp + (StableHLO_SelectOp + (StableHLO_CompareOp $features, - (MHLO_ConstantLike<"0">:$zero $features), - MHLO_ComparisonDirectionValue<"GT">, (MHLO_DEFAULT_COMPARISON_TYPE)), + (StableHLO_ConstantLike<"0">:$zero $features), + StableHLO_ComparisonDirectionValue<"GT">, (STABLEHLO_DEFAULT_COMPARISON_TYPE)), $features, - (MHLO_Expm1Op $features, ConstDefaultResultAccuracyAttr))>; + (StableHLO_Expm1Op $features, ConstDefaultResultAccuracyAttr))>; def : Pat<(TF_EluGradOp AnyStaticShapeTensor:$gradients, AnyRankedTensor:$features), - (MHLO_SelectOp + (StableHLO_SelectOp (CHLO_BroadcastCompareOp $features, - (MHLO_ConstantOp:$zero (GetScalarOfType<0> $features)), + (StableHLO_ConstantOp:$zero (GetScalarOfType<0> $features)), (BinBroadcastDimensions $zero, $features), CHLO_ComparisonDirectionValue<"GT">, (CHLO_DEFAULT_COMPARISON_TYPE)), $gradients, - (MHLO_MulOp + (StableHLO_MulOp $gradients, (CHLO_BroadcastAddOp $features, - (MHLO_ConstantOp:$one (GetScalarOfType<1> $features)), + (StableHLO_ConstantOp:$one (GetScalarOfType<1> $features)), (BinBroadcastDimensions $one, $features))))>; //===----------------------------------------------------------------------===// @@ -455,24 +463,24 @@ def : Pat<(TF_EluGradOp AnyStaticShapeTensor:$gradients, AnyRankedTensor:$featur // TODO(hinsu): Lower quantized types after supporting them in GetScalarOfType. def : Pat<(TF_ReluOp AnyTensor:$input), (CHLO_BroadcastMaxOp - (MHLO_ConstantOp:$zero (GetScalarOfType<0> $input)), $input, + (StableHLO_ConstantOp:$zero (GetScalarOfType<0> $input)), $input, (BinBroadcastDimensions $zero, $input)), [(TF_IntOrFpTensor $input)]>; // TODO(hinsu): Lower quantized types after supporting them in GetScalarOfType. 
def : Pat<(TF_Relu6Op AnyRankedTensor:$input), - (MHLO_ClampOp (MHLO_ConstantOp (GetScalarOfType<0> $input)), $input, - (MHLO_ConstantOp (GetScalarOfType<6> $input))), + (StableHLO_ClampOp (StableHLO_ConstantOp (GetScalarOfType<0> $input)), $input, + (StableHLO_ConstantOp (GetScalarOfType<6> $input))), [(TF_IntOrFpTensor $input)]>; // ReluGrad(gradients, features) = gradients * (features > 0) // The condition that $gradients and $features need to have the same shape is // implicitly enforced: $zero is created to have the same shape as $features, -// MHLO_SelectOp enforces that $gradients and $zero have the same shape. +// StableHLO_SelectOp enforces that $gradients and $zero have the same shape. def : Pat<(TF_ReluGradOp AnyTensor:$gradients, AnyTensor:$features), - (MHLO_SelectOp - (MHLO_CompareOp $features, (MHLO_ConstantLike<"0">:$zero $features), - MHLO_ComparisonDirectionValue<"GT">, (MHLO_DEFAULT_COMPARISON_TYPE)), + (StableHLO_SelectOp + (StableHLO_CompareOp $features, (StableHLO_ConstantLike<"0">:$zero $features), + StableHLO_ComparisonDirectionValue<"GT">, (STABLEHLO_DEFAULT_COMPARISON_TYPE)), $gradients, $zero)>; //===----------------------------------------------------------------------===// @@ -482,9 +490,9 @@ def : Pat<(TF_ReluGradOp AnyTensor:$gradients, AnyTensor:$features), /// Converts a TF::SoftsignOp to HLO. 
/// Softsign(features) = features / (1 + abs(features)) def : Pat<(TF_SoftsignOp AnyTensor:$input), - (MHLO_DivOp + (StableHLO_DivOp $input, - (MHLO_AddOp (MHLO_ConstantLike<"1"> $input), (MHLO_AbsOp $input)) + (StableHLO_AddOp (StableHLO_ConstantLike<"1"> $input), (StableHLO_AbsOp $input)) ) >; @@ -493,12 +501,12 @@ def : Pat<(TF_SoftsignOp AnyTensor:$input), def : Pattern< (TF_SoftsignGradOp AnyRankedTensor:$gradients, AnyRankedTensor:$features), [(CHLO_BroadcastAddOp:$add - (MHLO_ConstantOp:$one (GetScalarOfType<1> $features)), (MHLO_AbsOp $features), + (StableHLO_ConstantOp:$one (GetScalarOfType<1> $features)), (StableHLO_AbsOp $features), (BinBroadcastDimensions $one, $features) ), (CHLO_BroadcastDivOp $gradients, - (MHLO_MulOp $add, $add), + (StableHLO_MulOp $add, $add), (BinBroadcastDimensions $gradients, $add) ) ]>; @@ -511,15 +519,15 @@ def UnpackStartingIndices: NativeCodeCall< "UnpackTensorAlongZeroDim($0.getLoc(), $1, &$_builder).getOutput()">; def CanBeTranslatedToDynamicSlice : Constraint())">>; + "CanBeTranslatedToDynamicSlice($0, $1, llvm::cast($2))">>; def TFSliceSizes2HLOSliceSizes : NativeCodeCall< - "TFSliceSizes2HLOSliceSizes($0, $1, $2.cast()," + "TFSliceSizes2HLOSliceSizes($0, $1, llvm::cast($2)," "&$_builder)">; -def : Pat<(TF_SliceOp:$op MHLO_Tensor:$input, MHLO_Tensor:$starting_indices, +def : Pat<(TF_SliceOp:$op HLO_Tensor:$input, HLO_Tensor:$starting_indices, (ConstantLikeMatcher AnyAttr:$slice_sizes)), - (MHLO_DynamicSliceOp $input, + (StableHLO_DynamicSliceOp $input, (UnpackStartingIndices $op, $starting_indices), (TFSliceSizes2HLOSliceSizes $input, $starting_indices, $slice_sizes)), [(CanBeTranslatedToDynamicSlice $input, $starting_indices, @@ -529,8 +537,8 @@ def : Pat<(TF_SliceOp:$op MHLO_Tensor:$input, MHLO_Tensor:$starting_indices, // Select op patterns. 
//===----------------------------------------------------------------------===// - def : Pat<(TF_SelectV2Op MHLO_Tensor:$pred, MHLO_Tensor:$on_true, - MHLO_Tensor:$on_false), + def : Pat<(TF_SelectV2Op HLO_Tensor:$pred, HLO_Tensor:$on_true, + HLO_Tensor:$on_false), (CHLO_BroadcastSelectOp $pred, $on_true, $on_false)>; //===----------------------------------------------------------------------===// @@ -563,47 +571,47 @@ def : Pat<(TF_LegacyCallOp:$op $args, $args_attrs, $res_attrs, //===----------------------------------------------------------------------===// // Handles axis conversion for TF reverse. -def ConvertAxisAttr : NativeCodeCall<"ConvertAxisAttr($0, $1.cast(), &$_builder)">; +def ConvertAxisAttr : NativeCodeCall<"ConvertAxisAttr($0, llvm::cast($1), &$_builder)">; def : Pat<(TF_ReverseV2Op AnyRankedTensor:$values, (ConstantLikeMatcher ElementsAttr:$axis)), - (MHLO_ReverseOp $values, (ConvertAxisAttr $values, $axis))>; + (StableHLO_ReverseOp $values, (ConvertAxisAttr $values, $axis))>; //===----------------------------------------------------------------------===// // Unary op patterns. 
//===----------------------------------------------------------------------===// foreach Mapping = [ - [TF_AbsOp, MHLO_AbsOp], - [TF_CeilOp, MHLO_CeilOp], - [TF_ComplexAbsOp, MHLO_AbsOp], - [TF_ErfOp, MHLO_ErfOp], - [TF_FloorOp, MHLO_FloorOp], - [TF_ImagOp, MHLO_ImagOp], - [TF_InvertOp, MHLO_NotOp], - [TF_IsFiniteOp, MHLO_IsFiniteOp], - [TF_LogicalNotOp, MHLO_NotOp], - [TF_NegOp, MHLO_NegOp], - [TF_RealOp, MHLO_RealOp], + [TF_AbsOp, StableHLO_AbsOp], + [TF_CeilOp, StableHLO_CeilOp], + [TF_ComplexAbsOp, StableHLO_AbsOp], + [TF_ErfOp, CHLO_ErfOp], + [TF_FloorOp, StableHLO_FloorOp], + [TF_ImagOp, StableHLO_ImagOp], + [TF_InvertOp, StableHLO_NotOp], + [TF_IsFiniteOp, StableHLO_IsFiniteOp], + [TF_LogicalNotOp, StableHLO_NotOp], + [TF_NegOp, StableHLO_NegOp], + [TF_RealOp, StableHLO_RealOp], ] in { - def : Pat<(Mapping[0] MHLO_Tensor:$input), + def : Pat<(Mapping[0] HLO_Tensor:$input), (Mapping[1] $input)>; } foreach Mapping = [ - [TF_CosOp, MHLO_CosineOp], - [TF_ExpOp, MHLO_ExpOp], - [TF_Expm1Op, MHLO_Expm1Op], - [TF_LogOp, MHLO_LogOp], - [TF_Log1pOp, MHLO_Log1pOp], - [TF_RsqrtOp, MHLO_RsqrtOp], - [TF_SigmoidOp, MHLO_LogisticOp], - [TF_SinOp, MHLO_SineOp], - [TF_SqrtOp, MHLO_SqrtOp], - [TF_TanhOp, MHLO_TanhOp], - [TF_TanOp, MHLO_TanOp] + [TF_CosOp, StableHLO_CosineOp], + [TF_ExpOp, StableHLO_ExpOp], + [TF_Expm1Op, StableHLO_Expm1Op], + [TF_LogOp, StableHLO_LogOp], + [TF_Log1pOp, StableHLO_Log1pOp], + [TF_RsqrtOp, StableHLO_RsqrtOp], + [TF_SigmoidOp, StableHLO_LogisticOp], + [TF_SinOp, StableHLO_SineOp], + [TF_SqrtOp, StableHLO_SqrtOp], + [TF_TanhOp, StableHLO_TanhOp], + [TF_TanOp, StableHLO_TanOp] ] in { - def : Pat<(Mapping[0] MHLO_Tensor:$input), + def : Pat<(Mapping[0] HLO_Tensor:$input), (Mapping[1] $input, ConstDefaultResultAccuracyAttr)>; } @@ -622,28 +630,28 @@ foreach Mapping = [ [TF_LgammaOp, CHLO_LgammaOp], [TF_SinhOp, CHLO_SinhOp], ] in { - def : Pat<(Mapping[0] MHLO_AnyTensor:$input), + def : Pat<(Mapping[0] HLO_AnyTensor:$input), (Mapping[1] $input)>; } 
-def : Pat<(TF_AngleOp $x), (MHLO_Atan2Op (MHLO_ImagOp $x), (MHLO_RealOp $x))>; +def : Pat<(TF_AngleOp $x), (StableHLO_Atan2Op (StableHLO_ImagOp $x), (StableHLO_RealOp $x))>; // TODO(bixia): Lower with Truncate=True for floating point value conversions. -def : Pat<(TF_CastOp $arg, ConstBoolAttrFalse), (MHLO_ConvertOp $arg)>; +def : Pat<(TF_CastOp $arg, ConstBoolAttrFalse), (StableHLO_ConvertOp $arg)>; def : Pat<(TF_TransposeOp:$res $arg, (ConstantLikeMatcher ElementsAttr:$permutation)), - (MHLO_TransposeOp $arg, (CastElementsToI64Elements $permutation))>; + (StableHLO_TransposeOp $arg, (CastElementsToI64Array $permutation))>; -// Lowering these ops with static shape to mhlo.reshape +// Lowering these ops with static shape to stablehlo.reshape foreach TfOp = [TF_ExpandDimsOp, TF_ReshapeOp, TF_SqueezeOp, ] in { - def : Pat<(TfOp:$res MHLO_Tensor:$arg, $ignored), - (MHLO_ReshapeOp $arg), [(AnyStaticShapeTensor $res)], [], + def : Pat<(TfOp:$res HLO_Tensor:$arg, $ignored), + (StableHLO_ReshapeOp $arg), [(AnyStaticShapeTensor $res)], [], (addBenefit 2)>; } // Returns NaN if x is NaN, 0 if x is 0, -1 if x < 0 and 1 if x > 0. -def : Pat<(TF_SignOp $x), (MHLO_SignOp $x)>; +def : Pat<(TF_SignOp $x), (StableHLO_SignOp $x)>; def BothElementTypesSameWidthIntOrFloat : Constraint; // TODO(jpienaar): Lower constant like to constant to broadcast if dynamic -// and going to MHLO. +// and going to StableHLO. //===----------------------------------------------------------------------===// // Random ops. //===----------------------------------------------------------------------===// // TODO(b/148269299): handle random number generator seeds/states correctly. 
-class MHLO_RngDistributionValue : - ConstantAttr; +class StableHLO_RngDistributionValue : + ConstantAttr; def : Pat<(TF_RandomUniformOp:$old $shape, $seed, $seed2), - (MHLO_RngOp - (MHLO_ConstantOp + (StableHLO_RngOp + (StableHLO_ConstantOp (NativeCodeCall<"$_builder.getFloatAttr(old.getDtype(), 0.0)">)), - (MHLO_ConstantOp + (StableHLO_ConstantOp (NativeCodeCall<"$_builder.getFloatAttr(old.getDtype(), 1.0)">)), (CastValueToI64 $old, $shape), - MHLO_RngDistributionValue<"UNIFORM">), + StableHLO_RngDistributionValue<"UNIFORM">), [(IsShapedTensor $shape)]>; def : Pat<(TF_RandomStandardNormalOp:$old $shape, $seed, $seed2), - (MHLO_RngOp - (MHLO_ConstantOp + (StableHLO_RngOp + (StableHLO_ConstantOp (NativeCodeCall<"$_builder.getFloatAttr(old.getDtype(), 0.0)">)), - (MHLO_ConstantOp + (StableHLO_ConstantOp (NativeCodeCall<"$_builder.getFloatAttr(old.getDtype(), 1.0)">)), (CastValueToI64 $old, $shape), - MHLO_RngDistributionValue<"NORMAL">), + StableHLO_RngDistributionValue<"NORMAL">), [(IsShapedTensor $shape)]>; //===----------------------------------------------------------------------===// // Sigmoid grad op. //===----------------------------------------------------------------------===// -// TODO(hinsu): Handle unranked inputs by broadcasting constant one to the -// shape of $l instead of having it as a constant. +// Only handle static shape here, dynamic shape is handled by +// ConvertSigmoidGradOpDynamic +def HasStaticShape : Constraint< + CPred<"::llvm::dyn_cast($0.getType()).hasStaticShape()">>; + def : Pat<(TF_SigmoidGradOp AnyRankedTensor:$l, AnyRankedTensor:$r), - (MHLO_MulOp - (MHLO_MulOp $r, $l), - (MHLO_SubtractOp (MHLO_ConstantOp (ConstantSplat<"1"> $l)), $l))>; + (StableHLO_MulOp + (StableHLO_MulOp $r, $l), + (StableHLO_SubtractOp (StableHLO_ConstantOp (ConstantSplat<"1"> $l)), $l)), + [(HasStaticShape $l)]>; //===----------------------------------------------------------------------===// // Softplus op. 
@@ -707,22 +719,22 @@ def EpsilonValue : NativeCodeCall<"GetEpsilonValue($0.getType())">; def : Pattern<(TF_SoftplusOp AnyTensor:$features), [ - (MHLO_ExpOp:$features_exp $features, ConstDefaultResultAccuracyAttr), + (StableHLO_ExpOp:$features_exp $features, ConstDefaultResultAccuracyAttr), (CHLO_BroadcastAddOp:$threshold - (MHLO_LogOp (MHLO_ConstantOp (EpsilonValue $features)), ConstDefaultResultAccuracyAttr), - (MHLO_ConstantOp (GetScalarOfType<2> $features)), + (StableHLO_LogOp (StableHLO_ConstantOp (EpsilonValue $features)), ConstDefaultResultAccuracyAttr), + (StableHLO_ConstantOp (GetScalarOfType<2> $features)), (NullDenseI64ArrayAttr) ), - (MHLO_SelectOp:$output + (StableHLO_SelectOp:$output (CHLO_BroadcastCompareOp $features, - (MHLO_NegOp $threshold), + (StableHLO_NegOp $threshold), (NullDenseI64ArrayAttr), CHLO_ComparisonDirectionValue<"GT">, (CHLO_DEFAULT_COMPARISON_TYPE) ), $features, - (MHLO_SelectOp + (StableHLO_SelectOp (CHLO_BroadcastCompareOp $features, $threshold, @@ -731,7 +743,7 @@ def : Pattern<(TF_SoftplusOp AnyTensor:$features), (CHLO_DEFAULT_COMPARISON_TYPE) ), $features_exp, - (MHLO_Log1pOp $features_exp, ConstDefaultResultAccuracyAttr) + (StableHLO_Log1pOp $features_exp, ConstDefaultResultAccuracyAttr) ) ), (replaceWithValue $output) @@ -742,7 +754,7 @@ def : Pattern<(TF_SoftplusOp AnyTensor:$features), //===----------------------------------------------------------------------===// def : Pat<(TF_XlaReplicaIdOp), - (TF_CastOp (MHLO_ReplicaIdOp), /*truncate=*/ConstBoolAttrFalse)>; + (TF_CastOp (StableHLO_ReplicaIdOp), /*truncate=*/ConstBoolAttrFalse)>; //===----------------------------------------------------------------------===// // XlaGather op. 
@@ -754,9 +766,9 @@ def HasValidGatherDims : Constraint>; def : Pat<(TF_XlaGatherOp $operand, $start_indices, (ConstantLikeMatcher ElementsAttr:$slice_sizes), $dimension_numbers, $indices_are_sorted), - (MHLO_GatherOp $operand, $start_indices, + (StableHLO_GatherOp $operand, $start_indices, (ToGatherDimNumsAttr $dimension_numbers), - (CastElementsToI64Elements $slice_sizes), + (CastElementsToI64Array $slice_sizes), $indices_are_sorted), [(HasValidGatherDims $dimension_numbers)]>; @@ -773,7 +785,7 @@ def HasValidDotDims : Constraint>; def HasValidPrecisionConfig : Constraint>; def : Pat<(TF_XlaDotOp $lhs, $rhs, $dimension_numbers, $precision_config), - (MHLO_DotGeneralOp $lhs, $rhs, + (StableHLO_DotGeneralOp $lhs, $rhs, (ToDotDimNumsAttr $dimension_numbers), (ToPrecisionConfigsAttr $precision_config), (EmptyDotAlgorithmAttr)), @@ -784,7 +796,7 @@ def : Pat<(TF_XlaDotOp $lhs, $rhs, $dimension_numbers, $precision_config), //===----------------------------------------------------------------------===// def : Pat<(TF_XlaDotV2Op $lhs, $rhs, $dimension_numbers, $precision_config), - (MHLO_DotGeneralOp $lhs, $rhs, + (StableHLO_DotGeneralOp $lhs, $rhs, (ToDotDimNumsAttr $dimension_numbers), (ToPrecisionConfigsAttr $precision_config), (EmptyDotAlgorithmAttr)), @@ -794,9 +806,9 @@ def : Pat<(TF_XlaDotV2Op $lhs, $rhs, $dimension_numbers, $precision_config), // XlaDynamicSlice op. 
//===----------------------------------------------------------------------===// -def : Pat<(TF_XlaDynamicSliceOp:$op MHLO_Tensor:$input, MHLO_Tensor:$starting_indices, +def : Pat<(TF_XlaDynamicSliceOp:$op HLO_Tensor:$input, HLO_Tensor:$starting_indices, (ConstantLikeMatcher AnyAttr:$slice_sizes)), - (MHLO_DynamicSliceOp $input, + (StableHLO_DynamicSliceOp $input, (UnpackStartingIndices $op, $starting_indices), (TFSliceSizes2HLOSliceSizes $input, $starting_indices, $slice_sizes))>; @@ -805,11 +817,11 @@ def : Pat<(TF_XlaDynamicSliceOp:$op MHLO_Tensor:$input, MHLO_Tensor:$starting_in //===----------------------------------------------------------------------===// def : Pat<(TF_XlaEinsumOp $lhs, $rhs, $equation), - (MHLO_EinsumOp $lhs, $rhs, $equation)>; + (StableHLO_EinsumOp $lhs, $rhs, $equation)>; //===----------------------------------------------------------------------===// // XlaOptimizationBarrierOp op. //===----------------------------------------------------------------------===// def : Pat<(TF_XlaOptimizationBarrierOp $args), - (MHLO_OptimizationBarrierOp $args)>; + (StableHLO_OptimizationBarrierOp $args)>; diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_with_tf2xla.cc b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_with_tf2xla.cc index 2d9bc167d2c0a4..14cec354ddcb9e 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_with_tf2xla.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_with_tf2xla.cc @@ -22,6 +22,7 @@ limitations under the License. #include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Diagnostics.h" // from @llvm-project #include "mlir/IR/IRMapping.h" // from @llvm-project @@ -31,6 +32,7 @@ limitations under the License. 
#include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "stablehlo/dialect/Base.h" // from @stablehlo #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tpu_embedding_ops_registry.h" @@ -43,7 +45,6 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_expression.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/tsl/platform/env.h" #include "xla/tsl/platform/status.h" #include "xla/tsl/platform/statusor.h" @@ -75,13 +76,11 @@ bool IsBounded(Type ty) { if (ranked_ty.hasStaticShape()) return true; - auto encoding = - mlir::dyn_cast_or_null(ranked_ty.getEncoding()); - if (!encoding) return false; + auto bounds = hlo::encodingToBounds(ranked_ty.getEncoding()); + if (bounds.empty()) return false; for (int i = 0; i < ranked_ty.getRank(); ++i) { - if (ranked_ty.isDynamicDim(i) && - encoding.getBounds()[i] == ShapedType::kDynamic) { + if (ranked_ty.isDynamicDim(i) && bounds[i] == ShapedType::kDynamic) { return false; } } diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.cc b/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.cc index 161ae934df7d05..7f3ec19a70967a 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.cc @@ -37,6 +37,7 @@ limitations under the License. 
#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Diagnostics.h" // from @llvm-project #include "mlir/IR/IRMapping.h" // from @llvm-project @@ -50,6 +51,8 @@ limitations under the License. #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "stablehlo/dialect/Base.h" // from @stablehlo +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tpu_embedding_ops_registry.h" @@ -69,7 +72,6 @@ limitations under the License. #include "xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.h" #include "xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h" #include "xla/hlo/translate/mhlo_to_hlo/type_to_shape.h" -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/service/hlo.pb.h" #include "xla/tsl/platform/env.h" #include "xla/tsl/platform/errors.h" @@ -154,7 +156,7 @@ Tf2XlaRewriter::~Tf2XlaRewriter() { if (context_) context_->Unref(); } -absl::StatusOr Tf2XlaRewriter::ImportXlaComputation( +absl::StatusOr Tf2XlaRewriter::ImportXlaComputation( XlaComputation& computation) { xla::DebugOptions debug_options; TF_ASSIGN_OR_RETURN(auto hlo_module_config, @@ -193,8 +195,8 @@ absl::StatusOr Tf2XlaRewriter::ImportXlaComputation( xla::HloFunctionImporter::ImportInstructions( *hlo_module->entry_computation(), arguments, symbol_table, &builder)); - mhlo::TupleOp root_tuple = - mlir::dyn_cast_or_null(root_value.getDefiningOp()); + stablehlo::TupleOp root_tuple = + mlir::dyn_cast_or_null(root_value.getDefiningOp()); if 
(!root_tuple) { return tsl::errors::InvalidArgument( "Imported XLA Root Value is not a tuple op"); @@ -259,13 +261,11 @@ bool IsBounded(Type ty) { if (ranked_ty.hasStaticShape()) return true; - auto encoding = - mlir::dyn_cast_or_null(ranked_ty.getEncoding()); - if (!encoding) return false; + ArrayRef bounds = hlo::encodingToBounds(ranked_ty.getEncoding()); + if (bounds.empty()) return false; for (int i = 0; i < ranked_ty.getRank(); ++i) { - if (ranked_ty.isDynamicDim(i) && - encoding.getBounds()[i] == ShapedType::kDynamic) { + if (ranked_ty.isDynamicDim(i) && bounds[i] == ShapedType::kDynamic) { return false; } } @@ -410,23 +410,23 @@ LogicalResult Tf2XlaRewriter::LegalizeOp() { if (failed(VerifyOpResults(op_context))) return failure(); - absl::StatusOr tuple_result_or_status = + absl::StatusOr tuple_result_or_status = CompileWithHloImporter(op_context); if (!tuple_result_or_status.ok()) { return op_->emitRemark() << tuple_result_or_status.status().ToString(); } - mhlo::TupleOp tuple_result = tuple_result_or_status.value(); + stablehlo::TupleOp tuple_result = tuple_result_or_status.value(); - llvm::SmallVector output_values; - if (failed(GetKernelOutputs(op_context, tuple_result, output_values))) { - return failure(); - } + llvm::SmallVector output_values; + if (failed(GetKernelOutputs(op_context, tuple_result, output_values))) { + return failure(); + } rewriter_.replaceOp(op_, output_values); return success(); } -absl::StatusOr Tf2XlaRewriter::CompileWithHloImporter( +absl::StatusOr Tf2XlaRewriter::CompileWithHloImporter( tensorflow::OpKernelContext& op_context) { // XLA can only return a single value. Wrap all output op return values // in a Tuple op that gets unpacked later. @@ -470,7 +470,7 @@ mlir::LogicalResult Tf2XlaRewriter::VerifyOpResults( // multiple values. We get around this by returning a tuple as an XLA op. We // then unpack it here to return the multiple values instead. 
mlir::LogicalResult Tf2XlaRewriter::UnpackTupleResults( - mhlo::TupleOp tuple_result, llvm::SmallVector& outputs) { + stablehlo::TupleOp tuple_result, llvm::SmallVector& outputs) { if (tuple_result->getNumOperands() != op_->getNumResults()) { return op_->emitRemark() << "Translated TF2XLA tuple has different " "number of results than original op"; @@ -485,7 +485,7 @@ mlir::LogicalResult Tf2XlaRewriter::UnpackTupleResults( } mlir::LogicalResult Tf2XlaRewriter::GetKernelOutputs( - tensorflow::OpKernelContext& op_context, mhlo::TupleOp tuple_results, + tensorflow::OpKernelContext& op_context, stablehlo::TupleOp tuple_results, llvm::SmallVector& outputs) { outputs.reserve(op_->getNumResults()); diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.h b/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.h index c5c417e27ba022..c89316638a2ea5 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.h +++ b/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.h @@ -28,12 +28,12 @@ limitations under the License. #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_expression.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/builder/xla_computation.h" -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/framework/op_kernel.h" @@ -58,12 +58,12 @@ class Tf2XlaRewriter { // Compiles the given Operation with XlaBuilder and imports the generated HLO // via the HLO -> MHLO importer. 
- absl::StatusOr CompileWithHloImporter( + absl::StatusOr CompileWithHloImporter( tensorflow::OpKernelContext& op_context); // Import the given XlaComputation into the parent module. Returns the given // generated function. - absl::StatusOr ImportXlaComputation( + absl::StatusOr ImportXlaComputation( xla::XlaComputation& computation); // Prepares OpKernelContext params common to all the ops. @@ -83,12 +83,12 @@ class Tf2XlaRewriter { mlir::LogicalResult VerifyOpResults(tensorflow::OpKernelContext& op_context); mlir::LogicalResult GetKernelOutputs(tensorflow::OpKernelContext& op_context, - mhlo::TupleOp tuple_results, + stablehlo::TupleOp tuple_results, llvm::SmallVector& outputs); // Given a translated function with a single return value, unpack the tuple // results. - mlir::LogicalResult UnpackTupleResults(mhlo::TupleOp tuple_result, + mlir::LogicalResult UnpackTupleResults(stablehlo::TupleOp tuple_result, llvm::SmallVector& outputs); // Tries to legalize the specified TensorFlow op, if supported. diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter_test.cc b/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter_test.cc index eaad485ccab96a..e20be6bb9a173c 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter_test.cc @@ -26,6 +26,7 @@ limitations under the License. #include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Diagnostics.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/OwningOpRef.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project @@ -33,20 +34,19 @@ limitations under the License. 
#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tf2xla/transforms/test_utils.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/builder/xla_computation.h" -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/shape_util.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/tsl/platform/errors.h" #include "xla/tsl/platform/status.h" #include "xla/tsl/platform/statusor.h" #include "xla/xla_data.pb.h" -#include "tensorflow/core/framework/op_kernel.h" namespace mlir { namespace mhlo { @@ -102,7 +102,7 @@ class Tf2XlaRewriterTestPeer { tf2xla_rewriter_(op, empty_rewriter_, /*device_type=*/"XLA_CPU_JIT") {} - absl::StatusOr ImportXlaComputationIntoModule( + absl::StatusOr ImportXlaComputationIntoModule( XlaComputation& computation) { return tf2xla_rewriter_.ImportXlaComputation(computation); } @@ -184,7 +184,7 @@ class Tf2XlaRewriterTest : public ::testing::Test { return main_func.getBody().front().front(); } - absl::StatusOr ImportXlaComputationIntoModule( + absl::StatusOr ImportXlaComputationIntoModule( XlaComputation& computation) { SourceMgrDiagnosticHandler sourceMgrHandler(source_manager_, &context_); @@ -204,7 +204,8 @@ TEST_F(Tf2XlaRewriterTest, LegalizesOpWithTf2xlaHloImporter) { TF_EXPECT_OK(LegalizeModule()); int num_tuple_ops = 0; - module_->walk([&num_tuple_ops](TupleOp tuple_op) { num_tuple_ops += 1; }); + module_->walk( + [&num_tuple_ops](stablehlo::TupleOp tuple_op) { num_tuple_ops += 1; }); EXPECT_EQ(num_tuple_ops, 0); } @@ -214,7 +215,7 @@ TEST_F(Tf2XlaRewriterTest, ImportsXlaComputationIntoModule) { XlaComputation 
computation = GetTestXlaComputation(); - TF_ASSERT_OK_AND_ASSIGN(TupleOp root_tuple, + TF_ASSERT_OK_AND_ASSIGN(stablehlo::TupleOp root_tuple, ImportXlaComputationIntoModule(computation)); ModuleOp parent_module = @@ -261,7 +262,7 @@ TEST_F(Tf2XlaRewriterTest, ImportsSingleComputation) { EXPECT_EQ(computation.proto().computations_size(), 2); TF_ASSERT_OK(CreateMlirModule()); - TF_ASSERT_OK_AND_ASSIGN(TupleOp root_tuple, + TF_ASSERT_OK_AND_ASSIGN(stablehlo::TupleOp root_tuple, ImportXlaComputationIntoModule(computation)); EXPECT_TRUE(root_tuple); @@ -356,7 +357,7 @@ TEST_F(Tf2XlaRewriterTest, ErrorsWithInvalidNumberOfParametersToArgs) { EXPECT_EQ(computation.proto().computations_size(), 2); TF_ASSERT_OK(CreateMlirModule()); - absl::StatusOr status_or_tuple_op = + absl::StatusOr status_or_tuple_op = ImportXlaComputationIntoModule(computation); EXPECT_FALSE(status_or_tuple_op.ok()); } diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/verify_tfxla_legalization.cc b/tensorflow/compiler/mlir/tf2xla/transforms/verify_tfxla_legalization.cc index d99f80ff5eacd5..e3d5d5f1b5a5d3 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/verify_tfxla_legalization.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/verify_tfxla_legalization.cc @@ -21,11 +21,13 @@ limitations under the License. 
#include "llvm/Support/Error.h" #include "llvm/Support/FormatVariadic.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Diagnostics.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "stablehlo/dialect/Base.h" // from @stablehlo #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tf2xla/transforms/passes.h" @@ -87,9 +89,8 @@ static void IncrementCounterFor(tensorflow::monitoring::Counter<1>* counter, } bool HasBounds(RankedTensorType type) { - auto encoding = mlir::dyn_cast_or_null( - type.getEncoding()); - return (encoding && !encoding.getBounds().empty()); + auto bounds = hlo::encodingToBounds(type.getEncoding()); + return !bounds.empty(); } bool HasStaticShapeOrBounded(Value val) { diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_tf.cc b/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_tf.cc index f5364586ec73c9..6c964c4d9a1403 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_tf.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_tf.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "mhlo/transforms/rewriters.h" #include "absl/log/log.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Debug.h" @@ -35,18 +36,19 @@ limitations under the License. 
#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/DebugStringHelper.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "stablehlo/dialect/ChloOps.h" // from @stablehlo #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo -#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.h" +#include "stablehlo/transforms/Passes.h" // from @stablehlo #include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h" #include "tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config.h" #include "tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_with_tf2xla_passes.h" #include "tensorflow/compiler/mlir/tf2xla/transforms/passes.h" #include "tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_targets.h" -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" // IWYU pragma: keep, dependent dialect #include "xla/mlir_hlo/mhlo/transforms/rewriters.h" #include "xla/mlir_hlo/mhlo/utils/type_conversion.h" #include "tensorflow/core/lib/monitoring/counter.h" @@ -154,6 +156,22 @@ mlir::LogicalResult ApplyPatterns(Operation *op, RewritePatternSet &patterns, return result; } +mlir::LogicalResult StablehloToMhlo(Operation *op) { + ConversionTarget target(*op->getContext()); + stablehlo::setupStablehloToHloConversionTarget(target); + + RewritePatternSet patterns(op->getContext()); + stablehlo::StablehloToHloTypeConverter shlo_converter; + stablehlo::populateStablehloToHloPatterns(&patterns, &shlo_converter, + patterns.getContext()); + stablehlo::registerFuncOpsForTypeConversion(target, patterns, shlo_converter); + + if (failed(applyPartialConversion(op, target, std::move(patterns)))) { + return 
op->emitError("TF2XLA failed to convert StableHLO to MHLO"); + } + return success(); +} + /// When `tf2xla_fallback_device_type` is not `None`, also uses legalization /// patterns from TF2XLA fallback for provided device type (see /// legalize_tf_with_tf2xla.cc for details). By default, TF2XLA fallback is @@ -208,20 +226,30 @@ LogicalResult legalizeTF(Operation *op, bool legalize_chlo, // Populate with CHLO->HLO lowerings to account for TF ops legalized to // CHLO first. stablehlo::StablehloToHloTypeConverter hlo_converter; + stablehlo::populateStablehloToHloPatterns(&patterns, &hlo_converter, context); if (legalize_chlo) { - chlo::populateChloToHloPatterns(context, &hlo_converter, &patterns); + chlo::populateChloToHighLevelMhloOpPatterns(context, &patterns); + stablehlo::populateChloToStablehloPatterns(context, &patterns); } // ConstantLike op is convenient to create splat constants, but is // canonicalized to plain HLO constant if statically shaped. Add the // canonicalization pattern to pattern list to enable multi-hop lowering. chlo::ConstantLikeOp::getCanonicalizationPatterns(patterns, context); - return ApplyPatterns(op, patterns, legalize_chlo); + if (failed(ApplyPatterns(op, patterns, legalize_chlo))) { + return failure(); + } + + // HLO->MLIR raises to StableHLO, but users of this pass expect MHLO. + return StablehloToMhlo(op); } // Performs the lowering to XLA dialect. 
void LegalizeTF::runOnOperation() { auto op = getOperation(); + VLOG(3) << "LegalizeTF(legalize_chlo=" << legalize_chlo_ + << ", prefer_tf2xla=" << prefer_tf2xla_ << ") on module:\n" + << mlir::debugString(*op); auto op_name = op->getName().getStringRef().str(); mlir_legalization_count->GetCell(op_name)->IncrementBy(1); std::optional tf2xla_fallback_device_type = std::nullopt; diff --git a/tensorflow/compiler/mlir/tf_mlir_translate_main.cc b/tensorflow/compiler/mlir/tf_mlir_translate_main.cc index babd62f6b13f89..80e58756bbfaad 100644 --- a/tensorflow/compiler/mlir/tf_mlir_translate_main.cc +++ b/tensorflow/compiler/mlir/tf_mlir_translate_main.cc @@ -33,8 +33,8 @@ limitations under the License. #include "mlir/Support/ToolUtilities.h" // from @llvm-project #include "mlir/Tools/mlir-translate/Translation.h" // from @llvm-project #include "tensorflow/compiler/mlir/init_mlir.h" -#include "tensorflow/compiler/mlir/lite/tools/tf_mlir_translate_cl.h" #include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h" +#include "tensorflow/compiler/mlir/tools/tf_mlir_translate_cl.h" #include "tensorflow/core/platform/init_main.h" // NOLINTNEXTLINE diff --git a/tensorflow/compiler/mlir/tfr/BUILD b/tensorflow/compiler/mlir/tfr/BUILD index a90f25aab887cf..4435ef59a7e385 100644 --- a/tensorflow/compiler/mlir/tfr/BUILD +++ b/tensorflow/compiler/mlir/tfr/BUILD @@ -53,16 +53,10 @@ td_library( gentbl_cc_library( name = "tfr_ops_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-op-decls"], - "ir/tfr_ops.h.inc", - ), - ( - ["-gen-op-defs"], - "ir/tfr_ops.cc.inc", - ), - ], + tbl_outs = { + "ir/tfr_ops.h.inc": ["-gen-op-decls"], + "ir/tfr_ops.cc.inc": ["-gen-op-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "ir/tfr_ops.td", deps = [ @@ -73,12 +67,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tfr_decompose_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - 
"passes/generated_decompose.inc", - ), - ], + tbl_outs = {"passes/generated_decompose.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "passes/decompose_patterns.td", deps = [ diff --git a/tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc b/tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc index 6780328b8e8975..d44e65f029ada3 100644 --- a/tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc +++ b/tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc @@ -111,12 +111,11 @@ class TFRInlinerInterface : public DialectInlinerInterface { Operation *materializeCallConversion(OpBuilder &builder, Value input, Type result_type, Location conversion_loc) const final { - if (!input.getType().isa() || - !result_type.isa()) { + if (!isa(input.getType()) || !isa(result_type)) { return nullptr; } - auto input_itype = input.getType().cast(); - auto result_itype = result_type.cast(); + auto input_itype = llvm::cast(input.getType()); + auto result_itype = llvm::cast(result_type); if (input_itype.getWidth() == result_itype.getWidth()) return nullptr; if (input_itype.getWidth() > result_itype.getWidth()) { return builder.create(conversion_loc, result_type, @@ -150,10 +149,10 @@ Operation *TFRDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { if (arith::ConstantOp::isBuildableWith(value, type)) return builder.create(loc, type, - value.cast()); + llvm::cast(value)); if (func::ConstantOp::isBuildableWith(value, type)) - return builder.create(loc, type, - value.cast()); + return builder.create( + loc, type, llvm::cast(value)); return nullptr; } @@ -180,11 +179,11 @@ LogicalResult ConstantTensorOp::verify() { auto input_type = op.getArg().getType(); auto output_type = op.getOut().getType(); - if (auto output_tensor_type = output_type.dyn_cast()) { + if (auto output_tensor_type = llvm::dyn_cast(output_type)) { return success(); } - auto output_tensor_type = output_type.dyn_cast(); + auto output_tensor_type = llvm::dyn_cast(output_type); if 
(!output_tensor_type || !output_tensor_type.hasStaticShape()) { op.emitError("output type should be static and ranked."); return failure(); @@ -198,7 +197,7 @@ LogicalResult ConstantTensorOp::verify() { return success(same_scalar); } - if (auto input_vector_type = input_type.dyn_cast()) { + if (auto input_vector_type = llvm::dyn_cast(input_type)) { bool same_element_type = output_tensor_type.getElementType() == input_vector_type.getElementType(); bool same_shape = @@ -230,7 +229,7 @@ LogicalResult TFRFuncOp::verify() { for (auto arg : llvm::enumerate(func.getFunctionType().getInputs())) { Type arg_type = arg.value(); - if (auto tensor = arg_type.dyn_cast()) { + if (auto tensor = llvm::dyn_cast(arg_type)) { if (first_tensor == -1) { first_tensor = arg.index(); } @@ -240,7 +239,7 @@ LogicalResult TFRFuncOp::verify() { continue; } - if (auto tensor_list = arg_type.dyn_cast()) { + if (auto tensor_list = llvm::dyn_cast(arg_type)) { if (first_tensor_list == -1) { first_tensor_list = arg.index(); } @@ -250,7 +249,7 @@ LogicalResult TFRFuncOp::verify() { continue; } - if (!arg_type.isa()) { + if (!isa(arg_type)) { if (first_attr == -1) { first_attr = arg.index(); } @@ -307,7 +306,7 @@ LogicalResult TFRFuncOp::verify() { bool seen_tensor_list = false, has_tensor_list_order_error = false, has_multiple_tensor_lists_error = false; for (auto result_type : func.getFunctionType().getResults()) { - if (auto tensor = result_type.dyn_cast()) { + if (auto tensor = llvm::dyn_cast(result_type)) { if (seen_tensor_list) { has_tensor_list_order_error = true; } else { @@ -317,7 +316,7 @@ LogicalResult TFRFuncOp::verify() { continue; } - if (auto tensor_list = result_type.dyn_cast()) { + if (auto tensor_list = llvm::dyn_cast(result_type)) { if (seen_tensor_list) { has_multiple_tensor_lists_error = true; } else { @@ -413,7 +412,7 @@ class ConvertConstToTensorConst : public OpRewritePattern { if (matchPattern(cst_tensor_op.getArg(), m_Constant(&array))) { llvm::DenseSet all_types; for (auto 
it : array) { - TypedAttr typed_attr = it.dyn_cast(); + TypedAttr typed_attr = llvm::dyn_cast(it); if (!typed_attr) return failure(); all_types.insert(typed_attr.getType()); } @@ -423,7 +422,7 @@ class ConvertConstToTensorConst : public OpRewritePattern { DenseElementsAttr attr = DenseElementsAttr::get(new_out_type, array.getValue()); new_cst = rewriter.create(loc, new_out_type, attr); - if (out_type.isa()) { + if (isa(out_type)) { new_cst = rewriter.create(loc, out_type, new_cst->getResult(0)); } rewriter.replaceOp(cst_tensor_op, new_cst->getResult(0)); @@ -434,7 +433,7 @@ class ConvertConstToTensorConst : public OpRewritePattern { if (matchPattern(cst_tensor_op.getArg(), m_Constant(&scalar))) { Type new_out_type = RankedTensorType::get({}, scalar.getType()); new_cst = rewriter.create(loc, new_out_type, scalar); - if (out_type.isa()) { + if (isa(out_type)) { new_cst = rewriter.create(loc, out_type, new_cst->getResult(0)); } rewriter.replaceOp(cst_tensor_op, new_cst->getResult(0)); @@ -445,9 +444,9 @@ class ConvertConstToTensorConst : public OpRewritePattern { }; inline bool isQuantizedType(Type type) { - auto tensor_type = type.dyn_cast(); + auto tensor_type = llvm::dyn_cast(type); return (tensor_type && - tensor_type.getElementType().isa()); + isa(tensor_type.getElementType())); } class RemoveRedundantCast : public OpRewritePattern { @@ -471,8 +470,8 @@ class RemoveRedundantCast : public OpRewritePattern { return failure(); } - auto input_tensor_type = input_type.dyn_cast(); - auto output_tensor_type = output_type.dyn_cast(); + auto input_tensor_type = llvm::dyn_cast(input_type); + auto output_tensor_type = llvm::dyn_cast(output_type); if (!input_tensor_type || !output_tensor_type) { return failure(); } @@ -493,7 +492,7 @@ class RemoveRedundantCast : public OpRewritePattern { // If the two types are the same, the back-to-back tfr.cast ops can be // removed. 
- if (input_type == output_type || output_type.isa()) { + if (input_type == output_type || isa(output_type)) { rewriter.replaceOp(cast_op, {input}); return success(); } @@ -501,8 +500,8 @@ class RemoveRedundantCast : public OpRewritePattern { // If the rank of the input tensor isn't ranked, we replace the pair // with tf.EnsureShape op so it can be removed after shape inference or // confirmed at runtime. - if (input_type.isa()) { - auto shape = output_type.cast().getShape(); + if (isa(input_type)) { + auto shape = llvm::cast(output_type).getShape(); auto shape_attr = TF::ShapeAttr::get(rewriter.getContext(), shape); rewriter.replaceOpWithNewOp(cast_op, output_type, input, shape_attr); @@ -548,7 +547,7 @@ class RemoveRedundantGetElement : public OpRewritePattern { Value input = preceding_build_list.getOperand(index.getInt()); Type output_type = ge_op.getType(); if (input.getType() != output_type && - !output_type.isa()) { + !isa(output_type)) { return failure(); } rewriter.replaceOp(ge_op, {input}); @@ -599,10 +598,8 @@ quant::QuantizedType getQuantizedElementType(CastOp cast_op) { if (!cast_op || !cast_op.getInputElementType()) { return {}; } - return cast_op.getInputElementType() - .cast() - .getValue() - .dyn_cast(); + return llvm::dyn_cast( + llvm::cast(cast_op.getInputElementType()).getValue()); } class RemoveRawDataOp : public OpRewritePattern { @@ -681,15 +678,15 @@ class RemoveQParamsOp : public OpRewritePattern { // them to constants. 
rewriter.setInsertionPoint(qparams_op); Location loc = qparams_op->getLoc(); - if (auto qtype = cast_qtype.dyn_cast()) { + if (auto qtype = llvm::dyn_cast(cast_qtype)) { scale_op = rewriter.create( loc, RankedTensorType::get({}, rewriter.getF32Type()), rewriter.getF32FloatAttr(qtype.getScale())); zp_op = rewriter.create( loc, RankedTensorType::get({}, rewriter.getI32Type()), rewriter.getI32IntegerAttr(qtype.getZeroPoint())); - } else if (auto qtype = - cast_qtype.dyn_cast()) { + } else if (auto qtype = llvm::dyn_cast( + cast_qtype)) { SmallVector scales(qtype.getScales().begin(), qtype.getScales().end()); SmallVector zps(qtype.getZeroPoints().begin(), @@ -745,7 +742,7 @@ class RemoveScaleFactorOp : public OpRewritePattern { return failure(); } const double out_scale = - out_scale_op.getValue().cast().getValueAsDouble(); + llvm::cast(out_scale_op.getValue()).getValueAsDouble(); auto in_scales_op = scale_factor_op.getInScales().getDefiningOp(); @@ -778,7 +775,8 @@ class RemoveScaleFactorOp : public OpRewritePattern { // The shape of scale_type is {} (rank 0) for per-tensor quantized tensor, // and {num_channels} (rank 1) for per-channel quantized one. 
- auto scale_type = filter_scale_attr.getType().dyn_cast(); + auto scale_type = + llvm::dyn_cast(filter_scale_attr.getType()); if (scale_type.getRank() != 0 && scale_type.getRank() != 1) { return failure(); } @@ -995,14 +993,14 @@ Type TFRDialect::parseType(DialectAsmParser &parser) const { void TFRDialect::printType(Type type, DialectAsmPrinter &os) const { llvm::ArrayRef attrs; - if (type.isa()) { + if (isa(type)) { os << "attr"; return; } - if (auto tensor_ty = type.dyn_cast()) { + if (auto tensor_ty = llvm::dyn_cast(type)) { attrs = tensor_ty.getAttrKeys(); os << "tensor"; - } else if (auto tensor_list_ty = type.dyn_cast()) { + } else if (auto tensor_list_ty = llvm::dyn_cast(type)) { attrs = tensor_list_ty.getAttrKeys(); os << "tensor_list"; } else { diff --git a/tensorflow/compiler/mlir/tfr/ir/tfr_ops.td b/tensorflow/compiler/mlir/tfr/ir/tfr_ops.td index 7cdaee96512dfc..d1014fec8e3e26 100644 --- a/tensorflow/compiler/mlir/tfr/ir/tfr_ops.td +++ b/tensorflow/compiler/mlir/tfr/ir/tfr_ops.td @@ -49,7 +49,7 @@ def TFR_Dialect : Dialect { // tensor argument types class TFR_Type : DialectType()">, + CPred<"llvm::isa($_self)">, "TFR " # name #" type">, BuildableType<"$_builder.getType()">; def TFR_TensorType : TFR_Type<"TFRTensor">; @@ -178,7 +178,7 @@ def TFR_CastOp : TFR_Op<"cast", [Pure]> { // Return element type of the input tensor type. Only available when the // input is a MLIR built-in tensor type. Attribute getInputElementType() { - if (auto ty = getArg().getType().dyn_cast()) { + if (auto ty = llvm::dyn_cast(getArg().getType())) { return TypeAttr::get(ty.getElementType()); } return {}; diff --git a/tensorflow/compiler/mlir/tfr/passes/canonicalize.cc b/tensorflow/compiler/mlir/tfr/passes/canonicalize.cc index 9cc555b7893563..fb0640536d4fe5 100644 --- a/tensorflow/compiler/mlir/tfr/passes/canonicalize.cc +++ b/tensorflow/compiler/mlir/tfr/passes/canonicalize.cc @@ -29,6 +29,7 @@ limitations under the License. 
#include "mlir/IR/Region.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/Inliner.h" // from @llvm-project #include "mlir/Transforms/InliningUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.h" @@ -142,8 +143,9 @@ LogicalResult SimplifySCFIfOp::InlineRegion(Location loc, Operation *inline_point, Region *region) const { InlinerInterface interface(loc.getContext()); - if (failed(inlineRegion(interface, region, inline_point, {}, - inline_point->getResults(), loc, + InlinerConfig config; + if (failed(inlineRegion(interface, config.getCloneCallback(), region, + inline_point, {}, inline_point->getResults(), loc, /*shouldCloneInlinedRegion=*/true))) { return failure(); } diff --git a/tensorflow/compiler/mlir/tfr/passes/decompose.cc b/tensorflow/compiler/mlir/tfr/passes/decompose.cc index 3a5d6f23072b00..105cd8de2041aa 100644 --- a/tensorflow/compiler/mlir/tfr/passes/decompose.cc +++ b/tensorflow/compiler/mlir/tfr/passes/decompose.cc @@ -47,6 +47,7 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "mlir/Transforms/Inliner.h" // from @llvm-project #include "mlir/Transforms/InliningUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -282,6 +283,7 @@ LogicalResult DecomposeTFOpsPass::RewriteUnregisteredTFOps() { LogicalResult DecomposeTFOpsPass::InlineTFRFuncCalls() { // The Inliner will automatically use the registered dialect inliner. 
InlinerInterface inliner(&getContext()); + InlinerConfig config; func::FuncOp func = getOperation(); SymbolTable table(external_tfr_module_.has_value() ? *external_tfr_module_ @@ -301,7 +303,7 @@ LogicalResult DecomposeTFOpsPass::InlineTFRFuncCalls() { // Use the inliner to replace all the uses of the call_op by its // composition. - if (failed(inlineCall(inliner, + if (failed(inlineCall(inliner, config.getCloneCallback(), cast(call_op.getOperation()), cast(callee.getOperation()), callee.getCallableRegion(), diff --git a/tensorflow/compiler/mlir/tfr/passes/decompose_patterns.td b/tensorflow/compiler/mlir/tfr/passes/decompose_patterns.td index 503fd6256f16ed..d3b0322095d8d7 100644 --- a/tensorflow/compiler/mlir/tfr/passes/decompose_patterns.td +++ b/tensorflow/compiler/mlir/tfr/passes/decompose_patterns.td @@ -21,7 +21,7 @@ include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.td" class Quantize : NativeCodeCall<"TFR::Quantize(" # value # ", $0, $1, $_builder)">; class HasStringAttr : AttrConstraint< - CPred<"$_self.cast().getValue() == \"" # value # "\"">>; + CPred<"llvm::cast($_self).getValue() == \"" # value # "\"">>; def QuantActRangeNonePattern : Pattern< diff --git a/tensorflow/compiler/mlir/tfr/passes/rewrite_quantized_io.cc b/tensorflow/compiler/mlir/tfr/passes/rewrite_quantized_io.cc index 34ae51c14ed177..0a30c8f21b5843 100644 --- a/tensorflow/compiler/mlir/tfr/passes/rewrite_quantized_io.cc +++ b/tensorflow/compiler/mlir/tfr/passes/rewrite_quantized_io.cc @@ -17,6 +17,7 @@ limitations under the License. 
#include #include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" #include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project #include "mlir/IR/Block.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project @@ -54,12 +55,11 @@ void RewriteQuantizedIOPass::runOnOperation() { // with input_arg(tensor) -> tfr.cast for (BlockArgument arg : block.getArguments()) { Type arg_type = arg.getType(); - if (auto quant_type = arg_type.cast() - .getElementType() - .dyn_cast()) { + if (auto quant_type = llvm::dyn_cast( + llvm::cast(arg_type).getElementType())) { if (arg.hasOneUse() && llvm::isa(*arg.user_begin())) { - arg.setType( - arg_type.cast().clone(quant_type.getStorageType())); + arg.setType(llvm::cast(arg_type).clone( + quant_type.getStorageType())); } else { std::string error_message; llvm::raw_string_ostream os{error_message}; @@ -77,17 +77,17 @@ void RewriteQuantizedIOPass::runOnOperation() { // with tfr.cast(tensor) -> output for (OpOperand& returned_value : terminator->getOpOperands()) { auto returned_type = - returned_value.get().getType().dyn_cast(); + llvm::dyn_cast(returned_value.get().getType()); if (!returned_type || - !returned_type.getElementType().isa()) { + !llvm::isa(returned_type.getElementType())) { continue; } if (auto returned_op = returned_value.get().getDefiningOp()) { - auto new_type = returned_type.clone(returned_type.getElementType() - .cast() - .getStorageType()); + auto new_type = returned_type.clone( + llvm::cast(returned_type.getElementType()) + .getStorageType()); auto new_op = builder.create( returned_op->getLoc(), new_type, returned_op.getArg()); returned_value.set(new_op.getResult()); diff --git a/tensorflow/compiler/mlir/tfrt/BUILD b/tensorflow/compiler/mlir/tfrt/BUILD index 7c18a25ef08365..2439e8e3b5e924 100644 --- a/tensorflow/compiler/mlir/tfrt/BUILD +++ b/tensorflow/compiler/mlir/tfrt/BUILD @@ -60,16 +60,10 @@ td_library( gentbl_cc_library( name = "runtime_fallback_ops_inc_gen", - tbl_outs = [ - ( 
- ["-gen-op-decls"], - "runtime_fallback_ops.h.inc", - ), - ( - ["-gen-op-defs"], - "runtime_fallback_ops.cc.inc", - ), - ], + tbl_outs = { + "runtime_fallback_ops.h.inc": ["-gen-op-decls"], + "runtime_fallback_ops.cc.inc": ["-gen-op-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "runtime_fallback/runtime_fallback_ops.td", deps = [":runtime_fallback_ops_td_files"], @@ -556,6 +550,7 @@ cc_library( "@llvm-project//mlir:Support", "@llvm-project//mlir:Transforms", "@local_xla//xla/mlir_hlo", + "@stablehlo//:register", "@tf_runtime//:init_tfrt_dialects", "@tf_runtime//:print_stream_pass", ], diff --git a/tensorflow/compiler/mlir/tfrt/ir/BUILD b/tensorflow/compiler/mlir/tfrt/ir/BUILD index b29066807fbf78..ae5379f2102f36 100644 --- a/tensorflow/compiler/mlir/tfrt/ir/BUILD +++ b/tensorflow/compiler/mlir/tfrt/ir/BUILD @@ -141,16 +141,10 @@ td_library( gentbl_cc_library( name = "tfrt_fallback_opdefs_inc_gen", compatible_with = get_compatible_with_portable(), # copybara: comment - tbl_outs = [ - ( - ["-gen-op-decls"], - "tfrt_fallback.h.inc", - ), - ( - ["-gen-op-defs"], - "tfrt_fallback.cpp.inc", - ), - ], + tbl_outs = { + "tfrt_fallback.h.inc": ["-gen-op-decls"], + "tfrt_fallback.cpp.inc": ["-gen-op-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "tfrt_fallback.td", deps = [":tfrt_fallback_td_files"], @@ -159,16 +153,10 @@ gentbl_cc_library( gentbl_cc_library( name = "tfrt_fallback_async_opdefs_inc_gen", compatible_with = get_compatible_with_portable(), # copybara: comment - tbl_outs = [ - ( - ["-gen-op-decls"], - "tfrt_fallback_async.h.inc", - ), - ( - ["-gen-op-defs"], - "tfrt_fallback_async.cpp.inc", - ), - ], + tbl_outs = { + "tfrt_fallback_async.h.inc": ["-gen-op-decls"], + "tfrt_fallback_async.cpp.inc": ["-gen-op-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "tfrt_fallback_async.td", deps = [":tfrt_fallback_td_files"], @@ -176,23 +164,14 @@ gentbl_cc_library( gentbl_cc_library( name = 
"tfrt_fallback_sync_opdefs_inc_gen", - tbl_outs = [ - ( - ["-gen-op-decls"], - "tfrt_fallback_sync.h.inc", - ), - ( - ["-gen-op-defs"], - "tfrt_fallback_sync.cpp.inc", - ), - ( - [ - "-gen-dialect-decls", - "-dialect=tfrt_fallback_sync", - ], - "tfrt_fallback_sync_dialect.h.inc", - ), - ], + tbl_outs = { + "tfrt_fallback_sync.h.inc": ["-gen-op-decls"], + "tfrt_fallback_sync.cpp.inc": ["-gen-op-defs"], + "tfrt_fallback_sync_dialect.h.inc": [ + "-gen-dialect-decls", + "-dialect=tfrt_fallback_sync", + ], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "tfrt_fallback_sync.td", test = True, @@ -219,23 +198,14 @@ td_library( gentbl_cc_library( name = "tfrt_gpu_opdefs_inc_gen", - tbl_outs = [ - ( - ["-gen-op-decls"], - "gpu_ops.h.inc", - ), - ( - ["-gen-op-defs"], - "gpu_ops.cpp.inc", - ), - ( - [ - "-gen-dialect-decls", - "-dialect=gpurt", - ], - "gpurt_dialect.h.inc", - ), - ], + tbl_outs = { + "gpu_ops.h.inc": ["-gen-op-decls"], + "gpu_ops.cpp.inc": ["-gen-op-defs"], + "gpurt_dialect.h.inc": [ + "-gen-dialect-decls", + "-dialect=gpurt", + ], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "gpu_ops.td", test = True, diff --git a/tensorflow/compiler/mlir/tfrt/ir/mlrt/BUILD b/tensorflow/compiler/mlir/tfrt/ir/mlrt/BUILD index 374aad2a242d9b..200f66fd722fef 100644 --- a/tensorflow/compiler/mlir/tfrt/ir/mlrt/BUILD +++ b/tensorflow/compiler/mlir/tfrt/ir/mlrt/BUILD @@ -23,16 +23,10 @@ td_library( gentbl_cc_library( name = "mlrt_ops_inc_gen", - tbl_outs = [ - ( - ["-gen-op-decls"], - "mlrt_ops.h.inc", - ), - ( - ["-gen-op-defs"], - "mlrt_ops.cpp.inc", - ), - ], + tbl_outs = { + "mlrt_ops.h.inc": ["-gen-op-decls"], + "mlrt_ops.cpp.inc": ["-gen-op-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "mlrt_ops.td", deps = [":mlrt_td_files"], @@ -96,16 +90,10 @@ td_library( gentbl_cc_library( name = "tf_mlrt_ops_inc_gen", - tbl_outs = [ - ( - ["-gen-op-decls"], - "tf_mlrt_ops.h.inc", - ), - ( - ["-gen-op-defs"], - "tf_mlrt_ops.cpp.inc", - ), - ], 
+ tbl_outs = { + "tf_mlrt_ops.h.inc": ["-gen-op-decls"], + "tf_mlrt_ops.cpp.inc": ["-gen-op-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "tf_mlrt_ops.td", deps = [":tf_mlrt_td_files"], @@ -113,16 +101,10 @@ gentbl_cc_library( gentbl_cc_library( name = "tf_mlrt_tpu_ops_inc_gen", - tbl_outs = [ - ( - ["-gen-op-decls"], - "tf_mlrt_tpu_ops.h.inc", - ), - ( - ["-gen-op-defs"], - "tf_mlrt_tpu_ops.cpp.inc", - ), - ], + tbl_outs = { + "tf_mlrt_tpu_ops.h.inc": ["-gen-op-decls"], + "tf_mlrt_tpu_ops.cpp.inc": ["-gen-op-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "tf_mlrt_tpu_ops.td", deps = [":tf_mlrt_tpu_td_files"], @@ -130,16 +112,10 @@ gentbl_cc_library( gentbl_cc_library( name = "tf_ops_inc_gen", - tbl_outs = [ - ( - ["-gen-op-decls"], - "tf_ops.h.inc", - ), - ( - ["-gen-op-defs"], - "tf_ops.cpp.inc", - ), - ], + tbl_outs = { + "tf_ops.h.inc": ["-gen-op-decls"], + "tf_ops.cpp.inc": ["-gen-op-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "tf_ops.td", deps = [":tf_mlrt_td_files"], diff --git a/tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_dialect.td b/tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_dialect.td index b260dcb402f3f2..13409c3ece1f3d 100644 --- a/tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_dialect.td +++ b/tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_dialect.td @@ -29,7 +29,7 @@ def Mlrt_Dialect : Dialect { } def MlrtFutureType : DialectType()">, "!mlrt.future type">, + CPred<"::llvm::isa<::mlrt::compiler::FutureType>($_self)">, "!mlrt.future type">, BuildableType<"$_builder.getType<::mlrt::compiler::FutureType>()"> { let description = [{ `!mlrt.future type` represents a C++ mlrt::Future. @@ -37,7 +37,7 @@ def MlrtFutureType : DialectType()">, "!mlrt.promise type">, + CPred<"::llvm::isa<::mlrt::compiler::PromiseType>($_self)">, "!mlrt.promise type">, BuildableType<"$_builder.getType<::mlrt::compiler::PromiseType>()"> { let description = [{ `!mlrt.promise type` represents a C++ mlrt::Promise. 
@@ -45,7 +45,7 @@ def MlrtPromiseType : DialectType()">, "!mlrt.async_handle type">, + CPred<"::llvm::isa<::mlrt::compiler::AsyncHandleType>($_self)">, "!mlrt.async_handle type">, BuildableType<"$_builder.getType<::mlrt::compiler::AsyncHandleType>()"> { let description = [{ `!mlrt.async_handle type` represents a C++ mlrt::AsyncHandle. diff --git a/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_dialect.td b/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_dialect.td index 9cf997e0c3e8ce..e706ac0e36c723 100644 --- a/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_dialect.td +++ b/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_dialect.td @@ -37,7 +37,7 @@ class TensorflowMlrt_Op traits = []> : // This corresponds to tensorflow::Tensor. def TFTensorType : DialectType()">, "!tf_mlrt.tensor type">, + CPred<"::llvm::isa<::tensorflow::tf_mlrt::TFTensorType>($_self)">, "!tf_mlrt.tensor type">, BuildableType<"$_builder.getType<::tensorflow::tf_mlrt::TFTensorType>()"> { let description = [{ `!tf_mlrt.tensor type` represents a tensorflow::Tensor. @@ -46,7 +46,7 @@ def TFTensorType : DialectType()">, "!tf_mlrt.device type">, + CPred<"::llvm::isa<::tensorflow::tf_mlrt::TFDeviceType>($_self)">, "!tf_mlrt.device type">, BuildableType<"$_builder.getType<::tensorflow::tf_mlrt::TFDeviceType>()"> { let description = [{ `!tf_mlrt.device type` represents a tensorflow::device. diff --git a/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback.td b/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback.td index 0c42590f9aa7ee..6587f825d7a00a 100644 --- a/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback.td +++ b/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback.td @@ -31,7 +31,7 @@ def Fallback_Dialect : Dialect { // This corresponds to tensorflow::Tensor. 
def TFTensorType : DialectType()">, "!tfrt_fallback.tf_tensor type">, + CPred<"::llvm::isa<::tfrt::fallback::TFTensorType>($_self)">, "!tfrt_fallback.tf_tensor type">, BuildableType<"$_builder.getType<::tfrt::fallback::TFTensorType>()"> { let description = [{ `!tfrt_fallback.tf_tensor type` represents a tensorflow::Tensor. @@ -40,7 +40,7 @@ def TFTensorType : DialectType()">, "!tfrt_fallback.tf_allocator type">, + CPred<"::llvm::isa<::tfrt::fallback::TFAllocatorType>($_self)">, "!tfrt_fallback.tf_allocator type">, BuildableType<"$_builder.getType<::tfrt::fallback::TFAllocatorType>()"> { let description = [{ `!tfrt_fallback.tf_tensor type` represents a tensorflow::Tensor. diff --git a/tensorflow/compiler/mlir/tfrt/tests/mlrt/rewrite_ifrt_load_variable.mlir b/tensorflow/compiler/mlir/tfrt/tests/mlrt/rewrite_ifrt_load_variable.mlir index a862e6abf7274f..fa2ec0b14c8166 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/mlrt/rewrite_ifrt_load_variable.mlir +++ b/tensorflow/compiler/mlir/tfrt/tests/mlrt/rewrite_ifrt_load_variable.mlir @@ -18,3 +18,26 @@ %2 = "tf.IfrtCall"(%arg0, %array_key) <{program_id = 6515870160938153680 : i64, variable_arg_indices = [1 : i32]}> {__tpu_compile_metadata_text = "retvals { sharding { } }"} : (tensor<1x3xf32>, tensor) -> tensor<1x1xf32> return %2 : tensor<1x1xf32> } + + +// ----- +// Variable is used by two CPU ops +// +// CHECK-LABEL: func @serving_default +// CHECK-NEXT: [[HANDLE:%.*]] = "tf.VarHandleOp"() +// CHECK-NEXT: [[ARRAYKEY:%.*]], [[FURTURE:%.*]] = "tf_mlrt.tf_ifrt_load_variable"([[HANDLE]]) +// CHECK-SAME: <{used_by_host = true}> : (tensor>>) -> (tensor, !mlrt.future) +// CHECK: [[TENSOR:%.*]] = "tf_mlrt.tf_await"([[FURTURE]]) : (!mlrt.future) -> tensor<3x1xf32> +// CHECK-NEXT: "tf.AddV2"([[TENSOR]], %cst) : (tensor<3x1xf32>, tensor<3x1xf32>) -> tensor<3x1xf32> +// CHECK-NEXT: "tf.Sub"([[TENSOR]], %cst) : (tensor<3x1xf32>, tensor<3x1xf32>) -> tensor<3x1xf32> +// CHECK-NEXT: return +// + func.func @serving_default() { + %0 = 
"tf.VarHandleOp"() <{container = "", shared_name = "y"}> : () -> tensor>> + %array_key, %tensor = "tf.IfrtLoadVariable"(%0) <{used_by_host = true}> : (tensor>>) -> (tensor, tensor<3x1xf32>) + %cst_24 = "tf.Const"() <{value = dense<[[0.0], [1.0], [2.0]]> : tensor<3x1xf32>}> : () -> tensor<3x1xf32> + %1 = "tf.AddV2"(%tensor, %cst_24) : (tensor<3x1xf32>, tensor<3x1xf32>) -> tensor<3x1xf32> + %2 = "tf.Sub"(%tensor, %cst_24) : (tensor<3x1xf32>, tensor<3x1xf32>) -> tensor<3x1xf32> + + return + } diff --git a/tensorflow/compiler/mlir/tfrt/tf-tfrt-opt.cc b/tensorflow/compiler/mlir/tfrt/tf-tfrt-opt.cc index 0de1d1eaabf4bd..c6d21e330ad6ec 100644 --- a/tensorflow/compiler/mlir/tfrt/tf-tfrt-opt.cc +++ b/tensorflow/compiler/mlir/tfrt/tf-tfrt-opt.cc @@ -19,6 +19,7 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Tools/mlir-opt/MlirOptMain.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project +#include "stablehlo/dialect/Register.h" // from @stablehlo #include "tensorflow/compiler/mlir/init_mlir.h" #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" @@ -61,6 +62,7 @@ int main(int argc, char **argv) { mlrt::compiler::MlrtDialect>(); tensorflow::RegisterTPUDialects(®istry); tensorflow::RegisterGpuDialects(®istry); + mlir::stablehlo::registerAllDialects(registry); tfrt::RegisterTFRTDialects(registry); tensorflow::tfrt_compiler::RegisterTPULowerClusterToRuntimeOpsPassPipeline(); diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/BUILD b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/BUILD index 2162d37eebcfef..c3b2df1dd0f852 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/BUILD +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/BUILD @@ -23,7 +23,7 @@ package_group( "//learning/pathways/serving/runtime/...", "//learning/pathways/serving/tests/...", "//learning/brain/tfrt/ifrt/...", - 
"//learning/brain/tfrt/mlir/mlrt/application/pathways/compiler/...", + "//learning/brain/tfrt/tfrt_session/...", # Allow visibility from the mlir language server. "//learning/brain/mlir/mlir_lsp_server/...", "//learning/infra/mira/experimental/orbax_model/serving/sidecar/...", @@ -33,15 +33,10 @@ package_group( gentbl_cc_library( name = "pass_inc_gen", - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=TfrtIfrtServing", - ], - "passes.h.inc", - ), - ], + tbl_outs = {"passes.h.inc": [ + "-gen-pass-decls", + "-name=TfrtIfrtServing", + ]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "passes.td", deps = [ diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo_test.cc b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo_test.cc index 24252c40ae7da9..1bd737b98c3787 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo_test.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo_test.cc @@ -118,11 +118,10 @@ TEST_F(Tf2HloTest, Empty) { GetCompileMetadata(mlir_module.get(), *client)); TF_ASSERT_OK(UpdateCompileMetadata(compile_metadata, {})); - xla::CpuTopologyDescription cpu_topology = - xla::CpuTopologyDescription::Create( - xla::CpuId(), xla::CpuName(), /*platform_version=*/"", - /*devices=*/std::vector>{}, - /*machine_attributes=*/std::vector{}); + const xla::CpuTopologyDescription cpu_topology( + xla::CpuId(), xla::CpuName(), /*platform_version=*/"", + /*cpu_devices=*/{}, + /*machine_attributes=*/std::vector{}); std::shared_ptr cpu_topology_ptr = std::make_shared(cpu_topology); @@ -168,11 +167,10 @@ TEST_F(Tf2HloTest, Tuple) { GetCompileMetadata(mlir_module.get(), *client)); TF_ASSERT_OK(UpdateCompileMetadata(compile_metadata, dtype_and_shapes)); - xla::CpuTopologyDescription cpu_topology = - xla::CpuTopologyDescription::Create( - xla::CpuId(), xla::CpuName(), /*platform_version=*/"", - /*devices=*/std::vector>{}, - /*machine_attributes=*/std::vector{}); + const xla::CpuTopologyDescription cpu_topology( + xla::CpuId(), 
xla::CpuName(), /*platform_version=*/"", + /*cpu_devices=*/{}, + /*machine_attributes=*/std::vector{}); std::shared_ptr cpu_topology_ptr = std::make_shared(cpu_topology); @@ -219,11 +217,10 @@ TEST_F(Tf2HloTest, Spmd) { GetCompileMetadata(mlir_module.get(), *client)); TF_ASSERT_OK(UpdateCompileMetadata(compile_metadata, dtype_and_shapes)); - xla::CpuTopologyDescription cpu_topology = - xla::CpuTopologyDescription::Create( - xla::CpuId(), xla::CpuName(), /*platform_version=*/"", - /*devices=*/std::vector>{}, - /*machine_attributes=*/std::vector{}); + const xla::CpuTopologyDescription cpu_topology( + xla::CpuId(), xla::CpuName(), /*platform_version=*/"", + /*cpu_devices=*/{}, + /*machine_attributes=*/std::vector{}); std::shared_ptr cpu_topology_ptr = std::make_shared(cpu_topology); @@ -307,11 +304,10 @@ TEST_F(Tf2HloTest, UsingDefaultDeviceAssignment) { GetCompileMetadata(mlir_module.get(), *client)); TF_ASSERT_OK(UpdateCompileMetadata(compile_metadata, dtype_and_shapes)); - xla::CpuTopologyDescription cpu_topology = - xla::CpuTopologyDescription::Create( - xla::CpuId(), xla::CpuName(), /*platform_version=*/"", - /*devices=*/std::vector>{}, - /*machine_attributes=*/std::vector{}); + const xla::CpuTopologyDescription cpu_topology( + xla::CpuId(), xla::CpuName(), /*platform_version=*/"", + /*cpu_devices=*/{}, + /*machine_attributes=*/std::vector{}); std::shared_ptr cpu_topology_ptr = std::make_shared(cpu_topology); @@ -420,11 +416,10 @@ TEST_F(Tf2HloTest, XlaCallHostCallback) { GetCompileMetadata(mlir_module.get(), *client)); TF_ASSERT_OK(UpdateCompileMetadata(compile_metadata, dtype_and_shapes)); - xla::CpuTopologyDescription cpu_topology = - xla::CpuTopologyDescription::Create( - xla::CpuId(), xla::CpuName(), /*platform_version=*/"", - /*devices=*/std::vector>{}, - /*machine_attributes=*/std::vector{}); + const xla::CpuTopologyDescription cpu_topology( + xla::CpuId(), xla::CpuName(), /*platform_version=*/"", + /*cpu_devices=*/{}, + 
/*machine_attributes=*/std::vector{}); std::shared_ptr cpu_topology_ptr = std::make_shared(cpu_topology); @@ -530,11 +525,10 @@ TEST_F(Tf2HloTest, SameArgProduceSameKeyFingerprint) { GetCompileMetadata(mlir_module.get(), *client)); TF_ASSERT_OK(UpdateCompileMetadata(compile_metadata, dtype_and_shapes)); - xla::CpuTopologyDescription cpu_topology = - xla::CpuTopologyDescription::Create( - xla::CpuId(), xla::CpuName(), /*platform_version=*/"", - /*devices=*/std::vector>{}, - /*machine_attributes=*/std::vector{}); + const xla::CpuTopologyDescription cpu_topology( + xla::CpuId(), xla::CpuName(), /*platform_version=*/"", + /*cpu_devices=*/{}, + /*machine_attributes=*/std::vector{}); std::shared_ptr cpu_topology_ptr = std::make_shared(cpu_topology); @@ -592,11 +586,10 @@ TEST_F(Tf2HloTest, DifferentCompileMetadataProduceDifferentKeyFingerprint) { GetCompileMetadata(mlir_module.get(), *client)); TF_ASSERT_OK(UpdateCompileMetadata(compile_metadata, dtype_and_shapes)); - xla::CpuTopologyDescription cpu_topology = - xla::CpuTopologyDescription::Create( - xla::CpuId(), xla::CpuName(), /*platform_version=*/"", - /*devices=*/std::vector>{}, - /*machine_attributes=*/std::vector{}); + const xla::CpuTopologyDescription cpu_topology( + xla::CpuId(), xla::CpuName(), /*platform_version=*/"", + /*cpu_devices=*/{}, + /*machine_attributes=*/std::vector{}); std::shared_ptr cpu_topology_ptr = std::make_shared(cpu_topology); diff --git a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/rewrite_ifrt_load_variable.cc b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/rewrite_ifrt_load_variable.cc index 368a91ac54f955..98058a3b32028c 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/rewrite_ifrt_load_variable.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/rewrite_ifrt_load_variable.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include #include "llvm/ADT/APInt.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringRef.h" #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project @@ -75,16 +76,27 @@ class RewriteIfrtLoadVariablePass builder.create( load_variable_op->getLoc(), result_types, load_variable_op->getOperands(), load_variable_op->getAttrs()); - for (auto user : load_variable_op.getTensorFuture().getUsers()) { - builder.setInsertionPoint(user); - auto await_op = builder.create( - user->getLoc(), load_variable_op.getTensorFuture().getType(), - mlrt_load_variable_op.getTensorFuture()); + tf_mlrt::TFAwaitOp await_op; + for (auto user : llvm::make_early_inc_range( + load_variable_op.getTensorFuture().getUsers())) { + // Materialize the future for the first use. Reuse it for the rest of + // the uses. + if (!await_op) { + builder.setInsertionPoint(user); + await_op = builder.create( + user->getLoc(), load_variable_op.getTensorFuture().getType(), + mlrt_load_variable_op.getTensorFuture()); + } else { + if (user->isBeforeInBlock(await_op)) { + await_op->moveBefore(user); + } + } user->replaceUsesOfWith(load_variable_op.getTensorFuture(), await_op.getResult()); } - for (auto user : load_variable_op.getArrayKey().getUsers()) { + for (auto user : llvm::make_early_inc_range( + load_variable_op.getArrayKey().getUsers())) { user->replaceUsesOfWith(load_variable_op.getArrayKey(), mlrt_load_variable_op.getArrayKey()); } diff --git a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/tf_to_mlrt.cc b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/tf_to_mlrt.cc index 58ad4d8561628a..a82ba0be0cd234 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/tf_to_mlrt.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/tf_to_mlrt.cc @@ -1059,8 +1059,6 @@ class TfToMlrtConversionPass type_converter_.addTargetMaterialization(future_to_tensor_materialization); type_converter_.addSourceMaterialization(future_to_tensor_materialization); - 
type_converter_.addArgumentMaterialization( - future_to_tensor_materialization); if (use_tpu_host_allocator_for_inputs_.hasValue()) { options_.use_tpu_host_allocator_for_inputs = diff --git a/tensorflow/compiler/mlir/tfrt/transforms/xla_rewrite_pass.cc b/tensorflow/compiler/mlir/tfrt/transforms/xla_rewrite_pass.cc index 2e33dcb9e67d5a..0ed5a6ac1b6a8a 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/xla_rewrite_pass.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/xla_rewrite_pass.cc @@ -60,10 +60,8 @@ struct RewriteStatefulPartitionedCallToXlaLaunchOnCpu for (int i = 0; i < op.getNumOperands(); ++i) { auto value = op.getOperand(i); - if (value.getType() - .cast() - .getElementType() - .isa()) { + if (llvm::isa( + llvm::cast(value.getType()).getElementType())) { resources.push_back(i); } else if (auto* def = value.getDefiningOp(); def && llvm::isa(def)) { diff --git a/tensorflow/compiler/mlir/tools/BUILD b/tensorflow/compiler/mlir/tools/BUILD new file mode 100644 index 00000000000000..3b29e0f5666497 --- /dev/null +++ b/tensorflow/compiler/mlir/tools/BUILD @@ -0,0 +1,51 @@ +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [ + "//visibility:public", + ], + licenses = ["notice"], +) + +cc_library( + name = "translate_cl_options", + srcs = [ + "tf_mlir_translate_cl.cc", + ], + hdrs = [ + "tf_mlir_translate_cl.h", + ], + deps = [ + "@llvm-project//llvm:Support", + ], + alwayslink = 1, +) + +cc_library( + name = "translate_registration", + srcs = [ + "tf_mlir_translate_registration.cc", + ], + deps = [ + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow/translate:mlir_roundtrip_flags", + "//tensorflow/compiler/mlir/tensorflow/translate/tools:file_tf_mlir_translate", + "//tensorflow/compiler/mlir/tf2xla/api/v2:tf_executor_to_graph", + "//tensorflow/compiler/mlir/tools:translate_cl_options", + "//tensorflow/compiler/tf2xla:xla_compiler", + 
"//tensorflow/compiler/tf2xla/kernels:xla_ops", + "//tensorflow/core:core_cpu_base", + "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/container:flat_hash_set", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:TranslateLib", + "@local_tsl//tsl/platform:protobuf", + "@local_xla//xla/client:client_library", + "@local_xla//xla/client:compile_only_client", + "@local_xla//xla/service/cpu:cpu_compiler", + "@local_xla//xla/service/cpu:cpu_transfer_manager", + "@local_xla//xla/stream_executor/host:host_platform", + "@local_xla//xla/stream_executor/host:host_platform_id", + ], + alwayslink = 1, +) diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/ir/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/ir/BUILD index 42b29d86d31e51..0c504a62de1627 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/ir/BUILD +++ b/tensorflow/compiler/mlir/tools/kernel_gen/ir/BUILD @@ -30,24 +30,12 @@ td_library( gentbl_cc_library( name = "tf_framework_ops_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-op-decls"], - "tf_framework_ops.h.inc", - ), - ( - ["-gen-op-defs"], - "tf_framework_ops.cc.inc", - ), - ( - ["-gen-dialect-decls"], - "tf_framework_dialect.h.inc", - ), - ( - ["-gen-dialect-defs"], - "tf_framework_dialect.cc.inc", - ), - ], + tbl_outs = { + "tf_framework_ops.h.inc": ["-gen-op-decls"], + "tf_framework_ops.cc.inc": ["-gen-op-defs"], + "tf_framework_dialect.h.inc": ["-gen-dialect-decls"], + "tf_framework_dialect.cc.inc": ["-gen-dialect-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "tf_framework_ops.td", deps = [":td_files"], @@ -56,16 +44,10 @@ gentbl_cc_library( gentbl_cc_library( name = "tf_status_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-enum-decls"], - "tf_status.h.inc", - ), - ( - ["-gen-enum-defs"], - "tf_status.cc.inc", - ), - ], + tbl_outs = { + "tf_status.h.inc": ["-gen-enum-decls"], + "tf_status.cc.inc": 
["-gen-enum-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "tf_status.td", deps = [":td_files"], diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.td b/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.td index d8e7617cc352ba..64f782d02346e8 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.td +++ b/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.td @@ -43,7 +43,7 @@ def TFFramework_Dialect : Dialect { } def TFFramework_OpKernelContextType : DialectType()">, + CPred<"llvm::isa<::mlir::kernel_gen::tf_framework::OpKernelContextType>($_self)">, "op_kernel_construction">, BuildableType<"$_builder.getType<::mlir::kernel_gen::tf_framework::OpKernelContextType>()"> { let description = [{ @@ -53,7 +53,7 @@ def TFFramework_OpKernelContextType : DialectType()">>, + "llvm::isa<::mlir::kernel_gen::tf_framework::JITCallableType>($_self)">>, BuildableType<"$_builder.getType<::mlir::kernel_gen::tf_framework::JITCallableType>()"> { let description = [{ A `callable` represents the result of JIT compilation. 
Conceptually, it @@ -107,7 +107,7 @@ def TFFramework_TFAllocOp : TFFramework_Op<"alloc", [ }]>]; let extraClassDeclaration = [{ - MemRefType getType() { return getResult().getType().cast(); } + MemRefType getType() { return llvm::cast(getResult().getType()); } static constexpr StringRef kReuseOutputAttrName = "reuse_output"; static constexpr StringRef kReuseInputCandidatesAttrName = "reuse_input_candidates"; diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc b/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc index 6321ccbc3f5d87..aba9f13f0f2649 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc @@ -170,9 +170,9 @@ absl::Status LowerHlotoLoops(mlir::ModuleOp module, pm.addNestedPass(mlir::createCanonicalizerPass()); pm.addNestedPass(mlir::createCSEPass()); pm.addNestedPass(mlir::createCanonicalizerPass()); - pm.addNestedPass(mlir::mhlo::createShapeSimplification()); - pm.addNestedPass(mlir::mhlo::createMergeAssumingOpsPass()); - pm.addNestedPass(mlir::mhlo::createBroadcastPropagationPass()); + pm.addNestedPass(mlir::kernel_gen::createShapeSimplificationPass()); + pm.addNestedPass(mlir::kernel_gen::createMergeAssumingOpsPass()); + pm.addNestedPass(mlir::kernel_gen::createBroadcastPropagationPass()); pm.addNestedPass(mlir::createCanonicalizerPass()); pm.addNestedPass(mlir::createCSEPass()); diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/broadcast_propagation.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/broadcast_propagation.mlir similarity index 99% rename from third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/broadcast_propagation.mlir rename to tensorflow/compiler/mlir/tools/kernel_gen/tests/broadcast_propagation.mlir index 4bf50644127e70..f366f1938e0a38 100644 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/broadcast_propagation.mlir +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/broadcast_propagation.mlir @@ 
-1,4 +1,4 @@ -// RUN: mlir-hlo-opt %s --split-input-file --mhlo-broadcast-propagation | \ +// RUN: kernel-gen-opt %s --split-input-file --mhlo-broadcast-propagation | \ // RUN: FileCheck %s // CHECK-LABEL: @single_bcast diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/merge_assuming_ops.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/merge_assuming_ops.mlir similarity index 99% rename from third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/merge_assuming_ops.mlir rename to tensorflow/compiler/mlir/tools/kernel_gen/tests/merge_assuming_ops.mlir index f8ff1a33d1c97b..d463da199549e3 100644 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/merge_assuming_ops.mlir +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/merge_assuming_ops.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-hlo-opt --split-input-file --allow-unregistered-dialect \ +// RUN: kernel-gen-opt --split-input-file --allow-unregistered-dialect \ // RUN: --mhlo-merge-assuming-ops --canonicalize --cse %s | \ // RUN: FileCheck %s diff --git a/third_party/xla/xla/mlir_hlo/tests/shape_simplification.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/shape_simplification.mlir similarity index 98% rename from third_party/xla/xla/mlir_hlo/tests/shape_simplification.mlir rename to tensorflow/compiler/mlir/tools/kernel_gen/tests/shape_simplification.mlir index 998918bdfa0744..f7ff67753bc235 100644 --- a/third_party/xla/xla/mlir_hlo/tests/shape_simplification.mlir +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/shape_simplification.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-hlo-opt -split-input-file -shape-simplification %s | FileCheck %s +// RUN: kernel-gen-opt -split-input-file -shape-simplification %s | FileCheck %s // Incompatible shapes. No folding. 
// CHECK-LABEL: func @f diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD index 13b61395bcda08..262f9fc56d78f2 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD @@ -85,13 +85,10 @@ cc_library( gentbl_cc_library( name = "kernel_gen_passes_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [( - [ - "-gen-pass-decls", - "-name=KernelGen", - ], - "kernel_gen_passes.h.inc", - )], + tbl_outs = {"kernel_gen_passes.h.inc": [ + "-gen-pass-decls", + "-name=KernelGen", + ]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "passes.td", deps = ["@llvm-project//mlir:PassBaseTdFiles"], @@ -181,15 +178,18 @@ cc_library( cc_library( name = "passes", srcs = [ + "broadcast_propagation_pass.cc", "buffer_reuse_pass.cc", "bufferize_pass.cc", "copy_cleanup_pass.cc", "embed_tf_framework_pass.cc", "func_to_jit_invocations.cc", "fuse_inner_parallel_loops_pass.cc", + "merge_assuming_ops_pass.cc", "parallel_loops_to_sequential.cc", "rewrite_tf_framework_assert.cc", "same_shape_propagation.cc", + "shape_simplification_pass.cc", "shape_to_descriptors_pass.cc", "tensorflow_abi_knowledge_propagation.cc", ], @@ -200,8 +200,6 @@ cc_library( ":embed_tf_framework", # buildcleaner: keep ":kernel_gen_passes_inc_gen", ":tf_framework_legalize_to_llvm", # buildcleaner: keep - ":utils", - "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tools/kernel_gen/ir:tf_framework_ops", "@llvm-project//llvm:Support", "@llvm-project//mlir:Analysis", @@ -211,6 +209,7 @@ cc_library( "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:GPUDialect", "@llvm-project//mlir:IR", + "@llvm-project//mlir:InferTypeOpInterface", "@llvm-project//mlir:LLVMDialect", "@llvm-project//mlir:LinalgDialect", "@llvm-project//mlir:MathDialect", @@ -226,7 +225,9 @@ cc_library( "@llvm-project//mlir:Support", 
"@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:TransformUtils", - "@llvm-project//mlir:Transforms", + "@local_xla//xla/mlir_hlo", + "@local_xla//xla/mlir_hlo:mhlo_passes", "@local_xla//xla/mlir_hlo:transforms_passes", + "@stablehlo//:base", ], ) diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/broadcast_propagation_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/broadcast_propagation_pass.cc new file mode 100644 index 00000000000000..840de572368c83 --- /dev/null +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/broadcast_propagation_pass.cc @@ -0,0 +1,462 @@ +/* Copyright 2021 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+ +==============================================================================*/ + +#include +#include +#include +#include +#include + +#include "llvm/ADT/DenseMapInfo.h" +#include "llvm/ADT/Hashing.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "stablehlo/dialect/Base.h" // from @stablehlo +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" + +namespace mlir { +namespace kernel_gen { + +using mhlo::DynamicBroadcastInDimOp; + +#define GEN_PASS_DEF_BROADCASTPROPAGATIONPASS +#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc" + +namespace { + +// To avoid duplicate broadcasts, we collect all the intended broadcasts ahead +// of realizing any broadcasts in the IR. These are broadcasted versions of +// values that we are interested in, and they are uniquely characterized by a +// `BroadcastIntent` value. 
+struct BroadcastIntent { + RankedTensorType resultType; + Value targetValue; + Value outputDimensions; + Attribute broadcastDimensions; + bool operator==(BroadcastIntent rhs) const { + return resultType == rhs.resultType && targetValue == rhs.targetValue && + outputDimensions == rhs.outputDimensions && + broadcastDimensions == rhs.broadcastDimensions; + } + bool operator!=(BroadcastIntent rhs) const { return !(*this == rhs); } +}; + +} // namespace +} // namespace kernel_gen +} // namespace mlir + +namespace llvm { + +using mlir::kernel_gen::BroadcastIntent; + +template <> +struct DenseMapInfo { + static BroadcastIntent getEmptyKey() { + return {DenseMapInfo::getEmptyKey(), + DenseMapInfo::getEmptyKey(), + DenseMapInfo::getEmptyKey(), + DenseMapInfo::getEmptyKey()}; + } + static BroadcastIntent getTombstoneKey() { + return {DenseMapInfo::getTombstoneKey(), + DenseMapInfo::getTombstoneKey(), + DenseMapInfo::getTombstoneKey(), + DenseMapInfo::getTombstoneKey()}; + } + static unsigned getHashValue(const BroadcastIntent &intent) { + return hash_combine( + DenseMapInfo::getHashValue(intent.resultType), + DenseMapInfo::getHashValue(intent.targetValue), + DenseMapInfo::getHashValue(intent.outputDimensions), + DenseMapInfo::getHashValue( + intent.broadcastDimensions)); + } + static bool isEqual(const BroadcastIntent &lhs, const BroadcastIntent &rhs) { + return lhs == rhs; + } +}; + +} // namespace llvm + +namespace mlir { +namespace kernel_gen { +namespace { + +bool allowsForElementwiseBroadcastPropagation(Operation *op) { + if (op && op->hasTrait() && + op->hasTrait() && op->getNumResults() == 1) { + return true; + } + if (op && op->hasTrait() && + op->getNumResults() == 1) { + return true; + } + return false; +} + +bool allowsForBroadcastPropagation(Operation *op) { + return llvm::isa_and_nonnull(op) || + allowsForElementwiseBroadcastPropagation(op); +} + +DenseIntElementsAttr composeBroadcastDimensionsAttr(OpBuilder &builder, + DenseIntElementsAttr a, + 
DenseIntElementsAttr b) { + SmallVector bVec = + llvm::to_vector(llvm::map_range(b, [](const APInt &it) { + return static_cast(it.getLimitedValue()); + })); + SmallVector composedVec = llvm::to_vector(llvm::map_range( + a, [bVec](const APInt &it) { return bVec[it.getLimitedValue()]; })); + return builder.getI64TensorAttr(composedVec); +} + +// Find all the broadcast intents and their dependencies. Start analyzing from +// the root an collect all broadcast intents that can help broadcast propagation +// from there. +void findBroadcastIntents( + DynamicBroadcastInDimOp root, Block *parentBlock, + BroadcastIntent &rootBcastIntent, + SmallVector &bcastIntents, + DenseMap> + &bcastIntentDependencies) { + OpBuilder builder(root.getContext()); + + // Use the result vector of broadcast intents as a worklist. The set of + // broadcast intents helps to ensure their uniqueness. + DenseSet bcastIntentsSet; + auto addToWorklistIfNew = [&](BroadcastIntent bcastIntent) { + if (!bcastIntentsSet.count(bcastIntent)) { + bcastIntentsSet.insert(bcastIntent); + bcastIntents.push_back(bcastIntent); + } + }; + + // Derive the broadcast intent associated with the root broadcast operation. + // Add it to the worklist to seed the analysis. + rootBcastIntent = {mlir::cast(root.getResult().getType()), + root.getOperand(), root.getOutputDimensions(), + root.getBroadcastDimensions()}; + addToWorklistIfNew(rootBcastIntent); + + // We use result vector of broadcast intents as a worklist, the first `i` + // intents of which have been processed. + for (int64_t i = 0; i < static_cast(bcastIntents.size()); ++i) { + BroadcastIntent it = bcastIntents[i]; + Operation *producerOp = it.targetValue.getDefiningOp(); + + // We can propagate broadcasts over (broadcasting) element-wise operations + // and dynamic_broadcast_in_dim ops with the restriction that they must be + // in the same block as they may depend on assuming regions. 
+ if (!producerOp || producerOp->getBlock() != parentBlock || + !allowsForBroadcastPropagation(producerOp)) { + continue; + } + + // We can skip broadcasting producers (dynamic_broadcast_in_dim ops) if we + // compose their broadcasting dimensions. + if (auto producerBcastOp = + llvm::dyn_cast(producerOp)) { + DenseIntElementsAttr composedBcastDims = composeBroadcastDimensionsAttr( + builder, producerBcastOp.getBroadcastDimensions(), + mlir::cast(it.broadcastDimensions)); + BroadcastIntent bcastedOperandIntent = { + it.resultType, producerBcastOp.getOperand(), it.outputDimensions, + composedBcastDims}; + + // Record dependency and "recur". + bcastIntentDependencies[it] = {bcastedOperandIntent}; + addToWorklistIfNew(bcastedOperandIntent); + continue; + } + + // We can propagate broadcasts over (broadcasting) element-wise operations. + // Instead of broadcasting the result of such an op, we can broadcast the + // operands and apply the element-wise operation to them. + assert(allowsForElementwiseBroadcastPropagation(producerOp)); + bcastIntentDependencies[it] = {}; + for (auto operand : producerOp->getOperands()) { + auto operandTy = mlir::cast(operand.getType()); + auto operandBcastDims = operandTy.getRank() == 0 + ? builder.getI64TensorAttr({}) + : it.broadcastDimensions; + auto bcastedOperandTy = RankedTensorType::get(it.resultType.getShape(), + operandTy.getElementType()); + BroadcastIntent bcastedOperandIntent = { + bcastedOperandTy, operand, it.outputDimensions, operandBcastDims}; + + // Record dependency and "recur". + bcastIntentDependencies[it].push_back(bcastedOperandIntent); + addToWorklistIfNew(bcastedOperandIntent); + } + } +} + +void sortBroadcastIntentsInReverseTopologicalOrder( + SmallVector &bcastIntentsVec, Block *parentBlock) { + // Sort broadcast intents in reverse topological order of the producer ops. We + // can use the positions in the block for this. All broadcast intents outside + // the block (e.g. 
arguments) will be sorted towards the front. + // This ordering is independent of the output dimensions as dependencies can + // only occur between broadcast intents of the same output dimension. + std::sort(bcastIntentsVec.begin(), bcastIntentsVec.end(), + [parentBlock](const BroadcastIntent &a, const BroadcastIntent &b) { + Operation *producerOpA = a.targetValue.getDefiningOp(); + Operation *producerOpB = b.targetValue.getDefiningOp(); + bool aInBlock = producerOpA != nullptr && + producerOpA->getBlock() == parentBlock; + bool bInBlock = producerOpB != nullptr && + producerOpB->getBlock() == parentBlock; + if (aInBlock && bInBlock) { + return producerOpA->isBeforeInBlock(producerOpB); + } + return !aInBlock && bInBlock; + }); +} + +void setInsertionPointToEarliestPointWithAllValuesAvailable( + PatternRewriter &rewriter, Block *block, ValueRange values) { + Operation *lastDef = nullptr; + for (Value v : values) { + Operation *def = v.getDefiningOp(); + if (def && def->getBlock() == block) { + if (!lastDef || lastDef->isBeforeInBlock(def)) lastDef = def; + } + } + if (lastDef) { + rewriter.setInsertionPointAfter(lastDef); + } else { + rewriter.setInsertionPointToStart(block); + } +} + +DenseMap realizeBroadcastIntents( + SmallVector &sortedBcastIntents, + DenseMap> + &bcastIntentDependencies, + Block *parentBlock, PatternRewriter &rewriter) { + // Realize broadcast intents in order. They must be sorted so that their + // dependencies are realized before them. + DenseMap realizations; + for (auto it : sortedBcastIntents) { + Operation *producerOp = it.targetValue.getDefiningOp(); + assert(!realizations.count(it) && "expect unrealized broadcast intent"); + auto deps = bcastIntentDependencies.find(it); + + // If we cannot propagate broadcasts further, materialize them as a + // dynamic_broadcast_in_dim op. 
+ if (!producerOp || producerOp->getBlock() != parentBlock || + !allowsForBroadcastPropagation(producerOp)) { + assert(deps == bcastIntentDependencies.end() && "expect no dependencies"); + setInsertionPointToEarliestPointWithAllValuesAvailable( + rewriter, parentBlock, + ValueRange{it.targetValue, it.outputDimensions}); + realizations[it] = rewriter.create( + it.targetValue.getLoc(), it.resultType, it.targetValue, + it.outputDimensions, + mlir::cast(it.broadcastDimensions)); + continue; + } + + // For broadcast propagation across dynamic_broadcast_in_dim ops, the + // broadcasted value is already materialized. Forward it. + if (auto producerBcastOp = + llvm::dyn_cast_or_null(producerOp)) { + assert(deps != bcastIntentDependencies.end() && + deps->second.size() == 1 && "expect one dependency"); + auto bcastedOperand = realizations.find(deps->second.front()); + assert(bcastedOperand != realizations.end()); + realizations[it] = Value(bcastedOperand->second); + continue; + } + + // Othwerwise, realize broadcast intent for a (broadcasting) element-wise + // operation based on the broadcasted operands. 
+ assert(allowsForElementwiseBroadcastPropagation(producerOp) && + "expect broadcast propagation over an (broadcasting) element-wise " + "operation"); + assert(deps != bcastIntentDependencies.end() && + deps->second.size() == producerOp->getNumOperands() && + "expect one dependency per operand"); + auto bcastedOperands = llvm::to_vector( + llvm::map_range(deps->second, [&](BroadcastIntent operandIntent) { + auto bcastedOperand = realizations.find(operandIntent); + assert(bcastedOperand != realizations.end() && + "expect dependencies to be realized earlier"); + return bcastedOperand->second; + })); + setInsertionPointToEarliestPointWithAllValuesAvailable( + rewriter, parentBlock, bcastedOperands); + OperationState newProducerOpState( + producerOp->getLoc(), producerOp->getName().getStringRef(), + bcastedOperands, it.resultType, producerOp->getAttrs()); + Operation *newProducerOp = rewriter.create(newProducerOpState); + assert(newProducerOp->getNumResults() == 1 && "expect exactly one result"); + realizations[it] = newProducerOp->getResults().front(); + } + + return realizations; +} + +void transitivelyEraseUnusedSideEffectFreeOps(Operation *root, + PatternRewriter &rewriter) { + // Find ops to erase. + SmallPtrSet opsToEraseSet; + SmallVector opsToErase; + SmallVector worklist = {root}; + while (!worklist.empty()) { + Operation *op = worklist.pop_back_val(); + + // Erase ops only once. + if (opsToEraseSet.count(op)) continue; + + // Erase only operations that are unused and free of side effects. + if (!isMemoryEffectFree(op) || + !llvm::all_of(op->getUsers(), [opsToEraseSet](Operation *user) { + return opsToEraseSet.count(user); + })) { + continue; + } + + // Erase and "recur". + opsToEraseSet.insert(op); + opsToErase.push_back(op); + for (Value operand : op->getOperands()) { + if (Operation *def = operand.getDefiningOp()) worklist.push_back(def); + } + } + + // Finally, erase the ops in the order of their uses. 
+ for (Operation *op : opsToErase) rewriter.eraseOp(op); +} + +LogicalResult propagateBroadcast(DynamicBroadcastInDimOp root, + Block *parentBlock, + PatternRewriter &rewriter) { + // We can move broadcasts up over (i) (broadcasting) element-wise operations + // and (i) dynamic_broadcast_in_dim ops. This way, we propagate them through + // the IR to perform them early. Instead of broadcasting the result of such an + // op, we can broadcast the operands and apply the element-wise operation to + // them. + // + // To avoid exponential growth of the IR, we will do this in two phases: + // 1) First, we collect all the unique broadcast intents. These are + // broadcasted versions of values that we are interested in. They may + // later be materialized as an explicit broadcast or they can be the + // direct result of an operation over which a broadcast was propagated. + // 2) Then, we fulfill every broadcast intent in reverse topological order + // to ensure that their dependencies (the broadcasted operands) are + // available. + + // Find the unique broadcast intents. + BroadcastIntent rootBcastIntent; + SmallVector bcastIntents; + DenseMap> + bcastIntentDependencies; + findBroadcastIntents(root, parentBlock, rootBcastIntent, bcastIntents, + bcastIntentDependencies); + + // Fail if there is nothing but the root intent, i.e. if there is nothing to + // rewrite here. + if (bcastIntents.size() <= 1) { + assert(bcastIntents.front() == rootBcastIntent && "expect root intent"); + return failure(); + } + + // Sort the broadcast intents in reverse topological order so that they can be + // materialized and every depency is available when needed. + sortBroadcastIntentsInReverseTopologicalOrder(bcastIntents, parentBlock); + + // Realize broadcast intents. + DenseMap realizations = realizeBroadcastIntents( + bcastIntents, bcastIntentDependencies, parentBlock, rewriter); + + // Find the operations that may become redundant after replacing the root + // operation. 
This allows us to transitively erase unused side effect-free + // operations that result from this rewrite (after the root operation is no + // longer accessible). + SmallVector possiblyUnused; + for (auto operand : root->getOperands()) { + if (Operation *def = operand.getDefiningOp()) possiblyUnused.push_back(def); + } + + // Replace the root operation with its broadcast intent's realization. + rewriter.replaceOp(root, realizations[rootBcastIntent]); + + // Erase all the operations that have become redundant as a result of this + // rewrite. + for (Operation *op : possiblyUnused) { + transitivelyEraseUnusedSideEffectFreeOps(op, rewriter); + } + + return success(); +} + +struct BroadcastPropagationPattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(DynamicBroadcastInDimOp op, + PatternRewriter &rewriter) const override { + return propagateBroadcast(op, op->getBlock(), rewriter); + } +}; + +struct BroadcastPropagationPass + : public impl::BroadcastPropagationPassBase { + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + void runOnOperation() override { + MLIRContext *ctx = &getContext(); + + // Collect patterns. + RewritePatternSet patterns(ctx); + patterns.add(ctx); + + // Apply broadcast propagation in reverse order to start propagation at + // the root of broadcast chains. This avoids duplicate work. 
+ GreedyRewriteConfig config; + config.useTopDownTraversal = false; + + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns), + config))) { + return signalPassFailure(); + } + } +}; + +} // namespace + +} // namespace kernel_gen +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/merge_assuming_ops_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/merge_assuming_ops_pass.cc new file mode 100644 index 00000000000000..47a0d36fe2b748 --- /dev/null +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/merge_assuming_ops_pass.cc @@ -0,0 +1,476 @@ +/* Copyright 2021 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+ +==============================================================================*/ + +#include +#include +#include +#include +#include + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project +#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/IR/Block.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/IRMapping.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "stablehlo/dialect/Base.h" // from @stablehlo +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" + +namespace mlir { +namespace kernel_gen { + +#define GEN_PASS_DEF_MERGEASSUMINGOPSPASS +#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc" + +namespace { + +struct ShapeReificationPattern : public OpRewritePattern { + explicit ShapeReificationPattern(MLIRContext *context) + : OpRewritePattern(context) { + // Recursively reify until we hit an op that doesn't support it. + setHasBoundedRewriteRecursion(); + } + + LogicalResult matchAndRewrite(shape::ShapeOfOp op, + PatternRewriter &rewriter) const override { + // Only reify shape computation if operand allows for it. 
+ auto shapeOrigin = op.getArg().getDefiningOp(); + if (!shapeOrigin) return failure(); + + llvm::SmallVector reifications; + if (failed(shapeOrigin.reifyReturnTypeShapes( + rewriter, shapeOrigin->getOperands(), reifications))) + return failure(); + assert(reifications.size() == 1); + Value reifiedShape = reifications.front(); + + // Insert cast if needed. + if (reifiedShape.getType() != op.getType()) { + reifiedShape = rewriter.create(op.getLoc(), op.getType(), + reifiedShape); + } + + rewriter.replaceOp(op, reifiedShape); + return success(); + } +}; + +template +struct InlineBroadcastedShapeOperandsPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { + // Find all the shape operands, direct and indirect. + SmallVector inlinedOperands; + for (Value direct : op->getOperands()) { + if (auto bcastOp = direct.getDefiningOp()) { + for (Value indirect : bcastOp->getOperands()) + inlinedOperands.push_back(indirect); + } else { + inlinedOperands.push_back(direct); + } + } + + // Only rewrite if it makes a difference. + if (inlinedOperands.size() == op.getNumOperands()) return failure(); + + // Inline shape operands. + rewriter.replaceOpWithNewOp(op, op->getResultTypes(), inlinedOperands, + op->getAttrs()); + return success(); + } +}; + +LogicalResult moveUpIntoAssumingOpMatchAndRewrite(Operation *op, + PatternRewriter &rewriter) { + // Only implemented for single-result ops. + if (op->getNumResults() != 1) return failure(); + + // Find a preceding `assuming` op. 
+ auto *theBlock = op->getBlock(); + Operation *prev = op->getPrevNode(); + while (prev != nullptr && !llvm::isa(prev)) + prev = prev->getPrevNode(); + auto assumingOp = llvm::dyn_cast_or_null(prev); + if (!assumingOp) return failure(); + assert(assumingOp->getBlock() == theBlock && op->getBlock() == theBlock && + "expect assuming op and root op to be in the same block"); + + // Make sure that all operands will be available after moving. + auto isAvailable = [&](Value v) { + Operation *def = v.getDefiningOp(); + return def == nullptr || def->getBlock() != theBlock || + !assumingOp->isBeforeInBlock(def); + }; + if (!llvm::all_of(op->getOperands(), isAvailable)) return failure(); + + Block *body = assumingOp.getBody(); + auto yieldOp = llvm::cast(body->getTerminator()); + + // Find the operands to use if the op was within the assuming region. We + // will later use their copies, as we copy the assuming op and its body. + SmallVector newOperandsUnmapped = + llvm::to_vector<8>(llvm::map_range(op->getOperands(), [&](Value v) { + for (const auto &result : llvm::enumerate(assumingOp->getResults())) { + if (result.value() == v) return yieldOp->getOperand(result.index()); + } + return v; + })); + + // Insert the rewritten assuming op right before the old one. + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(assumingOp); + auto newAssumingOp = rewriter.create( + assumingOp.getLoc(), assumingOp.getWitness(), + [&](OpBuilder &b, Location) { + // Copy body. + IRMapping mapping; + for (auto &nested : body->without_terminator()) + b.clone(nested, mapping); + + // Copy op into the new body and use the mapped operands. + for (auto it : llvm::zip(op->getOperands(), newOperandsUnmapped)) { + Value oldOperand, newOperandUnmapped; + std::tie(oldOperand, newOperandUnmapped) = it; + mapping.map(oldOperand, mapping.lookupOrDefault(newOperandUnmapped)); + } + Operation *newOp = b.clone(*op, mapping); + + // Yield the previous results and also the new ones. 
+ auto mappedResults = llvm::to_vector<8>(llvm::map_range( + yieldOp.getOperands(), + [&](Value v) { return mapping.lookupOrDefault(v); })); + mappedResults.append(newOp->getResults().begin(), + newOp->getResults().end()); + return mappedResults; + }); + + // Replace the assuming op and the root op with the corresponding result + // values. + ValueRange newAssumingOpResults = newAssumingOp->getResults(); + rewriter.replaceOp(assumingOp, newAssumingOpResults.drop_back()); + rewriter.replaceOp(op, newAssumingOpResults.back()); + return success(); +} + +/// Move operation into a preceding assuming op. This allows to process +/// operations that depend on the assuming op's results. It will eventually +/// allow to make assuming regions' constraints independent from each other. +template +struct MoveUpIntoAssumingOpPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { + return moveUpIntoAssumingOpMatchAndRewrite(op.getOperation(), rewriter); + } +}; + +// Move elementwise operations into a preceding assuming op. This will +// eventually allow for more fusion opportunities. +struct MoveElementwiseOpsUpIntoAssumingOpPattern : public RewritePattern { + explicit MoveElementwiseOpsUpIntoAssumingOpPattern(MLIRContext *ctx) + : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {} + + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + // Apply to all elementwise and broadcasting elementwise operations with no + // side effects. + if (!op->hasTrait() && + !op->hasTrait()) { + return failure(); + } + if (!isMemoryEffectFree(op)) return failure(); + + return moveUpIntoAssumingOpMatchAndRewrite(op, rewriter); + } +}; + +// Move operation into an assuming region if all uses are within its body. 
+LogicalResult moveDownIntoAssumingOpMatchAndRewrite(Operation *op, + PatternRewriter &rewriter) { + auto users = op->getUsers(); + auto it = users.begin(); + auto end = users.end(); + if (it == end) return failure(); + + // Find candidate assuming op. + auto assumingOp = (it++)->getParentOfType(); + if (!assumingOp || assumingOp->isProperAncestor(op)) return failure(); + + // Make sure all uses are within the unique assuming op's body. + while (it != end) { + auto hopefullySameAssumingOp = (it++)->getParentOfType(); + if (!hopefullySameAssumingOp || hopefullySameAssumingOp != assumingOp) { + return failure(); + } + } + + // Move op into the assuming region. + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(assumingOp.getBody()); + Operation *newOp = rewriter.clone(*op); + rewriter.replaceOp(op, newOp->getResults()); + return success(); +} + +// Move elementwise operations into succeeding assuming regions. This will +// eventually allow for more fusion opportunities. +struct MoveElementwiseOpsDownIntoAssumingOpPattern : public RewritePattern { + explicit MoveElementwiseOpsDownIntoAssumingOpPattern(MLIRContext *ctx) + : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {} + + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + // Apply to all elementwise and broadcasting elementwise operations with no + // side effects. + if (!op->hasTrait() && + !op->hasTrait()) { + return failure(); + } + if (!isMemoryEffectFree(op)) return failure(); + + return moveDownIntoAssumingOpMatchAndRewrite(op, rewriter); + } +}; + +/// Move operation out of assuming op. This is only valid for +/// constraint-independent ops, like `cstr_broadcastable` and `shape_of`. It +/// will eventually allow to make assuming regions' constraints independent from +/// each other. 
+template <typename OpTy>
+struct MoveUpOutOfAssumingOpPattern : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy op,
+                                PatternRewriter &rewriter) const override {
+    // Must be inside of an assuming op.
+    auto assumingOp = op->template getParentOfType<shape::AssumingOp>();
+    if (!assumingOp) return failure();
+
+    // Operands must not be defined within the assuming op.
+    Block *body = assumingOp.getBody();
+    auto isAvailable = [&](Value v) {
+      Operation *def = v.getDefiningOp();
+      return def == nullptr || def->getBlock() != body;
+    };
+    if (!llvm::all_of(op->getOperands(), isAvailable)) return failure();
+
+    // Move op before the assuming region.
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPoint(assumingOp);
+    Operation *newOp = rewriter.clone(*op);
+    rewriter.replaceOp(op, newOp->getResults());
+
+    // If the assuming region yields none of the new op's results, these values
+    // are exclusively used in the assuming op's body. In these cases there is
+    // no need for further rewrites.
+    auto isNewOpResult = [newOp](Value v) {
+      return llvm::is_contained(newOp->getResults(), v);
+    };
+    auto yieldOp = cast<shape::AssumingYieldOp>(body->getTerminator());
+    if (llvm::none_of(yieldOp.getOperands(), isNewOpResult)) return success();
+
+    // If the assuming region yields any of the new op's results, these values
+    // can instead bypass the assuming region. There is no need to yield them
+    // explicitly as they are assumed to be independent. The assuming op is
+    // rewritten accordingly.
+    SmallVector<Value, 2> replacementValues;
+    auto newAssumingOp = rewriter.create<shape::AssumingOp>(
+        assumingOp.getLoc(), assumingOp.getWitness(),
+        [&](OpBuilder &b, Location) {
+          // Copy body.
+          IRMapping mapping;
+          for (Operation &nested : body->without_terminator()) {
+            b.clone(nested, mapping);
+          }
+
+          // Collect new yield operands.
+          SmallVector<Value, 2> newYieldOperands;
+          for (Value result : yieldOp.getOperands()) {
+            if (isNewOpResult(result)) {
+              replacementValues.push_back(result);
+            } else {
+              newYieldOperands.push_back(mapping.lookupOrDefault(result));
+              replacementValues.push_back(nullptr);
+            }
+          }
+          return newYieldOperands;
+        });
+
+    // Use the assuming op's results for the missing replacement values.
+    auto src = newAssumingOp.getResults().begin();
+    for (auto &dst : replacementValues) {
+      if (dst) continue;
+      dst = *src++;
+    }
+
+    rewriter.replaceOp(assumingOp, replacementValues);
+    return success();
+  }
+};
+
+/// Merge assuming regions if their constraints are independent from each other.
+struct MergeAssumingOpsPattern : public OpRewritePattern<shape::AssumingOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(shape::AssumingOp op,
+                                PatternRewriter &rewriter) const override {
+    // Merge assuming op with directly preceding one if both witnesses are
+    // available.
+    auto precedingOp =
+        llvm::dyn_cast_or_null<shape::AssumingOp>(op->getPrevNode());
+    if (!precedingOp) return failure();
+    if (op.getWitness().getDefiningOp() == precedingOp) return failure();
+
+    // Merge witnesses.
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPoint(precedingOp);
+    Value newWitness = rewriter.create<shape::AssumingAllOp>(
+        op.getWitness().getDefiningOp()->getLoc(),
+        ValueRange{precedingOp.getWitness(), op.getWitness()});
+
+    // Merge assuming ops.
+    Block *body_a = precedingOp.getBody();
+    Block *body_b = op.getBody();
+    auto newAssumingOp = rewriter.create<shape::AssumingOp>(
+        precedingOp.getLoc(), newWitness, [&](OpBuilder &b, Location) {
+          // Copy preceding op's body.
+          IRMapping mapping;
+          for (auto &nested : body_a->without_terminator()) {
+            b.clone(nested, mapping);
+          }
+
+          // Map result values of preceding assuming op.
+          auto yieldOpA =
+              llvm::dyn_cast<shape::AssumingYieldOp>(body_a->getTerminator());
+          for (auto pair :
+               llvm::zip(precedingOp->getResults(), yieldOpA.getOperands())) {
+            mapping.map(std::get<0>(pair),
+                        mapping.lookupOrDefault(std::get<1>(pair)));
+          }
+
+          // Copy op's body.
+          for (auto &nested : body_b->without_terminator()) {
+            b.clone(nested, mapping);
+          }
+
+          // Collect merged assuming op's results.
+          SmallVector<Value, 2> mappedResults;
+          auto yieldOpB =
+              llvm::dyn_cast<shape::AssumingYieldOp>(body_b->getTerminator());
+          for (Value v : yieldOpA.getOperands()) {
+            mappedResults.push_back(mapping.lookupOrDefault(v));
+          }
+          for (Value v : yieldOpB.getOperands()) {
+            mappedResults.push_back(mapping.lookupOrDefault(v));
+          }
+          return mappedResults;
+        });
+
+    // Replace the two assuming ops with the new corresponding results.
+    ValueRange newResults = newAssumingOp->getResults();
+    size_t splitAt = precedingOp->getNumResults();
+    rewriter.replaceOp(precedingOp, newResults.take_front(splitAt));
+    rewriter.replaceOp(op, newResults.drop_front(splitAt));
+    return success();
+  }
+};
+
+struct EliminateDuplicateCstrBroadcastableOps
+    : public OpRewritePattern<shape::CstrBroadcastableOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(shape::CstrBroadcastableOp op,
+                                PatternRewriter &rewriter) const override {
+    // Search for previous occurence of the same constraint.
+    Operation *it = op->getPrevNode();
+    while (it != nullptr) {
+      if (auto candidate = llvm::dyn_cast<shape::CstrBroadcastableOp>(it)) {
+        if (candidate.getShapes() == op.getShapes()) {
+          rewriter.replaceOp(op, candidate.getResult());
+          return success();
+        }
+      }
+      it = it->getPrevNode();
+    }
+
+    return failure();
+  }
+};
+
+void populateMergeAssumingOpsPatterns(MLIRContext *context,
+                                      RewritePatternSet *patterns) {
+  patterns->add<
+      EliminateDuplicateCstrBroadcastableOps,
+      InlineBroadcastedShapeOperandsPattern<shape::CstrBroadcastableOp>,
+      MergeAssumingOpsPattern, MoveElementwiseOpsDownIntoAssumingOpPattern,
+      MoveElementwiseOpsUpIntoAssumingOpPattern,
+      MoveUpIntoAssumingOpPattern<shape::AssumingAllOp>,
+      MoveUpIntoAssumingOpPattern<shape::CstrBroadcastableOp>,
+      MoveUpIntoAssumingOpPattern<shape::ShapeOfOp>,
+      MoveUpOutOfAssumingOpPattern<shape::AssumingAllOp>,
+      MoveUpOutOfAssumingOpPattern<shape::CstrBroadcastableOp>,
+      MoveUpOutOfAssumingOpPattern<shape::ShapeOfOp>, ShapeReificationPattern>(
+      context);
+  mhlo::DynamicBroadcastInDimOp::getCanonicalizationPatterns(*patterns,
+                                                             context);
+  mhlo::DynamicReshapeOp::getCanonicalizationPatterns(*patterns, context);
+  shape::AssumingAllOp::getCanonicalizationPatterns(*patterns, context);
+  shape::AssumingOp::getCanonicalizationPatterns(*patterns, context);
+  shape::BroadcastOp::getCanonicalizationPatterns(*patterns, context);
+  shape::CstrBroadcastableOp::getCanonicalizationPatterns(*patterns, context);
+  tensor::CastOp::getCanonicalizationPatterns(*patterns, context);
+}
+
+struct MergeAssumingOpsPass
+    : public impl::MergeAssumingOpsPassBase<MergeAssumingOpsPass> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<shape::ShapeDialect, mhlo::MhloDialect>();
+  }
+
+  void runOnOperation() override {
+    MLIRContext *ctx = &getContext();
+    RewritePatternSet patterns(ctx);
+    populateMergeAssumingOpsPatterns(ctx, &patterns);
+    GreedyRewriteConfig config;
+    config.maxIterations = GreedyRewriteConfig::kNoLimit;
+    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns),
+                                     config))) {
+      return signalPassFailure();
+    }
+  }
+};
+
+} // namespace
+
+} // namespace kernel_gen
+} // namespace mlir
diff --git
a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h index 45e248ceb904ff..d9dca26c8ce3a3 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h @@ -38,6 +38,9 @@ limitations under the License. #define GEN_PASS_DECL_PROPAGATESHAPEKNOWLEDGETOKERNELS #define GEN_PASS_DECL_FUSEINNERPARALLELLOOPSPASS #define GEN_PASS_DECL_COPYCLEANUPPASS +#define GEN_PASS_DECL_SHAPESIMPLIFICATIONPASS +#define GEN_PASS_DECL_MERGEASSUMINGOPSPASS +#define GEN_PASS_DECL_BROADCASTPROPAGATIONPASS namespace mlir { namespace kernel_gen { diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td index 4f92be70d25397..9bd6fb8b2e8bf8 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td @@ -137,4 +137,22 @@ def CopyCleanupPass : Pass<"copy-cleanup", "mlir::func::FuncOp"> { }]; } +def ShapeSimplificationPass + : Pass<"shape-simplification", "mlir::func::FuncOp"> { + let summary = "Simplify shape ops"; +} + +def MergeAssumingOpsPass : Pass<"mhlo-merge-assuming-ops", "func::FuncOp"> { + let summary = "Prepare moving dynamic broadcasts up over element-wise " + "operations and broadcast the operands rather than the result. This will " + "eventually allow for larger fusions."; +} + +def BroadcastPropagationPass : Pass<"mhlo-broadcast-propagation", "func::FuncOp"> { + let summary = "Move dynamic broadcasts up over element-wise operations and " + "broadcast the operands rather than the result. 
This will eventually allow " + "for larger fusions."; +} + + #endif // TF_KERNEL_GEN_PASSES diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/shape_simplification_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/shape_simplification_pass.cc new file mode 100644 index 00000000000000..b5ceec7f48e8fc --- /dev/null +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/shape_simplification_pass.cc @@ -0,0 +1,253 @@ +/* Copyright 2021 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file contains the patterns to simplify shape ops that were deemed not +// suitable for shape op canonicalization in MLIR Core. 
+
+#include <memory>
+#include <utility>
+
+#include "mlir/Dialect/Arith/IR/Arith.h"  // from @llvm-project
+#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
+#include "mlir/Dialect/Shape/IR/Shape.h"  // from @llvm-project
+#include "mlir/Dialect/Tensor/IR/Tensor.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
+#include "mlir/IR/PatternMatch.h"  // from @llvm-project
+#include "mlir/Support/LLVM.h"  // from @llvm-project
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
+#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
+#include "xla/mlir_hlo/mhlo/transforms/passes.h"
+
+namespace mlir {
+namespace kernel_gen {
+
+#define GEN_PASS_DEF_SHAPESIMPLIFICATIONPASS
+#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc"
+
+namespace {
+
+using shape::BroadcastOp;
+using shape::ConstShapeOp;
+using shape::ShapeOfOp;
+
+// Try to remove operands from broadcasts that don't contribute to the final
+// result.
+struct BroadcastRemoveSubsumedOperandsPattern
+    : public OpRewritePattern<BroadcastOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(BroadcastOp op,
+                                PatternRewriter &rewriter) const override {
+    // First collect the static components when joining all shapes. The
+    // resulting vector contains a static dimension if any operand has a static
+    // non-1 dimension in that position. The remaining dimensions are set to
+    // dynamic size.
+    SmallVector<int64_t> knownExtents;
+    SmallVector<SmallVector<int64_t, 4>, 4> operandExtents;
+    for (Value shape : op.getShapes()) {
+      auto &extents = operandExtents.emplace_back();
+      if (failed(shape::getShapeVec(shape, extents))) return failure();
+
+      // Prepend dynamic dims if sizes don't match.
+      if (extents.size() > knownExtents.size()) {
+        knownExtents.insert(knownExtents.begin(),
+                            extents.size() - knownExtents.size(),
+                            ShapedType::kDynamic);
+      }
+
+      for (size_t i = 0, e = extents.size(); i != e; ++i) {
+        int64_t extent = extents[e - i - 1];
+        if (extent != ShapedType::kDynamic && extent != 1) {
+          int64_t &knownExtent = knownExtents[knownExtents.size() - i - 1];
+          // A dynamic dimension is subsumed by a static one, but bail out for
+          // known conflicting shapes.
+          if (knownExtent != extent && knownExtent != ShapedType::kDynamic)
+            return failure();
+          knownExtent = extent;
+        }
+      }
+    }
+
+    // If we've figured out all shapes to be constants we're done.
+    if (!llvm::is_contained(knownExtents, ShapedType::kDynamic)) {
+      rewriter.replaceOpWithNewOp<ConstShapeOp>(
+          op, op->getResultTypes(), rewriter.getIndexTensorAttr(knownExtents));
+      return success();
+    }
+
+    // If only some dimensions are known see if any of the operands can be
+    // removed without affecting the result.
+    SmallVector<Value, 4> filteredOperands;
+    for (auto tuple : llvm::zip(op.getShapes(), operandExtents)) {
+      Value shape = std::get<0>(tuple);
+      auto &extents = std::get<1>(tuple);
+
+      // An operand can't be dead if it's the only operand of the maximum rank.
+      // Removing it would reduce the rank of the output.
+      if (llvm::count_if(operandExtents, [&](ArrayRef<int64_t> op) {
+            return op.size() >= extents.size();
+          }) <= 1) {
+        filteredOperands.push_back(shape);
+        continue;
+      }
+
+      for (size_t i = 0, e = extents.size(); i != e; ++i) {
+        int64_t extent = extents[e - i - 1];
+        // A dimension of an operand can be subsumed if it's
+        // - a 1 dimension. All other operands will have 1 dims or better.
+        if (extent == 1) continue;
+
+        // - a dynamic dim but the result is known to be constant.
+        int64_t knownExtent = knownExtents[knownExtents.size() - i - 1];
+        assert(knownExtent != 1);
+        if (knownExtent != ShapedType::kDynamic &&
+            extent == ShapedType::kDynamic)
+          continue;
+
+        // - a constant non-1 dimension equal to the "known" dim.
+        // In this case we also have to check whether this operand is the only
+        // contributor of that constant.
+        if (knownExtent != ShapedType::kDynamic && extent == knownExtent &&
+            llvm::count_if(operandExtents, [&](ArrayRef<int64_t> operandShape) {
+              return i < operandShape.size() &&
+                     operandShape[operandShape.size() - i - 1] == knownExtent;
+            }) > 1)
+          continue;
+
+        filteredOperands.push_back(shape);
+        break;
+      }
+    }
+    if (filteredOperands.size() != op.getShapes().size()) {
+      rewriter.replaceOpWithNewOp<BroadcastOp>(op, op->getResultTypes(),
+                                               filteredOperands);
+      return success();
+    }
+    return failure();
+  }
+};
+
+// Convert cases like:
+// ```
+// %1 = shape.shape_of %arg0 : tensor<?x?x?xf32> -> tensor<3xindex>
+// %2 = shape.shape_of %arg1 : tensor<?x?x?xf32> -> tensor<3xindex>
+// %3 = shape.broadcast %1, %2 : tensor<3xindex>, tensor<3xindex>
+//                               -> tensor<3xindex>
+// %result = tensor.extract %3[%c2] : tensor<3xindex>
+// ```
+// to
+//
+// ```
+// %result = tensor.dim %arg0[%c2] : tensor<?x?x?xf32>
+// ```
+struct ExtractFromBroadcastedTensorCanonicalizationPattern
+    : public OpRewritePattern<tensor::ExtractOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::ExtractOp op,
+                                PatternRewriter &rewriter) const override {
+    auto broadcastOp = op.getTensor().getDefiningOp<BroadcastOp>();
+    if (!broadcastOp) return failure();
+
+    // Confirm that there is a constant index. This is required, so we can
+    // confirm the DimOp's input will define the resulting broadcasted shape in
+    // that dimension.
+    auto index =
+        op.getIndices().front().getDefiningOp<arith::ConstantIndexOp>();
+    if (!index) return failure();
+    auto idx = index.value();
+
+    // Iterate through the operands with 3 considerations in this order:
+    // 1. If a static, non-1 dimension is seen, we know this to be the
+    //    broadcasted result
+    // 2. If a single dynamic dimension is seen, we know this to be the
+    //    broadcasted result (with a possibly 1 or non-1 result)
+    // 3. If no dynamic dimensions and no non-1 static dimensions are seen, we
+    //    know the result to be 1
+    //
+    // Iterate through all operands, keeping track of dynamic dimensions and
+    // returning immediately if a non-1 static dimension is seen.
+    ShapeOfOp dynamicShape;
+    int64_t numDynamic = 0;
+    for (auto shape : broadcastOp.getShapes()) {
+      auto shapeOfOp = shape.getDefiningOp<ShapeOfOp>();
+      if (!shapeOfOp) return failure();
+      auto shapedType =
+          mlir::cast<ShapedType>(shapeOfOp->getOperandTypes().front());
+
+      // Abort on the existence of unranked shapes as they require more logic.
+      if (!shapedType.hasRank()) return failure();
+      if (shapedType.getRank() <= idx) continue;
+
+      // Only consider dynamic dimensions after the loop because any non-1
+      // static dimension takes precedence.
+      if (shapedType.isDynamicDim(idx)) {
+        dynamicShape = shapeOfOp;
+        numDynamic++;
+        continue;
+      }
+
+      if (shapedType.getDimSize(idx) == 1) continue;
+
+      // Return as soon as we see a non-1 static dim.
+      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(
+          op, shapedType.getDimSize(idx));
+      return success();
+    }
+    if (numDynamic > 1) return failure();
+
+    // Replace with the single dynamic dimension or 1.
+    if (dynamicShape) {
+      rewriter.replaceOpWithNewOp<tensor::DimOp>(op, dynamicShape.getArg(),
+                                                 index);
+    } else {
+      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, 1);
+    }
+    return success();
+  }
+};
+
+struct ShapeSimplificationPass
+    : public impl::ShapeSimplificationPassBase<ShapeSimplificationPass> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<mhlo::MhloDialect>();
+    registry.insert<mlir::arith::ArithDialect>();
+    registry.insert<mlir::func::FuncDialect>();
+    registry.insert<mlir::shape::ShapeDialect>();
+    registry.insert<mlir::tensor::TensorDialect>();
+  }
+
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+    RewritePatternSet patterns(&getContext());
+
+    for (auto op : context->getRegisteredOperations()) {
+      if (isa<shape::ShapeDialect, tensor::TensorDialect>(op.getDialect()))
+        op.getCanonicalizationPatterns(patterns, context);
+    }
+
+    patterns.add<BroadcastRemoveSubsumedOperandsPattern,
+                 ExtractFromBroadcastedTensorCanonicalizationPattern>(context);
+
+    auto func = getOperation();
+    if (failed(applyPatternsGreedily(func, std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+
+} // namespace
+
+} // namespace kernel_gen
+} // namespace mlir
diff --git a/tensorflow/compiler/mlir/lite/tools/tf_mlir_translate_cl.cc b/tensorflow/compiler/mlir/tools/tf_mlir_translate_cl.cc
similarity index 98%
rename from tensorflow/compiler/mlir/lite/tools/tf_mlir_translate_cl.cc
rename to tensorflow/compiler/mlir/tools/tf_mlir_translate_cl.cc
index 46b3a5500052a9..db21d257cd58f5 100644
--- a/tensorflow/compiler/mlir/lite/tools/tf_mlir_translate_cl.cc
+++ b/tensorflow/compiler/mlir/tools/tf_mlir_translate_cl.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/ -#include "tensorflow/compiler/mlir/lite/tools/tf_mlir_translate_cl.h" +#include "tensorflow/compiler/mlir/tools/tf_mlir_translate_cl.h" #include diff --git a/tensorflow/compiler/mlir/lite/tools/tf_mlir_translate_cl.h b/tensorflow/compiler/mlir/tools/tf_mlir_translate_cl.h similarity index 91% rename from tensorflow/compiler/mlir/lite/tools/tf_mlir_translate_cl.h rename to tensorflow/compiler/mlir/tools/tf_mlir_translate_cl.h index b3da62caa95e5e..ef67186d206644 100644 --- a/tensorflow/compiler/mlir/lite/tools/tf_mlir_translate_cl.h +++ b/tensorflow/compiler/mlir/tools/tf_mlir_translate_cl.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TOOLS_TF_MLIR_TRANSLATE_CL_H_ -#define TENSORFLOW_COMPILER_MLIR_LITE_TOOLS_TF_MLIR_TRANSLATE_CL_H_ +#ifndef TENSORFLOW_COMPILER_MLIR_TOOLS_TF_MLIR_TRANSLATE_CL_H_ +#define TENSORFLOW_COMPILER_MLIR_TOOLS_TF_MLIR_TRANSLATE_CL_H_ // This file contains command-line options aimed to provide the parameters // required by the TensorFlow Graph(Def) to MLIR module conversion. 
It is only @@ -51,4 +51,4 @@ extern llvm::cl::opt set_original_tf_func_name; extern llvm::cl::opt export_entry_func_to_flib; extern llvm::cl::opt export_original_tf_func_name; -#endif // TENSORFLOW_COMPILER_MLIR_LITE_TOOLS_TF_MLIR_TRANSLATE_CL_H_ +#endif // TENSORFLOW_COMPILER_MLIR_TOOLS_TF_MLIR_TRANSLATE_CL_H_ diff --git a/tensorflow/compiler/mlir/lite/tools/tf_mlir_translate_registration.cc b/tensorflow/compiler/mlir/tools/tf_mlir_translate_registration.cc similarity index 96% rename from tensorflow/compiler/mlir/lite/tools/tf_mlir_translate_registration.cc rename to tensorflow/compiler/mlir/tools/tf_mlir_translate_registration.cc index 4a07a184bbffb9..7d14d3e954b5f9 100644 --- a/tensorflow/compiler/mlir/lite/tools/tf_mlir_translate_registration.cc +++ b/tensorflow/compiler/mlir/tools/tf_mlir_translate_registration.cc @@ -21,8 +21,8 @@ limitations under the License. #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/Tools/mlir-translate/Translation.h" // from @llvm-project -#include "tensorflow/compiler/mlir/lite/tools/tf_mlir_translate_cl.h" #include "tensorflow/compiler/mlir/tensorflow/translate/tools/file_tf_mlir_translate.h" +#include "tensorflow/compiler/mlir/tools/tf_mlir_translate_cl.h" #include "tensorflow/core/framework/graph.pb.h" namespace mlir { diff --git a/tensorflow/compiler/mlir/tosa/BUILD b/tensorflow/compiler/mlir/tosa/BUILD index 238781aa6455eb..4bc56d2d1b429e 100644 --- a/tensorflow/compiler/mlir/tosa/BUILD +++ b/tensorflow/compiler/mlir/tosa/BUILD @@ -4,6 +4,7 @@ # https://github.com/llvm/llvm-project/blob/main/mlir/docs/Dialects/TOSA.md load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") +load("//tensorflow:tensorflow.bzl", "tf_cc_binary") load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") # TODO: Tighten visibility once targets are at the right granularity. 
@@ -85,6 +86,7 @@ cc_library( "//tensorflow/core/kernels:conv_grad_shape_utils", "//tensorflow/lite/kernels/internal:reference_base", "@com_google_absl//absl/status", + "@gemmlowp", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:ArithUtils", @@ -252,3 +254,47 @@ cc_library( "@llvm-project//mlir:Transforms", ], ) + +tf_cc_binary( + name = "tf-tosa-opt", + testonly = True, + srcs = ["tf_tosa_opt.cc"], + tags = ["tf_tosa"], + deps = [ + "//tensorflow/compiler/mlir:init_mlir", + "//tensorflow/compiler/mlir:passes", + "//tensorflow/compiler/mlir:register_common_dialects", + "//tensorflow/compiler/mlir/lite:tensorflow_lite", + "//tensorflow/compiler/mlir/lite:tf_tfl_passes", # buildcleaner:keep + "//tensorflow/compiler/mlir/quantization/stablehlo:bridge_passes", + "//tensorflow/compiler/mlir/tensorflow:mlprogram_util", + "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes", + "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_test_passes", + "//tensorflow/compiler/mlir/tensorflow/transforms:tf_graph_optimization_pass", + "//tensorflow/compiler/mlir/tensorflow/transforms:tf_saved_model_passes", # buildcleaner:keep + "//tensorflow/compiler/mlir/tensorflow/transforms/host_runtime:lower_cluster_to_runtime_ops", + "//tensorflow/compiler/mlir/tensorflow/transforms/host_runtime:runtime_passes", + "//tensorflow/compiler/mlir/tensorflow/transforms/sparsecore:sparsecore_passes", + "//tensorflow/compiler/mlir/tf2xla:compile_mlir_util", + "//tensorflow/compiler/mlir/tf2xla/internal/passes:clustering_passes", + "//tensorflow/compiler/mlir/tf2xla/internal/passes:mlir_to_graph_passes", + "//tensorflow/compiler/mlir/tf2xla/transforms:tf_xla_passes", + "//tensorflow/compiler/mlir/tf2xla/transforms:xla_legalize_tf", + "//tensorflow/compiler/mlir/tosa:tf_passes", + "//tensorflow/compiler/mlir/tosa:tf_tfl_passes", + "//tensorflow/compiler/mlir/tosa:tfl_passes", + "@llvm-project//mlir:AllPassesAndDialects", + 
"@llvm-project//mlir:MlirOptLib", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:Transforms", + "@local_xla//xla/mlir/framework/ir:xla_framework", + "@local_xla//xla/mlir/framework/transforms:passes", + "@local_xla//xla/mlir_hlo:all_passes", + ], +) + +filegroup( + name = "litfiles", + srcs = glob(["runlit*py"]), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/compiler/mlir/tosa/glob_lit_test.bzl b/tensorflow/compiler/mlir/tosa/glob_lit_test.bzl new file mode 100644 index 00000000000000..c5c72a3b9610ad --- /dev/null +++ b/tensorflow/compiler/mlir/tosa/glob_lit_test.bzl @@ -0,0 +1,151 @@ +# Test definitions for Lit, the LLVM test runner. +# +# This is reusing the LLVM Lit test runner in the interim until the new build +# rules are upstreamed. +# TODO(b/136126535): remove this custom rule. +"""Lit runner globbing test +""" + +load("@bazel_skylib//lib:paths.bzl", "paths") +load( + "@local_xla//xla:lit.bzl", + "lit_script_with_xla_gpu_cuda_data_dir", +) + +# Default values used by the test runner. +_default_test_file_exts = ["mlir", ".pbtxt", ".td"] +_default_driver = "@llvm-project//mlir:run_lit.sh" +_default_size = "small" +_default_tags = [] + +# These are patterns which we should never match, for tests, subdirectories, or +# test input data files. +_ALWAYS_EXCLUDE = [ + "**/LICENSE.txt", + "**/README.txt", + "**/lit.local.cfg", + # Exclude input files that have spaces in their names, since bazel + # cannot cope with such "targets" in the srcs list. + "**/* *", + "**/* */**", +] + +def _run_lit_test(name, data, size, tags, driver, features, exec_properties): + """Runs lit on all tests it can find in `data` under tensorflow/compiler/mlir. + + Note that, due to Bazel's hermetic builds, lit only sees the tests that + are included in the `data` parameter, regardless of what other tests might + exist in the directory searched. + + Args: + name: str, the name of the test, including extension. + data: [str], the data input to the test. 
+ size: str, the size of the test. + tags: [str], tags to attach to the test. + driver: str, label of the driver shell script. + Note: use of a custom driver is not currently supported + and specifying a default driver will abort the tests. + features: [str], list of extra features to enable. + """ + + # Disable tests on windows for now, to enable testing rest of all xla and mlir. + native.py_test( + name = name, + srcs = ["@llvm-project//llvm:lit"], + tags = tags + ["no_pip", "no_windows"], + args = [ + "tensorflow/compiler/mlir/tosa/" + paths.basename(data[-1]) + " --config-prefix=runlit -v", + ] + features, + data = data + [ + "//tensorflow/compiler/mlir/tosa:litfiles", + "@llvm-project//llvm:FileCheck", + "@llvm-project//llvm:count", + "@llvm-project//llvm:not", + ], + deps = ["@pypi_lit//:pkg"], + size = size, + main = "lit.py", + exec_properties = exec_properties, + ) + +def glob_lit_tests( + name = None, + exclude = [], + test_file_exts = _default_test_file_exts, + default_size = _default_size, + size_override = {}, + data = [], + per_test_extra_data = {}, + default_tags = _default_tags, + tags_override = {}, + driver = _default_driver, + features = [], + exec_properties = {}, + use_lit_test_suite = None, # @unused + hermetic_cuda_data_dir = None): + """Creates all plausible Lit tests (and their inputs) under this directory. + + Args: + name: str, name of the test_suite rule to generate for running all tests. + exclude: [str], paths to exclude (for tests and inputs). + test_file_exts: [str], extensions for files that are tests. + default_size: str, the test size for targets not in "size_override". + size_override: {str: str}, sizes to use for specific tests. + data: [str], additional input data to the test. + per_test_extra_data: {str: [str]}, extra data to attach to a given file. + default_tags: [str], additional tags to attach to the test. + tags_override: {str: str}, tags to add to specific tests. + driver: str, label of the driver shell script. 
+ Note: use of a custom driver is not currently supported + and specifying a default driver will abort the tests. + features: [str], list of extra features to enable. + exec_properties: a dictionary of properties to pass on. + hermetic_cuda_data_dir: string. If set, the tests will be run with a + `--xla_gpu_cuda_data_dir` flag set to the hermetic CUDA data directory. + use_lit_test_suite: unused. For compatibility. + """ + + # Ignore some patterns by default for tests and input data. + exclude = _ALWAYS_EXCLUDE + exclude + + tests = native.glob( + ["*." + ext for ext in test_file_exts], + exclude = exclude, + ) + + # Run tests individually such that errors can be attributed to a specific + # failure. + all_tests = [] + for curr_test in tests: + final_test_name = curr_test + if hermetic_cuda_data_dir: + output_file = "with_xla_gpu_cuda_data_dir_{}".format(curr_test) + rule_name = "script_{}".format(output_file) + lit_script_with_xla_gpu_cuda_data_dir( + rule_name, + curr_test, + output_file, + hermetic_cuda_data_dir, + ) + final_test_name = output_file + all_tests.append(final_test_name + ".test") + + # Instantiate this test with updated parameters. + _run_lit_test( + name = final_test_name + ".test", + data = data + [final_test_name] + + per_test_extra_data.get(curr_test, []), + size = size_override.get(curr_test, default_size), + tags = default_tags + tags_override.get(curr_test, []), + driver = driver, + features = features, + exec_properties = exec_properties, + ) + + # TODO: remove this check after making it a required param. + if name: + native.test_suite( + name = name, + tests = all_tests, + tags = ["manual"], + ) diff --git a/tensorflow/compiler/mlir/tosa/runlit.cfg.py b/tensorflow/compiler/mlir/tosa/runlit.cfg.py new file mode 100644 index 00000000000000..ccf0852be8f655 --- /dev/null +++ b/tensorflow/compiler/mlir/tosa/runlit.cfg.py @@ -0,0 +1,71 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Lit runner configuration.""" + +import os +import platform +import sys +import lit.formats +from lit.llvm import llvm_config +from lit.llvm.subst import ToolSubst + +# Lint for undefined variables is disabled as config is not defined inside this +# file, instead config is injected by way of evaluating runlit.cfg.py from +# runlit.site.cfg.py which in turn is evaluated by lit.py. The structure is +# common for lit tests and intended to only persist temporarily (b/136126535). +# pylint: disable=undefined-variable +# Configuration file for the 'lit' test runner. + +# name: The name of this test suite. +config.name = 'MLIR ' + os.path.basename(config.mlir_test_dir) + +config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell) + +# suffixes: A list of file extensions to treat as test files. +config.suffixes = ['.cc', '.hlo', '.json', '.mlir', '.pbtxt', '.py'] + +# test_source_root: The root path where tests are located. +config.test_source_root = config.mlir_test_dir + +# test_exec_root: The root path where tests should be run. +config.test_exec_root = os.environ['RUNFILES_DIR'] + +if platform.system() == 'Windows': + tool_patterns = [ + ToolSubst('FileCheck.exe', unresolved='fatal'), + # Handle these specially as they are strings searched for during testing. 
+ ToolSubst('count.exe', unresolved='fatal'), + ToolSubst('not.exe', unresolved='fatal') + ] + + llvm_config.config.substitutions.append( + ('%python', '"%s"' % (sys.executable))) + + llvm_config.add_tool_substitutions(tool_patterns, + [llvm_config.config.llvm_tools_dir]) +else: + llvm_config.use_default_substitutions() + +# Tweak the PATH to include the tools dir. +llvm_config.with_environment('PATH', config.llvm_tools_dir, append_path=True) + +tool_dirs = config.mlir_tf_tools_dirs + [ + config.mlir_tools_dir, config.llvm_tools_dir +] +tool_names = [ + 'tf-tosa-opt', +] +tools = [ToolSubst(s, unresolved='ignore') for s in tool_names] +llvm_config.add_tool_substitutions(tools, tool_dirs) +# pylint: enable=undefined-variable diff --git a/tensorflow/compiler/mlir/tosa/runlit.site.cfg.py b/tensorflow/compiler/mlir/tosa/runlit.site.cfg.py new file mode 100644 index 00000000000000..3f17710069eb6c --- /dev/null +++ b/tensorflow/compiler/mlir/tosa/runlit.site.cfg.py @@ -0,0 +1,63 @@ +# Copyright 2019 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Lit runner site configuration.""" + +import os +import platform +import lit.llvm + +# Handle the test srcdir for platforms. On windows, things are weird with bazel. 
+if platform.system() == 'Windows': + srcdir = os.environ['TEST_SRCDIR'] + real_test_srcdir = srcdir[:srcdir.find('tensorflow/compiler/mlir/tosa')] + external_srcdir = os.path.join(real_test_srcdir, 'external') +else: + real_test_srcdir = os.environ['TEST_SRCDIR'] + external_srcdir = real_test_srcdir + +# Lint for undefined variables is disabled as config is not defined inside this +# file, instead config is injected by lit.py. The structure is common for lit +# tests and intended to only persist temporarily (b/136126535). +# pylint: disable=undefined-variable +config.llvm_tools_dir = os.path.join(external_srcdir, 'llvm-project', 'llvm') +config.mlir_obj_root = os.path.join(real_test_srcdir) +config.mlir_tools_dir = os.path.join(external_srcdir, 'llvm-project', 'mlir') +# TODO(jpienaar): Replace with suffices in build rule. +config.suffixes = ['.td', '.mlir', '.pbtxt'] + +mlir_tf_tools_dirs = [ + 'tensorflow/compiler/mlir/tosa', +] +config.mlir_tf_tools_dirs = [ + os.path.join(real_test_srcdir, os.environ['TEST_WORKSPACE'], s) + for s in mlir_tf_tools_dirs +] +test_dir = os.environ['TEST_TARGET'] +test_dir = test_dir.strip('/').rsplit(':', 1)[0] +config.mlir_test_dir = os.path.join(real_test_srcdir, + os.environ['TEST_WORKSPACE'], test_dir) + +if platform.system() == 'Windows': + # Configure this to work with msys2, TF's preferred windows bash. + config.lit_tools_dir = '/usr/bin' + +lit.llvm.initialize(lit_config, config) + +# Let the main config do the real work. 
+lit_config.load_config( + config, + os.path.join( + os.path.join(real_test_srcdir, os.environ['TEST_WORKSPACE'], + 'tensorflow/compiler/mlir/tosa/runlit.cfg.py'))) +# pylint: enable=undefined-variable diff --git a/tensorflow/compiler/mlir/tosa/tests/BUILD b/tensorflow/compiler/mlir/tosa/tests/BUILD index e936d924ef4abb..46a4c1fc752bf3 100644 --- a/tensorflow/compiler/mlir/tosa/tests/BUILD +++ b/tensorflow/compiler/mlir/tosa/tests/BUILD @@ -1,5 +1,5 @@ load("//tensorflow:tensorflow.default.bzl", "filegroup") -load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") +load("//tensorflow/compiler/mlir/tosa:glob_lit_test.bzl", "glob_lit_tests") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -22,7 +22,7 @@ filegroup( name = "test_utilities", testonly = True, data = [ - "//tensorflow/compiler/mlir:tf-opt", + "//tensorflow/compiler/mlir/tosa:tf-tosa-opt", "@llvm-project//llvm:FileCheck", "@llvm-project//llvm:not", ], diff --git a/tensorflow/compiler/mlir/tosa/tests/convert-tfl-uint8.mlir b/tensorflow/compiler/mlir/tosa/tests/convert-tfl-uint8.mlir index 8738a07ac60400..34d7007ea6cbb6 100644 --- a/tensorflow/compiler/mlir/tosa/tests/convert-tfl-uint8.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/convert-tfl-uint8.mlir @@ -1,5 +1,5 @@ -// RUN: tf-opt --tosa-convert-tfl-uint8 --verify-each %s | FileCheck %s -// REQUIRES: tf_tosa +// RUN: tf-tosa-opt --tosa-convert-tfl-uint8 --verify-each %s | FileCheck %s + // Operations for testing --tosa-convert-tfl-uint8 diff --git a/tensorflow/compiler/mlir/tosa/tests/convert_metadata.mlir b/tensorflow/compiler/mlir/tosa/tests/convert_metadata.mlir index 5d7c3316b19ef2..ced3651bff327f 100644 --- a/tensorflow/compiler/mlir/tosa/tests/convert_metadata.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/convert_metadata.mlir @@ -1,5 +1,5 @@ -// RUN: tf-opt --split-input-file --pass-pipeline='builtin.module(func.func(tosa-tflite-convert-function-metadata))' %s | FileCheck %s -// REQUIRES: 
tf_tosa +// RUN: tf-tosa-opt --split-input-file --pass-pipeline='builtin.module(func.func(tosa-tflite-convert-function-metadata))' %s | FileCheck %s + module attributes {tfl.schema_version = 3 : i32} { // CHECK: func.func @main( diff --git a/tensorflow/compiler/mlir/tosa/tests/fuse-bias-tf.mlir b/tensorflow/compiler/mlir/tosa/tests/fuse-bias-tf.mlir index f00c0358fdac67..c41b202edc8faf 100644 --- a/tensorflow/compiler/mlir/tosa/tests/fuse-bias-tf.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/fuse-bias-tf.mlir @@ -1,5 +1,5 @@ -// RUN: tf-opt --tosa-fuse-bias-tf --verify-each %s | FileCheck %s -// REQUIRES: tf_tosa +// RUN: tf-tosa-opt --tosa-fuse-bias-tf --verify-each %s | FileCheck %s + // Operations for testing --tosa-fuse-bias-tf diff --git a/tensorflow/compiler/mlir/tosa/tests/lower-complex-types.mlir b/tensorflow/compiler/mlir/tosa/tests/lower-complex-types.mlir index c9b59c2201c313..3985720caf1d19 100644 --- a/tensorflow/compiler/mlir/tosa/tests/lower-complex-types.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/lower-complex-types.mlir @@ -1,5 +1,5 @@ -// RUN: tf-opt --split-input-file --tosa-lower-complex-types --verify-each %s | FileCheck %s -// REQUIRES: tf_tosa +// RUN: tf-tosa-opt --split-input-file --tosa-lower-complex-types --verify-each %s | FileCheck %s + // CHECK-LABEL: test_complex_input // CHECK-SAME: %[[VAL_0:.*]]: tensor<1x4x4x2xf32> diff --git a/tensorflow/compiler/mlir/tosa/tests/multi_add.mlir b/tensorflow/compiler/mlir/tosa/tests/multi_add.mlir index 28f3192bae2f6d..8952d5fcd5ef9a 100644 --- a/tensorflow/compiler/mlir/tosa/tests/multi_add.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/multi_add.mlir @@ -1,5 +1,5 @@ -// RUN: tf-opt --tfl-to-tosa-pipeline=target-compilation-backend %s | FileCheck %s -// REQUIRES: tf_tosa +// RUN: tf-tosa-opt --tfl-to-tosa-pipeline=target-compilation-backend %s | FileCheck %s + // CHECK: tensor<1x8x8x3xf32> {ml_program.identifier = "a"} // CHECK-SAME: tensor<1x8x8x3xf32> {ml_program.identifier = "b"} diff --git 
a/tensorflow/compiler/mlir/tosa/tests/retain_call_once_funcs.mlir b/tensorflow/compiler/mlir/tosa/tests/retain_call_once_funcs.mlir index 8feb41f2631f0f..cf4dacffe76fef 100644 --- a/tensorflow/compiler/mlir/tosa/tests/retain_call_once_funcs.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/retain_call_once_funcs.mlir @@ -1,5 +1,5 @@ -// RUN: tf-opt --split-input-file --pass-pipeline='builtin.module(tflite-retain-call-once-funcs)' %s | FileCheck %s -// REQUIRES: tf_tosa +// RUN: tf-tosa-opt --split-input-file --pass-pipeline='builtin.module(tflite-retain-call-once-funcs)' %s | FileCheck %s + // CHECK-LABEL: module { module { diff --git a/tensorflow/compiler/mlir/tosa/tests/strip-quant-types.mlir b/tensorflow/compiler/mlir/tosa/tests/strip-quant-types.mlir index bac7879370fdd8..b595c032bef9ac 100644 --- a/tensorflow/compiler/mlir/tosa/tests/strip-quant-types.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/strip-quant-types.mlir @@ -1,5 +1,5 @@ -// RUN: tf-opt --split-input-file --tosa-strip-quant-types --verify-each %s | FileCheck %s -// REQUIRES: tf_tosa +// RUN: tf-tosa-opt --split-input-file --tosa-strip-quant-types --verify-each %s | FileCheck %s + // ----- diff --git a/tensorflow/compiler/mlir/tosa/tests/strip_metadata.mlir b/tensorflow/compiler/mlir/tosa/tests/strip_metadata.mlir index 5f75b923739d90..e607798da0d622 100644 --- a/tensorflow/compiler/mlir/tosa/tests/strip_metadata.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/strip_metadata.mlir @@ -1,5 +1,5 @@ -// RUN: tf-opt --pass-pipeline='builtin.module(tosa-tflite-strip-module-metadata,func.func(tosa-tflite-strip-function-metadata))' %s | FileCheck %s -// REQUIRES: tf_tosa +// RUN: tf-tosa-opt --pass-pipeline='builtin.module(tosa-tflite-strip-module-metadata,func.func(tosa-tflite-strip-function-metadata))' %s | FileCheck %s + // CHECK-LABEL: module { // CHECK-NOT: tf.schema_version diff --git a/tensorflow/compiler/mlir/tosa/tests/tf-tfl-to-tosa-pipeline.mlir 
b/tensorflow/compiler/mlir/tosa/tests/tf-tfl-to-tosa-pipeline.mlir index 7eadb79b757bd4..fc1403205ca34f 100644 --- a/tensorflow/compiler/mlir/tosa/tests/tf-tfl-to-tosa-pipeline.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/tf-tfl-to-tosa-pipeline.mlir @@ -1,5 +1,5 @@ -// RUN: tf-opt --split-input-file --tf-tfl-to-tosa-pipeline --verify-each %s | FileCheck %s -// REQUIRES: tf_tosa +// RUN: tf-tosa-opt --split-input-file --tf-tfl-to-tosa-pipeline --verify-each %s | FileCheck %s + // These tests focus on TensorFlow and TensorFlow Lite hybrid lowering and focus // on tfl.custom operations that are Flex ops. diff --git a/tensorflow/compiler/mlir/tosa/tests/tf-to-tosa-pipeline.mlir b/tensorflow/compiler/mlir/tosa/tests/tf-to-tosa-pipeline.mlir index da75738e80d1cf..0bd0eeb0285d1d 100644 --- a/tensorflow/compiler/mlir/tosa/tests/tf-to-tosa-pipeline.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/tf-to-tosa-pipeline.mlir @@ -1,7 +1,7 @@ -// RUN: tf-opt --tf-to-tosa-pipeline --verify-each %s | FileCheck %s -// REQUIRES: tf_tosa -// RUN: tf-opt --tf-tfl-to-tosa-pipeline --verify-each %s | FileCheck %s -// REQUIRES: tf_tosa +// RUN: tf-tosa-opt --tf-to-tosa-pipeline --verify-each %s | FileCheck %s + +// RUN: tf-tosa-opt --tf-tfl-to-tosa-pipeline --verify-each %s | FileCheck %s + // Operations for testing tf-to-tosa-pipeline // TODO: These tests are fairly minimal. Expand the checks to be more robust. 
@@ -106,7 +106,7 @@ func.func @test_mul(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> te // ----- // CHECK-LABEL: test_real_div -// CHECK: %[[VAR0:.*]] = tosa.int_div %arg0, %arg1 +// CHECK: %[[VAR0:.*]] = tosa.intdiv %arg0, %arg1 func.func @test_real_div(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x1x3xi32>) -> tensor<13x21x3xi32> { %2 = "tf.RealDiv"(%arg0, %arg1) : (tensor<13x21x3xi32>, tensor<13x1x3xi32>) -> tensor<13x21x3xi32> func.return %2 : tensor<13x21x3xi32> @@ -117,12 +117,12 @@ func.func @test_real_div(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x1x3xi32>) // CHECK-LABEL: func.func @test_floor_div( // CHECK-SAME: %[[VAL_0:.*]]: tensor<13x21x3xi32>, // CHECK-SAME: %[[VAL_1:.*]]: tensor<13x1x3xi32>) -> tensor<13x21x3xi32> { -// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<0> : tensor}> : () -> tensor -// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<0> : tensor<1x1x1xi32>}> : () -> tensor<1x1x1xi32> -// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<1> : tensor<1x1x1xi32>}> : () -> tensor<1x1x1xi32> -// CHECK: %[[VAL_5:.*]] = tosa.int_div %[[VAL_0]], %[[VAL_1]] : (tensor<13x21x3xi32>, tensor<13x1x3xi32>) -> tensor<13x21x3xi32> -// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] : (tensor<13x21x3xi32>, tensor<13x1x3xi32>, tensor) -> tensor<13x21x3xi32> -// CHECK: %[[VAL_7:.*]] = tosa.mul %[[VAL_1]], %[[VAL_5]], %[[VAL_2]] : (tensor<13x1x3xi32>, tensor<13x21x3xi32>, tensor) -> tensor<13x21x3xi32> +// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{values = dense<0> : tensor<1x1x1xi32>}> : () -> tensor<1x1x1xi32> +// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{values = dense<1> : tensor<1x1x1xi32>}> : () -> tensor<1x1x1xi32> +// CHECK: %[[VAL_5:.*]] = tosa.intdiv %[[VAL_0]], %[[VAL_1]] : (tensor<13x21x3xi32>, tensor<13x1x3xi32>) -> tensor<13x21x3xi32> +// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_0]], %[[VAL_1]], 
%[[VAL_2]] : (tensor<13x21x3xi32>, tensor<13x1x3xi32>, tensor<1xi8>) -> tensor<13x21x3xi32> +// CHECK: %[[VAL_7:.*]] = tosa.mul %[[VAL_1]], %[[VAL_5]], %[[VAL_2]] : (tensor<13x1x3xi32>, tensor<13x21x3xi32>, tensor<1xi8>) -> tensor<13x21x3xi32> // CHECK: %[[VAL_8:.*]] = tosa.equal %[[VAL_0]], %[[VAL_7]] : (tensor<13x21x3xi32>, tensor<13x21x3xi32>) -> tensor<13x21x3xi1> // CHECK: %[[VAL_9:.*]] = tosa.logical_not %[[VAL_8]] : (tensor<13x21x3xi1>) -> tensor<13x21x3xi1> // CHECK: %[[VAL_10:.*]] = tosa.greater %[[VAL_3]], %[[VAL_6]] : (tensor<1x1x1xi32>, tensor<13x21x3xi32>) -> tensor<13x21x3xi1> @@ -1053,6 +1053,20 @@ func.func @test_gather_nd(%arg0: tensor<13x21x3xf32>) -> tensor<6x7x21x3xf32> { func.return %1 : tensor<6x7x21x3xf32> } +// ----- + +// CHECK-LABEL: test_scatter_nd +// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1x224x512xf32>}> +// CHECK-DAG: %[[VAR2:.*]] = "tosa.const"() <{values = dense<0> : tensor<1x2xi32>}> +// CHECK-DAG: %[[VAR3:.*]] = tosa.reduce_sum %[[VAR2:.*]] {axis = 1 : i32} : (tensor<1x2xi32>) +// CHECK-DAG: %[[VAR4:.*]] = tosa.scatter %[[VAR1:.*]], %[[VAR3:.*]], %arg0 : (tensor<1x224x512xf32>, tensor<1x1xi32>, tensor<1x1x512xf32>) +// CHECK: return %[[VAR4]] +func.func @test_scatter_nd(%arg0: tensor<1x1x512xf32>) -> tensor<1x224x512xf32> { + %shape = "tf.Const"() {device = "", value = dense<[1, 224, 512]> : tensor<3xi32>} : () -> tensor<3xi32> + %indices = "tf.Const"() {device = "", value = dense<[[[0, 0]]]>: tensor<1x1x2xi32>} : () -> tensor<1x1x2xi32> + %1 = "tf.ScatterNd"(%indices, %arg0, %shape) {device = ""} : (tensor<1x1x2xi32>, tensor<1x1x512xf32>, tensor<3xi32>) -> tensor<1x224x512xf32> + func.return %1 : tensor<1x224x512xf32> +} // ----- diff --git a/tensorflow/compiler/mlir/tosa/tests/tf-unequal-ranks.mlir b/tensorflow/compiler/mlir/tosa/tests/tf-unequal-ranks.mlir index e34f0501c5533a..97ebeeac782a47 100644 --- a/tensorflow/compiler/mlir/tosa/tests/tf-unequal-ranks.mlir +++ 
b/tensorflow/compiler/mlir/tosa/tests/tf-unequal-ranks.mlir @@ -1,5 +1,5 @@ -// RUN: tf-opt --split-input-file --tf-to-tosa-pipeline --verify-each %s | FileCheck %s -// REQUIRES: tf_tosa +// RUN: tf-tosa-opt --split-input-file --tf-to-tosa-pipeline --verify-each %s | FileCheck %s + // Test tf legalization that produce TOSA ResultsBroadcastableShape operators with unequal ranks // ----- @@ -79,7 +79,7 @@ func.func @test_logical_or(%arg0: tensor<8x13x21x3xi1>, %arg1: tensor<13x21x1xi1 // ----- // CHECK-LABEL: test_floor_div -// CHECK: tosa.int_div +// CHECK: tosa.intdiv // CHECK: tosa.select func.func @test_floor_div(%arg0: tensor<13x21x3xi32>, %arg1: tensor<1x13x1x3xi32>) -> tensor<1x13x21x3xi32> { %2 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<13x21x3xi32>, tensor<1x13x1x3xi32>) -> tensor<1x13x21x3xi32> @@ -89,7 +89,7 @@ func.func @test_floor_div(%arg0: tensor<13x21x3xi32>, %arg1: tensor<1x13x1x3xi32 // ----- // CHECK-LABEL: test_real_div -// CHECK: tosa.int_div +// CHECK: tosa.intdiv func.func @test_real_div(%arg0: tensor<13x21x3xi32>, %arg1: tensor<1x13x1x3xi32>) -> tensor<1x13x21x3xi32> { %2 = "tf.RealDiv"(%arg0, %arg1) : (tensor<13x21x3xi32>, tensor<1x13x1x3xi32>) -> tensor<1x13x21x3xi32> func.return %2 : tensor<1x13x21x3xi32> diff --git a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-dequantize_softmax.mlir b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-dequantize_softmax.mlir index 936dbf7c69c630..28c764de62ab39 100644 --- a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-dequantize_softmax.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-dequantize_softmax.mlir @@ -1,5 +1,5 @@ -// RUN: tf-opt --tosa-dequantize-tfl-softmax %s | FileCheck %s -// REQUIRES: tf_tosa +// RUN: tf-tosa-opt --tosa-dequantize-tfl-softmax %s | FileCheck %s + // ----- diff --git a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline-filtered.mlir b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline-filtered.mlir index a14fe7e43f4bdc..1bc7e084fdbc1c 100644 --- 
a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline-filtered.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline-filtered.mlir @@ -1,5 +1,5 @@ -// RUN: tf-opt --pass-pipeline='builtin.module(func.func(tosa-legalize-tfl{disable-patterns=TFLConv2D,TFLSoftmax, enable-patterns=TFLFullyConnected,TFLTranspose}))' %s | FileCheck %s -// REQUIRES: tf_tosa +// RUN: tf-tosa-opt --pass-pipeline='builtin.module(func.func(tosa-legalize-tfl{disable-patterns=TFLConv2D,TFLSoftmax, enable-patterns=TFLFullyConnected,TFLTranspose}))' %s | FileCheck %s + // ----- @@ -30,11 +30,10 @@ func.func @test_softmax(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { // CHECK-DAG: %[[CONST1:.*]] = tosa.const_shape {values = dense<[28, 1, 1, 19]> : tensor<4xindex>} // CHECK-DAG: %[[CONST2:.*]] = tosa.const_shape {values = dense<[14, 28]> : tensor<2xindex>} // CHECK-DAG: %[[CONST3:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> -// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<28xf32>}> // CHECK: %[[VAR1:.*]] = tosa.transpose %arg1 {perms = array} // CHECK-DAG: %[[VAR2:.*]] = tosa.reshape %arg0, %[[CONST0]] // CHECK-DAG: %[[VAR3:.*]] = tosa.reshape %[[VAR1]], %[[CONST1]] -// CHECK-DAG: %[[VAR4:.*]] = tosa.conv2d %[[VAR2]], %[[VAR3]], %[[VAR0]], %[[CONST3]], %[[CONST3]] {acc_type = f32, dilation = array, pad = array, stride = array} +// CHECK-DAG: %[[VAR4:.*]] = tosa.conv2d %[[VAR2]], %[[VAR3]], %[[CONST3]], %[[CONST3]], %[[CONST3]] {acc_type = f32, dilation = array, pad = array, stride = array} // CHECK: %[[VAR5:.*]] = tosa.reshape %[[VAR4]], %[[CONST2]] func.func @test_matmul(%arg0: tensor<14x19xf32>, %arg1: tensor<19x28xf32>) -> tensor<*xf32> { %cst = arith.constant dense<[1, 0]> : tensor<2xi32> diff --git a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir index 2384330ea0b236..c217547b4a783b 100644 --- 
a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir @@ -1,7 +1,7 @@ -// RUN: tf-opt --split-input-file --tfl-to-tosa-pipeline --verify-each %s | FileCheck %s -// REQUIRES: tf_tosa -// RUN: tf-opt --split-input-file --tf-tfl-to-tosa-pipeline --verify-each %s | FileCheck %s -// REQUIRES: tf_tosa +// RUN: tf-tosa-opt --split-input-file --tfl-to-tosa-pipeline --verify-each %s | FileCheck %s + +// RUN: tf-tosa-opt --split-input-file --tf-tfl-to-tosa-pipeline --verify-each %s | FileCheck %s + // Operations for testing tfl-to-tosa-pipeline @@ -61,9 +61,8 @@ func.func @test_conv2d_slicing(%arg0: tensor<2x32x32x8xf32>, %arg1: tensor<16x3x // ----- // CHECK-LABEL: test_transpose_conv2d -// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<16xf32>}> // CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> -// CHECK: %[[VAR2:.*]] = tosa.transpose_conv2d %arg0, %arg1, %[[VAR0]], %[[VAR1]], %[[VAR1]] {acc_type = f32, out_pad = array, stride = array} +// CHECK: %[[VAR2:.*]] = tosa.transpose_conv2d %arg0, %arg1, %[[VAR1]], %[[VAR1]], %[[VAR1]] {acc_type = f32, out_pad = array, stride = array} func.func @test_transpose_conv2d(%arg0: tensor<1x32x32x8xf32>, %cst_0: tensor<16x1x1x8xf32>) -> tensor<1x32x32x16xf32> { %cst = arith.constant dense<[1, 32, 32, 16]> : tensor<4xi32> %cst_1 = "tfl.no_value"() {value = unit} : () -> none @@ -74,9 +73,8 @@ func.func @test_transpose_conv2d(%arg0: tensor<1x32x32x8xf32>, %cst_0: tensor<16 // ----- // CHECK-LABEL: test_transpose_conv2d_relu -// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<16xf32>}> // CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> -// CHECK: %[[VAR2:.*]] = tosa.transpose_conv2d %arg0, %arg1, %[[VAR0]], %[[VAR1]], %[[VAR1]] {acc_type = f32, out_pad = array, stride = array} +// CHECK: %[[VAR2:.*]] = 
tosa.transpose_conv2d %arg0, %arg1, %[[VAR1]], %[[VAR1]], %[[VAR1]] {acc_type = f32, out_pad = array, stride = array} // CHECK: %[[VAR3:.*]] = tosa.clamp %[[VAR2]] {max_val = 3.40282347E+38 : f32, min_val = 0.000000e+00 : f32} func.func @test_transpose_conv2d_relu(%arg0: tensor<1x32x32x8xf32>, %cst_0: tensor<16x1x1x8xf32>) -> tensor<1x32x32x16xf32> { %cst = arith.constant dense<[1, 32, 32, 16]> : tensor<4xi32> @@ -87,6 +85,21 @@ func.func @test_transpose_conv2d_relu(%arg0: tensor<1x32x32x8xf32>, %cst_0: tens // ----- +// CHECK-LABEL: test_transpose_conv2d_outpad +// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> +// CHECK: %[[VAR2:.*]] = tosa.transpose_conv2d %arg0, %arg1, %[[VAR0]], %[[VAR0]], %[[VAR0]] {acc_type = f32, out_pad = array, stride = array} +func.func @test_transpose_conv2d_outpad(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x1x1x8xf32>) -> tensor<1x33x33x16xf32> { + %cst = arith.constant dense<[1, 33, 33, 16]> : tensor<4xi32> + %cst_1 = "tfl.no_value"() {value = unit} : () -> none + %0 = "tfl.transpose_conv"(%cst, %arg1, %arg0, %cst_1) + {padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32, + fused_activation_function = "NONE"} + : (tensor<4xi32>, tensor<16x1x1x8xf32>, tensor<1x32x32x8xf32>, none) -> tensor<1x33x33x16xf32> + func.return %0 : tensor<1x33x33x16xf32> +} + +// ----- + // CHECK-LABEL: test_conv2d_qi8 // CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{values = dense<{{.*}}> : tensor<16x2x2x8xi8>}> // CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{values = dense<0> : tensor<16xi32>}> @@ -104,7 +117,7 @@ func.func @test_conv2d_qi8(%arg0: tensor<1x32x32x8x!quant.uniform : tensor<16x2x2x8xi8>}> -// CHECK-DAG: %[[VAL_4:.*]] = "tosa.const"() <{values = dense<1> : tensor<16xi8>}> +// CHECK-DAG: %[[VAL_4:.*]] = "tosa.const"() <{values = dense<1> : tensor<16xi32>}> // CHECK-DAG: %[[VAL_5:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> // CHECK: %[[VAL_6:.*]] = tosa.conv2d %arg0, %[[VAL_3]], 
%[[VAL_4]], %[[VAL_5]], %[[VAL_5]] {acc_type = i32, dilation = array, pad = array, stride = array} // CHECK: %[[VAL_7:.*]] = tosa.rescale %[[VAL_6]] @@ -284,10 +297,9 @@ func.func @test_depthwise_conv2d_slicing(%arg0: tensor<1x32x32x8xf32>, %arg1: te // CHECK-LABEL: test_conv3d // CHECK-SAME: %[[VAL_0:.*]]: tensor<2x2x7x7x2xf32> // CHECK-SAME: %[[VAL_1:.*]]: tensor<2x3x3x2x4xf32> -// CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<4xf32>}> // CHECK-DAG: %[[VAL_4:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> // CHECK: %[[VAL_5:.*]] = tosa.transpose %[[VAL_1]] {perms = array} -// CHECK: %[[VAL_6:.*]] = tosa.conv3d %[[VAL_0]], %[[VAL_5]], %[[VAL_2]], %[[VAL_4]], %[[VAL_4]] {acc_type = f32, dilation = array, pad = array, stride = array} +// CHECK: %[[VAL_6:.*]] = tosa.conv3d %[[VAL_0]], %[[VAL_5]], %[[VAL_4]], %[[VAL_4]], %[[VAL_4]] {acc_type = f32, dilation = array, pad = array, stride = array} func.func @test_conv3d(%arg0: tensor<2x2x7x7x2xf32>, %arg1: tensor<2x3x3x2x4xf32>) -> tensor<2x2x7x7x4xf32> { %cst = "tfl.no_value"() {value} : () -> none %0 = "tfl.conv_3d"(%arg0, %arg1, %cst) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<2x2x7x7x2xf32>, tensor<2x3x3x2x4xf32>, none) -> tensor<2x2x7x7x4xf32> @@ -299,10 +311,9 @@ func.func @test_conv3d(%arg0: tensor<2x2x7x7x2xf32>, %arg1: tensor<2x3x3x2x4xf32 // CHECK-LABEL: test_conv3d_dynamic // CHECK-SAME: %[[VAL_0:.*]]: tensor // CHECK-SAME: %[[VAL_1:.*]]: tensor<3x1x1x8x16xf32> -// CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<16xf32>}> // CHECK-DAG: %[[VAL_4:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> // CHECK: %[[VAL_5:.*]] = tosa.transpose %[[VAL_1]] {perms = array} -// CHECK: %[[VAL_6:.*]] = tosa.conv3d %[[VAL_0]], %[[VAL_5]], 
%[[VAL_2]], %[[VAL_4]], %[[VAL_4]] {acc_type = f32, dilation = array, pad = array, stride = array} +// CHECK: %[[VAL_6:.*]] = tosa.conv3d %[[VAL_0]], %[[VAL_5]], %[[VAL_4]], %[[VAL_4]], %[[VAL_4]] {acc_type = f32, dilation = array, pad = array, stride = array} func.func @test_conv3d_dynamic(%arg0: tensor, %arg1: tensor<3x1x1x8x16xf32>) -> tensor<*xf32> { %cst = "tfl.no_value"() {value} : () -> none %0 = "tfl.conv_3d"(%arg0, %arg1, %cst) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32} : (tensor, tensor<3x1x1x8x16xf32>, none) -> tensor<*xf32> @@ -345,13 +356,12 @@ func.func @test_conv3d_slicing(%arg0: tensor<1x32x32x32x8xf32>, %arg1: tensor<3x // CHECK-DAG: %[[VAL_3:.*]] = "tosa.const"() <{values = dense<0.0156862643> : tensor<1x1x1x1x1xf32>} // CHECK-DAG: %[[VAL_4:.*]] = "tosa.const"() <{values = dense<1.11982894> : tensor<1x1x1x1x1xf32>} // CHECK-DAG: %[[VAL_5:.*]] = "tosa.const"() <{values = dense<-4> : tensor<1x1x1x1x1xi32>} -// CHECK-DAG: %[[VAL_6:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<34xf32>} // CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> -// CHECK-DAG: %[[ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> +// CHECK-DAG: %[[BIAS_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> // CHECK: %[[VAL_8:.*]] = tosa.cast %[[VAL_0]] // CHECK: %[[VAL_10:.*]] = tosa.mul %[[VAL_8]], %[[VAL_3]], %[[SHIFT]] // CHECK: %[[VAL_11:.*]] = tosa.transpose %[[VAL_1]] {perms = array} -// CHECK: %[[VAL_12:.*]] = tosa.conv3d %[[VAL_10]], %[[VAL_11]], %[[VAL_6]], %[[ZP]], %[[ZP]] {acc_type = f32, dilation = array, pad = array, stride = array} +// CHECK: %[[VAL_12:.*]] = tosa.conv3d %[[VAL_10]], %[[VAL_11]], %[[BIAS_ZP]], %[[BIAS_ZP]], %[[BIAS_ZP]] {acc_type = f32, dilation = array, pad = array, stride = array} // CHECK: 
%[[VAL_13:.*]] = tosa.mul %[[VAL_12]], %[[VAL_4]], %[[SHIFT]] // CHECK: %[[VAL_14:.*]] = tosa.cast %[[VAL_13]] // CHECK: %[[VAL_15:.*]] = tosa.add %[[VAL_14]], %[[VAL_5]] @@ -367,6 +377,17 @@ func.func @test_conv3d_qi8(%arg0: tensor<1x4x8x21x17x!quant.uniform : tensor<16xi48>}> : () -> tensor<16xi48> +// CHECK: tosa.conv3d {{.+}}, %[[BIAS]], %{{.+}} {acc_type = i48, {{.+}}} : {{.+}} -> tensor<1x15x15x15x16xi48> +func.func @test_conv3d_qi16(%input: tensor<1x32x32x32x8x!quant.uniform>, %filter: tensor<3x3x3x8x16x!quant.uniform>) -> tensor<1x15x15x15x16x!quant.uniform> { + %bias = "tfl.pseudo_qconst"() {qtype = tensor<16x!quant.uniform>, value = dense<123> : tensor<16xi16>} : () -> tensor<16x!quant.uniform> + %0 = "tfl.conv_3d"(%input, %filter, %bias) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_d = 2 : i32, stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x32x32x32x8x!quant.uniform>, tensor<3x3x3x8x16x!quant.uniform>, tensor<16x!quant.uniform>) -> tensor<1x15x15x15x16x!quant.uniform> + func.return %0 : tensor<1x15x15x15x16x!quant.uniform> +} + +// ----- + // CHECK-LABEL: test_add // CHECK: %[[VAR0:.*]] = tosa.add %arg0, %arg1 func.func @test_add(%arg0: tensor<13x21x1xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { @@ -425,9 +446,31 @@ func.func @test_mul_unranked(%arg0: tensor<13x21x3xf32>, %arg1: tensor<1x1x1xf32 // CHECK-LABEL: test_exp // CHECK: %[[VAR0:.*]] = tosa.exp %arg0 -func.func @test_exp(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { - %0 = "tfl.exp"(%arg0) : (tensor<13x21x3xf32>) -> tensor<*xf32> - func.return %0 : tensor<*xf32> +func.func @test_exp(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + %0 = "tfl.exp"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + func.return %0 : tensor<13x21x3xf32> +} + +// ----- + +// CHECK-LABEL: test_exp_qi8 +// CHECK-SAME: %[[VAL_0:.*]]: tensor<13x21x3x!quant.uniform> +// CHECK: 
%[[VAL_1:.*]] = "tosa.const"() <{values = dense<{{.+}}> : tensor<256xi8>}> +// CHECK: %[[VAL_2:.*]] = tosa.table %[[VAL_0]], %[[VAL_1]] +func.func @test_exp_qi8(%arg0: tensor<13x21x3x!quant.uniform>) -> (tensor<13x21x3x!quant.uniform>) { + %0 = "tfl.exp"(%arg0) : (tensor<13x21x3x!quant.uniform>) -> tensor<13x21x3x!quant.uniform> + func.return %0 : tensor<13x21x3x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: test_exp_qi16 +// CHECK-SAME: %[[VAL_0:.*]]: tensor<13x21x3x!quant.uniform> +// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{values = dense<{{.+}}> : tensor<513xi16>}> +// CHECK: %[[VAL_2:.*]] = tosa.table %[[VAL_0]], %[[VAL_1]] +func.func @test_exp_qi16(%arg0: tensor<13x21x3x!quant.uniform>) -> (tensor<13x21x3x!quant.uniform>) { + %0 = "tfl.exp"(%arg0) : (tensor<13x21x3x!quant.uniform>) -> tensor<13x21x3x!quant.uniform> + func.return %0 : tensor<13x21x3x!quant.uniform> } // ----- @@ -444,7 +487,7 @@ func.func @test_rcp(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { // CHECK-LABEL: test_div // CHECK-DAG: %[[RESHAPE:.*]] = tosa.reshape %arg1 -// CHECK: %[[VAR0:.*]] = tosa.int_div %arg0, %[[RESHAPE]] +// CHECK: %[[VAR0:.*]] = tosa.intdiv %arg0, %[[RESHAPE]] func.func @test_div(%arg0: tensor<13x21x3xi32>, %arg1: tensor) -> tensor<*xi32> { %0 = "tfl.div"(%arg0, %arg1) {fused_activation_function = "NONE"} : (tensor<13x21x3xi32>, tensor) -> tensor<*xi32> func.return %0 : tensor<*xi32> @@ -455,16 +498,16 @@ func.func @test_div(%arg0: tensor<13x21x3xi32>, %arg1: tensor) -> tensor<*x // CHECK-LABEL: func.func @test_floor_div( // CHECK-SAME: %[[VAL_0:.*]]: tensor<13x21x3xi32>, // CHECK-SAME: %[[VAL_1:.*]]: tensor) -> tensor<13x21x3xi32> { -// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<0> : tensor}> : () -> tensor -// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<0> : tensor<1x1x1xi32>}> : () -> tensor<1x1x1xi32> -// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<1> : tensor<1x1x1xi32>}> : () -> tensor<1x1x1xi32> -// CHECK: %[[VAL_5:.*]] = 
tosa.const_shape {value = dense<1> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{values = dense<0> : tensor<1x1x1xi32>}> : () -> tensor<1x1x1xi32> +// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{values = dense<1> : tensor<1x1x1xi32>}> : () -> tensor<1x1x1xi32> +// CHECK: %[[VAL_5:.*]] = tosa.const_shape {values = dense<1> : tensor<3xindex>} : () -> !tosa.shape<3> // CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_1]], %[[VAL_5]] : (tensor, !tosa.shape<3>) -> tensor<1x1x1xi32> -// CHECK: %[[VAL_7:.*]] = tosa.int_div %[[VAL_0]], %[[VAL_6]] : (tensor<13x21x3xi32>, tensor<1x1x1xi32>) -> tensor<13x21x3xi32> +// CHECK: %[[VAL_7:.*]] = tosa.intdiv %[[VAL_0]], %[[VAL_6]] : (tensor<13x21x3xi32>, tensor<1x1x1xi32>) -> tensor<13x21x3xi32> // CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_1]], %[[VAL_5]] : (tensor, !tosa.shape<3>) -> tensor<1x1x1xi32> -// CHECK: %[[VAL_9:.*]] = tosa.mul %[[VAL_0]], %[[VAL_8]], %[[VAL_2]] : (tensor<13x21x3xi32>, tensor<1x1x1xi32>, tensor) -> tensor<13x21x3xi32> +// CHECK: %[[VAL_9:.*]] = tosa.mul %[[VAL_0]], %[[VAL_8]], %[[VAL_2]] : (tensor<13x21x3xi32>, tensor<1x1x1xi32>, tensor<1xi8>) -> tensor<13x21x3xi32> // CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_1]], %[[VAL_5]] : (tensor, !tosa.shape<3>) -> tensor<1x1x1xi32> -// CHECK: %[[VAL_11:.*]] = tosa.mul %[[VAL_10]], %[[VAL_7]], %[[VAL_2]] : (tensor<1x1x1xi32>, tensor<13x21x3xi32>, tensor) -> tensor<13x21x3xi32> +// CHECK: %[[VAL_11:.*]] = tosa.mul %[[VAL_10]], %[[VAL_7]], %[[VAL_2]] : (tensor<1x1x1xi32>, tensor<13x21x3xi32>, tensor<1xi8>) -> tensor<13x21x3xi32> // CHECK: %[[VAL_12:.*]] = tosa.equal %[[VAL_0]], %[[VAL_11]] : (tensor<13x21x3xi32>, tensor<13x21x3xi32>) -> tensor<13x21x3xi1> // CHECK: %[[VAL_13:.*]] = tosa.logical_not %[[VAL_12]] : (tensor<13x21x3xi1>) -> tensor<13x21x3xi1> // CHECK: %[[VAL_14:.*]] = tosa.greater %[[VAL_3]], %[[VAL_9]] : 
(tensor<1x1x1xi32>, tensor<13x21x3xi32>) -> tensor<13x21x3xi1> @@ -631,6 +674,33 @@ func.func @test_logical_or(%arg0: tensor<13x1x3xi1>, %arg1: tensor<13x21x3xi1>) // ----- +// CHECK-LABEL: test_bitwise_xor_int8 +// CHECK: %[[VAR0:.*]] = tosa.bitwise_xor %arg0, %arg1 : (tensor<1x11x5xi8>, tensor<29x11x5xi8>) -> tensor<29x11x5xi8> +func.func @test_bitwise_xor_int8(%arg0: tensor<1x11x5xi8>, %arg1: tensor<29x11x5xi8>) -> tensor<29x11x5xi8> { + %0 = "tfl.bitwise_xor"(%arg0, %arg1) : (tensor<1x11x5xi8>, tensor<29x11x5xi8>) -> tensor<29x11x5xi8> + func.return %0 : tensor<29x11x5xi8> +} + +// ----- + +// CHECK-LABEL: test_bitwise_xor_int16 +// CHECK: %[[VAR0:.*]] = tosa.bitwise_xor %arg0, %arg1 : (tensor<1x11x5xi16>, tensor<29x11x5xi16>) -> tensor<29x11x5xi16> +func.func @test_bitwise_xor_int16(%arg0: tensor<1x11x5xi16>, %arg1: tensor<29x11x5xi16>) -> tensor<*xi16> { + %0 = "tfl.bitwise_xor"(%arg0, %arg1) : (tensor<1x11x5xi16>, tensor<29x11x5xi16>) -> tensor<*xi16> + func.return %0 : tensor<*xi16> +} + +// ----- + +// CHECK-LABEL: test_bitwise_xor_int32 +// CHECK: %[[VAR0:.*]] = tosa.bitwise_xor %arg0, %arg1 : (tensor<4x16x1xi32>, tensor<1x16x1xi32>) -> tensor<4x16x1xi32> +func.func @test_bitwise_xor_int32(%arg0: tensor<4x16x1xi32>, %arg1: tensor<1x16x1xi32>) -> tensor<4x16x1xi32> { + %0 = "tfl.bitwise_xor"(%arg0, %arg1) : (tensor<4x16x1xi32>, tensor<1x16x1xi32>) -> tensor<4x16x1xi32> + func.return %0 : tensor<4x16x1xi32> +} + +// ----- + // CHECK-LABEL: test_logical_not // CHECK: %[[VAR0:.*]] = tosa.logical_not %arg0 func.func @test_logical_not(%arg0: tensor<1x21x3xi1>) -> tensor<1x21x3xi1> { @@ -709,6 +779,18 @@ func.func @test_reduce_any(%arg0: tensor<13x21x3xi1>) -> tensor<21x3xi1> { // ----- +// CHECK-LABEL: test_reduce_any_dynamic_output +// CHECK-DAG: %[[VAR0:.*]] = tosa.reduce_any %arg0 {axis = 0 : i32} +// CHECK-DAG: %[[VAR10:.*]] = tosa.const_shape {values = dense<[21, 3]> : tensor<2xindex>} +// CHECK: %[[VAR1:.*]] = tosa.reshape %[[VAR0]], %[[VAR10]] +func.func 
@test_reduce_any_dynamic_output(%arg0: tensor<13x21x3xi1>) -> tensor { + %cst = arith.constant dense<0> : tensor<1xi32> + %0 = "tfl.reduce_any"(%arg0, %cst) {keep_dims = false} : (tensor<13x21x3xi1>, tensor<1xi32>) -> tensor + func.return %0 : tensor +} + +// ----- + // CHECK-LABEL: test_reduce_min // CHECK-DAG: %[[VAR0:.*]] = tosa.reduce_min %arg0 {axis = 0 : i32} // CHECK-DAG: %[[VAR10:.*]] = tosa.const_shape {values = dense<[21, 3]> : tensor<2xindex>} @@ -789,6 +871,21 @@ func.func @test_reduce_mean(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> { // ----- +// CHECK-LABEL: test_reduce_mean_dynamic_output +// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{values = dense<0.0769230798> : tensor<1x1xf32>}> +// CHECK-DAG: %[[VAR1:.*]] = tosa.reduce_sum %arg0 {axis = 0 : i32} +// CHECK-DAG: %[[VAR10:.*]] = tosa.const_shape {values = dense<[21, 3]> : tensor<2xindex>} +// CHECK-DAG: %[[VAR2:.*]] = tosa.reshape %[[VAR1]], %[[VAR10]] +// CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> +// CHECK: %[[VAR4:.*]] = tosa.mul %[[VAR2]], %[[VAR0]], %[[SHIFT]] +func.func @test_reduce_mean_dynamic_output(%arg0: tensor<13x21x3xf32>) -> tensor { + %cst = arith.constant dense<0> : tensor<1xi32> + %0 = "tfl.mean"(%arg0, %cst) {keep_dims = false} : (tensor<13x21x3xf32>, tensor<1xi32>) -> tensor + func.return %0 : tensor +} + +// ----- + // CHECK-LABEL: test_reduce_mean_out_of_bounds // CHECK: "tfl.mean" func.func @test_reduce_mean_out_of_bounds(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { @@ -803,7 +900,6 @@ func.func @test_reduce_mean_out_of_bounds(%arg0: tensor<13x21x3xf32>) -> tensor< // CHECK-SAME: %[[VAL_0:.*]]: tensor<1x2x2x!quant.uniform> // CHECK: %[[VAL_1:.*]] = "tosa.const"() <{values = dense<31> : tensor<1xi8>}> : () -> tensor<1xi8> // CHECK: %[[VAL_2:.*]] = "tosa.const"() <{values = dense<1105078632> : tensor<1xi32>}> : () -> tensor<1xi32> -// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{values = dense<1073741824> : tensor<1xi32>}> : () -> 
tensor<1xi32> // CHECK: %[[VAL_4:.*]] = "tosa.const"() <{values = dense<30> : tensor<1xi8>}> : () -> tensor<1xi8> // CHECK: %[[VAL_5:.*]] = "tosa.const"() <{values = dense<-128> : tensor<1xi8>}> : () -> tensor<1xi8> // CHECK: %[[VAL_6:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32> @@ -906,9 +1002,31 @@ func.func @test_floor(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { // CHECK-LABEL: test_log // CHECK: %[[VAR0:.*]] = tosa.log %arg0 -func.func @test_log(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { - %0 = "tfl.log"(%arg0) : (tensor<13x21x3xf32>) -> tensor<*xf32> - func.return %0 : tensor<*xf32> +func.func @test_log(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + %0 = "tfl.log"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + func.return %0 : tensor<13x21x3xf32> +} + +// ----- + +// CHECK-LABEL: test_log_qi8 +// CHECK-SAME: %[[VAL_0:.*]]: tensor<13x21x3x!quant.uniform> +// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{values = dense<{{.+}}> : tensor<256xi8>}> +// CHECK: %[[VAL_2:.*]] = tosa.table %[[VAL_0]], %[[VAL_1]] +func.func @test_log_qi8(%arg0: tensor<13x21x3x!quant.uniform>) -> (tensor<13x21x3x!quant.uniform>) { + %0 = "tfl.log"(%arg0) : (tensor<13x21x3x!quant.uniform>) -> tensor<13x21x3x!quant.uniform> + func.return %0 : tensor<13x21x3x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: test_log_qi16 +// CHECK-SAME: %[[VAL_0:.*]]: tensor<13x21x3x!quant.uniform> +// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{values = dense<{{.+}}> : tensor<513xi16>}> +// CHECK: %[[VAL_2:.*]] = tosa.table %[[VAL_0]], %[[VAL_1]] +func.func @test_log_qi16(%arg0: tensor<13x21x3x!quant.uniform>) -> (tensor<13x21x3x!quant.uniform>) { + %0 = "tfl.log"(%arg0) : (tensor<13x21x3x!quant.uniform>) -> tensor<13x21x3x!quant.uniform> + func.return %0 : tensor<13x21x3x!quant.uniform> } // ----- @@ -1158,6 +1276,18 @@ func.func @test_max_pool2d_dynamic(%arg0: tensor) -> tensor<*xf32 // ----- +// CHECK-LABEL: test_max_pool2d_slicing +// CHECK-DAG: 
%[[VAL_1:.*]] = tosa.const_shape {values = dense<[1, 31, 31, 8]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_3:.*]] = tosa.slice %[[VAL_0]], %[[VAL_2]], %[[VAL_1]] : (tensor<1x32x32x8xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<1x31x31x8xf32> +// CHECK: %[[VAL_4:.*]] = tosa.max_pool2d %[[VAL_3]] {kernel = array, pad = array, stride = array} : (tensor<1x31x31x8xf32>) -> tensor<1x15x15x8xf32> +func.func @test_max_pool2d_slicing(%arg0: tensor<1x32x32x8xf32>) -> tensor<*xf32> { + %0 = "tfl.max_pool_2d"(%arg0) {filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x32x32x8xf32>) -> tensor<*xf32> + func.return %0 : tensor<*xf32> +} + +// ----- + // CHECK-LABEL: test_reshape // CHECK-DAG: %[[VAR10:.*]] = tosa.const_shape {values = dense<[1, 819]> : tensor<2xindex>} // CHECK: %[[VAR0:.*]] = tosa.reshape %arg0, %[[VAR10]] @@ -1758,7 +1888,6 @@ func.func @test_log_softmax(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { // ----- // CHECK-LABEL: test_matmul -// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<28xf32>}> // CHECK-DAG: %[[CONST0:.*]] = tosa.const_shape {values = dense<[14, 1, 1, 19]> : tensor<4xindex>} // CHECK-DAG: %[[CONST1:.*]] = tosa.const_shape {values = dense<[28, 1, 1, 19]> : tensor<4xindex>} // CHECK-DAG: %[[CONST2:.*]] = tosa.const_shape {values = dense<[14, 28]> : tensor<2xindex>} @@ -1766,7 +1895,7 @@ func.func @test_log_softmax(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { // CHECK: %[[VAR2:.*]] = tosa.transpose %arg1 {perms = array} // CHECK: %[[VAR3:.*]] = tosa.reshape %arg0, %[[CONST0]] // CHECK: %[[VAR4:.*]] = tosa.reshape %[[VAR2]], %[[CONST1]] -// CHECK: %[[VAR5:.*]] = tosa.conv2d %[[VAR3]], %[[VAR4]], %[[VAR1]], %[[CONST3]], %[[CONST3]] {acc_type = f32, dilation = array, pad 
= array, stride = array} +// CHECK: %[[VAR5:.*]] = tosa.conv2d %[[VAR3]], %[[VAR4]], %[[CONST3]], %[[CONST3]], %[[CONST3]] {acc_type = f32, dilation = array, pad = array, stride = array} // CHECK: %[[VAR6:.*]] = tosa.reshape %[[VAR5]], %[[CONST2]] func.func @test_matmul(%arg0: tensor<14x19xf32>, %arg1: tensor<19x28xf32>) -> tensor<*xf32> { %cst = arith.constant dense<[1, 0]> : tensor<2xi32> @@ -2426,73 +2555,70 @@ func.func @test_max_pool2d_qi8(%arg0: tensor<1x32x32x8x!quant.uniform> -// CHECK-DAG: %[[VAL_1:.*]] = "tosa.const"() <{values = dense<-128> : tensor<1xi8>}> -// CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() <{values = dense<5> : tensor<1x1x1xi32>}> -// CHECK-DAG: %[[VAL_3:.*]] = "tosa.const"() <{values = dense<31> : tensor<1x1x1xi32>}> -// CHECK-DAG: %[[VAL_4:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> -// CHECK-DAG: %[[VAL_5:.*]] = "tosa.const"() <{values = dense<31> : tensor<1xi8>}> -// CHECK-DAG: %[[VAL_6:.*]] = "tosa.const"() <{values = dense<-1010580540> : tensor<1x1x1xi32>}> -// CHECK-DAG: %[[VAL_7:.*]] = "tosa.const"() <{values = dense<1515870810> : tensor<1x1x1xi32>}> -// CHECK-DAG: %[[VAL_8:.*]] = "tosa.const"() <{values = dense<536870912> : tensor<1x1x1xi32>}> -// CHECK-DAG: %[[VAL_9:.*]] = "tosa.const"() <{values = dense<4> : tensor<1x1x1xi32>}> -// CHECK-DAG: %[[VAL_10:.*]] = "tosa.const"() <{values = dense<13> : tensor<1x1x1xi32>}> -// CHECK-DAG: %[[VAL_11:.*]] = "tosa.const"() <{values = dense<7> : tensor<1x1x1xi32>}> -// CHECK-DAG: %[[VAL_12:.*]] = "tosa.const"() <{values = dense<1> : tensor<1x1x1xi32>}> -// CHECK-DAG: %[[VAL_13:.*]] = "tosa.const"() <{values = dense<9> : tensor<1x1x1xi32>}> -// CHECK-DAG: %[[VAL_14:.*]] = "tosa.const"() <{values = dense<17> : tensor<1x1x1xi32>}> -// CHECK-DAG: %[[VAL_15:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi16>}> -// CHECK-DAG: %[[VAL_16:.*]] = "tosa.const"() <{values = dense<23> : tensor<1xi8>}> -// CHECK-DAG: %[[VAL_17:.*]] = "tosa.const"() <{values = dense<"0x5{{.*}}"> : 
tensor<513xi16>}> -// CHECK-DAG: %[[VAL_18:.*]] = "tosa.const"() <{values = dense<"0xE{{.*}}"> : tensor<513xi16>}> -// CHECK-DAG: %[[VAL_19:.*]] = "tosa.const"() <{values = dense<"0x4{{.*}}"> : tensor<513xi16>}> -// CHECK-DAG: %[[VAL_20:.*]] = "tosa.const"() <{values = dense<"0x0{{.*}}"> : tensor<513xi16>}> -// CHECK-DAG: %[[VAL_21:.*]] = "tosa.const"() <{values = dense<1073741824> : tensor<1xi32>}> -// CHECK-DAG: %[[VAL_22:.*]] = "tosa.const"() <{values = dense<30> : tensor<1xi8>}> -// CHECK-DAG: %[[VAL_23:.*]] = "tosa.const"() <{values = dense<-1> : tensor<1xi8>}> -// CHECK-DAG: %[[VAL_24:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> -// CHECK-DAG: %[[VAL_25:.*]] = tosa.rescale %[[VAL_0]], %[[VAL_21]], %[[VAL_22]], %[[VAL_23]], %[[VAL_24]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<13x21x3x!quant.uniform>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi32>) -> tensor<13x21x3xi32> -// CHECK-DAG: %[[VAL_26:.*]] = tosa.reduce_max %[[VAL_25]] -// CHECK-DAG: %[[VAL_27:.*]] = tosa.sub %[[VAL_25]], %[[VAL_26]] -// CHECK-DAG: %[[VAL_28:.*]] = tosa.rescale %[[VAL_27]], %[[VAL_21]], %[[VAL_16]], %[[VAL_24]], %[[VAL_15]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi16>) -> tensor<13x21x3x!quant.uniform> -// CHECK-DAG: %[[VAL_29:.*]] = tosa.table %[[VAL_28]], %[[VAL_20]] -// CHECK-DAG: %[[VAL_30:.*]] = tosa.table %[[VAL_28]], %[[VAL_19]] -// CHECK-DAG: %[[VAL_31:.*]] = tosa.table %[[VAL_28]], %[[VAL_18]] -// CHECK-DAG: %[[VAL_32:.*]] = tosa.table %[[VAL_28]], %[[VAL_17]] -// CHECK-DAG: %[[VAL_33:.*]] = tosa.logical_left_shift %[[VAL_29]], %[[VAL_14]] -// CHECK-DAG: %[[VAL_34:.*]] = tosa.logical_left_shift %[[VAL_30]], %[[VAL_13]] -// CHECK-DAG: %[[VAL_35:.*]] = tosa.logical_left_shift %[[VAL_31]], %[[VAL_12]] -// 
CHECK-DAG: %[[VAL_36:.*]] = tosa.arithmetic_right_shift %[[VAL_32]], %[[VAL_11]] -// CHECK-DAG: %[[VAL_37:.*]] = tosa.add %[[VAL_33]], %[[VAL_34]] -// CHECK-DAG: %[[VAL_38:.*]] = tosa.add %[[VAL_37]], %[[VAL_35]] -// CHECK-DAG: %[[VAL_39:.*]] = tosa.add %[[VAL_38]], %[[VAL_36]] -// CHECK-DAG: %[[VAL_40:.*]] = tosa.arithmetic_right_shift %[[VAL_39]], %[[VAL_10]] -// CHECK-DAG: %[[VAL_41:.*]] = tosa.reduce_sum %[[VAL_40]] -// CHECK-DAG: %[[VAL_42:.*]] = tosa.clz %[[VAL_41]] -// CHECK-DAG: %[[VAL_43:.*]] = tosa.sub %[[VAL_42]], %[[VAL_12]] -// CHECK-DAG: %[[VAL_44:.*]] = tosa.logical_left_shift %[[VAL_41]], %[[VAL_43]] -// CHECK-DAG: %[[VAL_45:.*]] = tosa.mul %[[VAL_44]], %[[VAL_6]], %[[VAL_5]] -// CHECK-DAG: %[[VAL_46:.*]] = tosa.add %[[VAL_45]], %[[VAL_7]] -// CHECK-DAG: %[[VAL_47:.*]] = tosa.mul %[[VAL_46]], %[[VAL_44]], %[[VAL_5]] -// CHECK-DAG: %[[VAL_48:.*]] = tosa.sub %[[VAL_8]], %[[VAL_47]] -// CHECK-DAG: %[[VAL_49:.*]] = tosa.mul %[[VAL_46]], %[[VAL_48]], %[[VAL_5]] -// CHECK-DAG: %[[VAL_50:.*]] = tosa.mul %[[VAL_49]], %[[VAL_9]], %[[VAL_4]] -// CHECK-DAG: %[[VAL_51:.*]] = tosa.add %[[VAL_46]], %[[VAL_50]] -// CHECK-DAG: %[[VAL_52:.*]] = tosa.mul %[[VAL_51]], %[[VAL_44]], %[[VAL_5]] -// CHECK-DAG: %[[VAL_53:.*]] = tosa.sub %[[VAL_8]], %[[VAL_52]] -// CHECK-DAG: %[[VAL_54:.*]] = tosa.mul %[[VAL_51]], %[[VAL_53]], %[[VAL_5]] -// CHECK-DAG: %[[VAL_55:.*]] = tosa.mul %[[VAL_54]], %[[VAL_9]], %[[VAL_4]] -// CHECK-DAG: %[[VAL_56:.*]] = tosa.add %[[VAL_51]], %[[VAL_55]] -// CHECK-DAG: %[[VAL_57:.*]] = tosa.mul %[[VAL_56]], %[[VAL_44]], %[[VAL_5]] -// CHECK-DAG: %[[VAL_58:.*]] = tosa.sub %[[VAL_8]], %[[VAL_57]] -// CHECK-DAG: %[[VAL_59:.*]] = tosa.mul %[[VAL_56]], %[[VAL_58]], %[[VAL_5]] -// CHECK-DAG: %[[VAL_60:.*]] = tosa.mul %[[VAL_59]], %[[VAL_9]], %[[VAL_4]] -// CHECK-DAG: %[[VAL_61:.*]] = tosa.add %[[VAL_56]], %[[VAL_60]] -// CHECK-DAG: %[[VAL_62:.*]] = tosa.mul %[[VAL_39]], %[[VAL_61]], %[[VAL_22]] -// CHECK-DAG: %[[VAL_63:.*]] = tosa.sub %[[VAL_3]], 
%[[VAL_42]] -// CHECK-DAG: %[[VAL_64:.*]] = tosa.arithmetic_right_shift %[[VAL_62]], %[[VAL_2]] -// CHECK-DAG: %[[VAL_65:.*]] = tosa.arithmetic_right_shift %[[VAL_64]], %[[VAL_63]] -// CHECK-DAG: %[[VAL_66:.*]] = tosa.rescale %[[VAL_65]], %[[VAL_21]], %[[VAL_22]], %[[VAL_24]], %[[VAL_1]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi8>) +// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{values = dense<35> : tensor<1x1x1xi32>}> +// CHECK-DAG: %[[VAR2:.*]] = "tosa.const"() <{values = dense<4> : tensor<1x1x1xi32>}> +// CHECK-DAG: %[[VAR3:.*]] = "tosa.const"() <{values = dense<536870912> : tensor<1x1x1xi32>}> +// CHECK-DAG: %[[VAR4:.*]] = "tosa.const"() <{values = dense<1515870810> : tensor<1x1x1xi32>}> +// CHECK-DAG: %[[VAR5:.*]] = "tosa.const"() <{values = dense<-1010580540> : tensor<1x1x1xi32>}> +// CHECK-DAG: %[[VAR6:.*]] = "tosa.const"() <{values = dense<1> : tensor<1x1x1xi32>}> +// CHECK-DAG: %[[VAR7:.*]] = "tosa.const"() <{values = dense<12> : tensor<1x1x1xi32>}> +// CHECK-DAG: %[[VAR8:.*]] = "tosa.const"() <{values = dense<7> : tensor<1x1x1xi32>}> +// CHECK-DAG: %[[VAR9:.*]] = "tosa.const"() <{values = dense<9> : tensor<1x1x1xi32>}> +// CHECK-DAG: %[[VAR10:.*]] = "tosa.const"() <{values = dense<17> : tensor<1x1x1xi32>}> +// CHECK-DAG: %[[VAR11:.*]] = "tosa.const"() <{values = dense<"0x5{{.*}}"> : tensor<513xi16>}> +// CHECK-DAG: %[[VAR12:.*]] = "tosa.const"() <{values = dense<"0xE{{.*}}"> : tensor<513xi16>}> +// CHECK-DAG: %[[VAR13:.*]] = "tosa.const"() <{values = dense<"0x4{{.*}}"> : tensor<513xi16>}> +// CHECK-DAG: %[[VAR14:.*]] = "tosa.const"() <{values = dense<"0x0{{.*}}"> : tensor<513xi16>}> +// CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> +// CHECK-DAG: %[[SHIFT_31:.*]] = "tosa.const"() <{values = dense<31> : tensor<1xi8>}> +// CHECK-DAG: %[[mult1073741824:.*]] = "tosa.const"() 
<{values = dense<1073741824> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK-DAG: %[[shift30:.*]] = "tosa.const"() <{values = dense<30> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK-DAG: %[[shift23:.*]] = "tosa.const"() <{values = dense<23> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK-DAG: %[[input_zp1:.*]] = "tosa.const"() <{values = dense<-1> : tensor<1xi8>}> +// CHECK-DAG: %[[zp0i32:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> +// CHECK-DAG: %[[output_zp128:.*]] = "tosa.const"() <{values = dense<-128> : tensor<1xi8>}> +// CHECK-DAG: %[[VAL27:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi16>}> +// CHECK-DAG: %[[VAR15:.*]] = tosa.rescale %arg0, %[[mult1073741824]], %[[shift30]], %[[input_zp1]], %[[zp0i32]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} +// CHECK-DAG: %[[VAR16:.*]] = tosa.reduce_max %[[VAR15]] {axis = 2 : i32} +// CHECK-DAG: %[[VAR17:.*]] = tosa.sub %[[VAR15]], %[[VAR16]] +// CHECK-DAG: %[[VAR18:.*]] = tosa.rescale %[[VAR17]], %[[mult1073741824]], %[[shift23]], %[[zp0i32]], %[[VAL27]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} +// CHECK-DAG: %[[VAR19:.*]] = tosa.table %[[VAR18]], %[[VAR14]] +// CHECK-DAG: %[[VAR20:.*]] = tosa.table %[[VAR18]], %[[VAR13]] +// CHECK-DAG: %[[VAR21:.*]] = tosa.table %[[VAR18]], %[[VAR12]] +// CHECK-DAG: %[[VAR22:.*]] = tosa.table %[[VAR18]], %[[VAR11]] +// CHECK-DAG: %[[VAR23:.*]] = tosa.logical_left_shift %[[VAR19]], %[[VAR10]] +// CHECK-DAG: %[[VAR24:.*]] = tosa.logical_left_shift %[[VAR20]], %[[VAR9]] +// CHECK-DAG: %[[VAR25:.*]] = tosa.logical_left_shift %[[VAR21]], %[[VAR6]] +// CHECK-DAG: %[[VAR26:.*]] = tosa.arithmetic_right_shift %[[VAR22]], %[[VAR8]] {round = true} +// CHECK-DAG: %[[VAR27:.*]] = tosa.add %[[VAR23]], %[[VAR24]] +// CHECK-DAG: %[[VAR28:.*]] = tosa.add %[[VAR27]], %[[VAR25]] +// CHECK-DAG: %[[VAR29:.*]] = tosa.add 
%[[VAR28]], %[[VAR26]] +// CHECK-DAG: %[[VAR30:.*]] = tosa.arithmetic_right_shift %[[VAR29]], %[[VAR7]] {round = true} +// CHECK-DAG: %[[VAR31:.*]] = tosa.reduce_sum %[[VAR30]] {axis = 2 : i32} +// CHECK-DAG: %[[VAR32:.*]] = tosa.clz %[[VAR31]] +// CHECK-DAG: %[[VAR33:.*]] = tosa.sub %[[VAR32]], %[[VAR6]] +// CHECK-DAG: %[[VAR34:.*]] = tosa.logical_left_shift %[[VAR31]], %[[VAR33]] +// CHECK-DAG: %[[VAR35:.*]] = tosa.mul %[[VAR34]], %[[VAR5]], %[[SHIFT_31]] +// CHECK-DAG: %[[VAR36:.*]] = tosa.add %[[VAR35]], %[[VAR4]] +// CHECK-DAG: %[[VAR37:.*]] = tosa.mul %[[VAR36]], %[[VAR34]], %[[SHIFT_31]] +// CHECK-DAG: %[[VAR38:.*]] = tosa.sub %[[VAR3]], %[[VAR37]] +// CHECK-DAG: %[[VAR39:.*]] = tosa.mul %[[VAR36]], %[[VAR38]], %[[SHIFT_31]] +// CHECK-DAG: %[[VAR40:.*]] = tosa.mul %[[VAR39]], %[[VAR2]], %[[SHIFT]] +// CHECK-DAG: %[[VAR41:.*]] = tosa.add %[[VAR36]], %[[VAR40]] +// CHECK-DAG: %[[VAR42:.*]] = tosa.mul %[[VAR41]], %[[VAR34]], %[[SHIFT_31]] +// CHECK-DAG: %[[VAR43:.*]] = tosa.sub %[[VAR3]], %[[VAR42]] +// CHECK-DAG: %[[VAR44:.*]] = tosa.mul %[[VAR41]], %[[VAR43]], %[[SHIFT_31]] +// CHECK-DAG: %[[VAR45:.*]] = tosa.mul %[[VAR44]], %[[VAR2]], %[[SHIFT]] +// CHECK-DAG: %[[VAR46:.*]] = tosa.add %[[VAR41]], %[[VAR45]] +// CHECK-DAG: %[[VAR47:.*]] = tosa.mul %[[VAR46]], %[[VAR34]], %[[SHIFT_31]] +// CHECK-DAG: %[[VAR48:.*]] = tosa.sub %[[VAR3]], %[[VAR47]] +// CHECK-DAG: %[[VAR49:.*]] = tosa.mul %[[VAR46]], %[[VAR48]], %[[SHIFT_31]] +// CHECK-DAG: %[[VAR50:.*]] = tosa.mul %[[VAR49]], %[[VAR2]], %[[SHIFT]] +// CHECK-DAG: %[[VAR51:.*]] = tosa.add %[[VAR46]], %[[VAR50]] +// CHECK-DAG: %[[VAR52:.*]] = tosa.mul %[[VAR29]], %[[VAR51]], %[[shift30]] +// CHECK-DAG: %[[VAR53:.*]] = tosa.sub %[[VAR1]], %[[VAR32]] +// CHECK-DAG: %[[VAR54:.*]] = tosa.arithmetic_right_shift %[[VAR52]], %[[VAR53]] {round = true} +// CHECK: %[[VAR55:.*]] = tosa.rescale %[[VAR54]], %[[mult1073741824]], %[[shift30]], %[[zp0i32]], %[[output_zp128]] {input_unsigned = false, output_unsigned = false, 
per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} func.func @test_softmax_qi8(%arg0: tensor<13x21x3x!quant.uniform>) -> tensor<13x21x3x!quant.uniform> { %0 = "tfl.softmax"(%arg0) {beta = 1.000000e+00 : f32} : (tensor<13x21x3x!quant.uniform>) -> tensor<13x21x3x!quant.uniform> func.return %0 : tensor<13x21x3x!quant.uniform> @@ -2860,6 +2986,58 @@ func.func @test_resize_nearest_align_qi8_scalar_input(%arg0: tensor<3x1x1x7x!qua } // ----- + +// CHECK-LABEL: test_fullyconnected_qi16 +// CHECK: %[[BIAS:.+]] = "tosa.const"() <{values = dense<123> : tensor<3xi48>}> : () -> tensor<3xi48> +// CHECK: tosa.conv2d {{.+}}, %[[BIAS]], %{{.+}} {acc_type = i48, {{.+}}} : {{.+}} -> tensor<1x1x1x3xi48> +func.func @test_fullyconnected_qi16(%input: tensor<1x7x!quant.uniform>, %filter: tensor<3x7x!quant.uniform>) -> tensor<1x3x!quant.uniform> { + %bias = "tfl.pseudo_qconst"() {qtype = tensor<3x!quant.uniform>, value = dense<123> : tensor<3xi32>} : () -> tensor<3x!quant.uniform> + %0 = "tfl.fully_connected"(%input, %filter, %bias) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x7x!quant.uniform>, tensor<3x7x!quant.uniform>, tensor<3x!quant.uniform>) -> tensor<1x3x!quant.uniform> + return %0 : tensor<1x3x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: @test_fullyconnected_dynamic_output +func.func @test_fullyconnected_dynamic_output(%arg0: tensor<1x2048xf32>, %arg1: tensor<1000x2048xf32>, %arg2: tensor<1000xf32>) -> tensor { + // CHECK-DAG: %[[CONST0:.*]] = tosa.const_shape {values = dense<[1, 1, 1, 2048]> : tensor<4xindex>} + // CHECK-DAG: %[[CONST1:.*]] = tosa.const_shape {values = dense<[1000, 1, 1, 2048]> : tensor<4xindex>} + // CHECK-DAG: %[[CONST2:.*]] = tosa.const_shape {values = dense<[1, 1000]> : tensor<2xindex>} + // CHECK-DAG: %[[CONST3:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> + // CHECK: %[[VAL0:.*]] = tosa.reshape %arg0, %[[CONST0]] + // CHECK: %[[VAL1:.*]] = tosa.reshape 
%arg1, %[[CONST1]] + // CHECK: %[[VAL2:.*]] = tosa.conv2d %[[VAL0]], %[[VAL1]], %arg2, %[[CONST3]], %[[CONST3]] {acc_type = f32, dilation = array, pad = array, stride = array} + // CHECK: %[[VAL3:.*]] = tosa.reshape %[[VAL2]], %[[CONST2]] + // return %[[VAL3]] + %0 = "tfl.fully_connected"(%arg0, %arg1, %arg2) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x2048xf32>, tensor<1000x2048xf32>, tensor<1000xf32>) -> tensor + func.return %0 : tensor +} + +// ----- + +// CHECK-LABEL: @test_fullyconnected_keep_dims +func.func @test_fullyconnected_keep_dims(%arg0: tensor<1x64x64x768x!quant.uniform>, %arg1: tensor<3072x768x!quant.uniform:f32, 0.003333511995151639>>, %arg2: tensor<3072x!quant.uniform>) -> tensor<1x64x64x3072x!quant.uniform> { + // CHECK-DAG: %[[CONST_SHAPE0:.*]] = tosa.const_shape {values = dense<[1, 64, 64, 3072]> : tensor<4xindex>} + // CHECK-DAG: %[[CONST0:.*]] = "tosa.const"() <{values = dense<38> : tensor<1xi8>}> + // CHECK-DAG: %[[CONST1:.*]] = "tosa.const"() <{values = dense<1241512252> : tensor<1xi32>}> + // CHECK-DAG: %[[CONST2:.*]] = "tosa.const"() <{values = dense<45> : tensor<1xi8>}> + // CHECK-DAG: %[[CONST3:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> + // CHECK-DAG: %[[CONST4:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> + // CHECK-DAG: %[[CONST5:.*]] = "tosa.const"() <{values = dense<5> : tensor<1xi8>}> + // CHECK-DAG: %[[CONST_SHAPE1:.*]] = tosa.const_shape {values = dense<[3072, 1, 1, 768]> : tensor<4xindex>} + // CHECK-DAG: %[[CONST_SHAPE2:.*]] = tosa.const_shape {values = dense<[4096, 1, 1, 768]> : tensor<4xindex>} + // CHECK: %[[RESHAPE_IN:.*]] = tosa.reshape %arg0, %[[CONST_SHAPE2]] : (tensor<1x64x64x768x!quant.uniform>, !tosa.shape<4>) + // CHECK: %[[RESHAPE_FILT:.*]] = tosa.reshape %arg1, %[[CONST_SHAPE1]] : (tensor<3072x768x!quant.uniform:f32, 0.003333511995151639>>, !tosa.shape<4>) + // CHECK: %[[CONV:.*]] = tosa.conv2d %[[RESHAPE_IN]], %[[RESHAPE_FILT]], 
%arg2, %[[CONST5]], %[[CONST4]] {acc_type = i32, dilation = array, pad = array, stride = array} : (tensor<4096x1x1x768x!quant.uniform>, tensor<3072x1x1x768x!quant.uniform:f32, 0.003333511995151639>>, tensor<3072x!quant.uniform>, tensor<1xi8>, tensor<1xi8>) + // CHECK: %[[RESCALE:.*]] = tosa.rescale %[[CONV]], %[[CONST1]], %[[CONST0]], %[[CONST3]], %[[CONST2]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "DOUBLE_ROUND", scale32 = true} : (tensor<4096x1x1x3072xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi8>) + // CHECK: %[[RESHAPE_OUT:.*]] = tosa.reshape %[[RESCALE]], %[[CONST_SHAPE0]] : (tensor<4096x1x1x3072x!quant.uniform>, !tosa.shape<4>) -> tensor<1x64x64x3072x!quant.uniform> + // CHECK: return %[[RESHAPE_OUT]] + %0 = "tfl.fully_connected"(%arg0, %arg1, %arg2) {asymmetric_quantize_inputs = false, fused_activation_function = "NONE", keep_num_dims = true, weights_format = "DEFAULT"} : (tensor<1x64x64x768x!quant.uniform>, tensor<3072x768x!quant.uniform:f32, 0.003333511995151639>>, tensor<3072x!quant.uniform>) -> tensor<1x64x64x3072x!quant.uniform> + func.return %0 : tensor<1x64x64x3072x!quant.uniform> +} + +// ----- + // CHECK-LABEL: test_gather // CHECK-DAG: %[[VAR10:.*]] = tosa.const_shape {values = dense<[1, 13, 63]> : tensor<3xindex>} // CHECK-DAG: %[[VAR4:.*]] = tosa.reshape %arg0, %[[VAR10]] @@ -3011,6 +3189,71 @@ func.func @test_sparse_to_dense(%arg0 : tensor, %arg1 : tensor) // ----- +// CHECK-LABEL: test_scatter_nd +// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1x224x512xf32>}> +// CHECK-DAG: %[[VAR2:.*]] = "tosa.const"() <{values = dense<0> : tensor<1x2xi32>}> +// CHECK-DAG: %[[VAR3:.*]] = tosa.reduce_sum %[[VAR2:.*]] {axis = 1 : i32} : (tensor<1x2xi32>) +// CHECK-DAG: %[[VAR5:.*]] = tosa.scatter %[[VAR1:.*]], %[[VAR3:.*]], %arg0 : (tensor<1x224x512xf32>, tensor<1x1xi32>, tensor<1x1x512xf32>) +// CHECK: return %[[VAR5]] +func.func @test_scatter_nd(%arg0: 
tensor<1x1x512xf32>) -> tensor<1x224x512xf32> { + %shape = "tfl.pseudo_const"() <{value = dense<[1, 224, 512]> : tensor<3xi32>}> : () -> tensor<3xi32> + %indices = "tfl.pseudo_const"() <{value = dense<[[[0, 0]]]> : tensor<1x1x2xi32>}> : () -> tensor<1x1x2xi32> + %0 = "tfl.scatter_nd"(%indices, %arg0, %shape) : (tensor<1x1x2xi32>, tensor<1x1x512xf32>, tensor<3xi32>) -> tensor<1x224x512xf32> + func.return %0 : tensor<1x224x512xf32> +} + +// ----- + +// CHECK-LABEL: test_scatter_nd_reshape +// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{values = dense<{{\[\[}}8, 4, 1]]> : tensor<1x3xi32>}> : () +// CHECK-DAG: %[[VAR2:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1x16x4xf32>}> : () +// CHECK-DAG: %[[VAR3:.*]] = "tosa.const"() <{values = dense<{{\[\[}}0, 0, 0], [0, 0, 1], [0, 0, 2], [0, 0, 3], [1, 0, 0], [1, 0, 1], [1, 0, 2], [1, 0, 3]]> : tensor<8x3xi32>}> : () +// CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> +// CHECK-DAG: %[[NEW_SHAPE:.*]] = tosa.const_shape {values = dense<[1, 8, 4]> : tensor<3xindex>} +// CHECK-DAG: %[[NEW_SHAPE1:.*]] = tosa.const_shape {values = dense<[1, 8]> : tensor<2xindex>} +// CHECK-DAG: %[[NEW_SHAPE2:.*]] = tosa.const_shape {values = dense<[2, 2, 4, 4]> : tensor<4xindex>} +// CHECK-DAG: %[[VAR4:.*]] = tosa.reshape %arg0, %[[NEW_SHAPE]] : (tensor<2x2x2x4xf32>, !tosa.shape<3>) +// CHECK-DAG: %[[VAR5:.*]] = tosa.mul %[[VAR3]], %[[VAR1]], %[[SHIFT]] : (tensor<8x3xi32>, tensor<1x3xi32>, tensor<1xi8>) +// CHECK-DAG: %[[VAR6:.*]] = tosa.reduce_sum %[[VAR5]] {axis = 1 : i32} : (tensor<8x3xi32>) +// CHECK-DAG: %[[VAR7:.*]] = tosa.reshape %[[VAR6]], %[[NEW_SHAPE1]] : (tensor<8x1xi32>, !tosa.shape<2>) +// CHECK-DAG: %[[VAR8:.*]] = tosa.scatter %[[VAR2]], %[[VAR7]], %[[VAR4]] : (tensor<1x16x4xf32>, tensor<1x8xi32>, tensor<1x8x4xf32>) +// CHECK-DAG: %[[VAR9:.*]] = tosa.reshape %[[VAR8]], %[[NEW_SHAPE2]] : (tensor<1x16x4xf32>, !tosa.shape<4>) +// CHECK-DAG: return %[[VAR9]] +func.func 
@test_scatter_nd_reshape(%arg0: tensor<2x2x2x4xf32>) -> tensor<2x2x4x4xf32> { + %shape = "tfl.pseudo_const"() <{value = dense<[2, 2, 4, 4]> : tensor<4xi32>}> : () -> tensor<4xi32> + %indices = "tfl.pseudo_const"() <{value = dense<[[[[0, 0, 0], [0, 0, 1]], [[0, 0, 2], [0, 0, 3]]], [[[1, 0, 0], [1, 0, 1]], [[1, 0, 2], [1, 0, 3]]]]> : tensor<2x2x2x3xi32>}> : () -> tensor<2x2x2x3xi32> + %0 = "tfl.scatter_nd"(%indices, %arg0, %shape) : (tensor<2x2x2x3xi32>, tensor<2x2x2x4xf32>, tensor<4xi32>) -> tensor<2x2x4x4xf32> + func.return %0 : tensor<2x2x4x4xf32> +} + +// ----- + +// CHECK-LABEL: test_scatter_nd_qi8 +// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{values = dense<0> : tensor<1x224x512xi8>}> : () -> tensor<1x224x512x!quant.uniform> +// CHECK-DAG: %[[VAR2:.*]] = "tosa.const"() <{values = dense<0> : tensor<1x2xi32>}> +// CHECK-DAG: %[[VAR3:.*]] = tosa.reduce_sum %[[VAR2:.*]] {axis = 1 : i32} : (tensor<1x2xi32>) +// CHECK-DAG: %[[VAR4:.*]] = tosa.scatter %[[VAR1:.*]], %[[VAR3:.*]], %arg0 : (tensor<1x224x512x!quant.uniform>, tensor<1x1xi32>, tensor<1x1x512x!quant.uniform>) +// CHECK: return %[[VAR4]] +func.func @test_scatter_nd_qi8(%arg0: tensor<1x1x512x!quant.uniform>) -> tensor<1x224x512x!quant.uniform> { + %shape = "tfl.pseudo_const"() <{value = dense<[1, 224, 512]> : tensor<3xi32>}> : () -> tensor<3xi32> + %indices = "tfl.pseudo_const"() <{value = dense<[[[0, 0]]]> : tensor<1x1x2xi32>}> : () -> tensor<1x1x2xi32> + %0 = "tfl.scatter_nd"(%indices, %arg0, %shape) : (tensor<1x1x2xi32>, tensor<1x1x512x!quant.uniform>, tensor<3xi32>) -> tensor<1x224x512x!quant.uniform> + func.return %0 : tensor<1x224x512x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: test_scatter_nd_duplicate_indices +// CHECK: tfl.scatter_nd +func.func @test_scatter_nd_duplicate_indices(%arg0: tensor<2x2x2x4xf32>) -> tensor<2x2x4x4xf32> { + %shape = "tfl.pseudo_const"() <{value = dense<[2, 2, 4, 4]> : tensor<4xi32>}> : () -> tensor<4xi32> + %indices = "tfl.pseudo_const"() <{value = dense<[[[[0, 0, 0], 
[0, 0, 1]], [[0, 0, 2], [0, 0, 3]]], [[[1, 0, 0], [1, 0, 0]], [[1, 0, 2], [1, 0, 3]]]]> : tensor<2x2x2x3xi32>}> : () -> tensor<2x2x2x3xi32> + %0 = "tfl.scatter_nd"(%indices, %arg0, %shape) : (tensor<2x2x2x3xi32>, tensor<2x2x2x4xf32>, tensor<4xi32>) -> tensor<2x2x4x4xf32> + func.return %0 : tensor<2x2x4x4xf32> +} + +// ----- + // CHECK-LABEL: @test_arg_max func.func @test_arg_max(%arg0: tensor<13x21x3xf32>) -> tensor<*xi32> { // CHECK: %[[ARGMAX:.+]] = tosa.argmax %arg0 {axis = 1 : i32} @@ -3113,6 +3356,15 @@ func.func @test_conv2d_infer(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x2x2x // ----- +// CHECK-LABEL: @test_conv2d_no_bias +func.func @test_conv2d_no_bias(%input: tensor<1x32x32x8x!quant.uniform>, %filter: tensor<3x3x8x16x!quant.uniform>) -> tensor<1x32x32x3x!quant.uniform> { + %bias = "tfl.no_value"() {value} : () -> none + %0 = "tfl.conv_2d"(%input, %filter, %bias) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x32x32x8x!quant.uniform>, tensor<3x3x8x16x!quant.uniform>, none) -> tensor<1x32x32x3x!quant.uniform> + return %0 : tensor<1x32x32x3x!quant.uniform> +} + +// ----- + // CHECK-LABEL: @test_squeeze func.func @test_squeeze(%arg0: tensor<2x1x3x1xf32>) -> tensor<2x3x1xf32> { // CHECK: tosa.reshape @@ -3720,3 +3972,14 @@ func.func @test_transpose_conv2d_bias_f32(%arg0: tensor<1x64x64x256xf32>) -> ten %2 = "tfl.transpose_conv"(%cst, %0, %arg0, %1) {padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32, fused_activation_function = "NONE"} : (tensor<4xi32>, tensor<128x2x2x256xf32>, tensor<1x64x64x256xf32>, tensor<128xf32>) -> tensor<1x128x128x128xf32> return %2 : tensor<1x128x128x128xf32> } + +// ----- + +// CHECK-LABEL: test_concat_qconst +// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{values = dense<42> : tensor<28x19xi8>}> : () -> tensor<28x19x!quant.uniform> +// CHECK-DAG: %[[VAR1:.*]] = tosa.concat %[[VAR0]], %arg0 {axis = 0 : i32} : 
(tensor<28x19x!quant.uniform>, tensor<1x19x!quant.uniform>) -> tensor<29x19x!quant.uniform> +func.func @test_concat_qconst(%arg0: tensor<1x19x!quant.uniform> ) -> tensor<29x19x!quant.uniform> { + %0 = "tfl.pseudo_qconst"() {qtype = tensor<28x19x!quant.uniform>, value = dense<42> : tensor<28x19xi8>} : () -> tensor<28x19x!quant.uniform> + %1 = "tfl.concatenation"(%0, %arg0) {axis = 0 : i32, fused_activation_function = "NONE"}: (tensor<28x19x!quant.uniform>, tensor<1x19x!quant.uniform>) -> tensor<29x19x!quant.uniform> + return %1 : tensor<29x19x!quant.uniform> +} diff --git a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-stateful.mlir b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-stateful.mlir index 84a960be8edf69..ac18406771fea4 100644 --- a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-stateful.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-stateful.mlir @@ -1,7 +1,7 @@ -// RUN: tf-opt --split-input-file --tfl-to-tosa-pipeline --verify-each %s | FileCheck %s -// REQUIRES: tf_tosa -// RUN: tf-opt --split-input-file --tf-tfl-to-tosa-pipeline --verify-each %s | FileCheck %s -// REQUIRES: tf_tosa +// RUN: tf-tosa-opt --split-input-file --tfl-to-tosa-pipeline --verify-each %s | FileCheck %s + +// RUN: tf-tosa-opt --split-input-file --tf-tfl-to-tosa-pipeline --verify-each %s | FileCheck %s + // Operations for testing tfl-to-tosa-pipeline diff --git a/tensorflow/compiler/mlir/tosa/tests/tfl-unequal-ranks.mlir b/tensorflow/compiler/mlir/tosa/tests/tfl-unequal-ranks.mlir index 311077409348e7..c4d07792549543 100644 --- a/tensorflow/compiler/mlir/tosa/tests/tfl-unequal-ranks.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/tfl-unequal-ranks.mlir @@ -1,5 +1,5 @@ -// RUN: tf-opt --split-input-file --tfl-to-tosa-pipeline --verify-each %s | FileCheck %s -// REQUIRES: tf_tosa +// RUN: tf-tosa-opt --split-input-file --tfl-to-tosa-pipeline --verify-each %s | FileCheck %s + // Test tf legalization that produce TOSA ResultsBroadcastableShape operators with unequal 
ranks // ----- @@ -111,7 +111,7 @@ func.func @test_mul_qi8(%arg0: tensor<13x21x3x!quant.uniform, %arg1: tensor<1x13x1x3xi32>) -> tensor<1x13x21x3xi32> { %0 = "tfl.floor_div"(%arg0, %arg1) {fused_activation_function = "NONE"} : (tensor<13x21x3xi32>, tensor<1x13x1x3xi32>) -> tensor<1x13x21x3xi32> @@ -120,7 +120,7 @@ func.func @test_floor_div(%arg0: tensor<13x21x3xi32>, %arg1: tensor<1x13x1x3xi32 // ----- // CHECK-LABEL: test_div -// CHECK: tosa.int_div +// CHECK: tosa.intdiv func.func @test_div(%arg0: tensor<13x21x3xi32>, %arg1: tensor) -> tensor<*xi32> { %0 = "tfl.div"(%arg0, %arg1) {fused_activation_function = "NONE"} : (tensor<13x21x3xi32>, tensor) -> tensor<*xi32> func.return %0 : tensor<*xi32> diff --git a/tensorflow/compiler/mlir/tosa/tests/verify_fully_converted.mlir b/tensorflow/compiler/mlir/tosa/tests/verify_fully_converted.mlir index ac918b321356e8..c8c8eb46c58cd7 100644 --- a/tensorflow/compiler/mlir/tosa/tests/verify_fully_converted.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/verify_fully_converted.mlir @@ -1,5 +1,5 @@ -// RUN: tf-opt %s --tosa-tflite-verify-fully-converted --split-input-file -verify-diagnostics -// REQUIRES: tf_tosa +// RUN: tf-tosa-opt %s --tosa-tflite-verify-fully-converted --split-input-file -verify-diagnostics + // CHECK-LABEL: func.func @main func.func @main(%arg0: tensor<2xf32>) -> (tensor<2xf32>) { diff --git a/tensorflow/compiler/mlir/tosa/tf_tosa_opt.cc b/tensorflow/compiler/mlir/tosa/tf_tosa_opt.cc new file mode 100644 index 00000000000000..9dd43370877802 --- /dev/null +++ b/tensorflow/compiler/mlir/tosa/tf_tosa_opt.cc @@ -0,0 +1,81 @@ +/* Copyright 2019 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mlir/InitAllPasses.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Tools/mlir-opt/MlirOptMain.h" // from @llvm-project +#include "mlir/Transforms/Passes.h" // from @llvm-project +#include "tensorflow//compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h" +#include "tensorflow/compiler/mlir/init_mlir.h" +#include "tensorflow/compiler/mlir/lite/transforms/passes.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.h" +#include "tensorflow/compiler/mlir/register_common_dialects.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/runtime_passes.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/sparsecore_passes.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/test_passes.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/mlprogram_util.h" +#include "tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.h" +#include "tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h" +#include "tensorflow/compiler/mlir/tf2xla/internal/passes/mlir_to_graph_passes.h" +#include "tensorflow/compiler/mlir/tf2xla/transforms/passes.h" +#include 
"tensorflow/compiler/mlir/tosa/tf_passes.h" +#include "tensorflow/compiler/mlir/tosa/tf_tfl_passes.h" +#include "tensorflow/compiler/mlir/tosa/tfl_passes.h" +#include "tensorflow/compiler/mlir/tosa/transforms/passes.h" +#include "xla/mlir/framework/transforms/passes.h" +#include "xla/mlir_hlo/mhlo/transforms/passes.h" + +int main(int argc, char** argv) { + tensorflow::InitMlir y(&argc, &argv); + + mlir::registerAllPasses(); + mlir::registerTransformsPasses(); + mlir::registerTensorFlowPasses(); + mlir::TFDevice::registerTensorFlowDevicePasses(); + mlir::tf_saved_model::registerTensorFlowSavedModelPasses(); + mlir::TFL::registerTensorFlowLitePasses(); + mlir::mhlo::registerAllMhloPasses(); + + // These are in compiler/mlir/tf2xla and not part of the above MHLO passes. + mlir::mhlo::registerLegalizeTfPasses(); + mlir::mhlo::registerTfXlaPasses(); + mlir::quant::stablehlo::registerBridgePasses(); + tensorflow::tf2xla::internal::registerTFXLABridgeClusteringPasses(); + tensorflow::tf2xla::internal::registerTFXLABridgeMlirToGraphPasses(); + mlir::tf_test::registerTensorFlowTestPasses(); + mlir::xla_framework::registerXlaFrameworkPasses(); + tensorflow::RegisterConvertMlirToXlaHloPipelineWithDefaults(); + tensorflow::RegisterGraphOptimizationPasses(); + tensorflow::RegisterMlProgramPasses(); + mlir::TFTPU::registerRuntimeLoweringPasses(); + mlir::TFDevice::registerSparseCorePasses(); + mlir::tosa::registerLegalizeTosaPasses(); + mlir::tosa::registerTFtoTOSALegalizationPipeline(); + mlir::tosa::registerTFLtoTOSALegalizationPipeline(); + mlir::tosa::registerTFTFLtoTOSALegalizationPipeline(); + + tensorflow::tfrt_compiler::RegisterTPULowerClusterToRuntimeOpsPassPipeline(); + tensorflow::tfrt_compiler:: + RegisterNonTPULowerClusterToRuntimeOpsPassPipeline(); + + mlir::DialectRegistry registry; + mlir::RegisterCommonToolingDialects(registry); + + return failed( + mlir::MlirOptMain(argc, argv, "TensorFlow pass driver\n", registry)); +} diff --git 
a/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc index 100a429af4a96a..d0bc0d6b57d5ae 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc @@ -46,6 +46,7 @@ limitations under the License. #include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" // from @llvm-project #include "mlir/Dialect/Tosa/Utils/QuantUtils.h" // from @llvm-project +#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" // from @llvm-project #include "mlir/Dialect/Utils/StaticValueUtils.h" // from @llvm-project #include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project @@ -1698,11 +1699,11 @@ std::optional convertSoftmaxOp(PatternRewriter& rewriter, Operation* op, rewriter, op->getLoc(), int32_logits_type, op12_add_op11_op9.getResult(), op10_rshift_op8.getResult()); - // Step 3. get sum(exp()). output 13.18 + // Step 3. get sum(exp()). output 12.19 auto op14_rshift_op13_12 = CreateOpAndInfer( rewriter, op->getLoc(), int32_logits_type, op13_add_op12_op10.getResult(), - getTosaConstTensorSingleI32(rewriter, op, 13, input_rank), true); + getTosaConstTensorSingleI32(rewriter, op, 12, input_rank), true); auto op15_reducesum_op14 = CreateOpAndInfer( rewriter, op->getLoc(), int32_rsum_type, @@ -1789,32 +1790,17 @@ std::optional convertSoftmaxOp(PatternRewriter& rewriter, Operation* op, // Right shift amount is // num_bits_over_unit + 31 - (sizeof(OutputT) * 8 = - // (13 - headroom_plus_one) + 31 - 8 = - // (13 + 31 - 8) - headroom_plus_one - - // The calculated shift amount can be larger than 31, which is invalid - // in TOSA. In this case, the output should be the quantized equivalent - // to all 0's. To emulate this behaviour, we can use two shifts: - // 1. 
Right shift of 5, calculated by: - // max_headroom_plus_one_value = 31; - // 13 + 31 - 8 - max_headroom_plus_one_value - // 2. Right shift by the remainder - constexpr int constant_shift_amount = 5; - + // (12 - headroom_plus_one) + 31 - 8 = + // (12 + 31 - 8) - headroom_plus_one auto op27_sub_op16 = CreateOpAndInfer( rewriter, op->getLoc(), int32_rsum_type, - getTosaConstTensorSingleI32(rewriter, op, 13 + 31 - 8 - constant_shift_amount, input_rank), + getTosaConstTensorSingleI32(rewriter, op, 12 + 31 - 8, input_rank), op16_clz_op15.getResult()); - auto constant_shift = CreateOpAndInfer( - rewriter, op->getLoc(), int32_logits_type, - op26_mul_op13_x.getResult(), getTosaConstTensorSingleI32(rewriter, op, constant_shift_amount, input_rank), - false); - auto op28_rshift_op26_op27 = CreateOpAndInfer( rewriter, op->getLoc(), int32_logits_type, - constant_shift.getResult(), op27_sub_op16.getResult(), true); + op26_mul_op13_x.getResult(), op27_sub_op16.getResult(), true); return buildRescale(rewriter, op, output_type, op28_rshift_op26_op27.getResult(), 1.0, 0, @@ -1839,8 +1825,8 @@ std::optional convertSoftmaxOp(PatternRewriter& rewriter, Operation* op, auto exp_func = [](double x) -> double { return std::exp(x); }; // Follow TFLite reference: tensorflow/lite/kernels/activations.cc - Value exp_table_const = - getTosaConst16bitTable(rewriter, op, exp_func, -10.0, 0); + Value exp_table_const = getTosaConst16bitTable( + rewriter, op, 10.0 / 65535.0, 32767, 2.0 / 65535.0, 0, exp_func); double input_diff_scale = in_quant_type.getScale() / (10.0 / 65535.0); @@ -1913,8 +1899,9 @@ std::optional convertSoftmaxOp(PatternRewriter& rewriter, Operation* op, return 1.0 / (1.0 + x); }; - Value one_over_one_plus_x_table_const = getTosaConst16bitTable( - rewriter, op, one_over_one_plus_x_func, 0.0, 1.0); + Value one_over_one_plus_x_table_const = getTosaConst16bitTable( + rewriter, op, 1.0 / 65535.0, -32768, 2.0 / 65535.0, 0, + one_over_one_plus_x_func); // Get (1 / sum(exp(x))) result as 
23 bits (including sign bit) auto op17_table_op16 = CreateOpAndInfer( @@ -3023,13 +3010,12 @@ std::optional convertReduceOpCommon( bool is_quantized, int32_t input_scale_multiplier, int32_t input_scale_shift, int64_t input_zp, int32_t output_scale_multiplier, int32_t output_scale_shift, - int64_t output_zp, StringRef nan_mode = "") { + int64_t output_zp, bool keep_dims, StringRef nan_mode = "") { RankedTensorType input_type = dyn_cast(input_value.getType()); if (!input_type) return std::nullopt; ArrayRef input_shape = input_type.getShape(); - ArrayRef output_shape = output_type.getShape(); auto input_rank = input_shape.size(); Location loc = op->getLoc(); @@ -3096,7 +3082,29 @@ std::optional convertReduceOpCommon( /*scale32=*/true); } + // If keep dims, no reshaping of the output is required + if (keep_dims) { + return val; + } + // Squeeze out the reduced axes. + const auto squeeze_axes = [](llvm::ArrayRef in, llvm::ArrayRef axes) { + llvm::SmallVector sorted_axes{axes}; + std::sort(sorted_axes.begin(), sorted_axes.end()); + auto current_axis = sorted_axes.begin(); + + llvm::SmallVector out; + out.reserve(in.size() - axes.size()); + for (const auto& [i, dim] : llvm::enumerate(in)) { + if (current_axis == sorted_axes.end() || i != *current_axis) + out.push_back(dim); + else + current_axis++; + } + return out; + }; + + const auto output_shape = squeeze_axes(input_shape, axes); auto output_shape_value = getTosaConstShape(rewriter, op->getLoc(), tensorflow::ConvertMlirShapeToTF(output_shape)); @@ -3112,7 +3120,7 @@ std::optional convertReduceOpCommon( PatternRewriter& rewriter, Operation* op, RankedTensorType output_type, Value input_value, ElementsAttr axes_elems, Type reduce_element_type, bool is_quantized, double input_scale, int64_t input_zp, - double output_scale, int64_t output_zp, StringRef nan_mode = "") { + double output_scale, int64_t output_zp, bool keep_dims, StringRef nan_mode = "") { const int32_t scale_width = 32; int32_t input_scale_multiplier; @@ 
-3128,7 +3136,7 @@ std::optional convertReduceOpCommon( return convertReduceOpCommon( rewriter, op, output_type, input_value, axes_elems, reduce_element_type, is_quantized, input_scale_multiplier, input_scale_shift, input_zp, - output_scale_multiplier, output_scale_shift, output_zp, nan_mode); + output_scale_multiplier, output_scale_shift, output_zp, keep_dims, nan_mode); } // Lowers ReduceAll to a sequence of TOSA ops. @@ -3136,14 +3144,15 @@ std::optional convertReduceAllOp(PatternRewriter& rewriter, Operation* op, RankedTensorType output_type, Value input_value, - ElementsAttr axes_elems) { + ElementsAttr axes_elems, + bool keep_dims) { RankedTensorType input_type = dyn_cast(input_value.getType()); if (!input_type) return std::nullopt; return convertReduceOpCommon( rewriter, op, output_type, input_value, axes_elems, - output_type.getElementType(), false, 1.0f, 0, 1.0f, 0); + output_type.getElementType(), false, 1.0f, 0, 1.0f, 0, keep_dims); } // Lowers ReduceAny to a sequence of TOSA ops. @@ -3151,14 +3160,15 @@ std::optional convertReduceAnyOp(PatternRewriter& rewriter, Operation* op, RankedTensorType output_type, Value input_value, - ElementsAttr axes_elems) { + ElementsAttr axes_elems, + bool keep_dims) { RankedTensorType input_type = dyn_cast(input_value.getType()); if (!input_type) return std::nullopt; return convertReduceOpCommon( rewriter, op, output_type, input_value, axes_elems, - output_type.getElementType(), false, 1.0f, 0, 1.0f, 0); + output_type.getElementType(), false, 1.0f, 0, 1.0f, 0, keep_dims); } // Lowers ReduceMin to a sequence of TOSA ops. 
@@ -3167,6 +3177,7 @@ std::optional convertReduceMinOp(PatternRewriter& rewriter, RankedTensorType output_type, Value input_value, ElementsAttr axes_elems, + bool keep_dims, StringRef nan_mode) { RankedTensorType input_type = dyn_cast(input_value.getType()); @@ -3174,7 +3185,7 @@ std::optional convertReduceMinOp(PatternRewriter& rewriter, return convertReduceOpCommon( rewriter, op, output_type, input_value, axes_elems, - output_type.getElementType(), false, 1.0f, 0, 1.0f, 0, nan_mode); + output_type.getElementType(), false, 1.0f, 0, 1.0f, 0, keep_dims, nan_mode); } // Lowers ReduceMax to a sequence of TOSA ops. @@ -3183,6 +3194,7 @@ std::optional convertReduceMaxOp(PatternRewriter& rewriter, RankedTensorType output_type, Value input_value, ElementsAttr axes_elems, + bool keep_dims, StringRef nan_mode) { RankedTensorType input_type = dyn_cast(input_value.getType()); @@ -3190,7 +3202,7 @@ std::optional convertReduceMaxOp(PatternRewriter& rewriter, return convertReduceOpCommon( rewriter, op, output_type, input_value, axes_elems, - output_type.getElementType(), false, 1.0f, 0, 1.0f, 0, nan_mode); + output_type.getElementType(), false, 1.0f, 0, 1.0f, 0, keep_dims, nan_mode); } // Lowers ReduceProd to a sequence of TOSA ops. @@ -3198,7 +3210,8 @@ std::optional convertReduceProdOp(PatternRewriter& rewriter, Operation* op, RankedTensorType output_type, Value input_value, - ElementsAttr axes_elems) { + ElementsAttr axes_elems, + bool keep_dims) { RankedTensorType input_type = dyn_cast(input_value.getType()); if (!input_type) return std::nullopt; @@ -3216,7 +3229,7 @@ std::optional convertReduceProdOp(PatternRewriter& rewriter, return convertReduceOpCommon( rewriter, op, output_type, input_value, axes_elems, - output_type.getElementType(), false, 1.0f, 0, 1.0f, 0); + output_type.getElementType(), false, 1.0f, 0, 1.0f, 0, keep_dims); } // Lowers ReduceSum to a sequence of TOSA ops. 
@@ -3224,7 +3237,8 @@ std::optional convertReduceSumOp(PatternRewriter& rewriter, Operation* op, RankedTensorType output_type, Value input_value, - ElementsAttr axes_elems) { + ElementsAttr axes_elems, + bool keep_dims) { RankedTensorType input_type = dyn_cast(input_value.getType()); if (!input_type) return std::nullopt; @@ -3267,7 +3281,7 @@ std::optional convertReduceSumOp(PatternRewriter& rewriter, return convertReduceOpCommon( rewriter, op, output_type, input_value, axes_elems, reduce_element_type, - input_is_qtype, input_scale, input_zp, output_scale, output_zp); + input_is_qtype, input_scale, input_zp, output_scale, output_zp, keep_dims); } // Lowers ReduceMean to a sequence of TOSA ops. @@ -3275,7 +3289,8 @@ std::optional convertReduceMeanOp(PatternRewriter& rewriter, Operation* op, RankedTensorType output_type, Value input_value, - ElementsAttr axes_elems) { + ElementsAttr axes_elems, + bool keep_dims) { // reduce_mean is lowered as followed for quantized types: // op1 = reduce_sum(input) with the 1.0/num_elements_on_reduced_axis // integrated to the rescale layer, @@ -3368,7 +3383,7 @@ std::optional convertReduceMeanOp(PatternRewriter& rewriter, auto val = convertReduceOpCommon( rewriter, op, output_type, input_value, axes_elems, reduce_element_type, input_is_qtype, input_scale_multiplier, input_scale_shift, input_zp, - output_scale_multiplier, output_scale_shift, output_zp); + output_scale_multiplier, output_scale_shift, output_zp, keep_dims); if (!val.has_value()) return std::nullopt; @@ -3940,7 +3955,7 @@ std::optional convertConv3DCommon( (void)rewriter.notifyMatchFailure(op, "currently only supports NDHWC"); return std::nullopt; } - RankedTensorType filter_type = filter.getType().cast(); + RankedTensorType filter_type = mlir::cast(filter.getType()); // Note that the kernel shape of tfl.conv_3d isn't [O, D, H, W, I] but // [D, H, W, I, O] which is the same as in TF. 
// Transpose filter shape from [D, H, W, I, O] to [O, D, H, W, C] @@ -4406,6 +4421,229 @@ std::optional convertGatherNdOp(PatternRewriter& rewriter, Operation* op, .getResult(); } +std::optional convertScatterNdOp(PatternRewriter& rewriter, + Operation* op, Value result_value, + Value indices_value, + Value updates_value, + Value shape_value) { + auto const result_type = dyn_cast(result_value.getType()); + auto const indices_type = dyn_cast(indices_value.getType()); + auto const updates_type = dyn_cast(updates_value.getType()); + auto const shape_type = dyn_cast(shape_value.getType()); + + if (!result_type || !indices_type || !updates_type || !shape_type) { + (void)rewriter.notifyMatchFailure( + op, "input/output types must be ranked tensor type"); + return std::nullopt; + } + + // Don't support variable indices yet since we cannot check uniqueness + // of indices in this case + Operation* indices_op = indices_value.getDefiningOp(); + if (!indices_op || !llvm::isa(indices_op)) { + (void)rewriter.notifyMatchFailure(op, "indices must be a constant tensor"); + return std::nullopt; + } + + Type indices_elmt_type = indices_type.getElementType(); + if (!indices_elmt_type.isInteger(32)) { + (void)rewriter.notifyMatchFailure(op, "indices expected to be int32"); + return std::nullopt; + } + + // The tosa scatter operation only supports unique indices, so if there + // are duplicates, we cannot legalize + tosa::ConstOp const_indices = cast(indices_op); + ElementsAttr const_data = const_indices.getValues(); + if (!checkUniqueConstantScatterIndices(indices_type, result_type, + const_data)) { + (void)rewriter.notifyMatchFailure(op, "index values must be unique"); + return std::nullopt; + } + + // N: number of batches + // Always 1 for ScatterND + // + // Because TOSA's SCATTER operator already uses the symbol 'N' for + // the number of batches, we will use the symbol 'ND' to specify the + // number of dimensions that are sliced from input instead of'N' in + // the TF MLIR 
documentation. + // + // ND: indices.shape[-1] + // + // W: number of indices in each batch + // Computed as: + // product(indices.shape[0:-1]) (all but the last dimension) + // + // K: range of each index + // Computed as: + // product(result.shape[0:ND-1]) + // + // C: number of channels for each index + // Computed as: + // product(result.shape[ND:]) + // + // The updates tensor needs to be reshaped, but not transposed, to move + // the dimensions into [N, W, C] order. + // + // Indices needs to be put in the form of [N, W], but a simple flattening + // will not suffice, because the indices need to index into the [W]-shape + // updates vector instead. + // + // To flatten the coordinates, first reshape indices to a [W, ND] matrix, + // where the matrix now represents W ND-dimensional coordinates into the + // updates tensor. + // + // From here, we take each of the ND dimensions and multiply it with + // the size of the next updates dimension (or 1 for the last + // dimension), then sum all these together with a reduce_sum + // operator. This is exactly the same mathematics as one would use + // flatten the indices of an N-dimensional row-major array into a + // 1-D array in C. + // + // More precisely, do an element-wise multiply with [updates.shape[1 + // .. ND], 1] in axis 1, then reduce_sum in axis 1 to flatten to a + // [W]-shaped tensor, then trivially reshape to [N=1, W] to be + // compatible with the SCATTER operator's shape. + // + // Then perform the tosa.SCATTER() operation. + // + // Now we have result = [N, K, C]. + // + // Reshape with a single, simple reshape to the final output shape + // provided by shape_value. 
+ + const unsigned int input_output_rank = result_type.getShape().size(); + const unsigned int indices_rank = indices_type.getShape().size(); + + const unsigned int ND = indices_type.getShape()[indices_rank - 1]; + + if (ND > input_output_rank) { + (void)rewriter.notifyMatchFailure( + op, "size of last dimension of indices must be <= input/output rank"); + return std::nullopt; + } + + // Calculate N, K, W, C. (N is always 1) + auto const indices_shape_begin{indices_type.getShape().begin()}; + auto const result_shape_begin{result_type.getShape().begin()}; + auto const accumulate_func = [](auto const& a_, auto const& b_) { + return a_ * b_; + }; + + const unsigned int N = 1; + const unsigned int W = std::accumulate(indices_shape_begin, + indices_shape_begin + indices_rank - 1, + 1, accumulate_func); + const unsigned int K = std::accumulate( + result_shape_begin, result_shape_begin + ND, 1, accumulate_func); + const unsigned int C = std::accumulate(result_shape_begin + ND, + result_shape_begin + input_output_rank, + 1, accumulate_func); + + SmallVector tosa_indices_shape({N, W}); + SmallVector indices_matrix_shape({W, ND}); + SmallVector tosa_input_shape({N, W, C}); + SmallVector tosa_values_in_out_shape({N, K, C}); + + // Flatten the updates tensor to an [N, W] matrix. + auto input_shape_value = + getTosaConstShape(rewriter, op->getLoc(), + tensorflow::ConvertMlirShapeToTF(tosa_input_shape)); + auto tosa_input_reshape_op = CreateOpAndInfer( + rewriter, op->getLoc(), + tensorflow::GetTypeFromTFTensorShape(tosa_input_shape, + result_type.getElementType()), + updates_value, input_shape_value); + + // Flatten the indices tensor to an [W, ND] matrix. 
+ auto indices_matrix_shape_value = + getTosaConstShape(rewriter, op->getLoc(), + tensorflow::ConvertMlirShapeToTF(indices_matrix_shape)); + auto indices_matrix_reshape_op = CreateOpAndInfer( + rewriter, op->getLoc(), + tensorflow::GetTypeFromTFTensorShape(indices_matrix_shape, + indices_elmt_type), + indices_value, indices_matrix_shape_value); + + SmallVector flattened_coeff_vec; + for (int i = 1; i < ND; i++) { + flattened_coeff_vec.push_back(result_type.getShape()[i]); + } + flattened_coeff_vec.push_back(1); + for (int i = ND - 1; i > 0; i--) { + flattened_coeff_vec[i - 1] *= flattened_coeff_vec[i]; + } + std::optional flattened_coeff_value = getConstTensor( + rewriter, op, flattened_coeff_vec, + {static_cast(flattened_coeff_vec.size())}); + + if (!flattened_coeff_value) { + (void)rewriter.notifyMatchFailure( + op, "failed to calculate flattened coeff value"); + return std::nullopt; + } + + // Multiply the coefficients by the coordinates + Value mul_x = indices_matrix_reshape_op.getResult(); + Value mul_y = flattened_coeff_value.value(); + RankedTensorType mul_type = tensorflow::GetTypeFromTFTensorShape( + indices_matrix_shape, indices_type.getElementType()); + if (EqualizeRanks(rewriter, op->getLoc(), mul_x, mul_y).failed()) { + (void)rewriter.notifyMatchFailure( + op, "failed to broadcast coefficients over the coordinates"); + return std::nullopt; + } + auto flattened_indices_mul_op = CreateMulOpAndInfer( + rewriter, op, mul_type, mul_x, mul_y); + + // Sum up the products of the coefficients and coordinates + auto flattened_indices_reduce_op = CreateOpAndInfer( + rewriter, op->getLoc(), + tensorflow::GetTypeFromTFTensorShape(tosa_indices_shape, + indices_type.getElementType()), + flattened_indices_mul_op.getResult(), rewriter.getI32IntegerAttr(1)); + + // And reshape to [N, W] + auto tosa_indices_shape_value = + getTosaConstShape(rewriter, op->getLoc(), + tensorflow::ConvertMlirShapeToTF(tosa_indices_shape)); + auto tosa_indices_reshape_op = CreateOpAndInfer( 
+ rewriter, op->getLoc(), + tensorflow::GetTypeFromTFTensorShape(tosa_indices_shape, + indices_type.getElementType()), + flattened_indices_reduce_op.getResult(), tosa_indices_shape_value); + + // Scatter_nd has no input tensor, use a zero tensor + Type const_element_type = updates_type.getElementType(); + auto const_type = + RankedTensorType::get(tosa_values_in_out_shape, const_element_type); + if (mlir::isa(const_element_type)) { + auto quant_type = dyn_cast(const_element_type); + const_element_type = quant_type.getStorageType(); + } + auto const_storage_type = + RankedTensorType::get(tosa_values_in_out_shape, const_element_type); + auto const_attr = DenseElementsAttr::get( + const_storage_type, rewriter.getZeroAttr(const_element_type)); + Value tosa_values_in = + rewriter.create(op->getLoc(), const_type, const_attr); + + // Now the scatter op itself + auto tosa_scatter_op = CreateOpAndInfer( + rewriter, op->getLoc(), result_type, tosa_values_in, + tosa_indices_reshape_op.getResult(), tosa_input_reshape_op.getResult()); + + // Finally, reshape back to the expected output shape. + auto reshape_shape_value = + getTosaConstShape(rewriter, op->getLoc(), + tensorflow::ConvertMlirShapeToTF(result_type.getShape())); + return CreateOpAndInfer(rewriter, op->getLoc(), result_type, + tosa_scatter_op.getResult(), + reshape_shape_value) + .getResult(); +} + // Lowers OneHot operator to a sequence of TOSA ops. std::optional convertOneHotOp(PatternRewriter& rewriter, Operation* op, Value result_value, Value indices_value, @@ -4764,7 +5002,7 @@ std::optional convertBroadcastToOp(PatternRewriter& rewriter, // Lowers cast operator to a sequence of TOSA ops. 
std::optional convertCastOp(PatternRewriter& rewriter, Operation* op, Value input, RankedTensorType output_type) { - auto input_type = input.getType().cast(); + auto input_type = mlir::cast(input.getType()); auto input_element_type = input_type.getElementType(); Value cast_input = input; diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_common.h b/tensorflow/compiler/mlir/tosa/transforms/legalize_common.h index 5ddcec25e821f9..8cc74ee9bd5157 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_common.h +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_common.h @@ -179,14 +179,16 @@ std::optional convertReduceAllOp(PatternRewriter& rewriter, Operation* op, RankedTensorType output_type, Value input_value, - ElementsAttr axes_elems); + ElementsAttr axes_elems, + bool keep_dims); // Lowers ReduceAny to a sequence of TOSA ops. std::optional convertReduceAnyOp(PatternRewriter& rewriter, Operation* op, RankedTensorType output_type, Value input_value, - ElementsAttr axes_elems); + ElementsAttr axes_elems, + bool keep_dims); // Lowers ReduceMin to a sequence of TOSA ops. std::optional convertReduceMinOp(PatternRewriter& rewriter, @@ -194,6 +196,7 @@ std::optional convertReduceMinOp(PatternRewriter& rewriter, RankedTensorType output_type, Value input_value, ElementsAttr axes_elems, + bool keep_dims, StringRef nan_mode = "PROPAGATE"); // Lowers ReduceMax to a sequence of TOSA ops. @@ -202,6 +205,7 @@ std::optional convertReduceMaxOp(PatternRewriter& rewriter, RankedTensorType output_type, Value input_value, ElementsAttr axes_elems, + bool keep_dims, StringRef nan_mode = "PROPAGATE"); // Lowers ReduceProd to a sequence of TOSA ops. @@ -209,21 +213,24 @@ std::optional convertReduceProdOp(PatternRewriter& rewriter, Operation* op, RankedTensorType output_type, Value input_value, - ElementsAttr axes_elems); + ElementsAttr axes_elems, + bool keep_dims); // Lowers ReduceSum to a sequence of TOSA ops. 
std::optional convertReduceSumOp(PatternRewriter& rewriter, Operation* op, RankedTensorType output_type, Value input_value, - ElementsAttr axes_elems); + ElementsAttr axes_elems, + bool keep_dims); // Lowers ReduceMean to a sequence of TOSA ops. std::optional convertReduceMeanOp(PatternRewriter& rewriter, Operation* op, RankedTensorType output_type, Value input_value, - ElementsAttr axes_elem); + ElementsAttr axes_elem, + bool keep_dims); // Lowers ResizeBilinear and ResizeNearestNeighbor to TOSA resize. std::optional convertResizeOp(PatternRewriter& rewriter, Operation* op, @@ -298,6 +305,12 @@ std::optional convertGatherNdOp(PatternRewriter& rewriter, Operation* op, Value result_value, Value params_value, Value indices_value); +// Lowers ScatterNd operator to a sequence of TOSA ops. +std::optional convertScatterNdOp(PatternRewriter& rewriter, + Operation* op, Value result_value, + Value indices_value, + Value updates_value, Value shape_value); + // Lowers OneHot operator to a sequence of TOSA ops. 
std::optional convertOneHotOp(PatternRewriter& rewriter, Operation* op, Value result_value, Value indices_value, diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc index b355829547f0c3..5f2f04ad4051a6 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc @@ -131,6 +131,7 @@ DECL_CONVERT_OP(ResizeNearestNeighbor); DECL_CONVERT_OP(Gather); DECL_CONVERT_OP(GatherV2); DECL_CONVERT_OP(GatherNd); +DECL_CONVERT_OP(ScatterNd); DECL_CONVERT_OP(SelectV2); DECL_CONVERT_OP(SpaceToDepth); DECL_CONVERT_OP(DepthToSpace); @@ -176,7 +177,7 @@ LogicalResult ConvertTFReluOp::matchAndRewrite( } mlir::Attribute min_val, max_val; - if (element_type.isa()) { + if (mlir::isa(element_type)) { min_val = rewriter.getFloatAttr(element_type, 0.0f); max_val = rewriter.getFloatAttr(element_type, std::numeric_limits::max()); @@ -207,7 +208,7 @@ LogicalResult ConvertTFRelu6Op::matchAndRewrite( } mlir::Attribute min_val, max_val; - if (element_type.isa()) { + if (mlir::isa(element_type)) { min_val = rewriter.getFloatAttr(element_type, 0.0f); max_val = rewriter.getFloatAttr(element_type, 6.0f); } else { @@ -1122,7 +1123,7 @@ LogicalResult ConvertTFAllOp::matchAndRewrite(Operation* op, return failure(); std::optional result = convertReduceAllOp( - rewriter, op, output_type, tf_all_op.getInput(), axes_elems); + rewriter, op, output_type, tf_all_op.getInput(), axes_elems, tf_all_op.getKeepDims()); if (!result) return failure(); @@ -1144,7 +1145,7 @@ LogicalResult ConvertTFAnyOp::matchAndRewrite(Operation* op, return failure(); std::optional result = convertReduceAnyOp( - rewriter, op, output_type, tf_any_op.getInput(), axes_elems); + rewriter, op, output_type, tf_any_op.getInput(), axes_elems, tf_any_op.getKeepDims()); if (!result) return failure(); @@ -1166,7 +1167,7 @@ LogicalResult ConvertTFMaxOp::matchAndRewrite(Operation* op, return failure(); 
std::optional result = convertReduceMaxOp( - rewriter, op, output_type, tf_max_op.getInput(), axes_elems); + rewriter, op, output_type, tf_max_op.getInput(), axes_elems, tf_max_op.getKeepDims()); if (!result) return failure(); @@ -1188,7 +1189,7 @@ LogicalResult ConvertTFMinOp::matchAndRewrite(Operation* op, return failure(); std::optional result = convertReduceMinOp( - rewriter, op, output_type, tf_min_op.getInput(), axes_elems); + rewriter, op, output_type, tf_min_op.getInput(), axes_elems, tf_min_op.getKeepDims()); if (!result) return failure(); @@ -1210,7 +1211,7 @@ LogicalResult ConvertTFMeanOp::matchAndRewrite( return failure(); std::optional result = convertReduceMeanOp( - rewriter, op, output_type, tf_mean_op.getInput(), axes_elems); + rewriter, op, output_type, tf_mean_op.getInput(), axes_elems, tf_mean_op.getKeepDims()); if (!result) return failure(); @@ -1232,7 +1233,7 @@ LogicalResult ConvertTFProdOp::matchAndRewrite( return failure(); std::optional result = convertReduceProdOp( - rewriter, op, output_type, tf_prod_op.getInput(), axes_elems); + rewriter, op, output_type, tf_prod_op.getInput(), axes_elems, tf_prod_op.getKeepDims()); if (!result) return failure(); @@ -1254,7 +1255,7 @@ LogicalResult ConvertTFSumOp::matchAndRewrite(Operation* op, return failure(); std::optional result = convertReduceSumOp( - rewriter, op, output_type, tf_sum_op.getInput(), axes_elems); + rewriter, op, output_type, tf_sum_op.getInput(), axes_elems, tf_sum_op.getKeepDims()); if (!result) return failure(); @@ -1446,7 +1447,7 @@ LogicalResult ConvertTFFusedBatchNormV3Op::matchAndRewrite( auto epsilon_const = CreateOpAndInfer( rewriter, op->getLoc(), epsilon_type, epsilon_attr); - variance_type = variance.getType().cast(); + variance_type = mlir::cast(variance.getType()); Value op2_add_var_epsilon = CreateOpAndInfer( rewriter, op->getLoc(), variance_type, variance, epsilon_const); @@ -1777,7 +1778,7 @@ LogicalResult ConvertTFPadV2Op::matchAndRewrite( auto tf_pad_op = cast(op); 
RankedTensorType output_type = - tf_pad_op.getResult().getType().dyn_cast(); + mlir::dyn_cast(tf_pad_op.getResult().getType()); if (!output_type) { return rewriter.notifyMatchFailure(op, "output type not a ranked tensor"); } @@ -2001,6 +2002,22 @@ LogicalResult ConvertTFGatherNdOp::matchAndRewrite( return success(); } +LogicalResult ConvertTFScatterNdOp::matchAndRewrite( + Operation* op, PatternRewriter& rewriter) const { + auto tfl_scatternd_op = cast(op); + + const std::optional result = convertScatterNdOp( + rewriter, op, tfl_scatternd_op.getResult(), tfl_scatternd_op.getIndices(), + tfl_scatternd_op.getUpdates(), tfl_scatternd_op.getShape()); + + if (!result) { + return failure(); + } + rewriter.replaceOp(op, {result.value()}); + + return success(); +} + LogicalResult ConvertTFSelectV2Op::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { auto tf_sel_op = cast(op); @@ -2620,6 +2637,7 @@ void populateLegalizeTFPatterns(MLIRContext* ctx, RewritePatternSet& patterns) { patterns.add(ctx); patterns.add(ctx); patterns.add(ctx); + patterns.add(ctx); patterns.add(ctx); patterns.add(ctx); patterns.add(ctx); diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc index a11ed5e33b465e..b37319b07d6ee4 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc @@ -196,6 +196,7 @@ DECL_CONVERT_OP(Const); DECL_CONVERT_OP(QConst); DECL_CONVERT_OP(Gather); DECL_CONVERT_OP(GatherNd); +DECL_CONVERT_OP(ScatterNd); DECL_CONVERT_OP(SparseToDense); DECL_CONVERT_OP(OneHot); DECL_CONVERT_OP(ArgMax); @@ -207,8 +208,11 @@ DECL_CONVERT_OP(Imag); DECL_CONVERT_OP(RFFT2d); DECL_CONVERT_OP(LogicalAnd); DECL_CONVERT_OP(LogicalOr); +DECL_CONVERT_OP(BitwiseXor); DECL_CONVERT_OP(Pow); DECL_CONVERT_OP(BroadcastTo); +DECL_CONVERT_OP(Exp); +DECL_CONVERT_OP(Log); #undef DECL_CONVERT_OP @@ -359,7 +363,7 @@ LogicalResult 
ConvertTFLReluOp::matchAndRewrite( } mlir::Attribute min_val, max_val; - if (element_type.isa()) { + if (mlir::isa(element_type)) { min_val = rewriter.getFloatAttr(element_type, 0.0f); max_val = rewriter.getFloatAttr(element_type, std::numeric_limits::max()); @@ -429,7 +433,7 @@ LogicalResult ConvertTFLRelu1Op::matchAndRewrite( } mlir::Attribute min_val, max_val; - if (element_type.isa()) { + if (mlir::isa(element_type)) { min_val = rewriter.getFloatAttr(element_type, -1.0f); max_val = rewriter.getFloatAttr(element_type, 1.0f); } else { @@ -496,7 +500,7 @@ LogicalResult ConvertTFLRelu0To1Op::matchAndRewrite( } mlir::Attribute min_val, max_val; - if (element_type.isa()) { + if (mlir::isa(element_type)) { min_val = rewriter.getFloatAttr(element_type, 0.0f); max_val = rewriter.getFloatAttr(element_type, 1.0f); } else { @@ -563,7 +567,7 @@ LogicalResult ConvertTFLRelu6Op::matchAndRewrite( } mlir::Attribute min_val, max_val; - if (element_type.isa()) { + if (mlir::isa(element_type)) { min_val = rewriter.getFloatAttr(element_type, 0.0f); max_val = rewriter.getFloatAttr(element_type, 6.0f); } else { @@ -1343,6 +1347,8 @@ LogicalResult ConvertTFLMaxPool2DOp::matchAndRewrite( DenseI64ArrayAttr kernel_size; DenseI64ArrayAttr stride; DenseI64ArrayAttr pad; + // Pooling has no non-unit dilation + DenseI64ArrayAttr dilation = rewriter.getDenseI64ArrayAttr({1, 1}); { int64_t kernel_h = tfl_maxpool_op.getFilterHeight(); int64_t kernel_w = tfl_maxpool_op.getFilterWidth(); @@ -1361,9 +1367,6 @@ LogicalResult ConvertTFLMaxPool2DOp::matchAndRewrite( if (!GetPaddingFromString(tfl_maxpool_op.getPadding().str(), &tf_pad).ok()) return failure(); - // Pooling has no non-unit dilation - DenseI64ArrayAttr dilation = rewriter.getDenseI64ArrayAttr({1, 1}); - RankedTensorType filter_type = RankedTensorType::get(i64array, rewriter.getIntegerType(64)); @@ -1376,8 +1379,13 @@ LogicalResult ConvertTFLMaxPool2DOp::matchAndRewrite( return failure(); } + // TFLite only supports NHWC format + const 
Value max_pool_input = getInputSlicedToItsUsedSize( + rewriter, op, tensorflow::FORMAT_NHWC, input_type, + tfl_maxpool_op.getInput(), kernel_size, pad, stride, dilation); + CreateReplaceOpAndInfer(rewriter, op, output_type, - tfl_maxpool_op.getInput(), + max_pool_input, kernel_size, stride, pad); return success(); } @@ -1511,6 +1519,102 @@ Value lowerGroupedConvolution(TFL::Conv2DOp op, PatternRewriter& rewriter) { convolutions, output_slice_dim); } +/* Ensure bias is of the correct type. +TOSA requires that bias must be of the same type as the output, and that +output type must be of a certain type depending on the input type. +*/ +static FailureOr> getTosaBias( + Operation* op, PatternRewriter& rewriter, ShapedType input_type, + ShapedType output_type, bool output_is_qtype, Value bias) { + Type bias_ety; + + int bias_bits; + if (output_is_qtype) { + auto input_qtype = + dyn_cast(input_type.getElementType()); + if (!input_qtype) { + return rewriter.notifyMatchFailure(op, + "output is qtype but input is not"); + } + int input_bits = input_qtype.getStorageTypeIntegralWidth(); + // For signed int8/int16 input tensor, int32/int48 bias and output + // tensor are generated. + bias_bits = input_bits == 16 ? 48 : 32; + bias_ety = rewriter.getIntegerType(bias_bits); + } else { + bias_ety = output_type.getElementType(); + bias_bits = bias_ety.getIntOrFloatBitWidth(); + } + + if (!bias || !dyn_cast(bias.getType())) { + // The bias may actually be typed "None" which has no value. TOSA requires + // bias to be an array of output_channel_count values, so create a constant + // of the appropriate number and type of zeros. 
+ RankedTensorType bias_type = RankedTensorType::get({1}, bias_ety); + auto bias_attr = rewriter.getZeroAttr(bias_type); + bias = CreateOpAndInfer(rewriter, op->getLoc(), bias_type, + mlir::cast(bias_attr)); + } + + auto prev_bias_type = dyn_cast(bias.getType()); + if (!prev_bias_type) { + return rewriter.notifyMatchFailure(op, "bias not a ranked tensor"); + } + + auto prev_bias_etype = prev_bias_type.getElementType(); + + int prev_bias_bits; + if (auto prev_bias_eqtype = + dyn_cast(prev_bias_etype)) { + prev_bias_bits = prev_bias_eqtype.getStorageTypeIntegralWidth(); + } else { + prev_bias_bits = prev_bias_etype.getIntOrFloatBitWidth(); + } + + if (prev_bias_bits == bias_bits) { + return std::pair(bias_ety, bias); + } + + auto const_op = bias.getDefiningOp(); + if (!const_op) { + return rewriter.notifyMatchFailure(op, "bias not a ConstOp"); + } + + DenseElementsAttr bias_attr; + { + auto prev_bias_attr = + dyn_cast(const_op.getValuesAttr()); + if (!prev_bias_attr) { + return rewriter.notifyMatchFailure( + op, "bias values not DenseIntElementsAttr"); + } + // Promote to int32/int48 if necessary. 
+ bias_attr = prev_bias_attr.mapValues( + bias_ety, + [bias_bits = bias_ety.getIntOrFloatBitWidth()]( + const APInt& x) -> APInt { return x.sext(bias_bits); }); + } + + ShapedType bias_output_type; + if (auto bias_attr_type = dyn_cast(bias_attr.getType())) { + bias_output_type = bias_attr_type.clone(bias_ety); + } else { + bias_output_type = dyn_cast(const_op.getResult().getType()); + if (!bias_output_type) { + return rewriter.notifyMatchFailure( + op, "bias defining op result not ShapedType"); + } + bias_output_type = bias_output_type.clone(bias_ety); + } + + auto new_const_op = + rewriter.create(op->getLoc(), bias_output_type, bias_attr); + Value new_bias = new_const_op.getResult(); + rewriter.replaceOp(const_op, new_bias); + + return std::make_pair(bias_ety, new_bias); +} + LogicalResult ConvertTFLConv2DOp::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { auto tfl_conv2d_op = cast(op); @@ -1583,19 +1687,10 @@ LogicalResult ConvertTFLConv2DOp::matchAndRewrite( return failure(); } - Value unquantized_bias = tfl_conv2d_op.getBias(); - Type bias_ety = - output_is_qtype ? 
rewriter.getI32Type() : output_type.getElementType(); - if (unquantized_bias) { - Type new_bias_ety = getElementTypeOrSelf(unquantized_bias.getType()); - if (auto qtype = mlir::dyn_cast(new_bias_ety)) { - new_bias_ety = qtype.getStorageType(); - } - if (new_bias_ety.getIntOrFloatBitWidth() > - bias_ety.getIntOrFloatBitWidth()) { - bias_ety = new_bias_ety; - } - } + auto bias_result = getTosaBias(op, rewriter, input_type, output_type, + output_is_qtype, tfl_conv2d_op.getBias()); + if (failed(bias_result)) return failure(); + auto [bias_ety, bias_val] = bias_result.value(); // TFLite only supports NHWC format Value conv2d_input = getInputSlicedToItsUsedSize( @@ -1609,8 +1704,7 @@ LogicalResult ConvertTFLConv2DOp::matchAndRewrite( auto a1_conv2d_op = CreateOpAndInfer( rewriter, op->getLoc(), output_type.clone(bias_ety), conv2d_input, - tfl_conv2d_op.getFilter(), unquantized_bias, pad, stride, dilation, - acc_type); + tfl_conv2d_op.getFilter(), bias_val, pad, stride, dilation, acc_type); Value conv2d_output; if (input_is_qtype) { @@ -1654,11 +1748,11 @@ LogicalResult ConvertTFLConv3DOp::matchAndRewrite( } bool input_is_qtype = - input_type.getElementType().isa(); + mlir::isa(input_type.getElementType()); bool filter_is_qtype = - filter_type.getElementType().isa(); + mlir::isa(filter_type.getElementType()); bool output_is_qtype = - output_type.getElementType().isa(); + mlir::isa(output_type.getElementType()); if ((input_is_qtype != filter_is_qtype) || (input_is_qtype != output_is_qtype)) { @@ -1710,37 +1804,26 @@ LogicalResult ConvertTFLConv3DOp::matchAndRewrite( } } - Value unquantized_bias = tfl_conv3d_op.getBias(); - if (!dyn_cast(unquantized_bias.getType())) { - // The bias may actually be typed "None" which has no value. TOSA requires - // bias to be an array of output_channel_count values, so create a constant - // of the appropriate number and type of zeros. 
- auto bias_dim = filter_type.getShape().back(); - RankedTensorType bias_type = - RankedTensorType::get({bias_dim}, filter_type.getElementType()); - auto bias_attr = rewriter.getZeroAttr(bias_type); - unquantized_bias = CreateOpAndInfer( - rewriter, op->getLoc(), bias_type, bias_attr.cast()); - } - // TFLite only supports NDHWC format, tensorflow::FORMAT_NHWC is used for both // rank 4 and rank 5 tensors Value conv3d_input = getInputSlicedToItsUsedSize( rewriter, op, tensorflow::FORMAT_NHWC, input_type, tfl_conv3d_op.getInput(), kernel_size, pad, stride, dilation); - Type bias_ety = - unquantized_bias.getType().cast().getElementType(); + auto bias_result = getTosaBias(op, rewriter, input_type, output_type, + output_is_qtype, tfl_conv3d_op.getBias()); + if (failed(bias_result)) return failure(); + auto [bias_ety, bias_val] = bias_result.value(); auto acc_type = getConvAccTypeAttr(rewriter, /* input_etype = */ input_type.getElementType(), /* output_etype = */ bias_ety); - std::optional a1_conv3d_op = convertConv3DCommon( - rewriter, op, output_type.clone(bias_ety), conv3d_input, - tfl_conv3d_op.getFilter(), unquantized_bias, pad, stride, dilation, - acc_type, StringRef("NDHWC")); + std::optional a1_conv3d_op = + convertConv3DCommon(rewriter, op, output_type.clone(bias_ety), + conv3d_input, tfl_conv3d_op.getFilter(), bias_val, + pad, stride, dilation, acc_type, StringRef("NDHWC")); if (!a1_conv3d_op) return failure(); @@ -1789,23 +1872,6 @@ LogicalResult ConvertTFLTransposeConvOp::matchAndRewrite( bool output_is_qtype = mlir::isa(output_type.getElementType()); - const bool has_bias = - tfl_conv_op.getBias() && !isa(tfl_conv_op.getBias().getType()); - - if (has_bias) { - RankedTensorType bias_type = - dyn_cast(tfl_conv_op.getBias().getType()); - bool bias_is_qtype = - isa(bias_type.getElementType()); - - if (input_is_qtype != bias_is_qtype) { - return rewriter.notifyMatchFailure( - op, - "input/bias tensor should " - "be all quantized or all floating-point"); - } - } - 
if ((input_is_qtype != filter_is_qtype) || (input_is_qtype != output_is_qtype)) { return rewriter.notifyMatchFailure( @@ -1835,49 +1901,10 @@ LogicalResult ConvertTFLTransposeConvOp::matchAndRewrite( return failure(); } - int output_channel = 0; - // TODO(suderman): We need to figure out how to guarantee output channel - // propagation. - if (output_type.hasRank()) { - output_channel = output_type.getDimSize(3); - } else if (filter_type.hasRank()) { - output_channel = filter_type.getDimSize(0); - } else { - return failure(); - } - - Value bias_val; - if (has_bias) { - bias_val = tfl_conv_op.getBias(); - } else { - std::optional zero_bias; - if (input_is_qtype) { - uint32_t input_bits = - cast(input_type.getElementType()) - .getStorageTypeIntegralWidth(); - uint32_t weight_bits = - cast(filter_type.getElementType()) - .getStorageTypeIntegralWidth(); - - if (input_bits == 16 && weight_bits == 8) { - // For signed 16x8, the output is accumulated into int48 - SmallVector vec(output_channel, APInt(48, 0, true)); - zero_bias = getConstTensor(rewriter, op, vec, {output_channel}); - } else { - SmallVector vec(output_channel, 0); - zero_bias = - getConstTensor(rewriter, op, vec, {output_channel}); - } - } else { - SmallVector vec(output_channel, 0.0f); - zero_bias = getConstTensor(rewriter, op, vec, {output_channel}); - } - - if (!zero_bias) return failure(); - bias_val = zero_bias.value(); - } - - Type bias_ety = cast(bias_val.getType()).getElementType(); + auto bias_result = getTosaBias(op, rewriter, input_type, output_type, + output_is_qtype, tfl_conv_op.getBias()); + if (failed(bias_result)) return failure(); + auto [bias_ety, bias_val] = bias_result.value(); auto acc_type = getConvAccTypeAttr(rewriter, @@ -1886,8 +1913,8 @@ LogicalResult ConvertTFLTransposeConvOp::matchAndRewrite( auto a1_conv2d_op = CreateOpAndInfer( rewriter, op->getLoc(), output_type.clone(bias_ety), - tfl_conv_op.getInput(), tfl_conv_op.getWeights(), bias_val, - outpad, stride, acc_type); + 
tfl_conv_op.getInput(), tfl_conv_op.getWeights(), bias_val, outpad, + stride, acc_type); Value conv2d_output; if (input_is_qtype) { @@ -1931,11 +1958,11 @@ LogicalResult ConvertTFLDepthwiseConv2DOp::matchAndRewrite( if (!filter_type) return failure(); bool input_is_qtype = - input_type.getElementType().isa(); + mlir::isa(input_type.getElementType()); bool filter_is_qtype = - filter_type.getElementType().isa(); + mlir::isa(filter_type.getElementType()); bool output_is_qtype = - output_type.getElementType().isa(); + mlir::isa(output_type.getElementType()); if ((input_is_qtype != filter_is_qtype) || (input_is_qtype != output_is_qtype)) { @@ -2020,20 +2047,10 @@ LogicalResult ConvertTFLDepthwiseConv2DOp::matchAndRewrite( filter_type.getElementType()), a1_filter_transpose_op.getResult(), a2_reshape_dims_value); - Type bias_ety = - output_is_qtype ? rewriter.getI32Type() : output_type.getElementType(); - - Value unquantized_bias = tfl_conv2d_op.getBias(); - if (unquantized_bias) { - Type new_bias_ety = getElementTypeOrSelf(unquantized_bias.getType()); - if (auto qtype = new_bias_ety.dyn_cast()) { - new_bias_ety = qtype.getStorageType(); - } - if (new_bias_ety.getIntOrFloatBitWidth() > - bias_ety.getIntOrFloatBitWidth()) { - bias_ety = new_bias_ety; - } - } + auto bias_result = getTosaBias(op, rewriter, input_type, output_type, + output_is_qtype, tfl_conv2d_op.getBias()); + if (failed(bias_result)) return failure(); + auto [bias_ety, bias_val] = bias_result.value(); // TFLite only supports NHWC format Value conv2d_input = getInputSlicedToItsUsedSize( @@ -2047,7 +2064,7 @@ LogicalResult ConvertTFLDepthwiseConv2DOp::matchAndRewrite( auto a3_depthwise_conv2d_op = CreateOpAndInfer( rewriter, op->getLoc(), output_type.clone(bias_ety), conv2d_input, - a2_filter_reshape_op.getResult(), unquantized_bias, pad, stride, dilation, + a2_filter_reshape_op.getResult(), bias_val, pad, stride, dilation, acc_type); Value conv2d_output; @@ -2138,8 +2155,8 @@ LogicalResult 
ConvertTFLBatchMatMulOp::matchAndRewrite( rewriter, op->getLoc(), UnrankedTensorType::get(rhs_ty.getElementType()), rhs, new_rhs_shape_value); - lhs_ty = lhs.getType().cast(); - rhs_ty = rhs.getType().cast(); + lhs_ty = mlir::cast(lhs.getType()); + rhs_ty = mlir::cast(rhs.getType()); } if (transpose_lhs) { @@ -2231,8 +2248,6 @@ LogicalResult ConvertTFLFullyConnectedOp::matchAndRewrite( dyn_cast(tfl_fc_op.getInput().getType()); RankedTensorType filter_type = dyn_cast(tfl_fc_op.getFilter().getType()); - RankedTensorType bias_type = - dyn_cast(tfl_fc_op.getBias().getType()); if (!input_type || !filter_type) return failure(); bool input_is_qtype = @@ -2306,53 +2321,10 @@ LogicalResult ConvertTFLFullyConnectedOp::matchAndRewrite( filter_val, new_filter_shape_value); filter_type = cast(filter_val.getType()); - Value bias_val; - if (!bias_type) { - // For some matmuls, the bias may actually be a "UnitType" which has no - // value. TOSA requires bias to be an array of output_channel_count values, - // so create a constant of the appropriate number and type of zeros. - SmallVector bias_shape({filter_type.getShape()[0]}); - RankedTensorType new_bias_type; - - DenseElementsAttr bias_attr; - if (mlir::isa(input_type.getElementType())) { - SmallVector bias_arr(bias_shape[0]); - - for (int i = 0; i < bias_shape[0]; i++) { - bias_arr[i] = 0.0; - } - new_bias_type = - RankedTensorType::get(bias_shape, input_type.getElementType()); - bias_attr = - DenseElementsAttr::get(new_bias_type, llvm::ArrayRef(bias_arr)); - } else { - SmallVector bias_arr(bias_shape[0]); - - for (int i = 0; i < bias_shape[0]; i++) { - bias_arr[i] = 0; - } - if (!input_is_qtype) { - return rewriter.notifyMatchFailure( - op, "input must be quantized type if it's not float type"); - } - auto input_qtype = - mlir::cast(input_type.getElementType()); - Type new_bias_ety = input_qtype.getStorageTypeIntegralWidth() == 16 - ? 
rewriter.getIntegerType(48) - : rewriter.getI32Type(); - new_bias_type = RankedTensorType::get(bias_shape, new_bias_ety); - bias_attr = - DenseElementsAttr::get(new_bias_type, llvm::ArrayRef(bias_arr)); - } - auto bias_op = CreateOpAndInfer(rewriter, op->getLoc(), - new_bias_type, bias_attr); - bias_val = bias_op.getResult(); - bias_type = new_bias_type; - } else { - bias_val = tfl_fc_op.getBias(); - } - - Type bias_ety = mlir::cast(bias_val.getType()).getElementType(); + auto bias_result = getTosaBias(op, rewriter, input_type, output_type, + output_is_qtype, tfl_fc_op.getBias()); + if (failed(bias_result)) return failure(); + auto [bias_ety, bias_val] = bias_result.value(); auto acc_type = getConvAccTypeAttr(rewriter, @@ -2378,19 +2350,16 @@ LogicalResult ConvertTFLFullyConnectedOp::matchAndRewrite( // If we know the output rank, we need to ensure the output shape is correct. ShapedType fc_type = mlir::cast(fc_output.getType()); - DenseI64ArrayAttr output_shape_attr; - if (output_type.hasRank()) { - output_shape_attr = rewriter.getDenseI64ArrayAttr(output_type.getShape()); + llvm::SmallVector output_shape; + if (tfl_fc_op.getKeepNumDims()) { + const llvm::ArrayRef orig_input_shape = tfl_fc_op.getInput().getType().getShape(); + output_shape.append(orig_input_shape.begin(), orig_input_shape.end() - 1); + output_shape.push_back(OC); } else { - // set output_shape to {N, OC} to match previous results - // with tosa::FullyConnectedOp - output_shape_attr = rewriter.getDenseI64ArrayAttr({N, OC}); + output_shape.append({N, OC}); } - auto output_shape_value = - (output_type.hasRank()) - ? 
getTosaConstShape(rewriter, op->getLoc(), output_type.getShape()) - : getTosaConstShape(rewriter, op->getLoc(), {N, OC}); + auto output_shape_value = getTosaConstShape(rewriter, op->getLoc(), output_shape); fc_output = CreateOpAndInfer( rewriter, op->getLoc(), UnrankedTensorType::get(fc_type.getElementType()), fc_output, output_shape_value); @@ -2644,7 +2613,7 @@ LogicalResult ConvertTFLReduceAllOp::matchAndRewrite( return rewriter.notifyMatchFailure(op, "fail to get reduction indices"); std::optional result = convertReduceAllOp( - rewriter, op, output_type, tfl_all_op.getInput(), axes_elems); + rewriter, op, output_type, tfl_all_op.getInput(), axes_elems, tfl_all_op.getKeepDims()); if (!result) return failure(); @@ -2666,7 +2635,7 @@ LogicalResult ConvertTFLReduceAnyOp::matchAndRewrite( return failure(); std::optional result = convertReduceAnyOp( - rewriter, op, output_type, tfl_any_op.getInput(), axes_elems); + rewriter, op, output_type, tfl_any_op.getInput(), axes_elems, tfl_any_op.getKeepDims()); if (!result) return failure(); @@ -2688,7 +2657,7 @@ LogicalResult ConvertTFLReduceMaxOp::matchAndRewrite( return failure(); std::optional result = convertReduceMaxOp( - rewriter, op, output_type, tfl_max_op.getInput(), axes_elems); + rewriter, op, output_type, tfl_max_op.getInput(), axes_elems, tfl_max_op.getKeepDims()); if (!result) return failure(); @@ -2710,7 +2679,7 @@ LogicalResult ConvertTFLReduceMinOp::matchAndRewrite( return failure(); std::optional result = convertReduceMinOp( - rewriter, op, output_type, tfl_min_op.getInput(), axes_elems); + rewriter, op, output_type, tfl_min_op.getInput(), axes_elems, tfl_min_op.getKeepDims()); if (!result) return failure(); @@ -2732,7 +2701,7 @@ LogicalResult ConvertTFLReduceProdOp::matchAndRewrite( return failure(); std::optional result = convertReduceProdOp( - rewriter, op, output_type, tfl_prod_op.getInput(), axes_elems); + rewriter, op, output_type, tfl_prod_op.getInput(), axes_elems, tfl_prod_op.getKeepDims()); if 
(!result) return failure(); @@ -2754,7 +2723,7 @@ LogicalResult ConvertTFLMeanOp::matchAndRewrite( return failure(); std::optional result = convertReduceMeanOp( - rewriter, op, output_type, tfl_mean_op.getInput(), axes_elems); + rewriter, op, output_type, tfl_mean_op.getInput(), axes_elems, tfl_mean_op.getKeepDims()); if (!result) return failure(); @@ -2776,7 +2745,7 @@ LogicalResult ConvertTFLSumOp::matchAndRewrite( return failure(); std::optional result = convertReduceSumOp( - rewriter, op, output_type, tfl_sum_op.getInput(), axes_elems); + rewriter, op, output_type, tfl_sum_op.getInput(), axes_elems, tfl_sum_op.getKeepDims()); if (!result) return failure(); @@ -3476,17 +3445,11 @@ LogicalResult ConvertTFLHardSwishOp::matchAndRewrite( mlir::dyn_cast_or_null( output_type.getElementType()); - auto hardswish_func = [](double v) -> double { - double w = v + 3.0; - w = w < 0.0 ? 0.0 : w > 6.0 ? 6.0 : w; - return v * w / 6.0; - }; - if (input_qtype.getStorageTypeIntegralWidth() == 8) { // Implement with 8-bit table lookup. - Value table_const = getTosaConst8bitTable( + Value table_const = getTosaConstHardSwish8bitTable( rewriter, op, input_qtype.getScale(), input_qtype.getZeroPoint(), - output_qtype.getScale(), output_qtype.getZeroPoint(), hardswish_func); + output_qtype.getScale(), output_qtype.getZeroPoint()); CreateReplaceOpAndInfer( rewriter, op, output_type, tfl_hardswish_op.getInput(), table_const); @@ -3636,7 +3599,8 @@ LogicalResult ConvertTFLAtan2Op::matchAndRewrite( // Note: the implementation of std::atan2 may be different on // different machines, so may result in varying numerical results. 
auto atan_func = [](double x) -> double { return std::atan(x); }; - Value table_const = getTosaConst16bitTable(rewriter, op, atan_func, 0.0, 1.0); + Value table_const = getTosaConst16bitTable( + rewriter, op, 1.0 / 65535.0, -32768, 2.0 / 65535.0, 0, atan_func); auto table_result = CreateOpAndInfer( rewriter, loc, output_ty.clone(rewriter.getIntegerType(32)), casted, table_const); @@ -3729,13 +3693,10 @@ LogicalResult ConvertTFLLogisticOp::matchAndRewrite( return rewriter.notifyMatchFailure( op, "input/output zeropoint should be 0 in 16-bit mode"); } - double input_min = -32768 * input_qtype.getScale(); - double input_max = 32767 * input_qtype.getScale(); - // Generate table with gen_lut() in - // tensorflow/lite/kernels/internal/common.h - Value table_const = getTosaConst16bitTable(rewriter, op, sigmoid_func, - input_min, input_max); + Value table_const = + getTosaConst16bitTable(rewriter, op, input_qtype.getScale(), + 0, 2.0 / 65535.0, 0, sigmoid_func); auto op1_table_in = CreateOpAndInfer(rewriter, op->getLoc(), int32_type, @@ -3801,13 +3762,9 @@ LogicalResult ConvertTFLTanhOp::matchAndRewrite( return rewriter.notifyMatchFailure( op, "input/output zeropoint should be 0 in 16-bit mode"); } - double input_min = -32768 * input_qtype.getScale(); - double input_max = 32767 * input_qtype.getScale(); - // Generate table with gen_lut() in - // tensorflow/lite/kernels/internal/common.h - Value table_const = - getTosaConst16bitTable(rewriter, op, tanh_func, input_min, input_max); + Value table_const = getTosaConst16bitTable( + rewriter, op, input_qtype.getScale(), 0, 2.0 / 65535.0, 0, tanh_func); auto op1_table_in = CreateOpAndInfer(rewriter, op->getLoc(), int32_type, @@ -3833,7 +3790,7 @@ static LogicalResult LegalizeFloatingPointPrelu(Operation* op, Value input, Value alpha, ShapedType output_type) { Value mul = CreateMulOpAndInfer(rewriter, op, output_type, input, alpha); - auto rank = mul.getType().cast().getRank(); + auto rank = mlir::cast(mul.getType()).getRank(); 
Value const_zero = getTosaConstTensorSingleF32(rewriter, op, 0.0, rank); auto ge = CreateOpAndInfer( rewriter, op->getLoc(), output_type.clone(rewriter.getIntegerType(1)), @@ -4007,7 +3964,7 @@ static LogicalResult LegalizeFloatingPointLeakyRelu(Operation* op, PatternRewriter& rewriter, Value input, double alpha, ShapedType output_type) { - auto rank = input.getType().cast().getRank(); + auto rank = mlir::cast(input.getType()).getRank(); Value const_alpha = getTosaConstTensorSingleF32(rewriter, op, alpha, rank); auto mul = CreateMulOpAndInfer(rewriter, op, output_type, input, const_alpha); if (alpha <= 1.0) { @@ -4382,6 +4339,22 @@ LogicalResult ConvertTFLGatherNdOp::matchAndRewrite( return success(); } +LogicalResult ConvertTFLScatterNdOp::matchAndRewrite( + Operation* op, PatternRewriter& rewriter) const { + auto tfl_scatternd_op = cast(op); + + const std::optional result = convertScatterNdOp( + rewriter, op, tfl_scatternd_op.getResult(), tfl_scatternd_op.getIndices(), + tfl_scatternd_op.getUpdates(), tfl_scatternd_op.getShape()); + + if (!result) { + return failure(); + } + rewriter.replaceOp(op, {result.value()}); + + return success(); +} + LogicalResult ConvertTFLSparseToDenseOp::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { auto tfl_sparse_to_dense_op = cast(op); @@ -4562,12 +4535,12 @@ LogicalResult ConvertTFLArgMinOp::matchAndRewrite( // so need to rescale ArgMax output to original output zero point int output_zp = 0; Type output_ty = arg_min_op.getType(); - Type output_ety = output_ty.cast().getElementType(); + Type output_ety = mlir::cast(output_ty).getElementType(); if (auto output_quantized_ty = dyn_cast(output_ety)) { output_zp = output_quantized_ty.getZeroPoint(); if (output_zp != 0) { // need to rescale arg_max output to output zero point - output_ty = output_ty.cast().clone(input_ety); + output_ty = mlir::cast(output_ty).clone(input_ety); } } @@ -4837,6 +4810,11 @@ LogicalResult ConvertTFLLogicalOrOp::matchAndRewrite( return 
ConvertBinaryOp(op, rewriter); } +LogicalResult ConvertTFLBitwiseXorOp::matchAndRewrite( + Operation* op, PatternRewriter& rewriter) const { + return ConvertBinaryOp(op, rewriter); +} + LogicalResult ConvertTFLPowOp::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { return ConvertBinaryOp(op, rewriter); @@ -4857,6 +4835,128 @@ LogicalResult ConvertTFLBroadcastToOp::matchAndRewrite( return success(); } +LogicalResult ConvertTFLExpOp::matchAndRewrite( + Operation* op, PatternRewriter& rewriter) const { + auto tfl_exp_op = cast(op); + + RankedTensorType input_type = + dyn_cast(tfl_exp_op.getX().getType()); + RankedTensorType output_type = + dyn_cast(tfl_exp_op.getResult().getType()); + + if (!input_type || !output_type) { + return rewriter.notifyMatchFailure( + op, "input/output are not all a ranked tensor"); + } + + mlir::quant::UniformQuantizedType input_qtype = + dyn_cast_or_null( + input_type.getElementType()); + mlir::quant::UniformQuantizedType output_qtype = + dyn_cast_or_null( + output_type.getElementType()); + + if ((input_qtype == nullptr) != (output_qtype == nullptr)) { + return rewriter.notifyMatchFailure( + op, + "input/output tensor should be all quantized or all floating-point"); + } + + // Quantization case + if (input_qtype && output_qtype) { + auto exp_func = [](float x) -> float { return std::exp(x); }; + + Value table_const; + if (input_qtype.getStorageTypeIntegralWidth() == 8) { + table_const = getTosaConst8bitTable( + rewriter, op, input_qtype.getScale(), input_qtype.getZeroPoint(), + output_qtype.getScale(), output_qtype.getZeroPoint(), exp_func); + } else if (input_qtype.getStorageTypeIntegralWidth() == 16) { + table_const = getTosaConst16bitTable( + rewriter, op, input_qtype.getScale(), input_qtype.getZeroPoint(), + output_qtype.getScale(), output_qtype.getZeroPoint(), exp_func); + } else { + return rewriter.notifyMatchFailure( + op, "only quantized int8 and int16 are supported"); + } + + CreateReplaceOpAndInfer(rewriter, op, 
output_type, + tfl_exp_op.getX(), table_const); + return success(); + } + + CreateReplaceOpAndInfer(rewriter, op, tfl_exp_op.getType(), + tfl_exp_op.getX()); + + return success(); +} + +LogicalResult ConvertTFLLogOp::matchAndRewrite( + Operation* op, PatternRewriter& rewriter) const { + auto tfl_log_op = cast(op); + + RankedTensorType input_type = + dyn_cast(tfl_log_op.getX().getType()); + RankedTensorType output_type = + dyn_cast(tfl_log_op.getResult().getType()); + + if (!input_type || !output_type) { + return rewriter.notifyMatchFailure( + op, "input/output are not all a ranked tensor"); + } + + mlir::quant::UniformQuantizedType input_qtype = + dyn_cast_or_null( + input_type.getElementType()); + mlir::quant::UniformQuantizedType output_qtype = + dyn_cast_or_null( + output_type.getElementType()); + + if ((input_qtype == nullptr) != (output_qtype == nullptr)) { + return rewriter.notifyMatchFailure( + op, + "input/output tensor should be all quantized or all floating-point"); + } + + // Quantization case + if (input_qtype && output_qtype) { + const float output_min = + ((input_qtype.getStorageTypeIntegralWidth() == 8 ? 
-128 : -32768) - + output_qtype.getZeroPoint()) * + static_cast(output_qtype.getScale()); + + auto log_func = [&](float x) -> float { + if (x <= 0.0f) { + return output_min; + } + return std::log(x); + }; + + Value table_const; + if (input_qtype.getStorageTypeIntegralWidth() == 8) { + table_const = getTosaConst8bitTable( + rewriter, op, input_qtype.getScale(), input_qtype.getZeroPoint(), + output_qtype.getScale(), output_qtype.getZeroPoint(), log_func); + } else if (input_qtype.getStorageTypeIntegralWidth() == 16) { + table_const = getTosaConst16bitTable( + rewriter, op, input_qtype.getScale(), input_qtype.getZeroPoint(), + output_qtype.getScale(), output_qtype.getZeroPoint(), log_func); + } else { + return rewriter.notifyMatchFailure( + op, "only quantized int8 and int16 are supported"); + } + + CreateReplaceOpAndInfer(rewriter, op, output_type, + tfl_log_op.getX(), table_const); + return success(); + } + + CreateReplaceOpAndInfer(rewriter, op, tfl_log_op.getType(), + tfl_log_op.getX()); + + return success(); +} + LogicalResult LegalizeTFL::initialize(MLIRContext* context) { RewritePatternSet patterns(context); mlir::tosa::populateLegalizeTFLPatterns(context, patterns); @@ -4892,6 +4992,7 @@ void populateLegalizeTFLPatterns(MLIRContext* ctx, DEF_PATTERN_INSERT(TFLLogicalAnd); DEF_PATTERN_INSERT(TFLLogicalOr); + DEF_PATTERN_INSERT(TFLBitwiseXor); DEF_PATTERN_INSERT(TFLPow); DEF_PATTERN_INSERT(TFLGelu); @@ -4983,6 +5084,7 @@ void populateLegalizeTFLPatterns(MLIRContext* ctx, DEF_PATTERN_INSERT(TFLConst); DEF_PATTERN_INSERT(TFLQConst); DEF_PATTERN_INSERT(TFLGatherNd); + DEF_PATTERN_INSERT(TFLScatterNd); DEF_PATTERN_INSERT(TFLSparseToDense); DEF_PATTERN_INSERT(Constant); DEF_PATTERN_INSERT(TFLOneHot); diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.cc index d3be58ce5a7d51..4356c7a34d177e 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.cc +++ 
b/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.cc @@ -113,8 +113,8 @@ std::optional convertTFConv2DCommon( stride = rewriter.getDenseI64ArrayAttr({1, 1}); } else { // Note: hardcoded to NHWC for now - int64_t stride_h = strides_attr[1].cast().getInt(); - int64_t stride_w = strides_attr[2].cast().getInt(); + int64_t stride_h = mlir::cast(strides_attr[1]).getInt(); + int64_t stride_w = mlir::cast(strides_attr[2]).getInt(); stride = rewriter.getDenseI64ArrayAttr({stride_h, stride_w}); } } @@ -123,8 +123,8 @@ std::optional convertTFConv2DCommon( dilation = rewriter.getDenseI64ArrayAttr({1, 1}); } else { // Note: hardcoded to NHWC for now - int64_t dilation_h = dilations_attr[1].cast().getInt(); - int64_t dilation_w = dilations_attr[2].cast().getInt(); + int64_t dilation_h = mlir::cast(dilations_attr[1]).getInt(); + int64_t dilation_w = mlir::cast(dilations_attr[2]).getInt(); dilation = rewriter.getDenseI64ArrayAttr({dilation_h, dilation_w}); } } @@ -172,9 +172,9 @@ std::optional convertTFConv3DCommon( // Defaults to [1, 1, 1]. strides = rewriter.getDenseI64ArrayAttr({1, 1, 1}); } else { - int64_t stride_d = strides_attr[1].cast().getInt(); - int64_t stride_h = strides_attr[2].cast().getInt(); - int64_t stride_w = strides_attr[3].cast().getInt(); + int64_t stride_d = mlir::cast(strides_attr[1]).getInt(); + int64_t stride_h = mlir::cast(strides_attr[2]).getInt(); + int64_t stride_w = mlir::cast(strides_attr[3]).getInt(); strides = rewriter.getDenseI64ArrayAttr({stride_d, stride_h, stride_w}); } @@ -183,17 +183,18 @@ std::optional convertTFConv3DCommon( // Defaults to [1, 1, 1]. 
dilations = rewriter.getDenseI64ArrayAttr({1, 1, 1}); } else { - int64_t dilation_d = dilations_attr[1].cast().getInt(); - int64_t dilation_h = dilations_attr[2].cast().getInt(); - int64_t dilation_w = dilations_attr[3].cast().getInt(); + int64_t dilation_d = mlir::cast(dilations_attr[1]).getInt(); + int64_t dilation_h = mlir::cast(dilations_attr[2]).getInt(); + int64_t dilation_w = mlir::cast(dilations_attr[3]).getInt(); dilations = rewriter.getDenseI64ArrayAttr({dilation_d, dilation_h, dilation_w}); } - RankedTensorType input_type = input.getType().cast(); + RankedTensorType input_type = mlir::cast(input.getType()); DenseI64ArrayAttr pads; { - RankedTensorType filter_type = filter.getType().cast(); + RankedTensorType filter_type = + mlir::cast(filter.getType()); tensorflow::TensorFormat data_format_tf; if (!FormatFromString(data_format_ref, &data_format_tf)) { @@ -582,6 +583,90 @@ Value buildRescaleOpConvOutput(PatternRewriter& rewriter, Operation* op, } } +Value getTosaConstHardSwish8bitTable(PatternRewriter& rewriter, Operation* op, + float input_scale, int32_t input_zp, + float output_scale, int32_t output_zp) { + // Define tflite params: + // See: HardSwishPrepare / HardSwishParams + const float hires_input_scale = (1.0f / 128.0f) * input_scale; + const float reluish_scale = 3.0f / 32768.0f; + const float output_multiplier = hires_input_scale / output_scale; + + int16_t output_multiplier_fixedpoint_int16; + int output_multiplier_exponent; + + int16_t reluish_multiplier_fixedpoint_int16; + int reluish_multiplier_exponent; + + int32_t output_multiplier_fixedpoint_int32; + tflite::QuantizeMultiplier(output_multiplier, + &output_multiplier_fixedpoint_int32, + &output_multiplier_exponent); + tflite::DownScaleInt32ToInt16Multiplier(output_multiplier_fixedpoint_int32, + &output_multiplier_fixedpoint_int16); + assert(output_multiplier_exponent <= 0); + + const float reluish_multiplier = hires_input_scale / reluish_scale; + int32_t reluish_multiplier_fixedpoint_int32; 
+ + tflite::QuantizeMultiplier(reluish_multiplier, + &reluish_multiplier_fixedpoint_int32, + &reluish_multiplier_exponent); + tflite::DownScaleInt32ToInt16Multiplier(reluish_multiplier_fixedpoint_int32, + &reluish_multiplier_fixedpoint_int16); + + // See HardSwish function in + // tensorflow/lite/kernels/internal/reference/hardswish.h + SmallVector table; + for (int32_t i = -128; i < 128; i++) { + const int16_t input_value = i - input_zp; + const int16_t input_value_on_hires_input_scale = input_value * (1 << 7); + const int16_t input_value_on_preshift_output_scale = + gemmlowp::SaturatingRoundingDoublingHighMul( + input_value_on_hires_input_scale, + output_multiplier_fixedpoint_int16); + int16_t reluish_value = input_value_on_hires_input_scale; + if (reluish_multiplier_exponent > 0) { + reluish_value = tflite::reference_ops::SaturatingLeftShift( + reluish_value, reluish_multiplier_exponent - 1); + } + reluish_value = gemmlowp::SaturatingRoundingDoublingHighMul( + reluish_value, reluish_multiplier_fixedpoint_int16); + if (reluish_multiplier_exponent > 0) { + reluish_value = + tflite::reference_ops::SaturatingLeftShift(reluish_value, 1); + } + if (reluish_multiplier_exponent < 0) { + reluish_value = gemmlowp::RoundingDivideByPOT( + reluish_value, -reluish_multiplier_exponent); + } + reluish_value = (reluish_value + (1 << 15)) >> 1; + const int16_t preshift_output_value = + tflite::reference_ops::SaturatingDoublingHighMul( + reluish_value, input_value_on_preshift_output_scale); + int16_t output_value = gemmlowp::RoundingDivideByPOT( + preshift_output_value, -output_multiplier_exponent); + output_value += output_zp; + output_value = + std::min(output_value, std::numeric_limits::max()); + output_value = + std::max(output_value, std::numeric_limits::min()); + table.push_back(output_value); + } + + auto element_qtype = + UniformQuantizedType::get(true, rewriter.getIntegerType(8), + rewriter.getF32Type(), 1.0f, 0, -128, 127); + auto const_type = 
tensorflow::GetTypeFromTFTensorShape({256}, element_qtype); + auto storage_type = tensorflow::GetTypeFromTFTensorShape( + {256}, element_qtype.getStorageType()); + auto const_attr = DenseElementsAttr::get(storage_type, llvm::ArrayRef(table)); + + auto const_op = + rewriter.create(op->getLoc(), const_type, const_attr); + return const_op.getResult(); +} + Value getTosaConstRsqrt8bitTable(PatternRewriter& rewriter, Operation* op, float input_scale, int32_t input_zp, float output_scale, int32_t output_zp) { @@ -637,24 +722,25 @@ Value getTosaConstRsqrt8bitTable(PatternRewriter& rewriter, Operation* op, } // Create a 8-bit TOSA TABLE constant tensor with int8[256] array. -// Follow PopulateLookupTable() tensorflow/lite/kernels/activations.cc +// Follow LUTPopulateInt8() tensorflow/lite/kernels/internal/common.h Value getTosaConst8bitTable(PatternRewriter& rewriter, Operation* op, - double input_scale, int32_t input_zp, - double output_scale, int32_t output_zp, - std::function func) { + float input_scale, int32_t input_zp, + float output_scale, int32_t output_zp, + std::function func) { SmallVector table; + float inverse_scale = 1.0f / output_scale; for (int32_t i = -128; i < 128; i++) { - double dequantized = input_scale * (i - input_zp); - double transformed = func(dequantized); + float dequantized = input_scale * (i - input_zp); + float transformed = func(dequantized); - double max = (output_scale > 1.0) ? DBL_MAX : (DBL_MAX * output_scale); + float max = (output_scale > 1.0) ? 
FLT_MAX : (FLT_MAX * output_scale); if (transformed >= max) { table.push_back(INT8_MAX); continue; } - int32_t rescaled = std::llround(transformed / output_scale); + int32_t rescaled = std::round(transformed * inverse_scale); int32_t quantized = static_cast(rescaled + output_zp); table.push_back( static_cast(std::min(std::max(quantized, -128), 127))); @@ -673,34 +759,52 @@ Value getTosaConst8bitTable(PatternRewriter& rewriter, Operation* op, return const_op.getResult(); } -// Create a 16-bit TOSA TABLE constant tensor with int16[513] array. -// Output is restricted to [-1.0, 1.0]. -// Follow gen_lut() tensorflow/lite/kernels/internal/common.h +// Create a 16-bit TOSA TABLE constant tensor. +// A float should be used by default for FloatT except if a double is required +// for backward compatibility. +// Follow LUTPopulateInt16() tensorflow/lite/kernels/internal/common.h +template Value getTosaConst16bitTable(PatternRewriter& rewriter, Operation* op, - std::function func, double min, - double max) { + FloatT input_scale, int32_t input_zp, + FloatT output_scale, int32_t output_zp, + std::function func) { + static_assert(std::is_floating_point::value, + "FloatT must be a floating-point type."); + SmallVector table; - double step = (max - min) / 512.0f; - double half_step = step / 2.0f; + FloatT input_min = + input_scale * (std::numeric_limits::min() - input_zp); + FloatT input_max = + input_scale * (std::numeric_limits::max() - input_zp); + FloatT output_min = + output_scale * (std::numeric_limits::min() - output_zp); + FloatT output_max = + output_scale * (std::numeric_limits::max() - output_zp); + + FloatT step = (input_max - input_min) / 512; + FloatT half_step = step / 2; + FloatT output_scaling_inv = 65536 / (output_max - output_min); + for (int32_t i = 0; i < 512; i++) { - int32_t sample_val = std::llround(func(min + (i * step)) * 32768.0); - double midpoint_interp_val = - std::round(((func(min + (i + 1) * step) * 32768.0) + - std::round(func(min + (i * step)) * 
32768.0)) / - 2.0); - double midpoint_val = - std::round(func(min + (i * step) + half_step) * 32768.0); - double midpoint_err = midpoint_interp_val - midpoint_val; - int32_t bias = std::llround(midpoint_err / 2.0); + FloatT sample_val = + std::round(func(input_min + (i * step)) * output_scaling_inv); + FloatT midpoint_interp_val = std::round( + ((func(input_min + (i + 1) * step) * output_scaling_inv) + + std::round(func(input_min + (i * step)) * output_scaling_inv)) / + 2); + FloatT midpoint_val = std::round(func(input_min + (i * step) + half_step) * + output_scaling_inv); + FloatT midpoint_err = midpoint_interp_val - midpoint_val; + FloatT bias = std::round(midpoint_err / 2); table.push_back(static_cast( - std::min(std::max(sample_val - bias, -32768), 32767))); + std::min(std::max(sample_val - bias, -32768), 32767))); } - int32_t max_val = std::llround(func(max) * 32768.0); - table.push_back( - static_cast(std::min(std::max(max_val, -32768), 32767))); + FloatT max_val = std::round(func(input_max) * output_scaling_inv); + table.push_back(static_cast( + std::min(std::max(max_val, -32768), 32767))); auto const_type = tensorflow::GetTypeFromTFTensorShape({513}, rewriter.getIntegerType(16)); @@ -711,6 +815,18 @@ Value getTosaConst16bitTable(PatternRewriter& rewriter, Operation* op, return const_op.getResult(); } +template Value getTosaConst16bitTable(PatternRewriter& rewriter, + Operation* op, float input_scale, + int32_t input_zp, + float output_scale, + int32_t output_zp, + std::function func); + +template Value getTosaConst16bitTable( + PatternRewriter& rewriter, Operation* op, double input_scale, + int32_t input_zp, double output_scale, int32_t output_zp, + std::function func); + // Create a 32-bit TOSA TABLE for Softmax Exp void getTosaConst32bitSoftmaxExpTable(PatternRewriter& rewriter, Operation* op, double beta, double input_scale, @@ -1036,14 +1152,14 @@ bool getTransposeConv2dPaddingValues( return false; } - int total_padding = ((ifm_size - 1) * dim_stride + 
filter_size - ofm_size); - total_padding = total_padding > 0 ? total_padding : 0; + int total_padding = + ((ifm_size - 1) * dim_stride + filter_size - ofm_size); pad_before = total_padding / 2; pad_after = total_padding - pad_before; - computed_paddings.push_back(pad_before); - computed_paddings.push_back(pad_after); + computed_paddings.push_back(-pad_before); + computed_paddings.push_back(-pad_after); } explicit_padding = rewriter.getDenseI64ArrayAttr(computed_paddings); @@ -1292,10 +1408,10 @@ Value reshapeScalarTo1D(PatternRewriter& rewriter, Location loc, Value value) { } DenseElementsAttr const_attr; - if (attr.getElementType().isa()) { + if (mlir::isa(attr.getElementType())) { const_attr = DenseElementsAttr::get(storage_type, {attr.getValues()[0]}); - } else if (attr.getElementType().isa()) { + } else if (mlir::isa(attr.getElementType())) { const_attr = DenseElementsAttr::get(storage_type, {attr.getValues()[0]}); } else { @@ -1382,5 +1498,36 @@ LogicalResult broadcastLowRankTensor(PatternRewriter& rewriter, Operation* op, return success(); } +bool checkUniqueConstantScatterIndices(ShapedType indices_type, + ShapedType result_type, + ElementsAttr const_data) { + llvm::ArrayRef const indices_shape = indices_type.getShape(); + const unsigned int indices_rank = indices_shape.size(); + const unsigned int result_rank = result_type.getRank(); + const unsigned int last_dim_size = indices_shape[indices_rank - 1]; + + // Reconstruct each index from the unshaped constant data array and + // calculate the corresponding flattened index + auto const const_data_range = const_data.getValues(); + assert((const_data_range.size() % last_dim_size == 0) && + "Constant data length should be a multiple of indices_shape[-1]"); + + std::vector flattened_indices; + flattened_indices.reserve(const_data_range.size() / last_dim_size); + for (auto beg = const_data_range.begin(); beg < const_data_range.end(); + beg += last_dim_size) { + std::vector current_single_index(result_rank); + 
std::copy(beg, beg + last_dim_size, current_single_index.begin()); + const uint64_t f_index{ + ElementsAttr::getFlattenedIndex(result_type, current_single_index)}; + flattened_indices.push_back(f_index); + } + + // If adjacent flattened values are found, there are non-unique indices + std::sort(flattened_indices.begin(), flattened_indices.end()); + return std::adjacent_find(flattened_indices.begin(), + flattened_indices.end()) == flattened_indices.end(); +} + } // namespace tosa } // namespace mlir diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h b/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h index 8dc618d2bf2608..a2b990446924c9 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h @@ -102,14 +102,18 @@ Value buildRescaleOpConvOutput(PatternRewriter& rewriter, Operation* op, // Create a 8-bit TOSA TABLE constant tensor Value getTosaConst8bitTable(PatternRewriter& rewriter, Operation* op, - double input_scale, int32_t input_zp, - double output_scale, int32_t output_zp, - std::function func); + float input_scale, int32_t input_zp, + float output_scale, int32_t output_zp, + std::function func); // Create a 16-bit TOSA TABLE constant tensor +// A float should be used by default for FloatT except if a double is required +// for backward compatibility +template Value getTosaConst16bitTable(PatternRewriter& rewriter, Operation* op, - std::function func, double min, - double max); + FloatT input_scale, int32_t input_zp, + FloatT output_scale, int32_t output_zp, + std::function func); // Create a 32-bit TOSA TABLE for Softmax Exp void getTosaConst32bitSoftmaxExpTable(PatternRewriter& rewriter, Operation* op, @@ -122,6 +126,11 @@ Value getTosaConstRsqrt8bitTable(PatternRewriter& rewriter, Operation* op, float input_scale, int32_t input_zp, float output_scale, int32_t output_zp); +// Create an 8-bit TOSA Table constant tensor for the HardSwish operator +Value 
getTosaConstHardSwish8bitTable(PatternRewriter& rewriter, Operation* op, + float input_scale, int32_t input_zp, + float output_scale, int32_t output_zp); + // Create a 32-bit float constant operator from a float Value getTosaConstTensorSingleF32(PatternRewriter& rewriter, Operation* op, float val, int rank); @@ -203,6 +212,14 @@ Value getInputSlicedToItsUsedSize(PatternRewriter& rewriter, Operation* op, // Check if scale32 mode is used for given output_element_type bool isScale32(mlir::quant::UniformQuantizedType output_element_type); +// Checks if the multi-dimensional indices supplied by a constant tensor +// are unique. This is a useful check for legalizations to tosa.scatter +// which requires indices are unique, while in TF/TFLite they may be +// non-unique. +bool checkUniqueConstantScatterIndices(ShapedType indices_type, + ShapedType result_type, + ElementsAttr const_data); + // Applies a set of patterns greedily to the specified function, then applies // a cleanup to guarantee the function contract and constants are valid. This // means patterns can performed shape inference while not altering immutable diff --git a/tensorflow/compiler/mlir/tosa/transforms/tfl_legalize_patterns.td b/tensorflow/compiler/mlir/tosa/transforms/tfl_legalize_patterns.td index a7230ccf901399..b0141dcaf9fa13 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/tfl_legalize_patterns.td +++ b/tensorflow/compiler/mlir/tosa/transforms/tfl_legalize_patterns.td @@ -29,8 +29,6 @@ include "mlir/Dialect/Tosa/IR/TosaOps.td" def ConvertTFLAbsOp : Pat<(TFL_AbsOp $arg), (Tosa_AbsOp $arg)>; def ConvertTFLCeilOp : Pat<(TFL_CeilOp $arg), (Tosa_CeilOp $arg)>; def ConvertTFLFloorOp : Pat<(TFL_FloorOp $arg), (Tosa_FloorOp $arg)>; -def ConvertTFLExpOp : Pat<(TFL_ExpOp $arg), (Tosa_ExpOp $arg)>; -def ConvertTFLLogOp : Pat<(TFL_LogOp $arg), (Tosa_LogOp $arg)>; def ConvertTFLLogicalNotOp : Pat<(TFL_LogicalNotOp $arg), (Tosa_LogicalNotOp $arg)>; // Removing the quant.stats op for unquantized models. 
diff --git a/tensorflow/compiler/mlir/utils/BUILD b/tensorflow/compiler/mlir/utils/BUILD index 2256c421b45717..ae6a01df20e1b2 100644 --- a/tensorflow/compiler/mlir/utils/BUILD +++ b/tensorflow/compiler/mlir/utils/BUILD @@ -37,3 +37,40 @@ cc_library( "@llvm-project//llvm:Support", ], ) + +cc_library( + name = "saved_model_converter_utils", + srcs = ["saved_model_converter_utils.cc"], + hdrs = ["saved_model_converter_utils.h"], + visibility = [ + "//tensorflow/cc/experimental/tfa:__subpackages__", + ], + deps = [ + "//tensorflow/cc/saved_model:loader", + "//tensorflow/compiler/mlir/tensorflow:mlir_import_options", + "//tensorflow/compiler/mlir/tensorflow:translate_lib", + "//tensorflow/compiler/mlir/tf2xla/api/v2:mlir_roundtrip_flags", + "//tensorflow/core/framework:op", + "//tensorflow/core/framework:op_def_builder", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:span", + "@llvm-project//mlir:IR", + ], +) + +cc_library( + name = "validators", + srcs = [ + "validators.cc", + ], + hdrs = [ + "validators.h", + ], + deps = [ + "@llvm-project//mlir:Dialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) diff --git a/tensorflow/compiler/mlir/utils/saved_model_converter_utils.cc b/tensorflow/compiler/mlir/utils/saved_model_converter_utils.cc new file mode 100644 index 00000000000000..d818acf6ee528d --- /dev/null +++ b/tensorflow/compiler/mlir/utils/saved_model_converter_utils.cc @@ -0,0 +1,94 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/utils/saved_model_converter_utils.h" + +#include +#include +#include +#include +#include + +#include "absl/log/log.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_import_options.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_def_builder.h" +#include "tensorflow/compiler/mlir/tf2xla/api/v2/mlir_roundtrip_flags.h" + + +namespace tensorflow { +namespace utils { + +// Util that registers 'extra_tf_opdefs' to the TF global registry. +// Return OK on success, failure if registering failed. +absl::Status RegisterExtraTfOpDefs( + absl::Span extra_tf_opdefs) { + for (const auto& tf_opdefs_string : extra_tf_opdefs) { + OpDef opdef; + // NOLINTNEXTLINE: Use tsl::protobuf to be compatible with OSS. + if (!tsl::protobuf::TextFormat::ParseFromString(tf_opdefs_string, &opdef)) { + LOG(ERROR) << "OpDef parsing failed for: " << tf_opdefs_string; + return absl::InvalidArgumentError("fail to parse extra OpDef"); + } + // Register extra opdefs. + // TODO: b/133770952 - Support shape functions. 
+ OpRegistry::Global()->Register( + [opdef](OpRegistrationData* op_reg_data) -> absl::Status { + *op_reg_data = OpRegistrationData(opdef); + return absl::OkStatus(); + }); + } + return absl::OkStatus(); +} + +absl::StatusOr> ImportSavedModel( + const std::string& input_filename, const int saved_model_version, + const std::unordered_set& tags, + absl::Span extra_tf_opdefs, + absl::Span exported_names, const GraphImportConfig& specs, + bool enable_variable_lifting, mlir::MLIRContext* context, + std::unique_ptr* saved_model_bundle) { + // Register extra TF ops passed as OpDef. + auto extra_opdefs_status = RegisterExtraTfOpDefs(extra_tf_opdefs); + if (!extra_opdefs_status.ok()) return extra_opdefs_status; + + if (saved_model_version == 2) { + auto module_or = SavedModelObjectGraphToMlirImport( + input_filename, tags, exported_names, context, + /*unconditionally_use_set_output_shapes=*/true); + if (!module_or.status().ok()) return module_or.status(); + return std::move(module_or).value(); + } else if (saved_model_version == 1) { + MLIRImportOptions options; + options.upgrade_legacy = specs.upgrade_legacy; + options.unconditionally_use_set_output_shapes = true; + options.lift_variables = enable_variable_lifting; + auto module_or = SavedModelSignatureDefsToMlirImport( + input_filename, tags, exported_names, context, options, + saved_model_bundle); + + if (!module_or.status().ok()) return module_or.status(); + return std::move(module_or).value(); + } else { + return absl::InvalidArgumentError("Should be either saved model v1 or v2."); + } +} + +} // namespace utils +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/utils/saved_model_converter_utils.h b/tensorflow/compiler/mlir/utils/saved_model_converter_utils.h new file mode 100644 index 00000000000000..fc4440fb918a37 --- /dev/null +++ b/tensorflow/compiler/mlir/utils/saved_model_converter_utils.h @@ -0,0 +1,46 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_UTILS_SAVED_MODEL_CONVERTER_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_UTILS_SAVED_MODEL_CONVERTER_UTILS_H_ + +#include +#include +#include + +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "tensorflow/cc/saved_model/loader.h" +#include "tensorflow/compiler/mlir/tf2xla/api/v2/mlir_roundtrip_flags.h" + +namespace tensorflow { +namespace utils { + +// 'saved_model_bundle' will be initialized if V1 model was loaded. +absl::StatusOr> ImportSavedModel( + const std::string& input_filename, int saved_model_version, + const std::unordered_set& tags, + absl::Span extra_tf_opdefs, + absl::Span exported_names, const GraphImportConfig& specs, + bool enable_variable_lifting, mlir::MLIRContext* context, + std::unique_ptr* saved_model_bundle); + +} // namespace utils +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_UTILS_SAVED_MODEL_CONVERTER_UTILS_H_ diff --git a/tensorflow/compiler/mlir/utils/validators.cc b/tensorflow/compiler/mlir/utils/validators.cc new file mode 100644 index 00000000000000..870c7e1f1efbfe --- /dev/null +++ b/tensorflow/compiler/mlir/utils/validators.cc @@ -0,0 +1,147 @@ +/* Copyright 2019 The TensorFlow Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/utils/validators.h" + +#include +#include + +#include "mlir/Dialect/Traits.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project + +namespace mlir { +namespace TF { + +// Returns true if the given `op` +// * has an attribute with the given `name`, +// * and the attribute is an integer list of the form [1, X, Y, 1], +// and writes X, Y as 32-bit integer attribute to `x`, `y`. 
+bool TFIntListIs1XY1(Operation *op, StringRef name, IntegerAttr *x, + IntegerAttr *y) { + auto attr = op->getAttrOfType(name); + if (!attr) return false; + + auto elements = attr.getValue(); + if (elements.size() != 4 || + std::any_of(elements.begin(), elements.end(), + [](Attribute e) { return !mlir::isa(e); })) + return false; + + if (mlir::cast(elements.front()).getInt() != 1 || + mlir::cast(elements.back()).getInt() != 1) + return false; + + Builder b(op->getContext()); + *x = b.getI32IntegerAttr(mlir::cast(elements[1]).getInt()); + *y = b.getI32IntegerAttr(mlir::cast(elements[2]).getInt()); + + return true; +} + +// Returns true if the attribute is an integer list of the form [1, X, Y, 1]. +bool TFIntListIs1XY1(const Attribute attr) { + const auto &elements = mlir::cast(attr).getValue(); + if (elements.size() != 4 || + std::any_of(elements.begin(), elements.end(), + [](Attribute e) { return !mlir::isa(e); })) + return false; + + if (mlir::cast(elements.front()).getValue() != 1 || + mlir::cast(elements.back()).getValue() != 1) + return false; + return true; +} + +// Returns true if the attribute is an integer list of the form [1, 1, X, Y]. +bool TFIntListIs11XY(const Attribute attr) { + const auto &elements = mlir::cast(attr).getValue(); + if (elements.size() != 4 || + std::any_of(elements.begin(), elements.end(), + [](Attribute e) { return !mlir::isa(e); })) + return false; + + const Attribute *data = elements.data(); + if (mlir::cast(data[0]).getValue() != 1 || + mlir::cast(data[1]).getValue() != 1) + return false; + return true; +} + +// Returns true if the given `op` +// * has an attribute with the given `name`, +// * and the attribute is an integer list of the form [1, X, Y, Z, 1], +// and writes X, Y as 32-bit integer attribute to `x`, `y`, z. 
+bool TFIntListIs1XYZ1(Operation *op, StringRef name, IntegerAttr *x, + IntegerAttr *y, IntegerAttr *z) { + auto attr = op->getAttrOfType(name); + if (!attr) return false; + + auto elements = attr.getValue(); + if (elements.size() != 5 || + std::any_of(elements.begin(), elements.end(), + [](Attribute e) { return !mlir::isa(e); })) + return false; + + if (mlir::cast(elements.front()).getInt() != 1 || + mlir::cast(elements.back()).getInt() != 1) + return false; + + Builder b(op->getContext()); + *x = b.getI32IntegerAttr(mlir::cast(elements[1]).getInt()); + *y = b.getI32IntegerAttr(mlir::cast(elements[2]).getInt()); + *z = b.getI32IntegerAttr(mlir::cast(elements[3]).getInt()); + + return true; +} + +// Returns true if every element of the attribute is 1. All elements of `attr` +// must be `IntegerAttr`. +bool TFIntListIsAllOnes(const Attribute attr) { + const auto &elements = mlir::cast(attr).getValue(); + + return !std::any_of(elements.begin(), elements.end(), [](Attribute e) { + return mlir::cast(e).getValue() != 1; + }); +} + +bool IsBroadcastableElementsAttrs(mlir::TypedAttr a, mlir::TypedAttr b) { + // This would return false if we had unranked tensors (where they should + // probably be considered as broadcastable), but given we are working with + // attributes here that shouldn't be an issue, + return OpTrait::util::getBroadcastedType(a.getType(), b.getType()) != Type(); +} + +bool IsDimensionsDegenerateExceptLastOne(ArrayRef elements_shape) { + if (elements_shape.empty()) return true; + + for (auto dim : elements_shape.drop_back(1)) { + if (dim != 1) return false; + } + return true; +} + +bool IsDimensionsDegenerateExceptLastOne(TypedAttr val) { + if (auto ranked_type = mlir::dyn_cast(val.getType())) { + return IsDimensionsDegenerateExceptLastOne(ranked_type.getShape()); + } + return false; +} + +} // namespace TF +} // namespace mlir diff --git a/tensorflow/compiler/mlir/utils/validators.h b/tensorflow/compiler/mlir/utils/validators.h new file mode 100644 
index 00000000000000..b55bd219914603 --- /dev/null +++ b/tensorflow/compiler/mlir/utils/validators.h @@ -0,0 +1,126 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This header file defines common validators used by TFLite transformation +// passes to validate op attributes or values. + +#ifndef TENSORFLOW_COMPILER_MLIR_UTILS_VALIDATORS_H_ +#define TENSORFLOW_COMPILER_MLIR_UTILS_VALIDATORS_H_ + +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project + +namespace mlir { +namespace TF { + +// TODO(jpienaar): Change these to being one of these variants and/or generate +// these predicates. + +// Returns true if the given TensorFlow op does not have a `data_format` +// attribute (then default to "NHWC"), or its `data_format` attribute is "NHWC". 
+inline bool TFDataFormatIsNHWC(Operation *op) { + auto attr = op->getAttrOfType("data_format"); + return !attr || attr.getValue() == "NHWC"; +} + +// Returns true if the given TensorFlow op does not have a `data_format` +// attribute (then default to "NDHWC"), or its `data_format` attribute is +// "NDHWC". +inline bool TFDataFormatIsNDHWC(Operation *op) { + auto attr = op->getAttrOfType("data_format"); + return !attr || attr.getValue() == "NDHWC"; +} + +// Returns true if the given `op` +// * has an attribute with the given `name`, +// * and the attribute is an integer list of the form [1, X, Y, 1], +// and writes X, Y as 32-bit integer attribute to `x`, `y`. +bool TFIntListIs1XY1(Operation *op, StringRef name, IntegerAttr *x, + IntegerAttr *y); + +// Returns true if the attribute is an integer list of the form [1, X, Y, 1]. +bool TFIntListIs1XY1(Attribute attr); + +// Returns true if the attribute is an integer list of the form [1, 1, X, Y]. +bool TFIntListIs11XY(Attribute attr); + +// Returns true if the given `op` +// * has an attribute with the given `name`, +// * and the attribute is an integer list of the form [1, X, Y, Z, 1], +// and writes X, Y as 32-bit integer attribute to `x`, `y`, z. +bool TFIntListIs1XYZ1(Operation *op, StringRef name, IntegerAttr *x, + IntegerAttr *y, IntegerAttr *z); + +// Returns true if every element of the attribute is 1. All elements of `attr` +// must be `IntegerAttr`. +bool TFIntListIsAllOnes(Attribute attr); + +// Returns true iff the given value is a float32 tensor. +// is "DT_FLOAT". +inline bool TFTypeIsFloat32Tensor(Value value) { + auto tensorType = mlir::dyn_cast(value.getType()); + if (!tensorType) return false; + return tensorType.getElementType().isF32(); +} + +// Returns true iff the given value is a bf16 tensor. 
+inline bool TFTypeIsBFloat16Tensor(Value value) { + auto tensorType = mlir::dyn_cast(value.getType()); + if (!tensorType) return false; + return tensorType.getElementType().isBF16(); +} + +// Returns true iff the given value is a f16 tensor. +inline bool TFTypeIsHalfTensor(Value value) { + auto tensorType = mlir::dyn_cast(value.getType()); + if (!tensorType) return false; + return tensorType.getElementType().isF16(); +} + +// Returns true iff the given value is a f16 or bf16 tensor. +inline bool TFTypeIsBFloat16OrHalfTensor(Value value) { + return TFTypeIsBFloat16Tensor(value) || TFTypeIsHalfTensor(value); +} + +// Returns true iff the given TensorFlow op has a `padding` attribute whose +// value is "SAME" or "VALID", and writes the attribute to `padding`. +inline bool TFPaddingIsSameOrValid(Operation *op, StringAttr *padding) { + auto padding_attr = op->getAttrOfType("padding"); + if (padding_attr.getValue() != "SAME" && padding_attr.getValue() != "VALID") + return false; + *padding = padding_attr; + return true; +} + +/// Returns whether the given `a` and `b` have broadcast-compatible +/// types. +bool IsBroadcastableElementsAttrs(mlir::TypedAttr a, mlir::TypedAttr b); +// Returns true if every dimension of the attribute is 1 except the last one. +bool IsDimensionsDegenerateExceptLastOne(mlir::TypedAttr val); +// Returns true if every element is 1 except the last one. 
+bool IsDimensionsDegenerateExceptLastOne(ArrayRef elements_shape); + +} // end namespace TF +} // end namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_UTILS_VALIDATORS_H_ diff --git a/tensorflow/compiler/tests/ternary_ops_test.py b/tensorflow/compiler/tests/ternary_ops_test.py index 809db242ac4afe..101ca75f8b68be 100644 --- a/tensorflow/compiler/tests/ternary_ops_test.py +++ b/tensorflow/compiler/tests/ternary_ops_test.py @@ -230,7 +230,8 @@ def testBetaincSanity(self): x = np.array([.3, .4, .0, .1], dtype=dtype) expected = sps.betainc(a, b, x) self._testTernary( - math_ops.betainc, a, b, x, expected, rtol=5e-6, atol=6e-6) + math_ops.betainc, a, b, x, expected, rtol=5e-5, atol=6e-5 + ) @parameterized.parameters( { diff --git a/tensorflow/compiler/tests/xla_call_module_test.py b/tensorflow/compiler/tests/xla_call_module_test.py index c83854fa7c4011..5504c8ecfd7765 100644 --- a/tensorflow/compiler/tests/xla_call_module_test.py +++ b/tensorflow/compiler/tests/xla_call_module_test.py @@ -1571,6 +1571,30 @@ def f(x): self._assertOpOutputMatchesExpected(f, (x,), (np.sin(np.cos(x)),)) + def test_op_backward_incompatibility(self): + """Test for ensuring XlaCallModuleOp with invalid bytecode.""" + x = np.array([1.0, 2.0, 3.0], dtype=np.float32) + + def f(x): + # Use an invalid MLIR string that will fail to parse when loading the + # call module op, emulating a backward incompatibility. 
+ corrupted_module = 'stablehlo.invalid_op' + return gen_xla_ops.xla_call_module( + [x], + version=xla.call_module_maximum_supported_version(), + module=corrupted_module, + Tout=[x.dtype], + Sout=[x.shape], + platforms=[self.testing_platform()], + ) + + # Expect any error message to be included after `:` + with self.assertRaisesRegex( + errors.InvalidArgumentError, + 'Cannot deserialize computation: .+', + ): + f(x) + if __name__ == '__main__': ops.enable_eager_execution( diff --git a/tensorflow/compiler/tf2xla/functionalize_cond_test.cc b/tensorflow/compiler/tf2xla/functionalize_cond_test.cc index 57f1cbdf3bd44f..50bd47ad73e77e 100644 --- a/tensorflow/compiler/tf2xla/functionalize_cond_test.cc +++ b/tensorflow/compiler/tf2xla/functionalize_cond_test.cc @@ -96,7 +96,7 @@ TEST_F(FunctionalizeCondTest, JoinCondStates) { // An non-merge op with inputs from then and else branch. absl::Status status = JoinCondStatesNonMerge(then_branch, else_branch).status(); - EXPECT_TRUE(errors::IsInvalidArgument(status)); + EXPECT_TRUE(absl::IsInvalidArgument(status)); // Merge between then and else branch. auto joined_or = JoinCondStatesMerge(m, then_branch, else_branch); diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc index 604a24514f8e5a..7727853a8c4233 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc @@ -1114,7 +1114,7 @@ void ComplexTestFixture::RunTest() { if (restrict_to_tpu_nodes_ && mark_outer_loop_tpu_ && !mark_inner_loop_tpu_) { // This case violates the precondition of `FunctionalizeControlFlow`, we // expect an internal error. - ASSERT_EQ(errors::IsInternal(status1), true); + ASSERT_EQ(absl::IsInternal(status1), true); return; } else { // Supported cases, no error expected. 
diff --git a/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc b/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc index ddd1f23cbb068e..c88c4042ca2c7b 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc @@ -267,8 +267,11 @@ absl::Status XlaCallModuleLoader::RefineDynamicShapes( // Get static MLIR Type from xla Shape. const xla::Shape &xla_shape = input_shapes[next_actual_input++]; - std::vector xla_dimensions(xla_shape.dimensions().begin(), - xla_shape.dimensions().end()); + std::vector xla_dimensions; + if (xla_shape.IsArray()) { + xla_dimensions = std::vector(xla_shape.dimensions().begin(), + xla_shape.dimensions().end()); + } TF_ASSIGN_OR_RETURN( mlir::Type element_type, ConvertPrimitiveTypeToMlirType(xla_shape.element_type(), builder)); @@ -399,9 +402,15 @@ absl::Status XlaCallModuleLoader::LoadModule( } // Parse the StableHLO/VHLO bytecode - module_ = mlir::stablehlo::deserializePortableArtifact(module_str, context_); - if (!module_) { - return absl::InvalidArgumentError("Cannot deserialize computation"); + { + mlir::StatusScopedDiagnosticHandler diag_handler(context_); + module_ = + mlir::stablehlo::deserializePortableArtifact(module_str, context_); + if (!module_) { + return absl::InvalidArgumentError( + absl::StrCat("Cannot deserialize computation: ", + diag_handler.ConsumeStatus().ToString())); + } } VLOG(3) << "Parsed serialized module (version = " << version << ", platforms = [" << absl::StrJoin(platforms, ", ") @@ -481,18 +490,14 @@ absl::Status XlaCallModuleLoader::ValidateStaticShapes() { absl::Status XlaCallModuleLoader::PrepareStablehloForLowering() { mlir::StatusScopedDiagnosticHandler diag_handler(module_->getContext()); - // TODO (b/393390051): Migrate required passes to StableHLO. + // TODO (b/410057228): Replace MHLO canonicalization with StableHLO. 
+ // This code requires MHLO CaseOp canonicalization to remove unreachable + // branches, else `tf.call_tf_function` inlining can fail. mlir::PassManager pm(module_->getContext()); - applyTensorflowAndCLOptions(pm); pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass()); - pm.addNestedPass( - mlir::mhlo::createChloLegalizeToHloPass()); pm.addNestedPass(mlir::createCanonicalizerPass()); - // In order to export to XLA, we must sink constants to control flow - // regions, since XLA uses functional control flow. - pm.addNestedPass( - mlir::mhlo::createSinkConstantsToControlFlowPass()); pm.addPass(mlir::mhlo::createHloLegalizeToStablehloPass()); + if (failed(pm.run(*module_))) { return absl::InternalError( absl::StrCat("MHLO->HLO lowering passes failed: ", @@ -500,7 +505,7 @@ absl::Status XlaCallModuleLoader::PrepareStablehloForLowering() { } if (VLOG_IS_ON(5)) { - DumpMlirOpToFile("xla_call_module.after_mhlo_lowering", *module_); + DumpMlirOpToFile("xla_call_module.after_canonicalization", *module_); } return absl::OkStatus(); diff --git a/tensorflow/compiler/tf2xla/layout_util.cc b/tensorflow/compiler/tf2xla/layout_util.cc index 69075d3c712523..b000c49f1f962e 100644 --- a/tensorflow/compiler/tf2xla/layout_util.cc +++ b/tensorflow/compiler/tf2xla/layout_util.cc @@ -72,8 +72,8 @@ absl::Status RewriteLayoutWithShardedShape( sharding->TileOffsetForDevice(*xla_shape, device); std::vector limit = sharding->TileLimitForDevice(*xla_shape, device); - std::vector dimensions(xla_shape->dimensions_size()); - for (int64_t i = 0; i < xla_shape->dimensions_size(); ++i) { + std::vector dimensions(xla_shape->dimensions().size()); + for (int64_t i = 0; i < xla_shape->dimensions().size(); ++i) { dimensions[i] = limit[i] - offset[i]; } xla::Shape per_device_xla_shape = @@ -102,7 +102,7 @@ absl::StatusOr ReshapeWithCorrectRepresentationAndSharding( std::optional sharding, bool fast_mem) { if (original_shape.IsTuple()) { std::vector elements; - for (int i = 0; i < 
original_shape.tuple_shapes_size(); ++i) { + for (int i = 0; i < original_shape.tuple_shapes().size(); ++i) { auto subsharding = sharding ? sharding->tuple_shardings(i) : sharding; TF_ASSIGN_OR_RETURN(auto element, ReshapeWithCorrectRepresentationAndSharding( @@ -131,7 +131,7 @@ absl::StatusOr ReshapeWithCorrectRepresentationAndSharding( hlo_sharding, fast_mem, shape_determination_fns, &to_shape)); } if (xla::ShapeUtil::Compatible(original_shape, to_shape)) { - for (int64_t i = 0; i < original_shape.dimensions_size(); ++i) { + for (int64_t i = 0; i < original_shape.dimensions().size(); ++i) { to_shape.set_dynamic_dimension(i, original_shape.is_dynamic_dimension(i)); } } diff --git a/tensorflow/compiler/tf2xla/lib/scatter.cc b/tensorflow/compiler/tf2xla/lib/scatter.cc index 69c8a830937257..91e357ec69eaa9 100644 --- a/tensorflow/compiler/tf2xla/lib/scatter.cc +++ b/tensorflow/compiler/tf2xla/lib/scatter.cc @@ -35,9 +35,8 @@ limitations under the License. namespace tensorflow { absl::StatusOr XlaScatter( - const xla::XlaOp& buffer, const xla::XlaOp& updates, - const xla::XlaOp& indices, bool indices_are_vectors, - bool indices_are_sorted, + const xla::XlaOp buffer, const xla::XlaOp updates, const xla::XlaOp indices, + bool indices_are_vectors, bool indices_are_sorted, const std::function& combiner, xla::XlaBuilder* builder) { @@ -52,7 +51,7 @@ absl::StatusOr XlaScatter( if (indices_are_vectors) { TF_RET_CHECK(!indices_dims.empty()); num_index_dims = indices_dims.back(); - if (num_index_dims > buffer_shape.dimensions_size()) { + if (num_index_dims > buffer_shape.dimensions().size()) { return errors::InvalidArgument( "The size of the minor dimension of the indices (shape: ", xla::ShapeUtil::HumanString(indices_shape), @@ -141,11 +140,11 @@ absl::StatusOr XlaScatter( xla::ScatterDimensionNumbers dim_numbers; dim_numbers.set_index_vector_dim(indices_are_vectors - ? indices_shape.dimensions_size() - 1 - : indices_shape.dimensions_size()); + ? 
indices_shape.dimensions().size() - 1 + : indices_shape.dimensions().size()); - int64_t updates_rank = updates_shape.dimensions_size(); - int64_t buffer_rank = buffer_shape.dimensions_size(); + int64_t updates_rank = updates_shape.dimensions().size(); + int64_t buffer_rank = buffer_shape.dimensions().size(); int64_t num_window_dims_in_updates = buffer_rank - num_index_dims; // If the rank of `updates` is 0 and does not match the expected rank of @@ -160,7 +159,7 @@ absl::StatusOr XlaScatter( if (updates_rank == 0 && expected_updates_rank != 0) { new_updates = xla::Broadcast(updates, expected_updates_dims); TF_ASSIGN_OR_RETURN(updates_shape, builder->GetShape(new_updates)); - updates_rank = updates_shape.dimensions_size(); + updates_rank = updates_shape.dimensions().size(); } if (updates_rank > 0) { diff --git a/tensorflow/compiler/tf2xla/lib/scatter.h b/tensorflow/compiler/tf2xla/lib/scatter.h index 90af6e63fcbf05..1428d173ea138c 100644 --- a/tensorflow/compiler/tf2xla/lib/scatter.h +++ b/tensorflow/compiler/tf2xla/lib/scatter.h @@ -45,9 +45,8 @@ namespace tensorflow { // the buffer using the combiner function. Otherwise, the updates replace the // existing values. The order of updates is implementation-defined. absl::StatusOr XlaScatter( - const xla::XlaOp& buffer, const xla::XlaOp& updates, - const xla::XlaOp& indices, bool indices_are_vectors, - bool indices_are_sorted, + xla::XlaOp buffer, xla::XlaOp updates, xla::XlaOp indices, + bool indices_are_vectors, bool indices_are_sorted, const std::function& combiner, xla::XlaBuilder* builder); diff --git a/tensorflow/compiler/tf2xla/mlir_bridge_pass.h b/tensorflow/compiler/tf2xla/mlir_bridge_pass.h index eae5fb83c5d682..f41c202b01e447 100644 --- a/tensorflow/compiler/tf2xla/mlir_bridge_pass.h +++ b/tensorflow/compiler/tf2xla/mlir_bridge_pass.h @@ -19,6 +19,7 @@ limitations under the License. 
#include #include "tensorflow/compiler/mlir/tf2xla/mlir_bridge_rollout_policy.h" +#include "absl/status/status.h" #include "llvm/ADT/StringRef.h" #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "tensorflow/compiler/jit/flags.h" diff --git a/tensorflow/compiler/tf2xla/shape_util.cc b/tensorflow/compiler/tf2xla/shape_util.cc index 04fbb0cf31f834..0d7549d81c20f6 100644 --- a/tensorflow/compiler/tf2xla/shape_util.cc +++ b/tensorflow/compiler/tf2xla/shape_util.cc @@ -43,7 +43,7 @@ absl::Status PopulateInfeedLayoutVector(const xla::Shape& shape, layouts->push_back(dim); } } else { - layouts->insert(layouts->end(), shape.dimensions_size(), -1); + layouts->insert(layouts->end(), shape.dimensions().size(), -1); } return absl::OkStatus(); } @@ -97,7 +97,7 @@ absl::Status XLAShapeToTensorShape(const xla::Shape& shape, " cannot be converted to a TensorShape"); } *tensor_shape = TensorShape(); - for (int i = 0; i < shape.dimensions_size(); ++i) { + for (int i = 0; i < shape.dimensions().size(); ++i) { TF_RETURN_IF_ERROR(tensor_shape->AddDimWithStatus(shape.dimensions(i))); } return absl::OkStatus(); @@ -237,7 +237,7 @@ absl::Status GetShapeWithLayout( "Nested tuples not supported: ", xla::ShapeUtil::HumanString(input_shape)); } - int64_t rank = shape.dimensions_size(); + int64_t rank = shape.dimensions().size(); if (position + rank > minor_to_major.size()) { return errors::InvalidArgument( "Not enough layout attribute elements: position=", position, @@ -259,7 +259,7 @@ absl::Status GetShapeWithLayout( } *output_shape = xla::ShapeUtil::MakeTupleShape(shapes); } else { - int64_t rank = input_shape.dimensions_size(); + int64_t rank = input_shape.dimensions().size(); const int64_t minor_to_major_size = minor_to_major.size(); if (rank != minor_to_major_size) { return errors::InvalidArgument( diff --git a/tensorflow/compiler/tf2xla/tf2xla_test.cc b/tensorflow/compiler/tf2xla/tf2xla_test.cc index 9cc8787d44b6ca..d61d66bfe53b72 100644 --- 
a/tensorflow/compiler/tf2xla/tf2xla_test.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_test.cc @@ -136,7 +136,7 @@ TEST(ConvertGraphDefToXla, Sum) { config.mutable_feed(0)->mutable_id()->set_output_index( 123); /* invalid output_index */ - EXPECT_TRUE(errors::IsInvalidArgument( + EXPECT_TRUE(absl::IsInvalidArgument( ConvertGraphDefToXla(graph_def, config, client, &computation))); } diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 81d62cb1ba7412..cfdf4addccad53 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -185,7 +185,6 @@ tf_proto_library( "//tensorflow/core/example:protos_all", "//tensorflow/core/framework:protos_all", "//tensorflow/core/lib/core:error_codes_proto", - "//tensorflow/core/profiler/protobuf:xplane_proto", "//tensorflow/core/profiler:profiler_options_proto", "//tensorflow/core/protobuf:error_codes_proto_impl", "//tensorflow/core/protobuf:for_core_protos", @@ -475,7 +474,6 @@ cc_library( hdrs = ["//tensorflow/core/public:session_options.h"], visibility = ["//visibility:public"], deps = [ - ":lib", ":protos_all_cc", ], ) @@ -1045,7 +1043,9 @@ cc_library( "//tensorflow/core:mobile_additional_lib_deps", "//tensorflow/core/platform:resource", "//tensorflow/core/public:release_version", + "//tensorflow/core/util:onednn_env_vars", "//tensorflow/core/util:stats_calculator_portable", + "@local_xla//xla/tsl/util:safe_reinterpret_cast", ] + tf_portable_proto_lib() + tf_portable_deps_no_runtime(), alwayslink = 1, ) @@ -1082,6 +1082,7 @@ cc_library( # "@com_google_absl//absl/strings", # "@com_google_absl//absl/types:optional", # "@local_xla//xla/tsl/framework/fixedpoint", +# "@local_xla//xla/tsl/util:safe_reinterpret_cast", # "//tensorflow/core/platform:resource", # "//tensorflow/core/util:managed_stack_trace", # "//tensorflow/core/util:stats_calculator_portable", @@ -1488,7 +1489,6 @@ cc_library( "//tensorflow/compiler/mlir/quantization/tensorflow:exported_model_proto_cc_impl", 
"//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc_impl", "//tensorflow/compiler/mlir/quantization/tensorflow/calibrator:calibration_statistics_proto_cc_impl", - "//tensorflow/core/profiler/protobuf:xplane_proto_cc_impl", "//tensorflow/core/protobuf:autotuning_proto_cc_impl", "//tensorflow/core/protobuf:conv_autotuning_proto_cc_impl", ":protos_all_cc_impl", diff --git a/tensorflow/core/api_def/base_api/api_def_XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdagradAndCsrInput.pbtxt b/tensorflow/core/api_def/base_api/api_def_XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdagradAndCsrInput.pbtxt new file mode 100644 index 00000000000000..7e008729792491 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdagradAndCsrInput.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdagradAndCsrInput" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/base_api/api_def_XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdagradMomentumAndCsrInput.pbtxt b/tensorflow/core/api_def/base_api/api_def_XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdagradMomentumAndCsrInput.pbtxt new file mode 100644 index 00000000000000..c9b5362c29f454 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdagradMomentumAndCsrInput.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdagradMomentumAndCsrInput" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/base_api/api_def_XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdamAndCsrInput.pbtxt b/tensorflow/core/api_def/base_api/api_def_XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdamAndCsrInput.pbtxt new file mode 100644 index 00000000000000..e55443b71e24b7 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdamAndCsrInput.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: 
"XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdamAndCsrInput" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/base_api/api_def_XlaSparseDenseMatmulCustomCombinerOnTcGradWithCsrInput.pbtxt b/tensorflow/core/api_def/base_api/api_def_XlaSparseDenseMatmulCustomCombinerOnTcGradWithCsrInput.pbtxt new file mode 100644 index 00000000000000..ccc3643bbf4345 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_XlaSparseDenseMatmulCustomCombinerOnTcGradWithCsrInput.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "XlaSparseDenseMatmulCustomCombinerOnTcGradWithCsrInput" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/base_api/api_def_XlaSparseDenseMatmulCustomCombinerOnTcGradWithFtrlAndCsrInput.pbtxt b/tensorflow/core/api_def/base_api/api_def_XlaSparseDenseMatmulCustomCombinerOnTcGradWithFtrlAndCsrInput.pbtxt new file mode 100644 index 00000000000000..493df681eabb6e --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_XlaSparseDenseMatmulCustomCombinerOnTcGradWithFtrlAndCsrInput.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "XlaSparseDenseMatmulCustomCombinerOnTcGradWithFtrlAndCsrInput" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/base_api/api_def_XlaSparseDenseMatmulCustomCombinerOnTcGradWithSgdAndCsrInput.pbtxt b/tensorflow/core/api_def/base_api/api_def_XlaSparseDenseMatmulCustomCombinerOnTcGradWithSgdAndCsrInput.pbtxt new file mode 100644 index 00000000000000..1725da02426c77 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_XlaSparseDenseMatmulCustomCombinerOnTcGradWithSgdAndCsrInput.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "XlaSparseDenseMatmulCustomCombinerOnTcGradWithSgdAndCsrInput" + visibility: HIDDEN +} \ No newline at end of file diff --git a/tensorflow/core/api_def/base_api/api_def_XlaSparseDenseMatmulCustomCombinerOnTcWithCsrInput.pbtxt b/tensorflow/core/api_def/base_api/api_def_XlaSparseDenseMatmulCustomCombinerOnTcWithCsrInput.pbtxt new file mode 100644 index 00000000000000..72728218d6ead2 --- 
/dev/null +++ b/tensorflow/core/api_def/base_api/api_def_XlaSparseDenseMatmulCustomCombinerOnTcWithCsrInput.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "XlaSparseDenseMatmulCustomCombinerOnTcWithCsrInput" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/python_api/api_def_XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdagradAndCsrInput.pbtxt b/tensorflow/core/api_def/python_api/api_def_XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdagradAndCsrInput.pbtxt new file mode 100644 index 00000000000000..7e008729792491 --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdagradAndCsrInput.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdagradAndCsrInput" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/python_api/api_def_XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdagradMomentumAndCsrInput.pbtxt b/tensorflow/core/api_def/python_api/api_def_XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdagradMomentumAndCsrInput.pbtxt new file mode 100644 index 00000000000000..c9b5362c29f454 --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdagradMomentumAndCsrInput.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdagradMomentumAndCsrInput" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/python_api/api_def_XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdamAndCsrInput.pbtxt b/tensorflow/core/api_def/python_api/api_def_XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdamAndCsrInput.pbtxt new file mode 100644 index 00000000000000..e55443b71e24b7 --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdamAndCsrInput.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdamAndCsrInput" + visibility: HIDDEN +} diff --git 
a/tensorflow/core/api_def/python_api/api_def_XlaSparseDenseMatmulCustomCombinerOnTcGradWithCsrInput.pbtxt b/tensorflow/core/api_def/python_api/api_def_XlaSparseDenseMatmulCustomCombinerOnTcGradWithCsrInput.pbtxt new file mode 100644 index 00000000000000..ccc3643bbf4345 --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_XlaSparseDenseMatmulCustomCombinerOnTcGradWithCsrInput.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "XlaSparseDenseMatmulCustomCombinerOnTcGradWithCsrInput" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/python_api/api_def_XlaSparseDenseMatmulCustomCombinerOnTcGradWithFtrlAndCsrInput.pbtxt b/tensorflow/core/api_def/python_api/api_def_XlaSparseDenseMatmulCustomCombinerOnTcGradWithFtrlAndCsrInput.pbtxt new file mode 100644 index 00000000000000..493df681eabb6e --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_XlaSparseDenseMatmulCustomCombinerOnTcGradWithFtrlAndCsrInput.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "XlaSparseDenseMatmulCustomCombinerOnTcGradWithFtrlAndCsrInput" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/python_api/api_def_XlaSparseDenseMatmulCustomCombinerOnTcGradWithSgdAndCsrInput.pbtxt b/tensorflow/core/api_def/python_api/api_def_XlaSparseDenseMatmulCustomCombinerOnTcGradWithSgdAndCsrInput.pbtxt new file mode 100644 index 00000000000000..1725da02426c77 --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_XlaSparseDenseMatmulCustomCombinerOnTcGradWithSgdAndCsrInput.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "XlaSparseDenseMatmulCustomCombinerOnTcGradWithSgdAndCsrInput" + visibility: HIDDEN +} \ No newline at end of file diff --git a/tensorflow/core/api_def/python_api/api_def_XlaSparseDenseMatmulCustomCombinerOnTcWithCsrInput.pbtxt b/tensorflow/core/api_def/python_api/api_def_XlaSparseDenseMatmulCustomCombinerOnTcWithCsrInput.pbtxt new file mode 100644 index 00000000000000..72728218d6ead2 --- /dev/null +++ 
b/tensorflow/core/api_def/python_api/api_def_XlaSparseDenseMatmulCustomCombinerOnTcWithCsrInput.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "XlaSparseDenseMatmulCustomCombinerOnTcWithCsrInput" + visibility: HIDDEN +} diff --git a/tensorflow/core/common_runtime/all_to_all_test.cc b/tensorflow/core/common_runtime/all_to_all_test.cc index ba483eb9452adc..96fc2c3581c3e7 100644 --- a/tensorflow/core/common_runtime/all_to_all_test.cc +++ b/tensorflow/core/common_runtime/all_to_all_test.cc @@ -150,7 +150,7 @@ TEST_F(AllToAllTest, WrongFirstDimensionSize) { absl::Status status = RunCollective(test_env_.get(), col_params.get(), device, &tensors[i], &tensors[i]); counter.DecrementCount(); - EXPECT_TRUE(errors::IsInvalidArgument(status)); + EXPECT_TRUE(absl::IsInvalidArgument(status)); }); } counter.Wait(); diff --git a/tensorflow/core/common_runtime/buf_rendezvous_test.cc b/tensorflow/core/common_runtime/buf_rendezvous_test.cc index b549b012a9ffa6..5a394c4c6c857e 100644 --- a/tensorflow/core/common_runtime/buf_rendezvous_test.cc +++ b/tensorflow/core/common_runtime/buf_rendezvous_test.cc @@ -267,7 +267,7 @@ TEST_F(BufRendezvousTest, DeviceIncarnationMismatch) { }, /*cancellation_manager=*/nullptr); note.WaitForNotification(); - EXPECT_TRUE(errors::IsFailedPrecondition(cons_status)); + EXPECT_TRUE(absl::IsFailedPrecondition(cons_status)); } TEST_F(BufRendezvousTest, ProvideThenCancel) { @@ -282,7 +282,7 @@ TEST_F(BufRendezvousTest, ProvideThenCancel) { &cm_); cm_.StartCancel(); note.WaitForNotification(); - EXPECT_TRUE(errors::IsCancelled(status)); + EXPECT_TRUE(absl::IsCancelled(status)); EXPECT_NE( status.message().find(absl::StrCat( "Operation was cancelled for BufRendezvous key ", *kDefaultKey)), @@ -301,7 +301,7 @@ TEST_F(BufRendezvousTest, CancelThenProvide) { }, &cm_); note.WaitForNotification(); - EXPECT_TRUE(errors::IsCancelled(status)); + EXPECT_TRUE(absl::IsCancelled(status)); EXPECT_NE( status.message().find(absl::StrCat( "Operation was cancelled for 
BufRendezvous key ", *kDefaultKey)), @@ -320,7 +320,7 @@ TEST_F(BufRendezvousTest, ConsumeThenCancel) { &cm_); cm_.StartCancel(); note.WaitForNotification(); - EXPECT_TRUE(errors::IsCancelled(status)); + EXPECT_TRUE(absl::IsCancelled(status)); EXPECT_NE( status.message().find(absl::StrCat( "Operation was cancelled for BufRendezvous key ", *kDefaultKey)), @@ -339,7 +339,7 @@ TEST_F(BufRendezvousTest, CancelThenConsume) { }, &cm_); note.WaitForNotification(); - EXPECT_TRUE(errors::IsCancelled(status)); + EXPECT_TRUE(absl::IsCancelled(status)); EXPECT_NE( status.message().find(absl::StrCat( "Operation was cancelled for BufRendezvous key ", *kDefaultKey)), @@ -392,23 +392,23 @@ TEST_F(BufRendezvousTest, CancelThenProvideConsume) { *kDefaultKey, default_device_, fake_device_context_, &a_, aa_, [&prod_status, &prod_callback_called](const absl::Status& s) { prod_status = s; - EXPECT_TRUE(errors::IsCancelled(prod_status)); + EXPECT_TRUE(absl::IsCancelled(prod_status)); prod_callback_called = true; }, &cm_); EXPECT_TRUE(prod_callback_called); - EXPECT_TRUE(errors::IsCancelled(prod_status)); + EXPECT_TRUE(absl::IsCancelled(prod_status)); br_->ConsumeBuf( *kDefaultKey, *kDefaultDeviceName, kDefaultIncarnation, [&cons_status, &cons_callback_called](const absl::Status& s, BufRendezvous::Hook* h) { cons_status = s; - EXPECT_TRUE(errors::IsCancelled(cons_status)); + EXPECT_TRUE(absl::IsCancelled(cons_status)); cons_callback_called = true; }, &cm_); EXPECT_TRUE(cons_callback_called); - EXPECT_TRUE(errors::IsCancelled(cons_status)); + EXPECT_TRUE(absl::IsCancelled(cons_status)); } } // namespace diff --git a/tensorflow/core/common_runtime/collective_rma_local_test.cc b/tensorflow/core/common_runtime/collective_rma_local_test.cc index ff60c2d5dcd97d..f52655dda366be 100644 --- a/tensorflow/core/common_runtime/collective_rma_local_test.cc +++ b/tensorflow/core/common_runtime/collective_rma_local_test.cc @@ -165,7 +165,7 @@ TEST_F(CollectiveRemoteAccessLocalTest, CheckHealth) { 
done.Notify(); }); done.WaitForNotification(); - EXPECT_TRUE(errors::IsInternal(status)); + EXPECT_TRUE(absl::IsInternal(status)); } TEST_F(CollectiveRemoteAccessLocalTest, RecvThenCancel) { @@ -187,7 +187,7 @@ TEST_F(CollectiveRemoteAccessLocalTest, RecvThenCancel) { cm_->StartCancel(); recv_note.WaitForNotification(); EXPECT_TRUE(cm_->IsCancelled()); - EXPECT_TRUE(errors::IsCancelled(recv_status)); + EXPECT_TRUE(absl::IsCancelled(recv_status)); } TEST_F(CollectiveRemoteAccessLocalTest, CancelThenRecv) { @@ -209,7 +209,7 @@ TEST_F(CollectiveRemoteAccessLocalTest, CancelThenRecv) { }); recv_note.WaitForNotification(); EXPECT_TRUE(cm_->IsCancelled()); - EXPECT_TRUE(errors::IsCancelled(recv_status)); + EXPECT_TRUE(absl::IsCancelled(recv_status)); } TEST_F(CollectiveRemoteAccessLocalTest, PostThenCancel) { @@ -231,7 +231,7 @@ TEST_F(CollectiveRemoteAccessLocalTest, PostThenCancel) { cm_->StartCancel(); send_note.WaitForNotification(); EXPECT_TRUE(cm_->IsCancelled()); - EXPECT_TRUE(errors::IsCancelled(send_status)); + EXPECT_TRUE(absl::IsCancelled(send_status)); } TEST_F(CollectiveRemoteAccessLocalTest, CancelThenPost) { @@ -253,7 +253,7 @@ TEST_F(CollectiveRemoteAccessLocalTest, CancelThenPost) { }); send_note.WaitForNotification(); EXPECT_TRUE(cm_->IsCancelled()); - EXPECT_TRUE(errors::IsCancelled(send_status)); + EXPECT_TRUE(absl::IsCancelled(send_status)); } } // namespace diff --git a/tensorflow/core/common_runtime/device_resolver_local_test.cc b/tensorflow/core/common_runtime/device_resolver_local_test.cc index 76d64f79af025b..9c37edfea2333f 100644 --- a/tensorflow/core/common_runtime/device_resolver_local_test.cc +++ b/tensorflow/core/common_runtime/device_resolver_local_test.cc @@ -53,19 +53,19 @@ TEST_F(DeviceResolverLocalTest, GetDeviceAttributesKnown) { TEST_F(DeviceResolverLocalTest, GetDeviceAttributesUnknown) { DeviceAttributes attributes; - EXPECT_TRUE(errors::IsNotFound(drl_->GetDeviceAttributes( + EXPECT_TRUE(absl::IsNotFound(drl_->GetDeviceAttributes( 
"/job:localhost/replica:0/task:0/device:CPU:9", &attributes))); } TEST_F(DeviceResolverLocalTest, GetAllDeviceAttributes) { std::vector attributes; - EXPECT_TRUE(errors::IsInternal( - drl_->GetAllDeviceAttributes(/*task*/ "", &attributes))); + EXPECT_TRUE( + absl::IsInternal(drl_->GetAllDeviceAttributes(/*task*/ "", &attributes))); } TEST_F(DeviceResolverLocalTest, UpdateDeviceAttributes) { std::vector attributes; - EXPECT_TRUE(errors::IsInternal(drl_->UpdateDeviceAttributes(attributes))); + EXPECT_TRUE(absl::IsInternal(drl_->UpdateDeviceAttributes(attributes))); } } // namespace diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc index 30d8d12897b4da..5309d473ca2077 100644 --- a/tensorflow/core/common_runtime/direct_session.cc +++ b/tensorflow/core/common_runtime/direct_session.cc @@ -904,7 +904,7 @@ absl::Status DirectSession::Run( } } const absl::Status s = call_frame.SetArgs(feed_args); - if (errors::IsInternal(s)) { + if (absl::IsInternal(s)) { return errors::InvalidArgument(s.message()); } else if (!s.ok()) { return s; @@ -925,7 +925,7 @@ absl::Status DirectSession::Run( std::vector sorted_outputs; const absl::Status s = call_frame.ConsumeRetvals( &sorted_outputs, /* allow_dead_tensors = */ false); - if (errors::IsInternal(s)) { + if (absl::IsInternal(s)) { return errors::InvalidArgument(s.message()); } else if (!s.ok()) { return s; diff --git a/tensorflow/core/common_runtime/direct_session_test.cc b/tensorflow/core/common_runtime/direct_session_test.cc index b2e617751825f0..7d9c80f5dc30ad 100644 --- a/tensorflow/core/common_runtime/direct_session_test.cc +++ b/tensorflow/core/common_runtime/direct_session_test.cc @@ -188,7 +188,7 @@ TEST_F(DirectSessionMinusAXTest, RunSimpleNetwork_Callable) { } absl::Status s = session->RunCallable(handle, {}, nullptr, nullptr); - EXPECT_TRUE(errors::IsInvalidArgument(s)); + EXPECT_TRUE(absl::IsInvalidArgument(s)); EXPECT_TRUE( absl::StrContains(s.message(), 
"`fetch_tensors` must be provided")); @@ -196,12 +196,12 @@ TEST_F(DirectSessionMinusAXTest, RunSimpleNetwork_Callable) { std::vector outputs; s = session->RunCallable(handle, {}, &outputs, nullptr); - EXPECT_TRUE(errors::IsInvalidArgument(s)); + EXPECT_TRUE(absl::IsInvalidArgument(s)); EXPECT_TRUE(absl::StrContains( s.message(), "Attempted to run callable after handle was released")); s = session->RunCallable(handle + 1, {}, &outputs, nullptr); - EXPECT_TRUE(errors::IsInvalidArgument(s)); + EXPECT_TRUE(absl::IsInvalidArgument(s)); EXPECT_TRUE(absl::StrContains(s.message(), "No such callable handle")); } } @@ -231,7 +231,7 @@ TEST_F(DirectSessionMinusAXTest, RunSimpleNetwork_OptimizeForStaticGraph) { EXPECT_FLOAT_EQ(5.0, mat(0, 0)); s = session->Extend({}); - EXPECT_TRUE(errors::IsFailedPrecondition(s)); + EXPECT_TRUE(absl::IsFailedPrecondition(s)); EXPECT_TRUE(absl::StrContains(s.message(), "optimize_for_static_graph")); } @@ -268,7 +268,7 @@ TEST_F(DirectSessionMinusAXTest, s = session->Run(run_options, inputs, output_names, target_nodes, &outputs, &run_metadata); - EXPECT_TRUE(errors::IsInvalidArgument(s)); + EXPECT_TRUE(absl::IsInvalidArgument(s)); EXPECT_TRUE( absl::StrContains(s.message(), "disable_output_partition_graphs")); } @@ -305,7 +305,7 @@ TEST_F(DirectSessionMinusAXTest, RunSimpleNetwork_FinalizeWithCallables) { // Making a new callable fails because the session has been finalized. absl::Status s = session->MakeCallable(MakeCallableOptions({}, {y_ + ":0"}, {}), &handle); - EXPECT_TRUE(errors::IsFailedPrecondition(s)); + EXPECT_TRUE(absl::IsFailedPrecondition(s)); EXPECT_TRUE(absl::StrContains(s.message(), "Session has been finalized.")); } @@ -337,7 +337,7 @@ TEST_F(DirectSessionMinusAXTest, RunSimpleNetwork_FinalizeWithRun) { // Running a different subgraph fails because the session has been finalized. 
absl::Status s = session->Run({}, {y_ + ":0"}, {}, &outputs); - EXPECT_TRUE(errors::IsFailedPrecondition(s)); + EXPECT_TRUE(absl::IsFailedPrecondition(s)); EXPECT_TRUE(absl::StrContains(s.message(), "Session has been finalized.")); } @@ -543,7 +543,7 @@ TEST_F(DirectSessionMinusAXTest, TestTensorConnection) { Session::CallableHandle handle; absl::Status s = session->MakeCallable(callable_options, &handle); - EXPECT_TRUE(errors::IsInvalidArgument(s)); + EXPECT_TRUE(absl::IsInvalidArgument(s)); EXPECT_TRUE(absl::StrContains(s.message(), "would create a cycle")); } @@ -557,7 +557,7 @@ TEST_F(DirectSessionMinusAXTest, TestTensorConnection) { Session::CallableHandle handle; absl::Status s = session->MakeCallable(callable_options, &handle); - EXPECT_TRUE(errors::IsInvalidArgument(s)); + EXPECT_TRUE(absl::IsInvalidArgument(s)); EXPECT_TRUE(absl::StrContains(s.message(), "unknown node")); } @@ -572,7 +572,7 @@ TEST_F(DirectSessionMinusAXTest, TestTensorConnection) { Session::CallableHandle handle; absl::Status s = session->MakeCallable(callable_options, &handle); - EXPECT_TRUE(errors::IsInvalidArgument(s)); + EXPECT_TRUE(absl::IsInvalidArgument(s)); EXPECT_TRUE(absl::StrContains(s.message(), "unknown edge")); } @@ -586,7 +586,7 @@ TEST_F(DirectSessionMinusAXTest, TestTensorConnection) { Session::CallableHandle handle; absl::Status s = session->MakeCallable(callable_options, &handle); - EXPECT_TRUE(errors::IsNotFound(s)); + EXPECT_TRUE(absl::IsNotFound(s)); EXPECT_TRUE(absl::StrContains(s.message(), "unable to find feed output")); } @@ -603,7 +603,7 @@ TEST_F(DirectSessionMinusAXTest, TestTensorConnection) { Session::CallableHandle handle; absl::Status s = session->MakeCallable(callable_options, &handle); - EXPECT_TRUE(errors::IsInvalidArgument(s)); + EXPECT_TRUE(absl::IsInvalidArgument(s)); EXPECT_TRUE(absl::StrContains(s.message(), "fed more than once")); } @@ -618,7 +618,7 @@ TEST_F(DirectSessionMinusAXTest, TestTensorConnection) { Session::CallableHandle handle; 
absl::Status s = session->MakeCallable(callable_options, &handle); - EXPECT_TRUE(errors::IsInvalidArgument(s)); + EXPECT_TRUE(absl::IsInvalidArgument(s)); EXPECT_TRUE(absl::StrContains(s.message(), "fed more than once")); } } @@ -1043,7 +1043,7 @@ TEST(DirectSessionTest, MultipleFeedTest) { {{first_const->name(), value_11}, {first_const->name(), value_22}}, {first_identity->name() + ":0", second_identity->name() + ":0"}, {}, &outputs); - EXPECT_TRUE(errors::IsInvalidArgument(s)); + EXPECT_TRUE(absl::IsInvalidArgument(s)); EXPECT_TRUE(absl::StrContains(s.message(), "fed more than once")); } @@ -1126,7 +1126,7 @@ TEST(DirectSessionTest, MultipleFeedTest_Callable) { {first_const->name(), first_const->name()}, {first_identity->name() + ":0", second_identity->name() + ":0"}, {}), &handle); - EXPECT_TRUE(errors::IsInvalidArgument(s)); + EXPECT_TRUE(absl::IsInvalidArgument(s)); EXPECT_TRUE(absl::StrContains(s.message(), "fed more than once")); } @@ -1280,7 +1280,7 @@ TEST(DirectSessionTest, MultipleFeedTestSomeSyncRun) { {{first_const->name(), value_11}, {first_const->name(), value_22}}, {first_identity->name() + ":0", second_identity->name() + ":0"}, {}, &outputs, nullptr); - EXPECT_TRUE(errors::IsInvalidArgument(s)); + EXPECT_TRUE(absl::IsInvalidArgument(s)); EXPECT_TRUE(absl::StrContains(s.message(), "fed more than once")); } @@ -1463,8 +1463,7 @@ TEST(DirectSessionTest, SessionMetadataKey) { // Trying to use the same metadata (name, version) will cause an error. Session* dup_ptr; - EXPECT_TRUE( - errors::IsInvalidArgument(NewSession(session_options0, &dup_ptr))); + EXPECT_TRUE(absl::IsInvalidArgument(NewSession(session_options0, &dup_ptr))); // A new (name, version) is fine. auto session_options1 = DefaultSessionOptions(); @@ -1503,7 +1502,7 @@ TEST(DirectSessionTest, SessionMetadataInvalid) { // Version should be >= 0. 
invalid_metadata->set_version(-1); Session* error_sess_ptr; - EXPECT_TRUE(errors::IsInvalidArgument( + EXPECT_TRUE(absl::IsInvalidArgument( NewSession(invalid_session_options, &error_sess_ptr))); } @@ -1657,7 +1656,7 @@ TEST(DirectSessionTest, DarthKernel) { TF_ASSERT_OK(sess->Create(def)); std::vector outputs; auto s = sess->Run({}, {y->name() + ":0"}, {}, &outputs); - EXPECT_TRUE(errors::IsInternal(s)); + EXPECT_TRUE(absl::IsInternal(s)); } // Have the Darth op in the graph placed on GPU, but don't run it. @@ -1677,7 +1676,7 @@ TEST(DirectSessionTest, PlacePrunedGraph) { SessionOptions options; std::unique_ptr sess(NewSession(options)); auto s = sess->Create(def); - EXPECT_TRUE(errors::IsInvalidArgument(s)); + EXPECT_TRUE(absl::IsInvalidArgument(s)); } { @@ -1794,7 +1793,7 @@ TEST(DirectSessionTest, PartialRunMissingFeed) { value_11.scalar()() = 11.0; s = session->PRun(handle, {{first_const->name(), value_11}}, {third_identity->name() + ":0"}, &outputs); - ASSERT_TRUE(errors::IsInvalidArgument(s)); + ASSERT_TRUE(absl::IsInvalidArgument(s)); EXPECT_TRUE( absl::StrContains(s.message(), "can't be computed from the feeds")); } @@ -1825,7 +1824,7 @@ TEST(DirectSessionTest, PartialRunMultiOutputFeed) { // Fetch fourth_identity without feeds. 
s = session->PRun(handle, {}, {fourth_identity->name() + ":0"}, &outputs); - ASSERT_TRUE(errors::IsInvalidArgument(s)); + ASSERT_TRUE(absl::IsInvalidArgument(s)); EXPECT_TRUE( absl::StrContains(s.message(), "can't be computed from the feeds")); @@ -1963,7 +1962,7 @@ TEST(DirectSessionTest, CreateGraphFailsWhenAssigningAFedVar) { std::vector outputs; absl::Status s = session->Run({{a->name(), zero}}, {assign->name()}, {}, &outputs); - ASSERT_TRUE(errors::IsInvalidArgument(s)); + ASSERT_TRUE(absl::IsInvalidArgument(s)); } TEST(DirectSessionTest, TimeoutSession) { diff --git a/tensorflow/core/common_runtime/eager/BUILD b/tensorflow/core/common_runtime/eager/BUILD index f2ada4a8ab854e..871d279bf396bf 100644 --- a/tensorflow/core/common_runtime/eager/BUILD +++ b/tensorflow/core/common_runtime/eager/BUILD @@ -253,7 +253,6 @@ tf_cuda_library( "@local_xla//xla/pjrt/gpu:se_gpu_pjrt_client", "@local_xla//xla/tsl/distributed_runtime/coordination:coordination_service", "@local_xla//xla/tsl/distributed_runtime/coordination:coordination_service_agent", - "@local_xla//xla/tsl/distributed_runtime/coordination:coordination_service_impl", "@local_xla//xla/tsl/distributed_runtime/preemption:preemption_notifier", "@local_xla//xla/tsl/platform:statusor", ], @@ -267,7 +266,6 @@ tf_cuda_library( clean_dep("//tensorflow:linux_x86_64_with_weightwatcher"): [], ( clean_dep("//tensorflow:linux_x86_64"), - clean_dep("//tensorflow:haswell"), ): [ "//tensorflow/core", "//tensorflow/core/framework:resource_base", diff --git a/tensorflow/core/common_runtime/eager/execute_test.cc b/tensorflow/core/common_runtime/eager/execute_test.cc index e424f217130c93..ea174fd22f76a2 100644 --- a/tensorflow/core/common_runtime/eager/execute_test.cc +++ b/tensorflow/core/common_runtime/eager/execute_test.cc @@ -169,7 +169,7 @@ TEST(ExecuteTest, SimpleFunctionInt32BadFullType) { std::vector retvals(1); int num_retvals = retvals.size(); absl::Status status = EagerExecute(op.get(), retvals.data(), &num_retvals); - 
ASSERT_TRUE(errors::IsInvalidArgument(status)) << "Actual status: " << status; + ASSERT_TRUE(absl::IsInvalidArgument(status)) << "Actual status: " << status; EXPECT_TRUE( absl::StrContains(status.message(), "TFT_TENSOR has 0 args instead of 1")) << "Actual: " << status.message(); diff --git a/tensorflow/core/common_runtime/eager/placement_test.cc b/tensorflow/core/common_runtime/eager/placement_test.cc index b89b9384ba7196..87bdf17a449d77 100644 --- a/tensorflow/core/common_runtime/eager/placement_test.cc +++ b/tensorflow/core/common_runtime/eager/placement_test.cc @@ -128,7 +128,7 @@ TEST_F(PlacementTest, SelectDeviceExplicitHardPlacement) { absl::Status status = context()->SelectDevice(requested, invalid_op, &dev); LOG(ERROR) << status; - EXPECT_TRUE(errors::IsNotFound(status)); + EXPECT_TRUE(absl::IsNotFound(status)); EXPECT_TRUE( absl::StrContains(status.message(), "Could not find device for node")) << "unexpected error message " << status.message(); @@ -138,7 +138,7 @@ TEST_F(PlacementTest, SelectDeviceExplicitHardPlacement) { NodeDef node = NDef("x", "TestOp", {}, {}); status = context()->SelectDevice(requested, node, &dev); - EXPECT_TRUE(errors::IsInvalidArgument(status)); + EXPECT_TRUE(absl::IsInvalidArgument(status)); EXPECT_TRUE(absl::StrContains(status.message(), "Could not satisfy device specification")) << "unexpected error message " << status.message(); @@ -169,7 +169,7 @@ TEST_F(PlacementTest, SelectDeviceExplicitSoftPlacement) { absl::Status status = context()->SelectDevice(requested, invalid_op, &dev); LOG(ERROR) << status; - EXPECT_TRUE(errors::IsNotFound(status)); + EXPECT_TRUE(absl::IsNotFound(status)); EXPECT_TRUE( absl::StrContains(status.message(), "Could not find device for node")) << "unexpected error message " << status.message(); diff --git a/tensorflow/core/common_runtime/gpu/BUILD b/tensorflow/core/common_runtime/gpu/BUILD index 278b2e61e2874a..f75716dc55a51e 100644 --- a/tensorflow/core/common_runtime/gpu/BUILD +++ 
b/tensorflow/core/common_runtime/gpu/BUILD @@ -1,4 +1,6 @@ load("@bazel_skylib//lib:selects.bzl", "selects") +load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") +load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm") load("@local_xla//xla/tsl:tsl.bzl", "if_cuda_libs") load( "//tensorflow:tensorflow.bzl", @@ -215,7 +217,6 @@ tf_cuda_library( clean_dep("//tensorflow:linux_x86_64_with_weightwatcher"): [], ( clean_dep("//tensorflow:linux_x86_64"), - clean_dep("//tensorflow:haswell"), ): [ "//tensorflow/compiler/tf2xla:layout_util", "//tensorflow/compiler/jit:flags", @@ -231,6 +232,10 @@ tf_cuda_library( ]), ) + if_cuda_or_rocm([ "@local_tsl//tsl/platform:dso_loader", + ]) + if_cuda([ + "@local_xla//xla/stream_executor/cuda:all_runtime", + ]) + if_rocm([ + "@local_xla//xla/stream_executor/rocm:all_runtime", ]), alwayslink = 1, ) diff --git a/tensorflow/core/data/hash_utils.cc b/tensorflow/core/data/hash_utils.cc index be806eedc62529..b3593646999882 100644 --- a/tensorflow/core/data/hash_utils.cc +++ b/tensorflow/core/data/hash_utils.cc @@ -116,7 +116,7 @@ absl::Status ShouldIgnoreInput(const NodeDef& node, int i, bool* result) { return absl::OkStatus(); } } - } else if (errors::IsNotFound(status)) { + } else if (absl::IsNotFound(status)) { LOG(WARNING) << "Cannot find " << node.op() << " in global op registry, so cannot determine which " "inputs are seeds."; diff --git a/tensorflow/core/data/metric_utils.cc b/tensorflow/core/data/metric_utils.cc index 8816593e3c050f..1219c04980a516 100644 --- a/tensorflow/core/data/metric_utils.cc +++ b/tensorflow/core/data/metric_utils.cc @@ -32,7 +32,7 @@ namespace tensorflow { namespace data { namespace { -// Safely subtracts `x` from `y` avoiding underflow. +// Safely subtracts `y` from `x` avoiding underflow. uint64_t safe_sub(uint64_t x, uint64_t y) { return x >= y ? 
x - y : 0; } } // namespace diff --git a/tensorflow/core/data/rewrite_utils.cc b/tensorflow/core/data/rewrite_utils.cc index 34cafb3403eaae..0628b408722d1e 100644 --- a/tensorflow/core/data/rewrite_utils.cc +++ b/tensorflow/core/data/rewrite_utils.cc @@ -17,6 +17,7 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/refcount.h" @@ -243,8 +244,7 @@ absl::Status RewriteDataset(OpKernelContext* ctx, const DatasetBase* input, VLOG(3) << "Failed to hash tensor: " << s; } } - string graph_hash = - strings::StrCat(strings::Hex(hash, strings::kZeroPad16)); + std::string graph_hash = absl::StrCat(absl::Hex(hash, absl::kZeroPad16)); metrics::RecordTFDataFingerprint(graph_hash); }); } diff --git a/tensorflow/core/data/root_dataset.cc b/tensorflow/core/data/root_dataset.cc index 9cd418c6e556d9..068ccf51a94310 100644 --- a/tensorflow/core/data/root_dataset.cc +++ b/tensorflow/core/data/root_dataset.cc @@ -502,7 +502,7 @@ absl::Status FinalizeDataset(OpKernelContext* ctx, const DatasetBase* input, *output = rewritten_output.get(); bool rewritten = (*output != input); - if (errors::IsDeadlineExceeded(s)) { + if (absl::IsDeadlineExceeded(s)) { // Ignore DeadlineExceeded as it implies that the attempted rewrite took too // long which should not prevent further computation. 
LOG(WARNING) << s; diff --git a/tensorflow/core/data/service/common.cc b/tensorflow/core/data/service/common.cc index adde241b38634c..787066353c772a 100644 --- a/tensorflow/core/data/service/common.cc +++ b/tensorflow/core/data/service/common.cc @@ -132,8 +132,8 @@ absl::StatusOr ParseDeploymentMode(absl::string_view s) { } bool IsPreemptedError(const absl::Status& status) { - return errors::IsAborted(status) || errors::IsCancelled(status) || - errors::IsUnavailable(status); + return absl::IsAborted(status) || absl::IsCancelled(status) || + absl::IsUnavailable(status); } } // namespace data } // namespace tensorflow diff --git a/tensorflow/core/data/service/dispatcher_impl.cc b/tensorflow/core/data/service/dispatcher_impl.cc index 1b3af854ac1940..3fcd560a54920a 100644 --- a/tensorflow/core/data/service/dispatcher_impl.cc +++ b/tensorflow/core/data/service/dispatcher_impl.cc @@ -227,8 +227,10 @@ absl::Status DataServiceDispatcherImpl::Start() { env_->RecursivelyCreateDir(DatasetsDir(config_.work_dir()))); } if (!config_.fault_tolerant_mode()) { - LOG(INFO) << "Running with fault_tolerant_mode=False. The dispatcher will " - "not be able to recover its state on restart."; + LOG(INFO) << "Started tf.data service dispatcher in non-fault-tolerant " + "mode with config: " + << config_.DebugString() + << "\nIt will not recover its state on restart."; started_ = true; return absl::OkStatus(); } @@ -240,7 +242,7 @@ absl::Status DataServiceDispatcherImpl::Start() { bool end_of_journal = false; FileJournalReader reader(env_, JournalDir(config_.work_dir())); absl::Status s = reader.Read(update, end_of_journal); - if (errors::IsNotFound(s)) { + if (absl::IsNotFound(s)) { LOG(INFO) << "No journal found. 
Starting dispatcher from new state."; } else if (!s.ok()) { return s; @@ -270,7 +272,7 @@ absl::Status DataServiceDispatcherImpl::Start() { TF_RETURN_IF_ERROR(journal_writer_.value()->EnsureInitialized()); TF_RETURN_IF_ERROR(RestoreSnapshots()); started_ = true; - LOG(INFO) << "Started tf.data service dispatcher with config " + LOG(INFO) << "Started tf.data service dispatcher with config: " << config_.DebugString(); return absl::OkStatus(); } @@ -438,7 +440,7 @@ absl::Status DataServiceDispatcherImpl::WorkerHeartbeat( std::vector> assigned_tasks; absl::Status s = state_.TasksForWorker(worker_address, assigned_tasks); if (!s.ok()) { - if (!errors::IsNotFound(s)) { + if (!absl::IsNotFound(s)) { return s; } VLOG(1) << "Registering new worker at address " << worker_address; @@ -637,7 +639,7 @@ DataServiceDispatcherImpl::FindDataset( absl::Status status = state_.DatasetFromId(request.dataset_id(), existing_dataset); - if (errors::IsNotFound(status)) { + if (absl::IsNotFound(status)) { return std::optional(); } TF_RETURN_IF_ERROR(status); @@ -704,7 +706,7 @@ absl::Status DataServiceDispatcherImpl::GetOrCreateJob( absl::Status s = state_.JobByName(job_name, job); if (s.ok()) { TF_RETURN_IF_ERROR(ValidateMatchingJob(job, *request)); - } else if (errors::IsNotFound(s)) { + } else if (absl::IsNotFound(s)) { TF_RETURN_IF_ERROR(CreateJob(job_name, *request, job)); } else { return s; @@ -729,10 +731,10 @@ absl::Status DataServiceDispatcherImpl::GetOrCreateIteration( TF_RETURN_IF_ERROR(state_.JobFromId(request->job_id(), job)); IterationKey key(job->job_name, request->repetition()); absl::Status s = state_.IterationByKey(key, iteration); - if (!s.ok() && !errors::IsNotFound(s)) { + if (!s.ok() && !absl::IsNotFound(s)) { return s; } - if (errors::IsNotFound(s) || iteration->garbage_collected) { + if (absl::IsNotFound(s) || iteration->garbage_collected) { TF_RETURN_IF_ERROR(CreateIteration(*request, iteration)); TF_RETURN_IF_ERROR(CreateTasksForIteration(iteration, tasks)); } @@ 
-755,7 +757,7 @@ absl::Status DataServiceDispatcherImpl::MaybeRemoveTask( { mutex_lock l(mu_); absl::Status s = state_.TaskFromId(request->task_id(), task); - if (errors::IsNotFound(s)) { + if (absl::IsNotFound(s)) { // Task is already removed. response->set_removed(true); return absl::OkStatus(); @@ -1074,7 +1076,7 @@ absl::Status DataServiceDispatcherImpl::ClientHeartbeat( std::shared_ptr iteration; absl::Status s = state_.IterationForIterationClientId( request->iteration_client_id(), iteration); - if (errors::IsNotFound(s) && !config_.fault_tolerant_mode()) { + if (absl::IsNotFound(s) && !config_.fault_tolerant_mode()) { return errors::NotFound( "Unknown iteration client id ", request->iteration_client_id(), ". The dispatcher is not configured to be fault tolerant, so this " diff --git a/tensorflow/core/data/service/task_runner_test.cc b/tensorflow/core/data/service/task_runner_test.cc index 62b1ab63251083..52fdf89dc9b89c 100644 --- a/tensorflow/core/data/service/task_runner_test.cc +++ b/tensorflow/core/data/service/task_runner_test.cc @@ -475,7 +475,7 @@ TEST(CachingTaskRunnerTest, Errors) { if (element.ok()) { result.push_back(*element); } - if (errors::IsInvalidArgument(element.status())) { + if (absl::IsInvalidArgument(element.status())) { EXPECT_THAT( element.status(), StatusIs(error::INVALID_ARGUMENT, diff --git a/tensorflow/core/data/service/worker_impl.cc b/tensorflow/core/data/service/worker_impl.cc index 9e1cf15b9d1f0a..c89c8a1c4881f4 100644 --- a/tensorflow/core/data/service/worker_impl.cc +++ b/tensorflow/core/data/service/worker_impl.cc @@ -701,7 +701,7 @@ void DataServiceWorkerImpl::UpdateTasks(const WorkerHeartbeatResponse& response) continue; } absl::Status s = ProcessTaskInternal(task); - if (!s.ok() && !errors::IsAlreadyExists(s)) { + if (!s.ok() && !absl::IsAlreadyExists(s)) { LOG(WARNING) << "Failed to start processing task " << task.task_id() << ": " << s; } diff --git a/tensorflow/core/data/snapshot_utils.cc 
b/tensorflow/core/data/snapshot_utils.cc index 0a35361092fd71..576cbed01fb633 100644 --- a/tensorflow/core/data/snapshot_utils.cc +++ b/tensorflow/core/data/snapshot_utils.cc @@ -1060,8 +1060,7 @@ absl::Status ReadMetadataFile( absl::Status DumpDatasetGraph(Env* env, const std::string& path, uint64 hash, const GraphDef* graph) { - std::string hash_hex = - strings::StrCat(strings::Hex(hash, strings::kZeroPad16)); + std::string hash_hex = absl::StrCat(absl::Hex(hash, absl::kZeroPad16)); std::string graph_file = io::JoinPath(path, absl::StrCat(hash_hex, "-graph.pbtxt")); diff --git a/tensorflow/core/debug/debug_graph_utils_test.cc b/tensorflow/core/debug/debug_graph_utils_test.cc index 5ffee94043a002..207b8bc1b3c1f7 100644 --- a/tensorflow/core/debug/debug_graph_utils_test.cc +++ b/tensorflow/core/debug/debug_graph_utils_test.cc @@ -47,15 +47,15 @@ TEST_F(DebugGraphUtilsTest, TestMalformedDebugOpName) { absl::Status s = ParseDebugOpName("(mute_if_healthy=true)", &debug_op_name_proper, &attributes); - ASSERT_TRUE(errors::IsInvalidArgument(s)); + ASSERT_TRUE(absl::IsInvalidArgument(s)); s = ParseDebugOpName("DebugNumericSummary(", &debug_op_name_proper, &attributes); - ASSERT_TRUE(errors::IsInvalidArgument(s)); + ASSERT_TRUE(absl::IsInvalidArgument(s)); s = ParseDebugOpName("DebugNumericSummary)", &debug_op_name_proper, &attributes); - ASSERT_TRUE(errors::IsInvalidArgument(s)); + ASSERT_TRUE(absl::IsInvalidArgument(s)); } TEST_F(DebugGraphUtilsTest, TestDebugOpNameWithMalformedAttributes) { @@ -64,28 +64,28 @@ TEST_F(DebugGraphUtilsTest, TestDebugOpNameWithMalformedAttributes) { absl::Status s = ParseDebugOpName("DebugNumericSummary(=)", &debug_op_name_proper, &attributes); - ASSERT_TRUE(errors::IsInvalidArgument(s)); + ASSERT_TRUE(absl::IsInvalidArgument(s)); s = ParseDebugOpName("DebugNumericSummary(mute_if_healthy=)", &debug_op_name_proper, &attributes); - ASSERT_TRUE(errors::IsInvalidArgument(s)); + ASSERT_TRUE(absl::IsInvalidArgument(s)); s = 
ParseDebugOpName("DebugNumericSummary(=true)", &debug_op_name_proper, &attributes); - ASSERT_TRUE(errors::IsInvalidArgument(s)); + ASSERT_TRUE(absl::IsInvalidArgument(s)); s = ParseDebugOpName("DebugNumericSummary(mute_if_healthy:true)", &debug_op_name_proper, &attributes); - ASSERT_TRUE(errors::IsInvalidArgument(s)); + ASSERT_TRUE(absl::IsInvalidArgument(s)); s = ParseDebugOpName("DebugNumericSummary(mute_if_healthy=true;threshold=)", &debug_op_name_proper, &attributes); - ASSERT_TRUE(errors::IsInvalidArgument(s)); + ASSERT_TRUE(absl::IsInvalidArgument(s)); s = ParseDebugOpName( "DebugNumericSummary(mute_if_healthy=true;threshold:300.0)", &debug_op_name_proper, &attributes); - ASSERT_TRUE(errors::IsInvalidArgument(s)); + ASSERT_TRUE(absl::IsInvalidArgument(s)); } TEST_F(DebugGraphUtilsTest, TestValidDebugOpNameWithSingleAttribute) { @@ -134,7 +134,7 @@ TEST_F(DebugGraphUtilsTest, TestValidDebugOpNameWithMoreDuplicateAttributes) { "DebugNumericSummary(mute_if_healthy=true; lower_bound=3; " "mute_if_healthy=false;)", &debug_op_name_proper, &attributes); - ASSERT_TRUE(errors::IsInvalidArgument(s)); + ASSERT_TRUE(absl::IsInvalidArgument(s)); } TEST_F(DebugGraphUtilsTest, TestValidDebugOpNameWithWhitespaceInAttributes) { diff --git a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc index f48758f993677f..fe9fba5ffa50eb 100644 --- a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc +++ b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc @@ -346,7 +346,7 @@ TEST_F(DeviceResDistTest, DifferentIncarnation) { const string task_name = "/job:worker/replica:0/task:1"; const string device_name = absl::StrCat(task_name, "/device:CPU:0"); IssueRequest(task_name, device_name, num_workers * num_devices); - EXPECT_TRUE(errors::IsFailedPrecondition(status_[device_name])); + 
EXPECT_TRUE(absl::IsFailedPrecondition(status_[device_name])); } TEST_F(DeviceResDistTest, BroadcastSourceRank0) { diff --git a/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc b/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc index 0338abeda899d3..3a97b7342023f8 100644 --- a/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc +++ b/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc @@ -568,7 +568,7 @@ TEST_P(CollRMADistTest, WorkerRestart) { post_restart_note.Notify(); }); post_restart_note.WaitForNotification(); - EXPECT_TRUE(errors::IsFailedPrecondition(consumer_status)); + EXPECT_TRUE(absl::IsFailedPrecondition(consumer_status)); } TEST_P(CollRMADistTest, CheckHealthOKWithCachedAttr) { @@ -611,7 +611,7 @@ TEST_P(CollRMADistTest, CheckHealthRestarted) { check_health_done.Notify(); }); check_health_done.WaitForNotification(); - EXPECT_TRUE(errors::IsFailedPrecondition(check_health_status)); + EXPECT_TRUE(absl::IsFailedPrecondition(check_health_status)); } TEST_P(CollRMADistTest, CheckHealthFailedPeer) { @@ -628,7 +628,7 @@ TEST_P(CollRMADistTest, CheckHealthFailedPeer) { check_health_done.Notify(); }); check_health_done.WaitForNotification(); - EXPECT_TRUE(errors::IsUnavailable(check_health_status)); + EXPECT_TRUE(absl::IsUnavailable(check_health_status)); } TEST_P(CollRMADistTest, CheckHealthRestartedWithDifferentDevices) { @@ -643,7 +643,7 @@ TEST_P(CollRMADistTest, CheckHealthRestartedWithDifferentDevices) { check_health_done.Notify(); }); check_health_done.WaitForNotification(); - EXPECT_TRUE(errors::IsFailedPrecondition(check_health_status)); + EXPECT_TRUE(absl::IsFailedPrecondition(check_health_status)); } INSTANTIATE_TEST_SUITE_P( diff --git a/tensorflow/core/distributed_runtime/coordination/coordination_service_barrier_proxy_test.cc b/tensorflow/core/distributed_runtime/coordination/coordination_service_barrier_proxy_test.cc index 95ee288c8db803..2a455583ee15e4 100644 --- 
a/tensorflow/core/distributed_runtime/coordination/coordination_service_barrier_proxy_test.cc +++ b/tensorflow/core/distributed_runtime/coordination/coordination_service_barrier_proxy_test.cc @@ -57,85 +57,6 @@ class MockCoordinationServiceAgent : public CoordinationServiceAgent { (override)); MOCK_METHOD(absl::Status, CancelBarrier, (std::string_view barrier_id), (override)); - - // All the following member functions are not needed for testing. - MOCK_METHOD(absl::Status, Initialize, - (Env * env, std::string_view job_name, int task_id, - const CoordinationServiceConfig& configs, - std::unique_ptr leader_client, - StatusCallback error_fn, bool recoverable), - (override)); - MOCK_METHOD(absl::Status, Initialize, - (Env * env, std::string_view job_name, int task_id, - const CoordinationServiceConfig& configs, - std::unique_ptr leader_client, - StatusCallback error_fn), - (override)); - MOCK_METHOD(absl::Status, Initialize, - (Env * env, const CoordinatedTask& task, - const CoordinationServiceConfig& configs, - std::unique_ptr leader_client, - StatusCallback error_fn), - (override)); - MOCK_METHOD(bool, IsInitialized, (), (override)); - MOCK_METHOD(bool, IsConnected, (), (override)); - MOCK_METHOD(bool, IsError, (), (override)); - MOCK_METHOD(absl::Status, Connect, (), (override)); - MOCK_METHOD(absl::Status, WaitForAllTasks, (const DeviceInfo& local_devices), - (override)); - MOCK_METHOD(const DeviceInfo&, GetClusterDeviceInfo, (), (override)); - MOCK_METHOD(absl::StatusOr, GetOwnTask, (), (override)); - MOCK_METHOD(absl::StatusOr>, - GetTaskState, (const std::vector& task), - (override)); - MOCK_METHOD(absl::StatusOr>, - GetJobState, (absl::string_view job_nam), (override)); - MOCK_METHOD(absl::Status, ReportError, (const absl::Status& error), - (override)); - MOCK_METHOD(absl::Status, Shutdown, (), (override)); - MOCK_METHOD(absl::Status, Reset, (), (override)); - MOCK_METHOD(absl::StatusOr, GetKeyValue, (std::string_view key), - (override)); - 
MOCK_METHOD(absl::StatusOr, GetKeyValue, - (std::string_view key, absl::Duration timeout), (override)); - MOCK_METHOD(std::shared_ptr, GetKeyValueAsync, - (std::string_view key, StatusOrValueCallback done), (override)); - MOCK_METHOD(absl::StatusOr, TryGetKeyValue, - (std::string_view key), (override)); - MOCK_METHOD(absl::StatusOr>, GetKeyValueDir, - (std::string_view key), (override)); - MOCK_METHOD(void, GetKeyValueDirAsync, - (std::string_view key, StatusOrValueDirCallback done), - (override)); - MOCK_METHOD(absl::Status, InsertKeyValue, - (std::string_view key, std::string_view value), (override)); - MOCK_METHOD(absl::Status, InsertKeyValue, - (std::string_view key, std::string_view value, - bool allow_overwrite), - (override)); - MOCK_METHOD(absl::Status, DeleteKeyValue, (std::string_view key), (override)); - MOCK_METHOD(absl::Status, UpdateKeyValue, - (std::string_view key, std::string_view value), (override)); - MOCK_METHOD(absl::Status, StartWatchKey, - (std::string_view key, ChangedKeyValuesCallback on_change), - (override)); - MOCK_METHOD(absl::Status, StopWatchKey, (std::string_view key), (override)); - MOCK_METHOD(void, WaitAtBarrierAsync, - (std::string_view barrier_id, absl::Duration timeout, - const std::vector& tasks, StatusCallback done), - (override)); - MOCK_METHOD(void, CancelBarrierAsync, - (std::string_view barrier_id, StatusCallback done), (override)); - MOCK_METHOD(absl::StatusOr>, GetAliveTasks, - (const std::vector& tasks), (override)); - MOCK_METHOD(void, AddJobStateCallback, (JobStateCallback callback), - (override)); - MOCK_METHOD(absl::StatusOr, GetEnv, (), (override)); - MOCK_METHOD(void, SetError, (const absl::Status& error), (override)); - MOCK_METHOD(absl::Status, ActivateWatch, - (std::string_view key, - (const std::map&)), - (override)); }; constexpr auto kTestKey = "test_key"; diff --git a/tensorflow/core/distributed_runtime/device_resolver_distributed_test.cc 
b/tensorflow/core/distributed_runtime/device_resolver_distributed_test.cc index 0c0eb756837e5a..0c2bdba1da59d4 100644 --- a/tensorflow/core/distributed_runtime/device_resolver_distributed_test.cc +++ b/tensorflow/core/distributed_runtime/device_resolver_distributed_test.cc @@ -83,7 +83,7 @@ TEST_F(DeviceResDistTest, GetDeviceAttributesLocal) { TEST_F(DeviceResDistTest, GetDeviceAttributesLocalUnknown) { DeviceAttributes attributes; - EXPECT_TRUE(errors::IsNotFound(dev_resolver_->GetDeviceAttributes( + EXPECT_TRUE(absl::IsNotFound(dev_resolver_->GetDeviceAttributes( "/job:worker/replica:0/task:0/device:CPU:9", &attributes))); } @@ -109,7 +109,7 @@ TEST_F(DeviceResDistTest, GetAllDeviceAttributes) { TEST_F(DeviceResDistTest, GetAllDeviceAttributesUnknown) { std::vector attributes; - EXPECT_TRUE(errors::IsNotFound(dev_resolver_->GetAllDeviceAttributes( + EXPECT_TRUE(absl::IsNotFound(dev_resolver_->GetAllDeviceAttributes( "/job:worker/replica:0/task:3", &attributes))); } @@ -157,7 +157,7 @@ TEST_F(DeviceResDistTest, UpdateDeviceAttributesDifferentIncarnation) { attributes.push_back( NewDevice("CPU", "/job:worker/replica:0/task:0/device:CPU:1") ->attributes()); - EXPECT_TRUE(errors::IsFailedPrecondition( + EXPECT_TRUE(absl::IsFailedPrecondition( dev_resolver_->UpdateDeviceAttributes(attributes))); } diff --git a/tensorflow/core/distributed_runtime/eager/remote_copy_node.cc b/tensorflow/core/distributed_runtime/eager/remote_copy_node.cc index c053c91f8fdb9d..a9f54eac03a96a 100644 --- a/tensorflow/core/distributed_runtime/eager/remote_copy_node.cc +++ b/tensorflow/core/distributed_runtime/eager/remote_copy_node.cc @@ -496,7 +496,7 @@ void RemoteCopyNode::RunAsync(StatusCallback done) { const std::shared_ptr& captured_state = captured_state_; auto done_wrapper = [captured_state, done = std::move(done)](const absl::Status& s) { - if (!s.ok() && errors::IsCancelled(s)) { + if (!s.ok() && absl::IsCancelled(s)) { absl::Status send_status = captured_state->GetSendStatus(); if 
(!send_status.ok()) { // In this case, Recv is cancelled because the Send op failed. diff --git a/tensorflow/core/distributed_runtime/rpc/BUILD b/tensorflow/core/distributed_runtime/rpc/BUILD index 5ef36545817eeb..c0351bbe448535 100644 --- a/tensorflow/core/distributed_runtime/rpc/BUILD +++ b/tensorflow/core/distributed_runtime/rpc/BUILD @@ -335,6 +335,7 @@ cc_library( "//tensorflow/core/nccl:collective_communicator", "//tensorflow/core/profiler/rpc:profiler_service_impl", "@com_google_absl//absl/strings", + "@local_xla//xla/tsl/distributed_runtime/coordination:coordination_service", "@local_xla//xla/tsl/distributed_runtime/rpc:async_service_interface", ] + tf_protos_profiler_service() + tf_grpc_dependencies() + tf_grpc_cc_dependencies(), alwayslink = 1, diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc b/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc index 6ccc00364c3962..f85d7836928b79 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc @@ -152,7 +152,7 @@ class GrpcRemoteMaster : public MasterInterface { ctx.set_deadline(absl::ToChronoTime(absl::Now() + timeout)); } s = FromGrpcStatus((stub_.get()->*pfunc)(&ctx, *request, response)); - if (!errors::IsUnavailable(s)) { + if (!absl::IsUnavailable(s)) { return s; } // TODO(b/117162170): we may want to make this configurable. diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc index 56d21e876f6db3..cc5ab2bd5a2cb5 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc @@ -27,6 +27,7 @@ limitations under the License. 
#include "grpcpp/security/credentials.h" #include "grpcpp/server_builder.h" #include "absl/strings/numbers.h" +#include "xla/tsl/distributed_runtime/coordination/coordination_service.h" #include "xla/tsl/distributed_runtime/rpc/async_service_interface.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/device_mgr.h" @@ -279,9 +280,11 @@ absl::Status GrpcServer::Init(const GrpcServerOptions& opts) { opts.worker_service_options) .release(); eager_service_ = new eager::GrpcEagerServiceImpl(&worker_env_, &builder); - thread::ThreadPool* compute_pool = ComputePool(sess_opts); - coordination_service_ = - new GrpcCoordinationServiceImpl(compute_pool, &builder); + coordination_compute_pool_ = std::make_unique( + env_, "CoordinationServiceRpcHandler", + /*num_threads=*/4); + coordination_service_ = new GrpcCoordinationServiceImpl( + coordination_compute_pool_.get(), &builder); profiler_service_ = tsl::profiler::CreateProfilerService(); builder.RegisterService(profiler_service_.get()); @@ -331,7 +334,7 @@ absl::Status GrpcServer::Init(const GrpcServerOptions& opts) { return WorkerCacheFactory(options, worker_cache); }, grpc_coordination_service->GetRpcHandler()); - worker_env_.compute_pool = compute_pool; + worker_env_.compute_pool = ComputePool(sess_opts); // Finish setting up master environment. 
master_env_.ops = OpRegistry::Global(); @@ -522,7 +525,7 @@ absl::Status GrpcServer::SetCoordinationServiceAgentInstance( } absl::Status GrpcServer::SetCoordinationServiceInstance( - tsl::CoordinationServiceInterface* service) { + tsl::CoordinationService* service) { auto* coord_service = static_cast(coordination_service_); coord_service->SetCoordinationServiceInstance(service); diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h index ca162c193d3b15..431e4c4490be2a 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h @@ -38,6 +38,7 @@ limitations under the License. #include "tensorflow/core/framework/collective.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/platform/env.h" +#include "tsl/platform/threadpool.h" #include "tsl/profiler/protobuf/profiler_service.grpc.pb.h" namespace tensorflow { @@ -175,7 +176,7 @@ class GrpcServer : public ServerInterface { GrpcWorkerEnv* grpc_worker_env() const { return grpc_worker_env_.get(); } absl::Status SetCoordinationServiceInstance( - tsl::CoordinationServiceInterface* service); + tsl::CoordinationService* service); private: Env* env_; @@ -225,6 +226,7 @@ class GrpcServer : public ServerInterface { std::shared_ptr worker_session_; // Experimental coordination service implementation, and RPC polling thread. 
+ std::unique_ptr coordination_compute_pool_ = nullptr; tsl::AsyncServiceInterface* coordination_service_ = nullptr; std::unique_ptr coordination_thread_ TF_GUARDED_BY(mu_); diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc b/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc index 9e293d70e0e3ea..573627e1ae2ddf 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc @@ -346,7 +346,7 @@ TEST(GrpcSessionTest, DisableOutputPartitionGraphs) { RunMetadata run_metadata; absl::Status s = session->Run(run_options, {}, {}, {node_names[2]}, nullptr, &run_metadata); - EXPECT_TRUE(errors::IsInvalidArgument(s)); + EXPECT_TRUE(absl::IsInvalidArgument(s)); EXPECT_TRUE( absl::StrContains(s.message(), "disable_output_partition_graphs")); } diff --git a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc index 701ce3ed4e61a3..80044d1fcb6aa1 100644 --- a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc +++ b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc @@ -177,7 +177,7 @@ TEST_F(RpcRendezvousMgrTest, LocalAbort) { bool val_dead = false; Rendezvous::Args args; TF_ASSERT_OK(rendez->Initialize(&worker_session_)); - EXPECT_TRUE(errors::IsAborted(rendez->Recv(key, args, &val, &val_dead))); + EXPECT_TRUE(absl::IsAborted(rendez->Recv(key, args, &val, &val_dead))); } { // Cleanup causes Abort(). 
const int64_t step_id = 321; @@ -190,7 +190,7 @@ TEST_F(RpcRendezvousMgrTest, LocalAbort) { bool val_dead = false; Rendezvous::Args args; TF_ASSERT_OK(rendez->Initialize(&worker_session_)); - EXPECT_TRUE(errors::IsAborted(rendez->Recv(key, args, &val, &val_dead))); + EXPECT_TRUE(absl::IsAborted(rendez->Recv(key, args, &val, &val_dead))); } } @@ -212,7 +212,7 @@ TEST_F(RpcRendezvousMgrTest, LocalCancel) { Rendezvous::Args args; args.cancellation_manager = cm; TF_ASSERT_OK(rendez->Initialize(&worker_session_)); - EXPECT_TRUE(errors::IsCancelled(rendez->Recv(key, args, &val, &val_dead))); + EXPECT_TRUE(absl::IsCancelled(rendez->Recv(key, args, &val, &val_dead))); n.WaitForNotification(); delete cm; } diff --git a/tensorflow/core/distributed_runtime/session_mgr.cc b/tensorflow/core/distributed_runtime/session_mgr.cc index a881b2952fa5fa..44e3eaf1ecc2d5 100644 --- a/tensorflow/core/distributed_runtime/session_mgr.cc +++ b/tensorflow/core/distributed_runtime/session_mgr.cc @@ -289,9 +289,8 @@ absl::Status SessionMgr::CreateSession( // Initialize coordination service if it is the leader. 
if (IsMultiClientLeader(server_def, coordination_config)) { - coordination_service_ = - tsl::CoordinationServiceInterface::EnableCoordinationService( - worker_env_->env, coordination_config, std::move(client_cache)); + coordination_service_ = tsl::CoordinationService::Create( + worker_env_->env, coordination_config, std::move(client_cache)); if (coordination_handler_ != nullptr) { coordination_handler_->SetServiceInstance(coordination_service_.get()); } diff --git a/tensorflow/core/distributed_runtime/session_mgr.h b/tensorflow/core/distributed_runtime/session_mgr.h index 55c64f45c9daeb..0a2bddddb1aeb7 100644 --- a/tensorflow/core/distributed_runtime/session_mgr.h +++ b/tensorflow/core/distributed_runtime/session_mgr.h @@ -134,7 +134,7 @@ class SessionMgr { std::unique_ptr default_worker_cache_; std::shared_ptr legacy_session_; - std::unique_ptr coordination_service_; + std::unique_ptr coordination_service_; std::unique_ptr coordination_service_agent_; bool is_logging_active_ = false; diff --git a/tensorflow/core/framework/BUILD b/tensorflow/core/framework/BUILD index 744b1df9e09c74..1bff2fba6601b2 100644 --- a/tensorflow/core/framework/BUILD +++ b/tensorflow/core/framework/BUILD @@ -44,6 +44,10 @@ default_visibility = [ #internal nexus library tests, "//tensorflow/compiler/jit:__subpackages__", #internal library, + # TODO(matthurd): to be removed when summary.proto.h deps moves to TSL + "@org_xprof//xprof:__subpackages__", + "//tensorflow/cc/experimental/tfa:__subpackages__", + "//tensorflow/compiler/mlir/utils:__subpackages__", ] package( diff --git a/tensorflow/core/framework/resource_mgr_test.cc b/tensorflow/core/framework/resource_mgr_test.cc index 0f4ab03958a53e..0071995578152b 100644 --- a/tensorflow/core/framework/resource_mgr_test.cc +++ b/tensorflow/core/framework/resource_mgr_test.cc @@ -67,7 +67,7 @@ class Other : public ResourceBase { class Finalizable : public ResourceBase { public: - explicit Finalizable(absl::Nonnull finalize_count) + explicit 
Finalizable(int* absl_nonnull finalize_count) : finalize_count_(*finalize_count) {} ~Finalizable() override = default; diff --git a/tensorflow/core/framework/tensor.cc b/tensorflow/core/framework/tensor.cc index 73bbc554ea1b05..e04c3e2f12eb72 100644 --- a/tensorflow/core/framework/tensor.cc +++ b/tensorflow/core/framework/tensor.cc @@ -1268,6 +1268,10 @@ inline float PrintOneElement(float8_e4m3fn f, bool print_v2) { return static_cast(f); } +inline float PrintOneElement(float8_e4m3b11fnuz f, bool print_v2) { + return static_cast(f); +} + inline int16_t PrintOneElement(int4 a, bool print_v2) { return static_cast(a); } @@ -1454,6 +1458,9 @@ string Tensor::SummarizeValue(int64_t max_entries, bool print_v2) const { case DT_FLOAT8_E4M3FN: return SummarizeArray(limit, num_elts, shape_, data, print_v2); + case DT_FLOAT8_E4M3B11FNUZ: + return SummarizeArray(limit, num_elts, shape_, data, + print_v2); case DT_FLOAT: return SummarizeArray(limit, num_elts, shape_, data, print_v2); break; diff --git a/tensorflow/core/graph/graph_partition.cc b/tensorflow/core/graph/graph_partition.cc index 0b08a127cdbd13..640806b2afaec0 100644 --- a/tensorflow/core/graph/graph_partition.cc +++ b/tensorflow/core/graph/graph_partition.cc @@ -1038,7 +1038,7 @@ absl::Status Partition(const PartitionOptions& opts, Graph* g, if (opts.need_to_record_start_times) { int64_t start_time; status = GetNodeAttr(*dst_def, "_start_time", &start_time); - if (errors::IsNotFound(status)) { + if (absl::IsNotFound(status)) { start_time = opts.start_times[dst->id()].value(); AddNodeAttr("_start_time", start_time, dst_def); } else if (!status.ok()) { @@ -1101,14 +1101,14 @@ absl::Status Partition(const PartitionOptions& opts, Graph* g, int64_t recv_start_time = 0; if (opts.scheduling_for_recvs) { status = GetNodeAttr(src->attrs(), "_start_time", &send_start_time); - if (errors::IsNotFound(status) && opts.need_to_record_start_times) { + if (absl::IsNotFound(status) && opts.need_to_record_start_times) { 
send_start_time = opts.start_times[src->id()].value(); } else if (!status.ok()) { return status; } status = GetNodeAttr(dst->attrs(), "_start_time", &recv_start_time); - if (errors::IsNotFound(status) && opts.need_to_record_start_times) { + if (absl::IsNotFound(status) && opts.need_to_record_start_times) { recv_start_time = opts.start_times[dst->id()].value(); } else if (!status.ok()) { return status; diff --git a/tensorflow/core/grappler/clusters/single_machine_test.cc b/tensorflow/core/grappler/clusters/single_machine_test.cc index 4e3a0019da2e67..fb3960ea2a8b79 100644 --- a/tensorflow/core/grappler/clusters/single_machine_test.cc +++ b/tensorflow/core/grappler/clusters/single_machine_test.cc @@ -249,9 +249,9 @@ TEST_F(SingleMachineTest, TimeOuts) { TF_CHECK_OK(cluster_->Initialize(item)); RunMetadata metadata; absl::Status s1 = cluster_->Run(item.graph, item.feed, item.fetch, &metadata); - EXPECT_TRUE(errors::IsDeadlineExceeded(s1)); + EXPECT_TRUE(absl::IsDeadlineExceeded(s1)); absl::Status s2 = cluster_->Run(item.graph, item.feed, item.fetch, &metadata); - EXPECT_TRUE(errors::IsDeadlineExceeded(s2)); + EXPECT_TRUE(absl::IsDeadlineExceeded(s2)); } static void RunInfiniteTFLoop() { @@ -337,7 +337,7 @@ static void RunInfiniteTFLoop() { TF_CHECK_OK(cluster.Initialize(item)); absl::Status s1 = cluster.Run(item.graph, item.feed, item.fetch, nullptr); - if (!errors::IsDeadlineExceeded(s1)) { + if (!absl::IsDeadlineExceeded(s1)) { LOG(ERROR) << "Expected 'deadline exceeded' error, got " << s1; // Exit to break the infinite loop _exit(1); @@ -345,7 +345,7 @@ static void RunInfiniteTFLoop() { // Attempt to shutdown the cluster and make sure we get the proper error code. 
absl::Status s2 = cluster.Shutdown(); - if (!errors::IsUnavailable(s2)) { + if (!absl::IsUnavailable(s2)) { LOG(ERROR) << "Expected 'unavailable' error, got " << s2; // Exit to break the infinite loop _exit(2); @@ -633,7 +633,7 @@ TEST_F(SingleMachineTest, PeakMemoryStatsNotEnabled) { absl::Status s = cluster.GetPeakMemoryUsage(&device_peak_memory); TF_CHECK_OK(cluster.Shutdown()); ASSERT_FALSE(s.ok()); - EXPECT_TRUE(errors::IsInvalidArgument(s)); + EXPECT_TRUE(absl::IsInvalidArgument(s)); } #endif diff --git a/tensorflow/core/grappler/optimizers/inference/batch_op_rewriter_test.cc b/tensorflow/core/grappler/optimizers/inference/batch_op_rewriter_test.cc index c032e61580d6f6..7b3a784bd205e8 100644 --- a/tensorflow/core/grappler/optimizers/inference/batch_op_rewriter_test.cc +++ b/tensorflow/core/grappler/optimizers/inference/batch_op_rewriter_test.cc @@ -183,7 +183,7 @@ TEST_P(BatchOpRewriterTest, InvalidArgumentForAdaptiveBatchScheduler) { Status status = optimizer.Optimize(nullptr, item, &optimized_graph); EXPECT_FALSE(status.ok()); - EXPECT_TRUE(errors::IsInvalidArgument(status)); + EXPECT_TRUE(absl::IsInvalidArgument(status)); } // Tests that reserved attributes relevant with adaptive scheduler are diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc index c05e5d1e84683d..ba2b40dc7708eb 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc @@ -896,7 +896,7 @@ TEST_F(MetaOptimizerTest, RunPostOptimizationVerifiersOnInvalidGraph) { MetaOptimizer optimizer_with_post_verifiers(nullptr, config_proto); absl::Status status = optimizer_with_post_verifiers.Optimize(nullptr, item, &output); - EXPECT_TRUE(errors::IsInvalidArgument(status)); + EXPECT_TRUE(absl::IsInvalidArgument(status)); EXPECT_TRUE(absl::StrContains( status.message(), "NodeDef expected inputs 'float' do not match 3 inputs specified")); diff 
--git a/tensorflow/core/grappler/optimizers/remapper.cc b/tensorflow/core/grappler/optimizers/remapper.cc index dbbab4e2c9c492..a5a48347f07517 100644 --- a/tensorflow/core/grappler/optimizers/remapper.cc +++ b/tensorflow/core/grappler/optimizers/remapper.cc @@ -489,17 +489,13 @@ bool RuntimeFusionEnabled(const Cluster* cluster) { } } } - bool runtime_fusion_enabled = CudnnUseRuntimeFusion() && - CudnnUseFrontend() && num_gpus > 0 && - num_gpus == num_ampere; + bool runtime_fusion_enabled = + CudnnUseRuntimeFusion() && num_gpus > 0 && num_gpus == num_ampere; if (CudnnUseRuntimeFusion() && !runtime_fusion_enabled) { - VLOG(1) << "Enabling Cudnn with runtime compilation requires the " - << "Cudnn frontend and Ampere GPUs or later, but we got " - << "Cudnn frontend is " - << (CudnnUseFrontend() ? "enabled" : "disabled") << " and " - << num_ampere << " Ampere GPU(s) out of total " << num_gpus - << " GPU(s)"; + VLOG(1) << "Enabling Cudnn with runtime compilation requires " + << "Ampere (sm_80) GPUs or later, but we got " << num_ampere + << " sm_80+ GPU(s) out of total " << num_gpus << " GPU(s)"; } return runtime_fusion_enabled; diff --git a/tensorflow/core/grappler/optimizers/tfg_optimizer_hook_test.cc b/tensorflow/core/grappler/optimizers/tfg_optimizer_hook_test.cc index 82f0be75e9efc4..9087fcc1b3aa91 100644 --- a/tensorflow/core/grappler/optimizers/tfg_optimizer_hook_test.cc +++ b/tensorflow/core/grappler/optimizers/tfg_optimizer_hook_test.cc @@ -119,7 +119,7 @@ TEST(TFGOptimizerTest, TestImportErrorReturnsAborted) { // Expect an aborted error. EXPECT_FALSE(status.ok()); - EXPECT_TRUE(errors::IsAborted(status)); + EXPECT_TRUE(absl::IsAborted(status)); } TEST(TFGOptimizerTest, TestPassErrorIsFatal) { @@ -139,8 +139,8 @@ TEST(TFGOptimizerTest, TestPassErrorIsFatal) { // Expect a non-aborted, non-timeout error. 
EXPECT_FALSE(status.ok()); - EXPECT_FALSE(errors::IsAborted(status)); - EXPECT_TRUE(errors::IsInvalidArgument(status)); + EXPECT_FALSE(absl::IsAborted(status)); + EXPECT_TRUE(absl::IsInvalidArgument(status)); } TEST(TFGOptimizerTest, TestImportErrorMetaOptimizerIsNotFatal) { diff --git a/tensorflow/core/grappler/verifiers/structure_verifier_test.cc b/tensorflow/core/grappler/verifiers/structure_verifier_test.cc index 95c0a759159c91..562deb5367493c 100644 --- a/tensorflow/core/grappler/verifiers/structure_verifier_test.cc +++ b/tensorflow/core/grappler/verifiers/structure_verifier_test.cc @@ -85,7 +85,7 @@ TEST_F(StructureVerifierTest, OpNotRegistered) { "node { name: 't1' op: 'TestMul' input: [ 'input:0', 't2' ] }" "node { name: 't2' op: 'TestMul' input: [ 'input:1', 't1' ] }"); absl::Status status = verifier_->Verify(graph_); - EXPECT_TRUE(errors::IsNotFound(status)); + EXPECT_TRUE(absl::IsNotFound(status)); EXPECT_TRUE(absl::StrContains(status.message(), "Op type not registered")); } @@ -94,7 +94,7 @@ TEST_F(StructureVerifierTest, DuplicateNodeNames) { "node { name: 'A' op: 'TestParams' }" "node { name: 'A' op: 'TestInput' }"); absl::Status status = verifier_->Verify(graph_); - EXPECT_TRUE(errors::IsAlreadyExists(status)); + EXPECT_TRUE(absl::IsAlreadyExists(status)); EXPECT_TRUE(absl::StrContains(status.message(), "Node already exists:")); } @@ -104,7 +104,7 @@ TEST_F(StructureVerifierTest, GraphWithInvalidCycle) { "node { name: 't1' op: 'TestMul' input: [ 'input:0', 't2' ] }" "node { name: 't2' op: 'TestMul' input: [ 'input:1', 't1' ] }"); absl::Status status = verifier_->Verify(graph_); - EXPECT_TRUE(errors::IsInvalidArgument(status)); + EXPECT_TRUE(absl::IsInvalidArgument(status)); EXPECT_TRUE(absl::StrContains( status.message(), "The graph couldn't be sorted in topological order")); } diff --git a/tensorflow/core/ir/BUILD b/tensorflow/core/ir/BUILD index 6d9aee324fb32d..9a77ffcfd97283 100644 --- a/tensorflow/core/ir/BUILD +++ b/tensorflow/core/ir/BUILD @@ 
-23,16 +23,10 @@ td_library( gentbl_cc_library( name = "InterfacesIncGen", - tbl_outs = [ - ( - ["-gen-op-interface-decls"], - "interfaces.h.inc", - ), - ( - ["-gen-op-interface-defs"], - "interfaces.cc.inc", - ), - ], + tbl_outs = { + "interfaces.h.inc": ["-gen-op-interface-decls"], + "interfaces.cc.inc": ["-gen-op-interface-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "interfaces.td", deps = [ @@ -59,56 +53,38 @@ td_library( gentbl_cc_library( name = "DialectIncGen", - tbl_outs = [ - ( - [ - "-gen-op-decls", - "-dialect", - "tfg", - ], - "ops.h.inc", - ), - ( - [ - "-gen-op-defs", - "-dialect", - "tfg", - ], - "ops.cc.inc", - ), - ( - [ - "-gen-dialect-decls", - "-dialect", - "tfg", - ], - "dialect.h.inc", - ), - ( - [ - "-gen-dialect-defs", - "-dialect", - "tfg", - ], - "dialect.cc.inc", - ), - ( - [ - "-gen-attrdef-decls", - "-attrdefs-dialect", - "tfg", - ], - "attributes.h.inc", - ), - ( - [ - "-gen-attrdef-defs", - "-attrdefs-dialect", - "tfg", - ], - "attributes.cc.inc", - ), - ], + tbl_outs = { + "ops.h.inc": [ + "-gen-op-decls", + "-dialect", + "tfg", + ], + "ops.cc.inc": [ + "-gen-op-defs", + "-dialect", + "tfg", + ], + "dialect.h.inc": [ + "-gen-dialect-decls", + "-dialect", + "tfg", + ], + "dialect.cc.inc": [ + "-gen-dialect-defs", + "-dialect", + "tfg", + ], + "attributes.h.inc": [ + "-gen-attrdef-decls", + "-attrdefs-dialect", + "tfg", + ], + "attributes.cc.inc": [ + "-gen-attrdef-defs", + "-attrdefs-dialect", + "tfg", + ], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "ops.td", deps = [ diff --git a/tensorflow/core/ir/importexport/convert_types.cc b/tensorflow/core/ir/importexport/convert_types.cc index c4e25aed1e882d..7aff049e2d7548 100644 --- a/tensorflow/core/ir/importexport/convert_types.cc +++ b/tensorflow/core/ir/importexport/convert_types.cc @@ -171,7 +171,7 @@ Status ConvertScalarTypeToDataType(Type type, DataType* dtype) { } #define HANDLE_TF_TYPE(tftype, enumerant, name) \ - if (type.isa()) { \ + if 
(llvm::isa(type)) { \ *dtype = tensorflow::DT_##enumerant; \ return ::tensorflow::OkStatus(); \ } diff --git a/tensorflow/core/ir/ops.td b/tensorflow/core/ir/ops.td index 0cb9ea90d8b92e..b6bbbee3b6e88e 100644 --- a/tensorflow/core/ir/ops.td +++ b/tensorflow/core/ir/ops.td @@ -684,7 +684,7 @@ class TFGraph_CaseLikeRegionOp : TFGraph_RegionOp< RegionAttr getPreservedAttrs(unsigned index) { if (auto attrs = getRegionAttrsAttr()) - return attrs[index].cast(); + return llvm::cast(attrs[index]); return {}; } void setPreservedAttrs(unsigned index, RegionAttr attrs) { diff --git a/tensorflow/core/ir/types/BUILD b/tensorflow/core/ir/types/BUILD index 4577967ebc7d19..73d8e61c03c6a3 100644 --- a/tensorflow/core/ir/types/BUILD +++ b/tensorflow/core/ir/types/BUILD @@ -30,16 +30,10 @@ td_library( gentbl_cc_library( name = "DialectIncGen", - tbl_outs = [ - ( - ["-gen-dialect-decls"], - "dialect.h.inc", - ), - ( - ["-gen-dialect-defs"], - "dialect.cpp.inc", - ), - ], + tbl_outs = { + "dialect.h.inc": ["-gen-dialect-decls"], + "dialect.cpp.inc": ["-gen-dialect-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "dialect.td", deps = [ @@ -49,24 +43,12 @@ gentbl_cc_library( gentbl_cc_library( name = "AttributesIncGen", - tbl_outs = [ - ( - ["-gen-attrdef-decls"], - "attributes.h.inc", - ), - ( - ["-gen-attrdef-defs"], - "attributes.cc.inc", - ), - ( - ["-gen-enum-decls"], - "attributes_enum.h.inc", - ), - ( - ["-gen-enum-defs"], - "attributes_enum.cc.inc", - ), - ], + tbl_outs = { + "attributes.h.inc": ["-gen-attrdef-decls"], + "attributes.cc.inc": ["-gen-attrdef-defs"], + "attributes_enum.h.inc": ["-gen-enum-decls"], + "attributes_enum.cc.inc": ["-gen-enum-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "attributes.td", deps = [":DialectTdFiles"], @@ -74,16 +56,10 @@ gentbl_cc_library( gentbl_cc_library( name = "TypesIncGen", - tbl_outs = [ - ( - ["-gen-typedef-decls"], - "types.h.inc", - ), - ( - ["-gen-typedef-defs"], - "types.cc.inc", - ), - ], + 
tbl_outs = { + "types.h.inc": ["-gen-typedef-decls"], + "types.cc.inc": ["-gen-typedef-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "types.td", deps = [":DialectTdFiles"], diff --git a/tensorflow/core/ir/types/dialect.cc b/tensorflow/core/ir/types/dialect.cc index 30da252b573e87..886b2265ab1f58 100644 --- a/tensorflow/core/ir/types/dialect.cc +++ b/tensorflow/core/ir/types/dialect.cc @@ -164,15 +164,15 @@ Type TFTypeDialect::parseType(DialectAsmParser& parser) const { // Entry point for Type parsing, TableGen generated code will handle the // dispatch to the individual classes. void TFTypeDialect::printType(Type type, DialectAsmPrinter& printer) const { -#define HANDLE_TF_TYPE(tftype, enumerant, name) \ - if (auto derived_ty = type.dyn_cast()) { \ - printer << name; \ - return; \ +#define HANDLE_TF_TYPE(tftype, enumerant, name) \ + if (auto derived_ty = mlir::dyn_cast(type)) { \ + printer << name; \ + return; \ } -#define HANDLE_CUSTOM_TF_TYPE(tftype, enumerant, name) \ - if (auto derived_ty = type.dyn_cast()) { \ - Print##tftype##Type(derived_ty, printer); \ - return; \ +#define HANDLE_CUSTOM_TF_TYPE(tftype, enumerant, name) \ + if (auto derived_ty = mlir::dyn_cast(type)) { \ + Print##tftype##Type(derived_ty, printer); \ + return; \ } // NOLINTNEXTLINE: intended redundant include. 
#include "tensorflow/core/ir/types/types.def" @@ -584,8 +584,8 @@ TensorFlowType TensorFlowRefType::get(Type type) { llvm_unreachable("unexpected integer type"); } } -#define HANDLE_TF_TYPE(tftype, enumerant, name) \ - if (auto derived_ty = type.dyn_cast()) \ +#define HANDLE_TF_TYPE(tftype, enumerant, name) \ + if (auto derived_ty = mlir::dyn_cast(type)) \ return tftype##RefType::get(ctx); #define HANDLE_TF_REF_TYPE(tftype, enumerant, name) @@ -631,7 +631,7 @@ Type TensorFlowRefType::RemoveRef() { if (mlir::isa(*this)) return ComplexType::get(Float64Type::get(ctx)); #define HANDLE_TF_TYPE(tftype, enumerant, name) \ - if (isa()) return tftype##Type::get(ctx); + if (mlir::isa(*this)) return tftype##Type::get(ctx); #define HANDLE_TF_REF_TYPE(tftype, enumerant, name) // NOLINTNEXTLINE diff --git a/tensorflow/core/ir/utility.cc b/tensorflow/core/ir/utility.cc index e04b4df00c98dd..34ab5dc5e44f95 100644 --- a/tensorflow/core/ir/utility.cc +++ b/tensorflow/core/ir/utility.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/core/ir/utility.h" +#include +#include #include #include "mlir/IR/Block.h" // from @llvm-project diff --git a/tensorflow/core/ir/utils/shape_inference_utils.cc b/tensorflow/core/ir/utils/shape_inference_utils.cc index 753ad1450b8a9e..2e417e25f08362 100644 --- a/tensorflow/core/ir/utils/shape_inference_utils.cc +++ b/tensorflow/core/ir/utils/shape_inference_utils.cc @@ -93,7 +93,7 @@ NamedAttrList GetAllAttributesFromOperation(Operation* op) { // TODO(tlongeri): Should num_elements overflow be handled by the MLIR // verifier? Are there other cases? std::optional GetShapeFromMlirType(Type t) { - if (auto ranked_type = t.dyn_cast()) { + if (auto ranked_type = llvm::dyn_cast(t)) { tensorflow::PartialTensorShape shape; const absl::Status status = tensorflow::PartialTensorShape::BuildPartialTensorShape( @@ -106,7 +106,7 @@ std::optional GetShapeFromMlirType(Type t) { // Extracts a PartialTensorShape from the MLIR attr. 
std::optional GetShapeFromMlirAttr(Value v) { // Function arguments may have shape attr to describe its output shape. - if (auto arg = v.dyn_cast()) { + if (auto arg = dyn_cast(v)) { Operation* parent_op = arg.getOwner()->getParentOp(); if (auto func_op = llvm::dyn_cast(parent_op)) { int arg_idx = arg.getArgNumber(); @@ -116,7 +116,7 @@ std::optional GetShapeFromMlirAttr(Value v) { // "tf._output_shapes" in certain models may not store the shape as // ShapeAttr, ignore them because we don't know how to interpret it. - auto shape_attr = attrs[0].dyn_cast(); + auto shape_attr = llvm::dyn_cast(attrs[0]); if (shape_attr && shape_attr.hasRank()) return tensorflow::PartialTensorShape(shape_attr.getShape()); } @@ -131,7 +131,7 @@ std::unique_ptr>> GetSubtypesHelper(Type type) { auto type_with_subtypes = - type.cast().getElementType().dyn_cast(); + llvm::dyn_cast(llvm::cast(type).getElementType()); if (!type_with_subtypes || type_with_subtypes.getSubtypes().empty()) { return nullptr; } @@ -317,8 +317,8 @@ LogicalResult InferReturnTypeComponentsForTFOp( if (input_tensors[input]) continue; if (c.requested_input_tensor(input)) { - if (auto attr = operand_as_constant_fn(op->getOperand(input)) - .dyn_cast_or_null()) { + if (auto attr = llvm::dyn_cast_if_present( + operand_as_constant_fn(op->getOperand(input)))) { VLOG(4) << "Requesting " << input << " as constant\n"; tensorflow::Tensor* input_tensor = &tensors.at(input); auto status = ConvertToTensor(attr, input_tensor); @@ -336,7 +336,7 @@ LogicalResult InferReturnTypeComponentsForTFOp( if (c.requested_input_tensor_as_partial_shape(input) && !input_tensors[input] && !input_tensors_as_shapes[input].Handle()) { VLOG(4) << "Requesting " << input << " as shape\n"; - auto op_result = op->getOperand(input).dyn_cast(); + auto op_result = dyn_cast(op->getOperand(input)); if (!op_result) continue; // Resize on first valid shape computed. 
auto handle = op_result_as_shape_fn(c, op_result); @@ -370,7 +370,7 @@ LogicalResult InferReturnTypeComponentsForTFOp( Type new_element_type = result_element_type_fn(output); // Populate the handle shapes for a resource/variant. if (new_element_type && - new_element_type.isa()) { + isa(new_element_type)) { auto handle_shapes_types = c.output_handle_shapes_and_types(output); if (handle_shapes_types) { SmallVector subtypes; @@ -382,7 +382,7 @@ LogicalResult InferReturnTypeComponentsForTFOp( subtypes.push_back( CreateTensorType(c, shape_n_type.shape, element_type)); } - if (new_element_type.isa()) { + if (isa(new_element_type)) { new_element_type = tf_type::ResourceType::get(subtypes, op->getContext()); } else { diff --git a/tensorflow/core/kernels/batching_util/BUILD b/tensorflow/core/kernels/batching_util/BUILD index 5d131a81f55e7c..9a66862779e5cb 100644 --- a/tensorflow/core/kernels/batching_util/BUILD +++ b/tensorflow/core/kernels/batching_util/BUILD @@ -261,7 +261,6 @@ tf_cc_test( "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", - "//tensorflow/core:test_main", "//tensorflow/core/platform:status_matchers", "@com_google_absl//absl/base", "@com_google_absl//absl/container:fixed_array", diff --git a/tensorflow/core/kernels/batching_util/batch_resource_base.cc b/tensorflow/core/kernels/batching_util/batch_resource_base.cc index 545baba8110c10..84722b65043a21 100644 --- a/tensorflow/core/kernels/batching_util/batch_resource_base.cc +++ b/tensorflow/core/kernels/batching_util/batch_resource_base.cc @@ -1054,39 +1054,35 @@ void BatchResourceBase::ProcessFuncBatch( batch->task(batch->num_tasks() - 1).captured_inputs; args.insert(args.end(), captured_inputs.begin(), captured_inputs.end()); - uint64 current_time = EnvTime::NowNanos(); - for (int i = 0; i < batch->num_tasks(); ++i) { - RecordBatchDelayUs((current_time - batch->task(i).start_time) * 1e-3, - model_name, last_task_context->op_kernel().name(), - processed_size); - 
RecordBatchDelayUsV2((current_time - batch->task(i).start_time) * 1e-3, - model_name, last_task_context->op_kernel().name(), - processed_size); - } + RecordBatchDelayMetrics( + *batch, model_name, op_name, processed_size, + /*batch_schedule_time=*/absl::FromUnixNanos(EnvTime::NowNanos()), + GetBatchTimeout()); + // Releases the cleanup method here, because the callback of the function // library runtime will handle it now. finally.release(); - ProcessFuncBatchImpl( - last_task, args, &combined_outputs, [&](const absl::Status& run_status) { - absl::Status final_status; - auto run_finally = gtl::MakeCleanup([&]() { - // We do the cleanup here as an optimization, so that - // it runs in the underlying TF inter-op threadpool. - // Running it in the threadpool, let's the ensuing - // ops be scheduled faster, because the executor will - // add them to the front of the threadpool's task - // queue rather than the end. - cleanup_fn(final_status); - }); - final_status = run_status; - if (!final_status.ok()) { - return; - } - if (last_task.forced_warmup_batch_size == 0) { - final_status = SplitOutputTensors(combined_outputs, batch.get(), - unbatched_tasks); - } - }); + ProcessFuncBatchImpl(last_task, args, &combined_outputs, + [&](const absl::Status& run_status) { + absl::Status final_status; + auto run_finally = gtl::MakeCleanup([&]() { + // We do the cleanup here as an optimization, so that + // it runs in the underlying TF inter-op threadpool. + // Running it in the threadpool, let's the ensuing + // ops be scheduled faster, because the executor will + // add them to the front of the threadpool's task + // queue rather than the end. + cleanup_fn(final_status); + }); + final_status = run_status; + if (!final_status.ok()) { + return; + } + if (last_task.forced_warmup_batch_size == 0) { + final_status = SplitOutputTensors( + combined_outputs, batch.get(), unbatched_tasks); + } + }); } // Processes a batch of one or more BatchTask entries. 
@@ -1248,6 +1244,17 @@ absl::Status BatchResourceBase::LookupOrCreateBatcherQueue( return absl::OkStatus(); } +std::optional BatchResourceBase::GetBatchTimeout() const { + if (batcher_) { + return absl::Microseconds(batcher_queue_options_.batch_timeout_micros); + } + if (adaptive_batcher_) { + return absl::Microseconds( + adaptive_batcher_queue_options_.batch_timeout_micros); + } + return std::nullopt; +} + void BatchResourceBase::SplitBatchCostsAndRecordMetrics( const std::string& model_name, const std::string& op_name, const std::vector>& @@ -1324,5 +1331,50 @@ void BatchResourceBase::SplitBatchCostsAndRecordMetrics( } } +void BatchResourceBase::RecordBatchDelayMetrics( + const BatchResourceBase::BatchT& batch, const std::string& model_name, + const std::string& op_name, int64_t processed_size, + absl::Time batch_schedule_time, + std::optional batch_timeout) { + absl::Time earliest_task_start_time = absl::InfiniteFuture(); + for (int i = 0; i < batch.num_tasks(); ++i) { + earliest_task_start_time = + std::min(earliest_task_start_time, + absl::FromUnixNanos(batch.task(i).start_time)); + } + for (int i = 0; i < batch.num_tasks(); ++i) { + const BatchResourceBase::BatchTask& task = batch.task(i); + + const absl::Time start_time = absl::FromUnixNanos(task.start_time); + const absl::Duration total_scheduler_delay = + batch_schedule_time - start_time; + RecordBatchDelayUs(absl::ToInt64Microseconds(total_scheduler_delay), + model_name, op_name, processed_size); + RecordBatchDelayUsV2(absl::ToInt64Microseconds(total_scheduler_delay), + model_name, op_name, processed_size); + + RequestCost* request_cost = task.request_cost; + // Skip recording the cost if the request_cost is null. + if (!request_cost) continue; + + // The duration from when the task was enqueued to when the earliest task in + // its batch has been in the queue for a duration of batch_timeout (i.e. 
+ // when the task is eligible being scheduled into a batch, regardless of the + // number of tasks in the queue) is considered as batching delay, and the + // remaining duration in the queue is considered as queueing delay. + const absl::Duration remaining_batch_timeout = + std::max(earliest_task_start_time + + batch_timeout.value_or(absl::ZeroDuration()) - start_time, + absl::ZeroDuration()); + const absl::Duration batching_delay = + std::min(remaining_batch_timeout, total_scheduler_delay); + const absl::Duration queueing_delay = + total_scheduler_delay - batching_delay; + request_cost->RecordMetrics( + {{"batching_delay_msecs", absl::ToDoubleMilliseconds(batching_delay)}, + {"queueing_delay_msecs", absl::ToDoubleMilliseconds(queueing_delay)}}); + } +} + } // namespace serving } // namespace tensorflow diff --git a/tensorflow/core/kernels/batching_util/batch_resource_base.h b/tensorflow/core/kernels/batching_util/batch_resource_base.h index 633724de79a21b..54e83c82367f3e 100644 --- a/tensorflow/core/kernels/batching_util/batch_resource_base.h +++ b/tensorflow/core/kernels/batching_util/batch_resource_base.h @@ -20,6 +20,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -27,6 +28,7 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/synchronization/blocking_counter.h" +#include "absl/time/time.h" #include "xla/tsl/platform/criticality.h" #include "tensorflow/core/common_runtime/cost_measurement_registry.h" #include "tensorflow/core/common_runtime/request_cost.h" @@ -82,7 +84,7 @@ class BatchResourceBase : public ResourceBase { // Note input from one batch-op invocation is valid and considered a // specialized `slice`. 
struct BatchTask : public tensorflow::serving::BatchTask { - BatchTask() : criticality_val(tsl::criticality::GetCriticality()){}; + BatchTask() : criticality_val(tsl::criticality::GetCriticality()) {}; // A unique ID to identify this invocation of Batch. int64_t guid; @@ -274,6 +276,14 @@ class BatchResourceBase : public ResourceBase { batch_cost_measurements, int64_t processed_size, BatchT& batch); + // Records information about the delay between a task being registered and + // that task being scheduled into a batch. + static void RecordBatchDelayMetrics( + const BatchResourceBase::BatchT& batch, const std::string& model_name, + const std::string& op_name, int64_t processed_size, + absl::Time batch_schedule_time, + std::optional batch_timeout); + private: // Implementation of calling the process batch function. virtual void ProcessFuncBatchImpl( @@ -346,6 +356,10 @@ class BatchResourceBase : public ResourceBase { const string& op_name, BatcherQueueT** queue); + // Returns the batch timeout for the configured scheduler, or nullopt if the + // scheduler does not have such a parameter. + std::optional GetBatchTimeout() const; + SessionMetadata session_metadata_; absl::Mutex outstanding_batch_mu_; diff --git a/tensorflow/core/kernels/batching_util/batch_resource_base_test.cc b/tensorflow/core/kernels/batching_util/batch_resource_base_test.cc index b3e2548b58d326..ee912719fba873 100644 --- a/tensorflow/core/kernels/batching_util/batch_resource_base_test.cc +++ b/tensorflow/core/kernels/batching_util/batch_resource_base_test.cc @@ -25,6 +25,7 @@ limitations under the License. 
#include #include "absl/status/status.h" #include "absl/strings/string_view.h" +#include "absl/time/clock.h" #include "absl/time/time.h" #include "absl/types/span.h" #include "xla/tsl/platform/criticality.h" @@ -129,10 +130,12 @@ class TestGcuCostMeasurement : public CostMeasurement { REGISTER_COST_MEASUREMENT("test_gcu", TestGcuCostMeasurement); std::unique_ptr MakeBatchTask( - const int64_t task_size, RequestCost* request_cost) { + const int64_t task_size, RequestCost* request_cost, + absl::Time start_time = absl::UnixEpoch()) { auto task = std::make_unique(); task->inputs.push_back(Tensor(DT_DOUBLE, TensorShape({task_size, 1}))); task->request_cost = request_cost; + task->start_time = absl::ToUnixNanos(start_time); return task; } @@ -418,6 +421,136 @@ TEST(SplitBatchCostsAndRecordMetricsTest, GlobalBatchStatsProcessedSize) { original_cumulative_processed_size + 4); } +TEST(RecordBatchDelayMetricsTest, + TwoRequestsWithNoQueueingDelayAndSchedulingAtBatchTimeout) { + const absl::Duration batch_timeout = absl::Seconds(1); + const absl::Duration task2_delay = batch_timeout / 4; + const absl::Time task1_start_time = absl::Now(); + const absl::Time task2_start_time = task1_start_time + task2_delay; + const absl::Time batch_schedule_time = task1_start_time + batch_timeout; + + BatchResourceBase::BatchT batch; + RequestCost cost1, cost2; + batch.AddTask(MakeBatchTask(/*task_size=*/1, &cost1, task1_start_time)); + batch.AddTask(MakeBatchTask(/*task_size=*/1, &cost2, task2_start_time)); + batch.Close(); + + BatchResourceBase::RecordBatchDelayMetrics( + batch, "model_name", "op_name", /*processed_size=*/20, + batch_schedule_time, batch_timeout); + + EXPECT_THAT( + batch.task(0).request_cost->GetMetrics(), + UnorderedElementsAre(Pair("batching_delay_msecs", + absl::ToDoubleMilliseconds(batch_timeout)), + Pair("queueing_delay_msecs", 0))); + EXPECT_THAT(batch.task(1).request_cost->GetMetrics(), + UnorderedElementsAre( + Pair("batching_delay_msecs", + 
absl::ToDoubleMilliseconds(batch_timeout - task2_delay)), + Pair("queueing_delay_msecs", 0))); +} + +TEST(RecordBatchDelayMetricsTest, + TwoRequestsWithNoQueueingDelayAndSchedulingAfterSecondRequest) { + const absl::Duration batch_timeout = absl::Seconds(1); + const absl::Duration task2_delay = batch_timeout / 4; + const absl::Duration scheduling_delay = batch_timeout / 10; + const absl::Time task1_start_time = absl::Now(); + const absl::Time task2_start_time = task1_start_time + task2_delay; + const absl::Time batch_schedule_time = + task1_start_time + task2_delay + scheduling_delay; + + BatchResourceBase::BatchT batch; + RequestCost cost1, cost2; + batch.AddTask(MakeBatchTask(/*task_size=*/1, &cost1, task1_start_time)); + batch.AddTask(MakeBatchTask(/*task_size=*/1, &cost2, task2_start_time)); + batch.Close(); + + BatchResourceBase::RecordBatchDelayMetrics( + batch, "model_name", "op_name", /*processed_size=*/20, + batch_schedule_time, batch_timeout); + + EXPECT_THAT( + batch.task(0).request_cost->GetMetrics(), + UnorderedElementsAre( + Pair("batching_delay_msecs", + absl::ToDoubleMilliseconds(task2_delay + scheduling_delay)), + Pair("queueing_delay_msecs", 0))); + EXPECT_THAT( + batch.task(1).request_cost->GetMetrics(), + UnorderedElementsAre(Pair("batching_delay_msecs", + absl::ToDoubleMilliseconds(scheduling_delay)), + Pair("queueing_delay_msecs", 0))); +} + +TEST(RecordBatchDelayMetricsTest, TwoRequestWithQueueingDelay) { + const absl::Duration batch_timeout = absl::Seconds(1); + const absl::Duration task2_delay = batch_timeout / 4; + const absl::Duration queueing_delay = 5 * batch_timeout; + const absl::Time task1_start_time = absl::Now(); + const absl::Time task2_start_time = task1_start_time + task2_delay; + const absl::Time batch_schedule_time = + task1_start_time + batch_timeout + queueing_delay; + + BatchResourceBase::BatchT batch; + RequestCost cost1, cost2; + batch.AddTask(MakeBatchTask(/*task_size=*/1, &cost1, task1_start_time)); + 
batch.AddTask(MakeBatchTask(/*task_size=*/1, &cost2, task2_start_time)); + batch.Close(); + + BatchResourceBase::RecordBatchDelayMetrics( + batch, "model_name", "op_name", /*processed_size=*/20, + batch_schedule_time, batch_timeout); + + EXPECT_THAT( + batch.task(0).request_cost->GetMetrics(), + UnorderedElementsAre(Pair("batching_delay_msecs", + absl::ToDoubleMilliseconds(batch_timeout)), + Pair("queueing_delay_msecs", + absl::ToDoubleMilliseconds(queueing_delay)))); + EXPECT_THAT(batch.task(1).request_cost->GetMetrics(), + UnorderedElementsAre( + Pair("batching_delay_msecs", + absl::ToDoubleMilliseconds(batch_timeout - task2_delay)), + Pair("queueing_delay_msecs", + absl::ToDoubleMilliseconds(queueing_delay)))); +} + +TEST(RecordBatchDelayMetricsTest, + TwoRequestsWithQueueingDelayAndSecondArrivingAfterBatchTimeout) { + const absl::Duration batch_timeout = absl::Seconds(1); + const absl::Duration task2_delay = 3 * batch_timeout; + const absl::Duration queueing_delay = 5 * batch_timeout; + const absl::Time task1_start_time = absl::Now(); + const absl::Time task2_start_time = task1_start_time + task2_delay; + const absl::Time batch_schedule_time = + task1_start_time + task2_delay + queueing_delay; + + BatchResourceBase::BatchT batch; + RequestCost cost1, cost2; + batch.AddTask(MakeBatchTask(/*task_size=*/1, &cost1, task1_start_time)); + batch.AddTask(MakeBatchTask(/*task_size=*/1, &cost2, task2_start_time)); + batch.Close(); + + BatchResourceBase::RecordBatchDelayMetrics( + batch, "model_name", "op_name", /*processed_size=*/20, + batch_schedule_time, batch_timeout); + + EXPECT_THAT(batch.task(0).request_cost->GetMetrics(), + UnorderedElementsAre( + Pair("batching_delay_msecs", + absl::ToDoubleMilliseconds(batch_timeout)), + Pair("queueing_delay_msecs", + absl::ToDoubleMilliseconds(task2_delay - batch_timeout + + queueing_delay)))); + EXPECT_THAT( + batch.task(1).request_cost->GetMetrics(), + UnorderedElementsAre(Pair("batching_delay_msecs", 0), + 
Pair("queueing_delay_msecs", + absl::ToDoubleMilliseconds(queueing_delay)))); +} + class BatchResourceBaseTest : public ::testing::Test { protected: // Like BatchResourceBase but overrides abstract methods, one of which diff --git a/tensorflow/core/kernels/conv_ops_fused_int8.cc b/tensorflow/core/kernels/conv_ops_fused_int8.cc index 5d7af5c8482be9..7f919d5087dbbe 100644 --- a/tensorflow/core/kernels/conv_ops_fused_int8.cc +++ b/tensorflow/core/kernels/conv_ops_fused_int8.cc @@ -570,21 +570,20 @@ void operator()( constexpr auto type = se::dnn::ToDataType::value; constexpr auto bias_type = se::dnn::ToDataType::value; - const bool use_cudnn_frontend = CudnnUseFrontend(); AutotuneEntry autotune_entry; if (!FusedConvAutotuneMap::GetInstance()->Find(fused_conv_parameters, &autotune_entry)) { - VLOG(2) << "Autotuning fused convolution (use_frontend=" - << use_cudnn_frontend << "): " << fused_conv_parameters.ToString(); + VLOG(2) << "Autotuning fused convolution: " + << fused_conv_parameters.ToString(); profiler::ScopedAnnotation trace("cudnn_autotuning"); std::vector> runners; auto dnn = stream->parent()->AsDnn(); CHECK_NE(dnn, nullptr); TF_CHECK_OK(dnn->GetFusedConvolveRunners( - use_cudnn_frontend, se::dnn::ConvolutionKind::FORWARD, type, bias_type, - type, conv_scale, side_input_scale, /*leakyrelu_alpha=*/0.0, stream, - conv_input_desc, filter_desc, bias_desc, output_desc, conv_desc, + se::dnn::ConvolutionKind::FORWARD, type, bias_type, type, conv_scale, + side_input_scale, /*leakyrelu_alpha=*/0.0, stream, conv_input_desc, + filter_desc, bias_desc, output_desc, conv_desc, /*use_fallback=*/false, dnn_activation_mode, GetNumericOptionsForCuDnn(), &runners)); @@ -621,7 +620,7 @@ void operator()( } } - if (!CudnnUseFrontend() || found_working_engine) { + if (found_working_engine) { auto runners_or = BestCudnnConvAlgorithm( results, std::move(runners)); OP_REQUIRES_OK(ctx, runners_or.status()); @@ -632,10 +631,9 @@ void operator()( auto dnn = stream->parent()->AsDnn(); 
CHECK_NE(dnn, nullptr); TF_CHECK_OK(dnn->GetFusedConvolveRunners( - use_cudnn_frontend, se::dnn::ConvolutionKind::FORWARD, type, - bias_type, type, conv_scale, side_input_scale, leakyrelu_alpha, - stream, conv_input_desc, filter_desc, bias_desc, output_desc, - conv_desc, + se::dnn::ConvolutionKind::FORWARD, type, bias_type, type, conv_scale, + side_input_scale, leakyrelu_alpha, stream, conv_input_desc, + filter_desc, bias_desc, output_desc, conv_desc, /*use_fallback=*/true, dnn_activation_mode, GetNumericOptionsForCuDnn(), &fallback_runners)); diff --git a/tensorflow/core/kernels/conv_ops_gpu.cc b/tensorflow/core/kernels/conv_ops_gpu.cc index 608fcaca5a76e1..d47a0ff628f6b3 100644 --- a/tensorflow/core/kernels/conv_ops_gpu.cc +++ b/tensorflow/core/kernels/conv_ops_gpu.cc @@ -103,11 +103,11 @@ StatusOr> AutotuneFusedConv( return absl::InvalidArgumentError("No DNN in stream executor."); } TF_RETURN_IF_ERROR(dnn->GetFusedConvolveRunners( - CudnnUseFrontend(), se::dnn::ConvolutionKind::FORWARD, element_type, - element_type, element_type, conv_scale, side_input_scale, - leakyrelu_alpha, stream, input_desc, filter_desc, bias_desc, - output_desc, conv_desc, /*use_fallback=*/false, activation_mode, - GetNumericOptionsForCuDnn(), &runners)); + se::dnn::ConvolutionKind::FORWARD, element_type, element_type, + element_type, conv_scale, side_input_scale, leakyrelu_alpha, stream, + input_desc, filter_desc, bias_desc, output_desc, conv_desc, + /*use_fallback=*/false, activation_mode, GetNumericOptionsForCuDnn(), + &runners)); auto launch_func = [&](se::ScratchAllocator* allocator_used, @@ -142,7 +142,7 @@ StatusOr> AutotuneFusedConv( } } - if (!CudnnUseFrontend() || found_working_engine) { + if (found_working_engine) { TF_ASSIGN_OR_RETURN(autotune_entry, BestCudnnConvAlgorithm( results, std::move(runners))); @@ -154,11 +154,11 @@ StatusOr> AutotuneFusedConv( std::vector> fallback_runners; TF_RETURN_IF_ERROR(dnn->GetFusedConvolveRunners( - CudnnUseFrontend(), 
se::dnn::ConvolutionKind::FORWARD, element_type, - element_type, element_type, conv_scale, side_input_scale, - leakyrelu_alpha, stream, input_desc, filter_desc, bias_desc, - output_desc, conv_desc, /*use_fallback=*/true, activation_mode, - GetNumericOptionsForCuDnn(), &fallback_runners)); + se::dnn::ConvolutionKind::FORWARD, element_type, element_type, + element_type, conv_scale, side_input_scale, leakyrelu_alpha, stream, + input_desc, filter_desc, bias_desc, output_desc, conv_desc, + /*use_fallback=*/true, activation_mode, GetNumericOptionsForCuDnn(), + &fallback_runners)); TF_ASSIGN_OR_RETURN(auto fallback_results, internal::AutotuneConvImpl( @@ -289,10 +289,10 @@ StatusOr> AutotuneUnfusedConv( return absl::InvalidArgumentError("No DNN in stream executor."); } TF_RETURN_IF_ERROR(dnn->GetConvolveRunners( - CudnnUseFrontend(), kind, element_type, element_type, stream, - input_desc, input_ptr, filter_desc, filter_ptr, output_desc, output_ptr, - conv_desc, /*use_fallback=*/false, &rz_allocator, - GetNumericOptionsForCuDnn(), &runners)); + kind, element_type, element_type, stream, input_desc, input_ptr, + filter_desc, filter_ptr, output_desc, output_ptr, conv_desc, + /*use_fallback=*/false, &rz_allocator, GetNumericOptionsForCuDnn(), + &runners)); auto launch_func = [&](se::ScratchAllocator* allocator_used, const std::unique_ptr& runner, @@ -323,7 +323,7 @@ StatusOr> AutotuneUnfusedConv( } } - if (!CudnnUseFrontend() || found_working_engine) { + if (found_working_engine) { TF_ASSIGN_OR_RETURN( autotune_entry, BestCudnnConvAlgorithm(results, std::move(runners))); @@ -334,10 +334,10 @@ StatusOr> AutotuneUnfusedConv( << conv_parameters.ToString(); std::vector> fallback_runners; TF_RETURN_IF_ERROR(dnn->GetConvolveRunners( - CudnnUseFrontend(), kind, element_type, element_type, stream, - input_desc, input_ptr, filter_desc, filter_ptr, output_desc, - output_ptr, conv_desc, /*use_fallback=*/true, &rz_allocator, - GetNumericOptionsForCuDnn(), &fallback_runners)); + kind, 
element_type, element_type, stream, input_desc, input_ptr, + filter_desc, filter_ptr, output_desc, output_ptr, conv_desc, + /*use_fallback=*/true, &rz_allocator, GetNumericOptionsForCuDnn(), + &fallback_runners)); TF_ASSIGN_OR_RETURN(auto fallback_results, internal::AutotuneConvImpl( diff --git a/tensorflow/core/kernels/cwise_op_leakyrelu.cc b/tensorflow/core/kernels/cwise_op_leakyrelu.cc index 0de8a65a01aa71..7bff5c9f61c1a1 100644 --- a/tensorflow/core/kernels/cwise_op_leakyrelu.cc +++ b/tensorflow/core/kernels/cwise_op_leakyrelu.cc @@ -36,8 +36,8 @@ namespace internal { template struct leakyrelu_op { - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE explicit leakyrelu_op(float val = 0.2f) - EIGEN_NO_THROW { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE explicit leakyrelu_op( + float val = 0.2f) { m_alpha = Scalar(val); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar diff --git a/tensorflow/core/kernels/data/cache_dataset_ops.cc b/tensorflow/core/kernels/data/cache_dataset_ops.cc index eff4f1c145518f..c6f46b3253968c 100644 --- a/tensorflow/core/kernels/data/cache_dataset_ops.cc +++ b/tensorflow/core/kernels/data/cache_dataset_ops.cc @@ -1195,7 +1195,7 @@ void CacheDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input, auto handle = HandleFromInput(ctx, 2); absl::Status s = ctx->resource_manager()->Lookup( handle.container(), handle.name(), &manager); - if (errors::IsNotFound(s)) { + if (absl::IsNotFound(s)) { owns_resource = true; OP_REQUIRES_OK( ctx, diff --git a/tensorflow/core/kernels/data/cache_ops.cc b/tensorflow/core/kernels/data/cache_ops.cc index 32b01e74a1ab2e..0dce7f73215f92 100644 --- a/tensorflow/core/kernels/data/cache_ops.cc +++ b/tensorflow/core/kernels/data/cache_ops.cc @@ -94,7 +94,7 @@ void DeleteMemoryCacheOp::Compute(OpKernelContext* ctx) { const ResourceHandle& handle = ctx->input(0).flat()(0); // The resource might have been already deleted by the dataset. 
absl::Status s = ctx->resource_manager()->Delete(handle); - if (!errors::IsNotFound(s)) { + if (!absl::IsNotFound(s)) { OP_REQUIRES_OK(ctx, s); } } diff --git a/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc b/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc index f1d7e58c141158..34dad6e46a59c4 100644 --- a/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc @@ -939,8 +939,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { << dump_status; } - std::string graph_hash = - strings::StrCat(strings::Hex(hash, strings::kZeroPad16)); + std::string graph_hash = absl::StrCat(absl::Hex(hash, absl::kZeroPad16)); LOG(INFO) << "Graph def serialized to hash: " << graph_hash; *output = new Dataset(ctx, input, path, graph_hash, reader_path_prefix_, @@ -1667,8 +1666,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { hash_dir_(hash_dir), run_id_(run_id) { if (run_id_.empty()) { - run_id_ = strings::StrCat( - strings::Hex(random::New64(), strings::kZeroPad4)); + run_id_ = absl::StrCat(absl::Hex(random::New64(), absl::kZeroPad4)); } run_dir_ = io::JoinPath(dataset()->writer_path_prefix_, hash_dir_, run_id_); diff --git a/tensorflow/core/kernels/data/fixed_length_record_dataset_op.cc b/tensorflow/core/kernels/data/fixed_length_record_dataset_op.cc index e996fac56ae648..7d2eed92576931 100644 --- a/tensorflow/core/kernels/data/fixed_length_record_dataset_op.cc +++ b/tensorflow/core/kernels/data/fixed_length_record_dataset_op.cc @@ -301,7 +301,7 @@ class FixedLengthRecordDatasetOp::Dataset : public DatasetBase { *end_of_sequence = false; return absl::OkStatus(); } - if (errors::IsOutOfRange(s) && !record.empty()) { + if (absl::IsOutOfRange(s) && !record.empty()) { uint64 body_size = current_pos + record.size() - (dataset()->header_bytes_ + dataset()->footer_bytes_); diff --git a/tensorflow/core/kernels/data/generator_dataset_op.cc 
b/tensorflow/core/kernels/data/generator_dataset_op.cc index d7e6faeec7a345..0f6f7b6fe892d6 100644 --- a/tensorflow/core/kernels/data/generator_dataset_op.cc +++ b/tensorflow/core/kernels/data/generator_dataset_op.cc @@ -144,7 +144,7 @@ class GeneratorDatasetOp::Dataset : public DatasetBase { ctx, state_, out_tensors, model_node()); if (s.ok()) { *end_of_sequence = false; - } else if (errors::IsOutOfRange(s)) { + } else if (absl::IsOutOfRange(s)) { // `next_func` may deliberately raise `errors::OutOfRange` // to indicate that we should terminate the iteration. s = absl::OkStatus(); diff --git a/tensorflow/core/kernels/data/map_dataset_op.cc b/tensorflow/core/kernels/data/map_dataset_op.cc index bf034a569733f5..d733640e3bd38b 100644 --- a/tensorflow/core/kernels/data/map_dataset_op.cc +++ b/tensorflow/core/kernels/data/map_dataset_op.cc @@ -190,7 +190,7 @@ class MapDatasetOp::Dataset : public DatasetBase { absl::Status s = instantiated_captured_func_->Run( ctx, std::move(args), out_tensors, model_node()); - if (errors::IsOutOfRange(s)) { + if (absl::IsOutOfRange(s)) { if (dataset()->preserve_cardinality_) { // To guarantee that the transformation preserves the cardinality of // the dataset, we convert `OutOfRange` to `InvalidArgument` as the diff --git a/tensorflow/core/kernels/data/optimize_dataset_op.cc b/tensorflow/core/kernels/data/optimize_dataset_op.cc index d06cb1ffe419f0..98cb9f34f3aa30 100644 --- a/tensorflow/core/kernels/data/optimize_dataset_op.cc +++ b/tensorflow/core/kernels/data/optimize_dataset_op.cc @@ -86,7 +86,7 @@ void MakeDatasetHelper(OpKernelContext* ctx, absl::Status s = RewriteDataset(ctx, input, std::move(config_factory), /*record_fingerprint=*/false, &rewritten); *output = rewritten.release(); - if (errors::IsDeadlineExceeded(s)) { + if (absl::IsDeadlineExceeded(s)) { // Ignore DeadlineExceeded as it implies that the attempted rewrite took too // long which should not prevent further computation. 
LOG(WARNING) << s.ToString(); diff --git a/tensorflow/core/kernels/data/parallel_filter_dataset_op.cc b/tensorflow/core/kernels/data/parallel_filter_dataset_op.cc index e8a16c0a08e20a..3099d6e654fc65 100644 --- a/tensorflow/core/kernels/data/parallel_filter_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_filter_dataset_op.cc @@ -407,7 +407,7 @@ class ParallelFilterDatasetOp::Dataset : public DatasetBase { *end_of_sequence = false; return absl::OkStatus(); } - if (errors::IsOutOfRange(result->status)) { + if (absl::IsOutOfRange(result->status)) { // `predicate` may deliberately raise `errors::OutOfRange` to indicate // that we should terminate the iteration early. return errors::InvalidArgument( diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc index 92e0827c7cd928..65373934f2512e 100644 --- a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc @@ -625,7 +625,7 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase { *end_of_sequence = false; return absl::OkStatus(); } - if (errors::IsOutOfRange(result->status)) { + if (absl::IsOutOfRange(result->status)) { if (preserve_cardinality_) { // To guarantee that the transformation preserves the cardinality of // the dataset, we convert `OutOfRange` to `InvalidArgument` as the diff --git a/tensorflow/core/kernels/data/shard_dataset_op.cc b/tensorflow/core/kernels/data/shard_dataset_op.cc index 522c95364a10e4..6edd6259655ce6 100644 --- a/tensorflow/core/kernels/data/shard_dataset_op.cc +++ b/tensorflow/core/kernels/data/shard_dataset_op.cc @@ -214,7 +214,7 @@ class ShardDatasetOp::Dataset : public DatasetBase { absl::Status s = input_impl_->Skip(ctx, dataset()->num_shards_ - next_index_, end_of_sequence, &num_skipped); - if (*end_of_sequence || errors::IsOutOfRange(s)) { + if (*end_of_sequence || absl::IsOutOfRange(s)) { // `dataset()->require_non_empty_` implies that 
this transformation // was introduced by auto_sharding rewrite, so it's acceptable // produce an error message that assumes auto-sharding context. diff --git a/tensorflow/core/kernels/data/shuffle_dataset_op.cc b/tensorflow/core/kernels/data/shuffle_dataset_op.cc index 04e1ae70179972..c331fc16a69ddb 100644 --- a/tensorflow/core/kernels/data/shuffle_dataset_op.cc +++ b/tensorflow/core/kernels/data/shuffle_dataset_op.cc @@ -816,7 +816,7 @@ void ShuffleDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input, OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kSeed2, &seed2)); RandomSeeds seeds(seed, seed2); bool owns_resource = false; - if (errors::IsNotFound(s)) { + if (absl::IsNotFound(s)) { owns_resource = true; OP_REQUIRES_OK( ctx, @@ -848,7 +848,7 @@ void ShuffleDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input, absl::Status s = ctx->resource_manager()->Lookup( handle.container(), handle.name(), &manager); bool owns_resource = false; - if (errors::IsNotFound(s)) { + if (absl::IsNotFound(s)) { owns_resource = true; LOG(WARNING) << "Failed to find seed generator resource. 
Falling back to " "using a non-deterministically seeded generator and " @@ -1076,7 +1076,7 @@ void ShuffleAndRepeatDatasetOp::MakeDataset(OpKernelContext* ctx, absl::Status s = ctx->resource_manager()->Lookup( handle.container(), handle.name(), &manager); bool owns_resource = false; - if (errors::IsNotFound(s)) { + if (absl::IsNotFound(s)) { owns_resource = true; OP_REQUIRES_OK( ctx, diff --git a/tensorflow/core/kernels/data/text_line_dataset_op.cc b/tensorflow/core/kernels/data/text_line_dataset_op.cc index f7b858b6566437..1d07fc81c04455 100644 --- a/tensorflow/core/kernels/data/text_line_dataset_op.cc +++ b/tensorflow/core/kernels/data/text_line_dataset_op.cc @@ -121,7 +121,7 @@ class TextLineDatasetOp::Dataset : public DatasetBase { out_tensors->push_back(std::move(line_contents)); *end_of_sequence = false; return absl::OkStatus(); - } else if (!errors::IsOutOfRange(s)) { + } else if (!absl::IsOutOfRange(s)) { // Report non-EOF errors to the caller. return s; } diff --git a/tensorflow/core/kernels/data/tf_record_dataset_op.cc b/tensorflow/core/kernels/data/tf_record_dataset_op.cc index df91a2b9ea9e0a..d391c6dcd731f3 100644 --- a/tensorflow/core/kernels/data/tf_record_dataset_op.cc +++ b/tensorflow/core/kernels/data/tf_record_dataset_op.cc @@ -155,7 +155,7 @@ class TFRecordDatasetOp::Dataset : public DatasetBase { return absl::OkStatus(); } out_tensors->pop_back(); - if (!errors::IsOutOfRange(s)) { + if (!absl::IsOutOfRange(s)) { // In case of other errors e.g., DataLoss, we still move forward // the file index so that it works with ignore_errors. // Otherwise the same file will repeat. @@ -197,7 +197,7 @@ class TFRecordDatasetOp::Dataset : public DatasetBase { *end_of_sequence = false; return absl::OkStatus(); } - if (!errors::IsOutOfRange(s)) { + if (!absl::IsOutOfRange(s)) { // In case of other errors e.g., DataLoss, we still move forward // the file index so that it works with ignore_errors. // Otherwise the same file will repeat. 
diff --git a/tensorflow/core/kernels/gpu_prim.h b/tensorflow/core/kernels/gpu_prim.h index bef22b50ada12c..bcc873bfb03180 100644 --- a/tensorflow/core/kernels/gpu_prim.h +++ b/tensorflow/core/kernels/gpu_prim.h @@ -44,10 +44,9 @@ __device__ __forceinline__ void ThreadStoreVolatilePtr( Eigen::numext::bit_cast(val); } -template <> -__device__ __forceinline__ Eigen::half ThreadLoadVolatilePointer( +__device__ __forceinline__ Eigen::half ThreadLoadVolatilePointer( Eigen::half *ptr, Int2Type /*is_primitive*/) { - uint16_t result = *reinterpret_cast(ptr); + const uint16_t result = *reinterpret_cast(ptr); return Eigen::numext::bit_cast(result); } @@ -59,10 +58,8 @@ __device__ __forceinline__ void ThreadStoreVolatilePtr( Eigen::numext::bit_cast(val); } -template <> -__device__ __forceinline__ Eigen::bfloat16 -ThreadLoadVolatilePointer(Eigen::bfloat16 *ptr, - Int2Type /*is_primitive*/) { +__device__ __forceinline__ Eigen::bfloat16 ThreadLoadVolatilePointer( + Eigen::bfloat16 *ptr, Int2Type /*is_primitive*/) { uint16_t result = *reinterpret_cast(ptr); return Eigen::numext::bit_cast(result); } diff --git a/tensorflow/core/kernels/image/BUILD b/tensorflow/core/kernels/image/BUILD index 480eadb279bb2a..a47f1771243a1b 100644 --- a/tensorflow/core/kernels/image/BUILD +++ b/tensorflow/core/kernels/image/BUILD @@ -318,7 +318,10 @@ tf_kernel_library( tf_kernel_library( name = "sample_distorted_bounding_box_op", prefix = "sample_distorted_bounding_box_op", - deps = IMAGE_DEPS + ["//tensorflow/core/kernels:stateless_random_ops"], + deps = IMAGE_DEPS + [ + "//tensorflow/core/kernels:stateless_random_ops", + "@com_google_absl//absl/log:check", + ], ) tf_kernel_library( diff --git a/tensorflow/core/kernels/image/adjust_hue_op.cc b/tensorflow/core/kernels/image/adjust_hue_op.cc index fb089f13f8edd9..8795185c365dfd 100644 --- a/tensorflow/core/kernels/image/adjust_hue_op.cc +++ b/tensorflow/core/kernels/image/adjust_hue_op.cc @@ -11,16 +11,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY 
KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #define EIGEN_USE_THREADS #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define EIGEN_USE_GPU #endif -#include "tensorflow/core/kernels/image/adjust_hue_op.h" - -#include - #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -28,6 +25,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/image/adjust_hue_op.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/util/work_sharder.h" diff --git a/tensorflow/core/kernels/image/adjust_saturation_op.cc b/tensorflow/core/kernels/image/adjust_saturation_op.cc index 5c108aa2ab7434..5387e636f69a4f 100644 --- a/tensorflow/core/kernels/image/adjust_saturation_op.cc +++ b/tensorflow/core/kernels/image/adjust_saturation_op.cc @@ -12,22 +12,22 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include +#include +#include #define EIGEN_USE_THREADS #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define EIGEN_USE_GPU #endif -#include "tensorflow/core/kernels/image/adjust_saturation_op.h" - -#include - #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/image/adjust_saturation_op.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/util/work_sharder.h" diff --git a/tensorflow/core/kernels/image/encode_jpeg_op_test.cc b/tensorflow/core/kernels/image/encode_jpeg_op_test.cc index 922a3aff5f72b0..0e51b4e244141f 100644 --- a/tensorflow/core/kernels/image/encode_jpeg_op_test.cc +++ b/tensorflow/core/kernels/image/encode_jpeg_op_test.cc @@ -39,7 +39,7 @@ TEST_F(EncodeJpegWithVariableQualityTest, FailsForInvalidQuality) { {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}); AddInputFromArray(TensorShape({}), {200}); absl::Status status = RunOpKernel(); - EXPECT_TRUE(errors::IsInvalidArgument(status)); + EXPECT_TRUE(absl::IsInvalidArgument(status)); EXPECT_TRUE(absl::StartsWith(status.message(), "quality must be in [0,100]")); } diff --git a/tensorflow/core/kernels/image/extract_image_patches_op.cc b/tensorflow/core/kernels/image/extract_image_patches_op.cc index a1dbcd9efa3650..b40c59147e51b5 100644 --- a/tensorflow/core/kernels/image/extract_image_patches_op.cc +++ b/tensorflow/core/kernels/image/extract_image_patches_op.cc @@ -15,6 +15,7 @@ limitations under the License. // See docs in ../ops/image_ops.cc. 
+#include #define USE_EIGEN_TENSOR #define EIGEN_USE_THREADS diff --git a/tensorflow/core/kernels/image/random_crop_op.cc b/tensorflow/core/kernels/image/random_crop_op.cc index 987001c58c0a69..1fceed794d29a4 100644 --- a/tensorflow/core/kernels/image/random_crop_op.cc +++ b/tensorflow/core/kernels/image/random_crop_op.cc @@ -15,6 +15,8 @@ limitations under the License. // See docs in ../ops/image_ops.cc. +#include + #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" diff --git a/tensorflow/core/kernels/image/resize_bicubic_op.cc b/tensorflow/core/kernels/image/resize_bicubic_op.cc index 23e6251f8a0f48..338a9fbfcf9a98 100644 --- a/tensorflow/core/kernels/image/resize_bicubic_op.cc +++ b/tensorflow/core/kernels/image/resize_bicubic_op.cc @@ -14,6 +14,11 @@ limitations under the License. ==============================================================================*/ // See docs in ../ops/image_ops.cc +#include +#include +#include +#include +#include #define EIGEN_USE_THREADS #include diff --git a/tensorflow/core/kernels/image/sample_distorted_bounding_box_op.cc b/tensorflow/core/kernels/image/sample_distorted_bounding_box_op.cc index 90e26496ed8f0a..a754a8cec1fc62 100644 --- a/tensorflow/core/kernels/image/sample_distorted_bounding_box_op.cc +++ b/tensorflow/core/kernels/image/sample_distorted_bounding_box_op.cc @@ -15,8 +15,13 @@ limitations under the License. // See docs in ../ops/image_ops.cc. 
#include +#include #include +#include +#include +#include +#include "absl/log/check.h" #include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" diff --git a/tensorflow/core/kernels/matmul_op_fused.cc b/tensorflow/core/kernels/matmul_op_fused.cc index f632785545df68..6be7d241921739 100644 --- a/tensorflow/core/kernels/matmul_op_fused.cc +++ b/tensorflow/core/kernels/matmul_op_fused.cc @@ -369,8 +369,8 @@ StatusOr> AutotuneFusedMatmul( return errors::Internal("No DNN in stream executor."); } TF_RETURN_IF_ERROR(dnn->GetFusedMatmulRunners( - CudnnUseFrontend(), element_type, element_type, element_type, stream, - trans_a, trans_b, m, n, k, lda, ldb, ldc, activation_mode, + element_type, element_type, element_type, stream, trans_a, trans_b, m, + n, k, lda, ldb, ldc, activation_mode, /*use_fallback=*/false, GetNumericOptionsForCuDnn(), &runners)); auto launch_func = @@ -413,8 +413,8 @@ StatusOr> AutotuneFusedMatmul( std::vector> fallback_runners; TF_RETURN_IF_ERROR(dnn->GetFusedMatmulRunners( - CudnnUseFrontend(), element_type, element_type, element_type, stream, - trans_a, trans_b, m, n, k, lda, ldb, ldc, activation_mode, + element_type, element_type, element_type, stream, trans_a, trans_b, m, + n, k, lda, ldb, ldc, activation_mode, /*use_fallback=*/true, GetNumericOptionsForCuDnn(), &fallback_runners)); diff --git a/tensorflow/core/kernels/rnn/BUILD b/tensorflow/core/kernels/rnn/BUILD index 3b9298c5bac42f..17b2545986a7c6 100644 --- a/tensorflow/core/kernels/rnn/BUILD +++ b/tensorflow/core/kernels/rnn/BUILD @@ -49,6 +49,8 @@ tf_kernel_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core/kernels:eigen_helpers", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@eigen_archive//:eigen3", ], ) diff --git a/tensorflow/core/kernels/rnn/gru_ops.cc b/tensorflow/core/kernels/rnn/gru_ops.cc index ed424e922a4fe3..f1722497cc81c3 100644 
--- a/tensorflow/core/kernels/rnn/gru_ops.cc +++ b/tensorflow/core/kernels/rnn/gru_ops.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #define EIGEN_USE_THREADS #include "tensorflow/core/kernels/rnn/gru_ops.h" diff --git a/tensorflow/core/kernels/rnn/lstm_ops.cc b/tensorflow/core/kernels/rnn/lstm_ops.cc index 5bf12c3b56cd62..8fb0dcfd9ce645 100644 --- a/tensorflow/core/kernels/rnn/lstm_ops.cc +++ b/tensorflow/core/kernels/rnn/lstm_ops.cc @@ -13,15 +13,18 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/log/log.h" #define EIGEN_USE_THREADS #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define EIGEN_USE_GPU #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM -#include "tensorflow/core/kernels/rnn/lstm_ops.h" - -#include #include #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive @@ -31,6 +34,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/rnn/lstm_ops.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" diff --git a/tensorflow/core/kernels/rnn/lstm_ops_gpu.cu.cc b/tensorflow/core/kernels/rnn/lstm_ops_gpu.cu.cc index 795791254a475d..4b3867edfbf4bd 100644 --- a/tensorflow/core/kernels/rnn/lstm_ops_gpu.cu.cc +++ b/tensorflow/core/kernels/rnn/lstm_ops_gpu.cu.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include +#include #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define EIGEN_USE_GPU diff --git a/tensorflow/core/lib/gif/gif_io.cc b/tensorflow/core/lib/gif/gif_io.cc index e4123665dfe788..61c9ea32792be8 100644 --- a/tensorflow/core/lib/gif/gif_io.cc +++ b/tensorflow/core/lib/gif/gif_io.cc @@ -61,6 +61,7 @@ uint8* Decode(const void* srcdata, int datasize, string* error_string, bool expand_animations) { int error_code = D_GIF_SUCCEEDED; InputBufferInfo info = {reinterpret_cast(srcdata), datasize}; + /// NOTE: After this, gif file is mostly not initialized! GifFileType* gif_file = DGifOpen(static_cast(&info), &input_callback, &error_code); const auto cleanup = gtl::MakeCleanup([gif_file]() { @@ -82,20 +83,18 @@ uint8* Decode(const void* srcdata, int datasize, // Stop load if no images are detected or the allocation of the last image // buffer was failed. if (gif_file->ImageCount <= 0 || - gif_file->SavedImages[gif_file->ImageCount - 1].RasterBits == nullptr || - gif_file->Error == D_GIF_ERR_EOF_TOO_SOON) { + gif_file->SavedImages[gif_file->ImageCount - 1].RasterBits == nullptr) { return nullptr; } LOG(ERROR) << *error_string; } + int target_num_frames = gif_file->ImageCount; - if (gif_file->ImageCount <= 0) { + if (target_num_frames <= 0) { *error_string = "gif file does not contain any image"; return nullptr; } - int target_num_frames = gif_file->ImageCount; - // Don't request more memory than needed for each frame, preventing OOM int max_frame_width = 0; int max_frame_height = 0; diff --git a/tensorflow/core/lib/gtl/BUILD b/tensorflow/core/lib/gtl/BUILD index 31a74dca33bb07..338d6fe6fb4529 100644 --- a/tensorflow/core/lib/gtl/BUILD +++ b/tensorflow/core/lib/gtl/BUILD @@ -22,7 +22,7 @@ package( # tensorflow/examples/custom_ops_doc/simple_hash_table uses map_util "//tensorflow/examples/custom_ops_doc/simple_hash_table:__pkg__", # tensorflow/core/profiler/convert uses map_util - 
"//tensorflow/core/profiler/convert:__pkg__", + "@org_xprof//xprof/convert:__pkg__", ], licenses = ["notice"], ) diff --git a/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdagradAndCsrInput.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdagradAndCsrInput.pbtxt new file mode 100644 index 00000000000000..19e97d82db8d86 --- /dev/null +++ b/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdagradAndCsrInput.pbtxt @@ -0,0 +1,107 @@ +op { + name: "XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdagradAndCsrInput" + input_arg { + name: "row_pointers" + type: DT_INT32 + } + input_arg { + name: "sorted_sample_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_token_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_pos_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_gains" + type: DT_FLOAT + } + input_arg { + name: "weights" + type: DT_FLOAT + } + input_arg { + name: "preserved_valencies" + type: DT_INT32 + } + input_arg { + name: "preserved_vectors" + type: DT_FLOAT + } + input_arg { + name: "preserved_weights" + type: DT_FLOAT + } + input_arg { + name: "activation_gradients" + type: DT_FLOAT + } + input_arg { + name: "learning_rate" + type: DT_FLOAT + } + input_arg { + name: "combiner_weights_learning_rate" + type: DT_FLOAT + } + input_arg { + name: "embedding_table" + type: DT_FLOAT + } + input_arg { + name: "accumulator" + type: DT_FLOAT + } + output_arg { + name: "updated_embedding_table" + type: DT_FLOAT + } + output_arg { + name: "updated_accumulator" + type: DT_FLOAT + } + output_arg { + name: "updated_weights" + type: DT_FLOAT + } + attr { + name: "clip_weight_min" + type: "float" + default_value { + f: -inf + } + } + attr { + name: "clip_weight_max" + type: "float" + default_value { + f: inf + } + } + attr { + name: "max_valency" + type: "int" + has_minimum: true + } + attr { + name: "num_weights" + type: "int" + 
has_minimum: true + } + attr { + name: "combiner_table_vjp_computation" + type: "func" + } + attr { + name: "combiner_weights_vjp_computation" + type: "func" + } + attr { + name: "table_name" + type: "string" + } +} diff --git a/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdagradMomentumAndCsrInput.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdagradMomentumAndCsrInput.pbtxt new file mode 100644 index 00000000000000..92d6891fb9ef60 --- /dev/null +++ b/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdagradMomentumAndCsrInput.pbtxt @@ -0,0 +1,135 @@ +op { + name: "XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdagradMomentumAndCsrInput" + input_arg { + name: "row_pointers" + type: DT_INT32 + } + input_arg { + name: "sorted_sample_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_token_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_pos_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_gains" + type: DT_FLOAT + } + input_arg { + name: "weights" + type: DT_FLOAT + } + input_arg { + name: "preserved_valencies" + type: DT_INT32 + } + input_arg { + name: "preserved_vectors" + type: DT_FLOAT + } + input_arg { + name: "preserved_weights" + type: DT_FLOAT + } + input_arg { + name: "activation_gradients" + type: DT_FLOAT + } + input_arg { + name: "learning_rate" + type: DT_FLOAT + } + input_arg { + name: "combiner_weights_learning_rate" + type: DT_FLOAT + } + input_arg { + name: "embedding_table" + type: DT_FLOAT + } + input_arg { + name: "accumulator" + type: DT_FLOAT + } + input_arg { + name: "momenta" + type: DT_FLOAT + } + output_arg { + name: "updated_embedding_table" + type: DT_FLOAT + } + output_arg { + name: "updated_accumulator" + type: DT_FLOAT + } + output_arg { + name: "updated_momenta" + type: DT_FLOAT + } + output_arg { + name: "updated_weights" + type: DT_FLOAT + } + attr { + name: "use_nesterov" + type: 
"bool" + } + attr { + name: "exponent" + type: "float" + } + attr { + name: "beta1" + type: "float" + } + attr { + name: "beta2" + type: "float" + } + attr { + name: "epsilon" + type: "float" + } + attr { + name: "clip_weight_min" + type: "float" + default_value { + f: -inf + } + } + attr { + name: "clip_weight_max" + type: "float" + default_value { + f: inf + } + } + attr { + name: "max_valency" + type: "int" + has_minimum: true + } + attr { + name: "num_weights" + type: "int" + has_minimum: true + } + attr { + name: "combiner_table_vjp_computation" + type: "func" + } + attr { + name: "combiner_weights_vjp_computation" + type: "func" + } + attr { + name: "table_name" + type: "string" + } +} diff --git a/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdamAndCsrInput.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdamAndCsrInput.pbtxt new file mode 100644 index 00000000000000..850cd016b7297f --- /dev/null +++ b/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdamAndCsrInput.pbtxt @@ -0,0 +1,131 @@ +op { + name: "XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdamAndCsrInput" + input_arg { + name: "row_pointers" + type: DT_INT32 + } + input_arg { + name: "sorted_sample_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_token_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_pos_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_gains" + type: DT_FLOAT + } + input_arg { + name: "weights" + type: DT_FLOAT + } + input_arg { + name: "preserved_valencies" + type: DT_INT32 + } + input_arg { + name: "preserved_vectors" + type: DT_FLOAT + } + input_arg { + name: "preserved_weights" + type: DT_FLOAT + } + input_arg { + name: "activation_gradients" + type: DT_FLOAT + } + input_arg { + name: "learning_rate" + type: DT_FLOAT + } + input_arg { + name: "combiner_weights_learning_rate" + type: DT_FLOAT + } + input_arg { + name: 
"embedding_table" + type: DT_FLOAT + } + input_arg { + name: "momenta" + type: DT_FLOAT + } + input_arg { + name: "velocity" + type: DT_FLOAT + } + output_arg { + name: "updated_embedding_table" + type: DT_FLOAT + } + output_arg { + name: "updated_momenta" + type: DT_FLOAT + } + output_arg { + name: "updated_velocity" + type: DT_FLOAT + } + output_arg { + name: "updated_weights" + type: DT_FLOAT + } + attr { + name: "use_sum_inside_sqrt" + type: "bool" + } + attr { + name: "beta1" + type: "float" + } + attr { + name: "beta2" + type: "float" + } + attr { + name: "epsilon" + type: "float" + } + attr { + name: "clip_weight_min" + type: "float" + default_value { + f: -inf + } + } + attr { + name: "clip_weight_max" + type: "float" + default_value { + f: inf + } + } + attr { + name: "max_valency" + type: "int" + has_minimum: true + } + attr { + name: "num_weights" + type: "int" + has_minimum: true + } + attr { + name: "combiner_table_vjp_computation" + type: "func" + } + attr { + name: "combiner_weights_vjp_computation" + type: "func" + } + attr { + name: "table_name" + type: "string" + } +} diff --git a/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulCustomCombinerOnTcGradWithCsrInput.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulCustomCombinerOnTcGradWithCsrInput.pbtxt new file mode 100644 index 00000000000000..9a2cd09ba62f36 --- /dev/null +++ b/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulCustomCombinerOnTcGradWithCsrInput.pbtxt @@ -0,0 +1,104 @@ +op { + name: "XlaSparseDenseMatmulCustomCombinerOnTcGradWithCsrInput" + input_arg { + name: "row_pointers" + type: DT_INT32 + } + input_arg { + name: "sorted_sample_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_token_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_pos_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_gains" + type: DT_FLOAT + } + input_arg { + name: "weights" + type: DT_FLOAT + } + input_arg { + name: "preserved_valencies" + type: 
DT_INT32 + } + input_arg { + name: "preserved_vectors" + type: DT_FLOAT + } + input_arg { + name: "preserved_weights" + type: DT_FLOAT + } + input_arg { + name: "activation_gradients" + type: DT_FLOAT + } + input_arg { + name: "tables" + type: DT_FLOAT + number_attr: "N" + } + input_arg { + name: "hyperparameters" + type: DT_FLOAT + number_attr: "M" + } + input_arg { + name: "combiner_weights_learning_rate" + type: DT_FLOAT + } + output_arg { + name: "updated_tables" + type: DT_FLOAT + number_attr: "N" + } + output_arg { + name: "updated_weights" + type: DT_FLOAT + } + attr { + name: "N" + type: "int" + has_minimum: true + minimum: 1 + } + attr { + name: "M" + type: "int" + has_minimum: true + minimum: 1 + } + attr { + name: "max_valency" + type: "int" + has_minimum: true + } + attr { + name: "num_weights" + type: "int" + has_minimum: true + } + attr { + name: "combiner_table_vjp_computation" + type: "func" + } + attr { + name: "combiner_weights_vjp_computation" + type: "func" + } + attr { + name: "optimizer_custom_computation" + type: "func" + } + attr { + name: "table_name" + type: "string" + } +} diff --git a/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulCustomCombinerOnTcGradWithFtrlAndCsrInput.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulCustomCombinerOnTcGradWithFtrlAndCsrInput.pbtxt new file mode 100644 index 00000000000000..3b4da49183e69c --- /dev/null +++ b/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulCustomCombinerOnTcGradWithFtrlAndCsrInput.pbtxt @@ -0,0 +1,135 @@ +op { + name: "XlaSparseDenseMatmulCustomCombinerOnTcGradWithFtrlAndCsrInput" + input_arg { + name: "row_pointers" + type: DT_INT32 + } + input_arg { + name: "sorted_sample_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_token_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_pos_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_gains" + type: DT_FLOAT + } + input_arg { + name: "weights" + type: DT_FLOAT + } + 
input_arg { + name: "preserved_valencies" + type: DT_INT32 + } + input_arg { + name: "preserved_vectors" + type: DT_FLOAT + } + input_arg { + name: "preserved_weights" + type: DT_FLOAT + } + input_arg { + name: "activation_gradients" + type: DT_FLOAT + } + input_arg { + name: "learning_rate" + type: DT_FLOAT + } + input_arg { + name: "combiner_weights_learning_rate" + type: DT_FLOAT + } + input_arg { + name: "embedding_table" + type: DT_FLOAT + } + input_arg { + name: "accumulator" + type: DT_FLOAT + } + input_arg { + name: "linear" + type: DT_FLOAT + } + output_arg { + name: "updated_embedding_table" + type: DT_FLOAT + } + output_arg { + name: "updated_accumulator" + type: DT_FLOAT + } + output_arg { + name: "updated_linear" + type: DT_FLOAT + } + output_arg { + name: "updated_weights" + type: DT_FLOAT + } + attr { + name: "multiply_linear_by_learning_rate" + type: "bool" + } + attr { + name: "beta" + type: "float" + } + attr { + name: "learning_rate_power" + type: "float" + } + attr { + name: "l1_regularization_strength" + type: "float" + } + attr { + name: "l2_regularization_strength" + type: "float" + } + attr { + name: "clip_weight_min" + type: "float" + default_value { + f: -inf + } + } + attr { + name: "clip_weight_max" + type: "float" + default_value { + f: inf + } + } + attr { + name: "max_valency" + type: "int" + has_minimum: true + } + attr { + name: "num_weights" + type: "int" + has_minimum: true + } + attr { + name: "combiner_table_vjp_computation" + type: "func" + } + attr { + name: "combiner_weights_vjp_computation" + type: "func" + } + attr { + name: "table_name" + type: "string" + } +} diff --git a/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulCustomCombinerOnTcGradWithSgdAndCsrInput.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulCustomCombinerOnTcGradWithSgdAndCsrInput.pbtxt new file mode 100644 index 00000000000000..4f3f35d0f34f4a --- /dev/null +++ 
b/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulCustomCombinerOnTcGradWithSgdAndCsrInput.pbtxt @@ -0,0 +1,99 @@ +op { + name: "XlaSparseDenseMatmulCustomCombinerOnTcGradWithSgdAndCsrInput" + input_arg { + name: "row_pointers" + type: DT_INT32 + } + input_arg { + name: "sorted_sample_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_token_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_pos_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_gains" + type: DT_FLOAT + } + input_arg { + name: "weights" + type: DT_FLOAT + } + input_arg { + name: "preserved_valencies" + type: DT_INT32 + } + input_arg { + name: "preserved_vectors" + type: DT_FLOAT + } + input_arg { + name: "preserved_weights" + type: DT_FLOAT + } + input_arg { + name: "activation_gradients" + type: DT_FLOAT + } + input_arg { + name: "learning_rate" + type: DT_FLOAT + } + input_arg { + name: "combiner_weights_learning_rate" + type: DT_FLOAT + } + input_arg { + name: "embedding_table" + type: DT_FLOAT + } + output_arg { + name: "updated_embedding_table" + type: DT_FLOAT + } + output_arg { + name: "updated_weights" + type: DT_FLOAT + } + attr { + name: "clip_weight_min" + type: "float" + default_value { + f: -inf + } + } + attr { + name: "clip_weight_max" + type: "float" + default_value { + f: inf + } + } + attr { + name: "max_valency" + type: "int" + has_minimum: true + } + attr { + name: "num_weights" + type: "int" + has_minimum: true + } + attr { + name: "combiner_table_vjp_computation" + type: "func" + } + attr { + name: "combiner_weights_vjp_computation" + type: "func" + } + attr { + name: "table_name" + type: "string" + } +} diff --git a/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulCustomCombinerOnTcWithCsrInput.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulCustomCombinerOnTcWithCsrInput.pbtxt new file mode 100644 index 00000000000000..1379b143fc3cb3 --- /dev/null +++ 
b/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulCustomCombinerOnTcWithCsrInput.pbtxt @@ -0,0 +1,165 @@ +op { + name: "XlaSparseDenseMatmulCustomCombinerOnTcWithCsrInput" + input_arg { + name: "row_pointers" + type: DT_INT32 + } + input_arg { + name: "sorted_sample_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_token_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_pos_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_gains" + type: DT_FLOAT + } + input_arg { + name: "embedding_table" + type: DT_FLOAT + } + input_arg { + name: "weights" + type: DT_FLOAT + } + output_arg { + name: "activations" + type: DT_FLOAT + } + output_arg { + name: "preserved_valencies" + type: DT_INT32 + } + output_arg { + name: "preserved_vectors" + type: DT_FLOAT + } + attr { + name: "input_size" + type: "int" + has_minimum: true + } + attr { + name: "max_valency" + type: "int" + has_minimum: true + } + attr { + name: "num_weights" + type: "int" + has_minimum: true + } + attr { + name: "combiner_computation" + type: "func" + } + attr { + name: "quantization_config_low" + type: "float" + } + attr { + name: "quantization_config_high" + type: "float" + } + attr { + name: "quantization_config_num_buckets" + type: "int" + has_minimum: true + } + attr { + name: "table_name" + type: "string" + } +} +op { + name: "XlaSparseDenseMatmulCustomCombinerOnTcWithCsrInput" + input_arg { + name: "row_pointers" + type: DT_INT32 + } + input_arg { + name: "sorted_sample_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_token_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_pos_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_gains" + type: DT_FLOAT + } + input_arg { + name: "embedding_table" + type: DT_FLOAT + } + input_arg { + name: "weights" + type: DT_FLOAT + } + output_arg { + name: "activations" + type: DT_FLOAT + } + output_arg { + name: "preserved_valencies" + type: DT_INT32 + } + output_arg { + name: "preserved_vectors" + type: DT_FLOAT + } + 
attr { + name: "input_size" + type: "int" + has_minimum: true + } + attr { + name: "max_valency" + type: "int" + has_minimum: true + } + attr { + name: "num_weights" + type: "int" + has_minimum: true + } + attr { + name: "combiner_computation" + type: "func" + } + attr { + name: "quantization_config_low" + type: "float" + } + attr { + name: "quantization_config_high" + type: "float" + } + attr { + name: "quantization_config_num_buckets" + type: "int" + has_minimum: true + } + attr { + name: "table_name" + type: "string" + } + attr { + name: "num_sparsecores_per_device" + type: "int" + default_value { + i: -1 + } + } +} diff --git a/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulGradWithAdagradAndCsrInput.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulGradWithAdagradAndCsrInput.pbtxt index fbf266b10e35fa..41f2c7e450e1be 100644 --- a/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulGradWithAdagradAndCsrInput.pbtxt +++ b/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulGradWithAdagradAndCsrInput.pbtxt @@ -114,3 +114,75 @@ op { type: "string" } } +op { + name: "XlaSparseDenseMatmulGradWithAdagradAndCsrInput" + input_arg { + name: "row_pointers" + type: DT_INT32 + } + input_arg { + name: "sorted_sample_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_token_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_gains" + type: DT_FLOAT + } + input_arg { + name: "activation_gradients" + type: DT_FLOAT + } + input_arg { + name: "learning_rate" + type: DT_FLOAT + } + input_arg { + name: "embedding_table" + type: DT_FLOAT + } + input_arg { + name: "accumulator" + type: DT_FLOAT + } + input_arg { + name: "num_minibatches_per_physical_sparse_core" + type: DT_INT32 + } + output_arg { + name: "updated_embedding_table" + type: DT_FLOAT + } + output_arg { + name: "updated_accumulator" + type: DT_FLOAT + } + attr { + name: "clip_weight_min" + type: "float" + default_value { + f: -inf + } + } + attr { + name: 
"clip_weight_max" + type: "float" + default_value { + f: inf + } + } + attr { + name: "table_name" + type: "string" + } + attr { + name: "num_sparsecores_per_device" + type: "int" + default_value { + i: -1 + } + } +} diff --git a/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulGradWithAdagradAndStaticBufferSize.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulGradWithAdagradAndStaticBufferSize.pbtxt index 359a038ea9b6b7..d163a385310401 100644 --- a/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulGradWithAdagradAndStaticBufferSize.pbtxt +++ b/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulGradWithAdagradAndStaticBufferSize.pbtxt @@ -75,3 +75,87 @@ op { type: "string" } } +op { + name: "XlaSparseDenseMatmulGradWithAdagradAndStaticBufferSize" + input_arg { + name: "row_pointers" + type: DT_INT32 + } + input_arg { + name: "sorted_sample_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_token_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_gains" + type: DT_FLOAT + } + input_arg { + name: "activation_gradients" + type: DT_FLOAT + } + input_arg { + name: "learning_rate" + type: DT_FLOAT + } + input_arg { + name: "embedding_table" + type: DT_FLOAT + } + input_arg { + name: "accumulator" + type: DT_FLOAT + } + input_arg { + name: "num_minibatches_per_physical_sparse_core" + type: DT_INT32 + } + output_arg { + name: "updated_embedding_table" + type: DT_FLOAT + } + output_arg { + name: "updated_accumulator" + type: DT_FLOAT + } + attr { + name: "clip_weight_min" + type: "float" + default_value { + f: -inf + } + } + attr { + name: "clip_weight_max" + type: "float" + default_value { + f: inf + } + } + attr { + name: "max_ids_per_sparse_core" + type: "int" + has_minimum: true + minimum: 1 + } + attr { + name: "max_unique_ids_per_sparse_core" + type: "int" + has_minimum: true + minimum: 1 + } + attr { + name: "table_name" + type: "string" + } + attr { + name: "num_sparsecores_per_device" + type: "int" 
+ default_value { + i: -1 + } + } +} diff --git a/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulGradWithAdagradMomentumAndCsrInput.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulGradWithAdagradMomentumAndCsrInput.pbtxt index 5150a4f23b598f..bb07dc634bfff6 100644 --- a/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulGradWithAdagradMomentumAndCsrInput.pbtxt +++ b/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulGradWithAdagradMomentumAndCsrInput.pbtxt @@ -170,3 +170,103 @@ op { type: "string" } } +op { + name: "XlaSparseDenseMatmulGradWithAdagradMomentumAndCsrInput" + input_arg { + name: "row_pointers" + type: DT_INT32 + } + input_arg { + name: "sorted_sample_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_token_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_gains" + type: DT_FLOAT + } + input_arg { + name: "activation_gradients" + type: DT_FLOAT + } + input_arg { + name: "learning_rate" + type: DT_FLOAT + } + input_arg { + name: "embedding_table" + type: DT_FLOAT + } + input_arg { + name: "accumulator" + type: DT_FLOAT + } + input_arg { + name: "momenta" + type: DT_FLOAT + } + input_arg { + name: "num_minibatches_per_physical_sparse_core" + type: DT_INT32 + } + output_arg { + name: "updated_embedding_table" + type: DT_FLOAT + } + output_arg { + name: "updated_accumulator" + type: DT_FLOAT + } + output_arg { + name: "updated_momenta" + type: DT_FLOAT + } + attr { + name: "use_nesterov" + type: "bool" + } + attr { + name: "exponent" + type: "float" + } + attr { + name: "beta1" + type: "float" + } + attr { + name: "beta2" + type: "float" + } + attr { + name: "epsilon" + type: "float" + } + attr { + name: "clip_weight_min" + type: "float" + default_value { + f: -inf + } + } + attr { + name: "clip_weight_max" + type: "float" + default_value { + f: inf + } + } + attr { + name: "table_name" + type: "string" + } + attr { + name: "num_sparsecores_per_device" + type: "int" + default_value { + 
i: -1 + } + } +} diff --git a/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulGradWithAdagradMomentumAndStaticBufferSize.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulGradWithAdagradMomentumAndStaticBufferSize.pbtxt index 4fd6fa9bb5a5b2..bbcb164284b8bd 100644 --- a/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulGradWithAdagradMomentumAndStaticBufferSize.pbtxt +++ b/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulGradWithAdagradMomentumAndStaticBufferSize.pbtxt @@ -103,3 +103,115 @@ op { type: "string" } } +op { + name: "XlaSparseDenseMatmulGradWithAdagradMomentumAndStaticBufferSize" + input_arg { + name: "row_pointers" + type: DT_INT32 + } + input_arg { + name: "sorted_sample_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_token_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_gains" + type: DT_FLOAT + } + input_arg { + name: "activation_gradients" + type: DT_FLOAT + } + input_arg { + name: "learning_rate" + type: DT_FLOAT + } + input_arg { + name: "embedding_table" + type: DT_FLOAT + } + input_arg { + name: "accumulator" + type: DT_FLOAT + } + input_arg { + name: "momenta" + type: DT_FLOAT + } + input_arg { + name: "num_minibatches_per_physical_sparse_core" + type: DT_INT32 + } + output_arg { + name: "updated_embedding_table" + type: DT_FLOAT + } + output_arg { + name: "updated_accumulator" + type: DT_FLOAT + } + output_arg { + name: "updated_momenta" + type: DT_FLOAT + } + attr { + name: "use_nesterov" + type: "bool" + } + attr { + name: "exponent" + type: "float" + } + attr { + name: "beta1" + type: "float" + } + attr { + name: "beta2" + type: "float" + } + attr { + name: "epsilon" + type: "float" + } + attr { + name: "clip_weight_min" + type: "float" + default_value { + f: -inf + } + } + attr { + name: "clip_weight_max" + type: "float" + default_value { + f: inf + } + } + attr { + name: "max_ids_per_sparse_core" + type: "int" + has_minimum: true + minimum: 1 + } + attr { + name: 
"max_unique_ids_per_sparse_core" + type: "int" + has_minimum: true + minimum: 1 + } + attr { + name: "table_name" + type: "string" + } + attr { + name: "num_sparsecores_per_device" + type: "int" + default_value { + i: -1 + } + } +} diff --git a/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulGradWithAdamAndCsrInput.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulGradWithAdamAndCsrInput.pbtxt index aaa27b25954a9e..1041a351d30fa5 100644 --- a/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulGradWithAdamAndCsrInput.pbtxt +++ b/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulGradWithAdamAndCsrInput.pbtxt @@ -162,3 +162,99 @@ op { type: "string" } } +op { + name: "XlaSparseDenseMatmulGradWithAdamAndCsrInput" + input_arg { + name: "row_pointers" + type: DT_INT32 + } + input_arg { + name: "sorted_sample_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_token_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_gains" + type: DT_FLOAT + } + input_arg { + name: "activation_gradients" + type: DT_FLOAT + } + input_arg { + name: "learning_rate" + type: DT_FLOAT + } + input_arg { + name: "embedding_table" + type: DT_FLOAT + } + input_arg { + name: "momenta" + type: DT_FLOAT + } + input_arg { + name: "velocity" + type: DT_FLOAT + } + input_arg { + name: "num_minibatches_per_physical_sparse_core" + type: DT_INT32 + } + output_arg { + name: "updated_embedding_table" + type: DT_FLOAT + } + output_arg { + name: "updated_momenta" + type: DT_FLOAT + } + output_arg { + name: "updated_velocity" + type: DT_FLOAT + } + attr { + name: "use_sum_inside_sqrt" + type: "bool" + } + attr { + name: "beta1" + type: "float" + } + attr { + name: "beta2" + type: "float" + } + attr { + name: "epsilon" + type: "float" + } + attr { + name: "clip_weight_min" + type: "float" + default_value { + f: -inf + } + } + attr { + name: "clip_weight_max" + type: "float" + default_value { + f: inf + } + } + attr { + name: "table_name" + type: 
"string" + } + attr { + name: "num_sparsecores_per_device" + type: "int" + default_value { + i: -1 + } + } +} diff --git a/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulGradWithAdamAndStaticBufferSize.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulGradWithAdamAndStaticBufferSize.pbtxt index 5024f72b5c66cb..992de662552114 100644 --- a/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulGradWithAdamAndStaticBufferSize.pbtxt +++ b/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulGradWithAdamAndStaticBufferSize.pbtxt @@ -99,3 +99,111 @@ op { type: "string" } } +op { + name: "XlaSparseDenseMatmulGradWithAdamAndStaticBufferSize" + input_arg { + name: "row_pointers" + type: DT_INT32 + } + input_arg { + name: "sorted_sample_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_token_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_gains" + type: DT_FLOAT + } + input_arg { + name: "activation_gradients" + type: DT_FLOAT + } + input_arg { + name: "learning_rate" + type: DT_FLOAT + } + input_arg { + name: "embedding_table" + type: DT_FLOAT + } + input_arg { + name: "momenta" + type: DT_FLOAT + } + input_arg { + name: "velocity" + type: DT_FLOAT + } + input_arg { + name: "num_minibatches_per_physical_sparse_core" + type: DT_INT32 + } + output_arg { + name: "updated_embedding_table" + type: DT_FLOAT + } + output_arg { + name: "updated_momenta" + type: DT_FLOAT + } + output_arg { + name: "updated_velocity" + type: DT_FLOAT + } + attr { + name: "use_sum_inside_sqrt" + type: "bool" + } + attr { + name: "beta1" + type: "float" + } + attr { + name: "beta2" + type: "float" + } + attr { + name: "epsilon" + type: "float" + } + attr { + name: "clip_weight_min" + type: "float" + default_value { + f: -inf + } + } + attr { + name: "clip_weight_max" + type: "float" + default_value { + f: inf + } + } + attr { + name: "max_ids_per_sparse_core" + type: "int" + has_minimum: true + minimum: 1 + } + attr { + name: 
"max_unique_ids_per_sparse_core" + type: "int" + has_minimum: true + minimum: 1 + } + attr { + name: "table_name" + type: "string" + } + attr { + name: "num_sparsecores_per_device" + type: "int" + default_value { + i: -1 + } + } +} diff --git a/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulGradWithCsrInput.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulGradWithCsrInput.pbtxt index 9c4e7f05f03570..c5fa98477af82e 100644 --- a/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulGradWithCsrInput.pbtxt +++ b/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulGradWithCsrInput.pbtxt @@ -60,3 +60,72 @@ op { type: "string" } } +op { + name: "XlaSparseDenseMatmulGradWithCsrInput" + input_arg { + name: "row_pointers" + type: DT_INT32 + } + input_arg { + name: "sorted_sample_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_token_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_gains" + type: DT_FLOAT + } + input_arg { + name: "activation_gradients" + type: DT_FLOAT + } + input_arg { + name: "tables" + type: DT_FLOAT + number_attr: "N" + } + input_arg { + name: "hyperparameters" + type: DT_FLOAT + number_attr: "M" + } + input_arg { + name: "num_minibatches_per_physical_sparse_core" + type: DT_INT32 + } + output_arg { + name: "updated_tables" + type: DT_FLOAT + number_attr: "N" + } + attr { + name: "N" + type: "int" + has_minimum: true + minimum: 1 + } + attr { + name: "M" + type: "int" + has_minimum: true + minimum: 1 + } + attr { + name: "custom_computation" + type: "func" + } + attr { + name: "table_name" + type: "string" + } + attr { + name: "num_sparsecores_per_device" + type: "int" + default_value { + i: -1 + } + } +} diff --git a/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulGradWithFtrlAndCsrInput.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulGradWithFtrlAndCsrInput.pbtxt index 261f25bebfd7df..693f05a2d553fa 100644 --- 
a/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulGradWithFtrlAndCsrInput.pbtxt +++ b/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulGradWithFtrlAndCsrInput.pbtxt @@ -170,3 +170,103 @@ op { type: "string" } } +op { + name: "XlaSparseDenseMatmulGradWithFtrlAndCsrInput" + input_arg { + name: "row_pointers" + type: DT_INT32 + } + input_arg { + name: "sorted_sample_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_token_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_gains" + type: DT_FLOAT + } + input_arg { + name: "activation_gradients" + type: DT_FLOAT + } + input_arg { + name: "learning_rate" + type: DT_FLOAT + } + input_arg { + name: "embedding_table" + type: DT_FLOAT + } + input_arg { + name: "accumulator" + type: DT_FLOAT + } + input_arg { + name: "linear" + type: DT_FLOAT + } + input_arg { + name: "num_minibatches_per_physical_sparse_core" + type: DT_INT32 + } + output_arg { + name: "updated_embedding_table" + type: DT_FLOAT + } + output_arg { + name: "updated_accumulator" + type: DT_FLOAT + } + output_arg { + name: "updated_linear" + type: DT_FLOAT + } + attr { + name: "multiply_linear_by_learning_rate" + type: "bool" + } + attr { + name: "beta" + type: "float" + } + attr { + name: "learning_rate_power" + type: "float" + } + attr { + name: "l1_regularization_strength" + type: "float" + } + attr { + name: "l2_regularization_strength" + type: "float" + } + attr { + name: "clip_weight_min" + type: "float" + default_value { + f: -inf + } + } + attr { + name: "clip_weight_max" + type: "float" + default_value { + f: inf + } + } + attr { + name: "table_name" + type: "string" + } + attr { + name: "num_sparsecores_per_device" + type: "int" + default_value { + i: -1 + } + } +} diff --git a/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulGradWithFtrlAndStaticBufferSize.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulGradWithFtrlAndStaticBufferSize.pbtxt index f2f57f2f744d7b..44254c5d5f8993 
100644 --- a/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulGradWithFtrlAndStaticBufferSize.pbtxt +++ b/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulGradWithFtrlAndStaticBufferSize.pbtxt @@ -103,3 +103,115 @@ op { type: "string" } } +op { + name: "XlaSparseDenseMatmulGradWithFtrlAndStaticBufferSize" + input_arg { + name: "row_pointers" + type: DT_INT32 + } + input_arg { + name: "sorted_sample_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_token_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_gains" + type: DT_FLOAT + } + input_arg { + name: "activation_gradients" + type: DT_FLOAT + } + input_arg { + name: "learning_rate" + type: DT_FLOAT + } + input_arg { + name: "embedding_table" + type: DT_FLOAT + } + input_arg { + name: "accumulator" + type: DT_FLOAT + } + input_arg { + name: "linear" + type: DT_FLOAT + } + input_arg { + name: "num_minibatches_per_physical_sparse_core" + type: DT_INT32 + } + output_arg { + name: "updated_embedding_table" + type: DT_FLOAT + } + output_arg { + name: "updated_accumulator" + type: DT_FLOAT + } + output_arg { + name: "updated_linear" + type: DT_FLOAT + } + attr { + name: "multiply_linear_by_learning_rate" + type: "bool" + } + attr { + name: "beta" + type: "float" + } + attr { + name: "learning_rate_power" + type: "float" + } + attr { + name: "l1_regularization_strength" + type: "float" + } + attr { + name: "l2_regularization_strength" + type: "float" + } + attr { + name: "clip_weight_min" + type: "float" + default_value { + f: -inf + } + } + attr { + name: "clip_weight_max" + type: "float" + default_value { + f: inf + } + } + attr { + name: "max_ids_per_sparse_core" + type: "int" + has_minimum: true + minimum: 1 + } + attr { + name: "max_unique_ids_per_sparse_core" + type: "int" + has_minimum: true + minimum: 1 + } + attr { + name: "table_name" + type: "string" + } + attr { + name: "num_sparsecores_per_device" + type: "int" + default_value { + i: -1 + } + } +} diff --git 
a/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulGradWithSgdAndCsrInput.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulGradWithSgdAndCsrInput.pbtxt index 9446a6fa98c515..5b496b7b543f0a 100644 --- a/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulGradWithSgdAndCsrInput.pbtxt +++ b/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulGradWithSgdAndCsrInput.pbtxt @@ -98,3 +98,67 @@ op { type: "string" } } +op { + name: "XlaSparseDenseMatmulGradWithSgdAndCsrInput" + input_arg { + name: "row_pointers" + type: DT_INT32 + } + input_arg { + name: "sorted_sample_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_token_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_gains" + type: DT_FLOAT + } + input_arg { + name: "activation_gradients" + type: DT_FLOAT + } + input_arg { + name: "learning_rate" + type: DT_FLOAT + } + input_arg { + name: "embedding_table" + type: DT_FLOAT + } + input_arg { + name: "num_minibatches_per_physical_sparse_core" + type: DT_INT32 + } + output_arg { + name: "updated_embedding_table" + type: DT_FLOAT + } + attr { + name: "clip_weight_min" + type: "float" + default_value { + f: -inf + } + } + attr { + name: "clip_weight_max" + type: "float" + default_value { + f: inf + } + } + attr { + name: "table_name" + type: "string" + } + attr { + name: "num_sparsecores_per_device" + type: "int" + default_value { + i: -1 + } + } +} diff --git a/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulGradWithSgdAndStaticBufferSize.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulGradWithSgdAndStaticBufferSize.pbtxt index dbb06c95f6d643..362489eef3bef7 100644 --- a/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulGradWithSgdAndStaticBufferSize.pbtxt +++ b/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulGradWithSgdAndStaticBufferSize.pbtxt @@ -67,3 +67,79 @@ op { type: "string" } } +op { + name: 
"XlaSparseDenseMatmulGradWithSgdAndStaticBufferSize" + input_arg { + name: "row_pointers" + type: DT_INT32 + } + input_arg { + name: "sorted_sample_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_token_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_gains" + type: DT_FLOAT + } + input_arg { + name: "activation_gradients" + type: DT_FLOAT + } + input_arg { + name: "learning_rate" + type: DT_FLOAT + } + input_arg { + name: "embedding_table" + type: DT_FLOAT + } + input_arg { + name: "num_minibatches_per_physical_sparse_core" + type: DT_INT32 + } + output_arg { + name: "updated_embedding_table" + type: DT_FLOAT + } + attr { + name: "clip_weight_min" + type: "float" + default_value { + f: -inf + } + } + attr { + name: "clip_weight_max" + type: "float" + default_value { + f: inf + } + } + attr { + name: "max_ids_per_sparse_core" + type: "int" + has_minimum: true + minimum: 1 + } + attr { + name: "max_unique_ids_per_sparse_core" + type: "int" + has_minimum: true + minimum: 1 + } + attr { + name: "table_name" + type: "string" + } + attr { + name: "num_sparsecores_per_device" + type: "int" + default_value { + i: -1 + } + } +} diff --git a/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulWithCsrInput.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulWithCsrInput.pbtxt index 2b4bc1dcba74ac..2093beb487799f 100644 --- a/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulWithCsrInput.pbtxt +++ b/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulWithCsrInput.pbtxt @@ -51,3 +51,63 @@ op { type: "string" } } +op { + name: "XlaSparseDenseMatmulWithCsrInput" + input_arg { + name: "row_pointers" + type: DT_INT32 + } + input_arg { + name: "sorted_sample_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_token_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_gains" + type: DT_FLOAT + } + input_arg { + name: "embedding_table" + type: DT_FLOAT + } + input_arg { + name: 
"num_minibatches_per_physical_sparse_core" + type: DT_INT32 + } + output_arg { + name: "activations" + type: DT_FLOAT + } + attr { + name: "input_size" + type: "int" + has_minimum: true + } + attr { + name: "quantization_config_low" + type: "float" + } + attr { + name: "quantization_config_high" + type: "float" + } + attr { + name: "quantization_config_num_buckets" + type: "int" + has_minimum: true + } + attr { + name: "table_name" + type: "string" + } + attr { + name: "num_sparsecores_per_device" + type: "int" + default_value { + i: -1 + } + } +} diff --git a/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulWithStaticBufferSize.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulWithStaticBufferSize.pbtxt index 471ded1635244a..f712f5e406e4ac 100644 --- a/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulWithStaticBufferSize.pbtxt +++ b/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulWithStaticBufferSize.pbtxt @@ -63,3 +63,75 @@ op { type: "string" } } +op { + name: "XlaSparseDenseMatmulWithStaticBufferSize" + input_arg { + name: "row_pointers" + type: DT_INT32 + } + input_arg { + name: "sorted_sample_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_token_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_gains" + type: DT_FLOAT + } + input_arg { + name: "embedding_table" + type: DT_FLOAT + } + input_arg { + name: "num_minibatches_per_physical_sparse_core" + type: DT_INT32 + } + output_arg { + name: "activations" + type: DT_FLOAT + } + attr { + name: "input_size" + type: "int" + has_minimum: true + } + attr { + name: "quantization_config_low" + type: "float" + } + attr { + name: "quantization_config_high" + type: "float" + } + attr { + name: "quantization_config_num_buckets" + type: "int" + has_minimum: true + } + attr { + name: "max_ids_per_sparse_core" + type: "int" + has_minimum: true + minimum: 1 + } + attr { + name: "max_unique_ids_per_sparse_core" + type: "int" + has_minimum: true + minimum: 
1 + } + attr { + name: "table_name" + type: "string" + } + attr { + name: "num_sparsecores_per_device" + type: "int" + default_value { + i: -1 + } + } +} diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index df2588439eda9b..1ce120a11054a8 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -67609,6 +67609,803 @@ op { has_minimum: true } } +op { + name: "XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdagradAndCsrInput" + input_arg { + name: "row_pointers" + type: DT_INT32 + } + input_arg { + name: "sorted_sample_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_token_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_pos_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_gains" + type: DT_FLOAT + } + input_arg { + name: "weights" + type: DT_FLOAT + } + input_arg { + name: "preserved_valencies" + type: DT_INT32 + } + input_arg { + name: "preserved_vectors" + type: DT_FLOAT + } + input_arg { + name: "preserved_weights" + type: DT_FLOAT + } + input_arg { + name: "activation_gradients" + type: DT_FLOAT + } + input_arg { + name: "learning_rate" + type: DT_FLOAT + } + input_arg { + name: "combiner_weights_learning_rate" + type: DT_FLOAT + } + input_arg { + name: "embedding_table" + type: DT_FLOAT + } + input_arg { + name: "accumulator" + type: DT_FLOAT + } + output_arg { + name: "updated_embedding_table" + type: DT_FLOAT + } + output_arg { + name: "updated_accumulator" + type: DT_FLOAT + } + output_arg { + name: "updated_weights" + type: DT_FLOAT + } + attr { + name: "clip_weight_min" + type: "float" + default_value { + f: -inf + } + } + attr { + name: "clip_weight_max" + type: "float" + default_value { + f: inf + } + } + attr { + name: "max_valency" + type: "int" + has_minimum: true + } + attr { + name: "num_weights" + type: "int" + has_minimum: true + } + attr { + name: "combiner_table_vjp_computation" + type: "func" + } + attr { + name: "combiner_weights_vjp_computation" + type: "func" + } + attr 
{ + name: "table_name" + type: "string" + } +} +op { + name: "XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdagradMomentumAndCsrInput" + input_arg { + name: "row_pointers" + type: DT_INT32 + } + input_arg { + name: "sorted_sample_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_token_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_pos_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_gains" + type: DT_FLOAT + } + input_arg { + name: "weights" + type: DT_FLOAT + } + input_arg { + name: "preserved_valencies" + type: DT_INT32 + } + input_arg { + name: "preserved_vectors" + type: DT_FLOAT + } + input_arg { + name: "preserved_weights" + type: DT_FLOAT + } + input_arg { + name: "activation_gradients" + type: DT_FLOAT + } + input_arg { + name: "learning_rate" + type: DT_FLOAT + } + input_arg { + name: "combiner_weights_learning_rate" + type: DT_FLOAT + } + input_arg { + name: "embedding_table" + type: DT_FLOAT + } + input_arg { + name: "accumulator" + type: DT_FLOAT + } + input_arg { + name: "momenta" + type: DT_FLOAT + } + output_arg { + name: "updated_embedding_table" + type: DT_FLOAT + } + output_arg { + name: "updated_accumulator" + type: DT_FLOAT + } + output_arg { + name: "updated_momenta" + type: DT_FLOAT + } + output_arg { + name: "updated_weights" + type: DT_FLOAT + } + attr { + name: "use_nesterov" + type: "bool" + } + attr { + name: "exponent" + type: "float" + } + attr { + name: "beta1" + type: "float" + } + attr { + name: "beta2" + type: "float" + } + attr { + name: "epsilon" + type: "float" + } + attr { + name: "clip_weight_min" + type: "float" + default_value { + f: -inf + } + } + attr { + name: "clip_weight_max" + type: "float" + default_value { + f: inf + } + } + attr { + name: "max_valency" + type: "int" + has_minimum: true + } + attr { + name: "num_weights" + type: "int" + has_minimum: true + } + attr { + name: "combiner_table_vjp_computation" + type: "func" + } + attr { + name: "combiner_weights_vjp_computation" + type: "func" + } 
+ attr { + name: "table_name" + type: "string" + } +} +op { + name: "XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdamAndCsrInput" + input_arg { + name: "row_pointers" + type: DT_INT32 + } + input_arg { + name: "sorted_sample_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_token_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_pos_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_gains" + type: DT_FLOAT + } + input_arg { + name: "weights" + type: DT_FLOAT + } + input_arg { + name: "preserved_valencies" + type: DT_INT32 + } + input_arg { + name: "preserved_vectors" + type: DT_FLOAT + } + input_arg { + name: "preserved_weights" + type: DT_FLOAT + } + input_arg { + name: "activation_gradients" + type: DT_FLOAT + } + input_arg { + name: "learning_rate" + type: DT_FLOAT + } + input_arg { + name: "combiner_weights_learning_rate" + type: DT_FLOAT + } + input_arg { + name: "embedding_table" + type: DT_FLOAT + } + input_arg { + name: "momenta" + type: DT_FLOAT + } + input_arg { + name: "velocity" + type: DT_FLOAT + } + output_arg { + name: "updated_embedding_table" + type: DT_FLOAT + } + output_arg { + name: "updated_momenta" + type: DT_FLOAT + } + output_arg { + name: "updated_velocity" + type: DT_FLOAT + } + output_arg { + name: "updated_weights" + type: DT_FLOAT + } + attr { + name: "use_sum_inside_sqrt" + type: "bool" + } + attr { + name: "beta1" + type: "float" + } + attr { + name: "beta2" + type: "float" + } + attr { + name: "epsilon" + type: "float" + } + attr { + name: "clip_weight_min" + type: "float" + default_value { + f: -inf + } + } + attr { + name: "clip_weight_max" + type: "float" + default_value { + f: inf + } + } + attr { + name: "max_valency" + type: "int" + has_minimum: true + } + attr { + name: "num_weights" + type: "int" + has_minimum: true + } + attr { + name: "combiner_table_vjp_computation" + type: "func" + } + attr { + name: "combiner_weights_vjp_computation" + type: "func" + } + attr { + name: "table_name" + type: "string" + } 
+} +op { + name: "XlaSparseDenseMatmulCustomCombinerOnTcGradWithCsrInput" + input_arg { + name: "row_pointers" + type: DT_INT32 + } + input_arg { + name: "sorted_sample_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_token_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_pos_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_gains" + type: DT_FLOAT + } + input_arg { + name: "weights" + type: DT_FLOAT + } + input_arg { + name: "preserved_valencies" + type: DT_INT32 + } + input_arg { + name: "preserved_vectors" + type: DT_FLOAT + } + input_arg { + name: "preserved_weights" + type: DT_FLOAT + } + input_arg { + name: "activation_gradients" + type: DT_FLOAT + } + input_arg { + name: "tables" + type: DT_FLOAT + number_attr: "N" + } + input_arg { + name: "hyperparameters" + type: DT_FLOAT + number_attr: "M" + } + input_arg { + name: "combiner_weights_learning_rate" + type: DT_FLOAT + } + output_arg { + name: "updated_tables" + type: DT_FLOAT + number_attr: "N" + } + output_arg { + name: "updated_weights" + type: DT_FLOAT + } + attr { + name: "N" + type: "int" + has_minimum: true + minimum: 1 + } + attr { + name: "M" + type: "int" + has_minimum: true + minimum: 1 + } + attr { + name: "max_valency" + type: "int" + has_minimum: true + } + attr { + name: "num_weights" + type: "int" + has_minimum: true + } + attr { + name: "combiner_table_vjp_computation" + type: "func" + } + attr { + name: "combiner_weights_vjp_computation" + type: "func" + } + attr { + name: "optimizer_custom_computation" + type: "func" + } + attr { + name: "table_name" + type: "string" + } +} +op { + name: "XlaSparseDenseMatmulCustomCombinerOnTcGradWithFtrlAndCsrInput" + input_arg { + name: "row_pointers" + type: DT_INT32 + } + input_arg { + name: "sorted_sample_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_token_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_pos_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_gains" + type: DT_FLOAT + } + input_arg { + name: 
"weights" + type: DT_FLOAT + } + input_arg { + name: "preserved_valencies" + type: DT_INT32 + } + input_arg { + name: "preserved_vectors" + type: DT_FLOAT + } + input_arg { + name: "preserved_weights" + type: DT_FLOAT + } + input_arg { + name: "activation_gradients" + type: DT_FLOAT + } + input_arg { + name: "learning_rate" + type: DT_FLOAT + } + input_arg { + name: "combiner_weights_learning_rate" + type: DT_FLOAT + } + input_arg { + name: "embedding_table" + type: DT_FLOAT + } + input_arg { + name: "accumulator" + type: DT_FLOAT + } + input_arg { + name: "linear" + type: DT_FLOAT + } + output_arg { + name: "updated_embedding_table" + type: DT_FLOAT + } + output_arg { + name: "updated_accumulator" + type: DT_FLOAT + } + output_arg { + name: "updated_linear" + type: DT_FLOAT + } + output_arg { + name: "updated_weights" + type: DT_FLOAT + } + attr { + name: "multiply_linear_by_learning_rate" + type: "bool" + } + attr { + name: "beta" + type: "float" + } + attr { + name: "learning_rate_power" + type: "float" + } + attr { + name: "l1_regularization_strength" + type: "float" + } + attr { + name: "l2_regularization_strength" + type: "float" + } + attr { + name: "clip_weight_min" + type: "float" + default_value { + f: -inf + } + } + attr { + name: "clip_weight_max" + type: "float" + default_value { + f: inf + } + } + attr { + name: "max_valency" + type: "int" + has_minimum: true + } + attr { + name: "num_weights" + type: "int" + has_minimum: true + } + attr { + name: "combiner_table_vjp_computation" + type: "func" + } + attr { + name: "combiner_weights_vjp_computation" + type: "func" + } + attr { + name: "table_name" + type: "string" + } +} +op { + name: "XlaSparseDenseMatmulCustomCombinerOnTcGradWithSgdAndCsrInput" + input_arg { + name: "row_pointers" + type: DT_INT32 + } + input_arg { + name: "sorted_sample_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_token_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_pos_ids" + type: DT_INT32 + } + input_arg { + 
name: "sorted_gains" + type: DT_FLOAT + } + input_arg { + name: "weights" + type: DT_FLOAT + } + input_arg { + name: "preserved_valencies" + type: DT_INT32 + } + input_arg { + name: "preserved_vectors" + type: DT_FLOAT + } + input_arg { + name: "preserved_weights" + type: DT_FLOAT + } + input_arg { + name: "activation_gradients" + type: DT_FLOAT + } + input_arg { + name: "learning_rate" + type: DT_FLOAT + } + input_arg { + name: "combiner_weights_learning_rate" + type: DT_FLOAT + } + input_arg { + name: "embedding_table" + type: DT_FLOAT + } + output_arg { + name: "updated_embedding_table" + type: DT_FLOAT + } + output_arg { + name: "updated_weights" + type: DT_FLOAT + } + attr { + name: "clip_weight_min" + type: "float" + default_value { + f: -inf + } + } + attr { + name: "clip_weight_max" + type: "float" + default_value { + f: inf + } + } + attr { + name: "max_valency" + type: "int" + has_minimum: true + } + attr { + name: "num_weights" + type: "int" + has_minimum: true + } + attr { + name: "combiner_table_vjp_computation" + type: "func" + } + attr { + name: "combiner_weights_vjp_computation" + type: "func" + } + attr { + name: "table_name" + type: "string" + } +} +op { + name: "XlaSparseDenseMatmulCustomCombinerOnTcWithCsrInput" + input_arg { + name: "row_pointers" + type: DT_INT32 + } + input_arg { + name: "sorted_sample_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_token_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_pos_ids" + type: DT_INT32 + } + input_arg { + name: "sorted_gains" + type: DT_FLOAT + } + input_arg { + name: "embedding_table" + type: DT_FLOAT + } + input_arg { + name: "weights" + type: DT_FLOAT + } + output_arg { + name: "activations" + type: DT_FLOAT + } + output_arg { + name: "preserved_valencies" + type: DT_INT32 + } + output_arg { + name: "preserved_vectors" + type: DT_FLOAT + } + attr { + name: "input_size" + type: "int" + has_minimum: true + } + attr { + name: "max_valency" + type: "int" + has_minimum: true + } + attr 
{ + name: "num_weights" + type: "int" + has_minimum: true + } + attr { + name: "combiner_computation" + type: "func" + } + attr { + name: "quantization_config_low" + type: "float" + } + attr { + name: "quantization_config_high" + type: "float" + } + attr { + name: "quantization_config_num_buckets" + type: "int" + has_minimum: true + } + attr { + name: "table_name" + type: "string" + } + attr { + name: "num_sparsecores_per_device" + type: "int" + default_value { + i: -1 + } + } +} op { name: "XlaSparseDenseMatmulGradWithAdagradAndCsrInput" input_arg { @@ -67673,6 +68470,13 @@ op { name: "table_name" type: "string" } + attr { + name: "num_sparsecores_per_device" + type: "int" + default_value { + i: -1 + } + } } op { name: "XlaSparseDenseMatmulGradWithAdagradAndStaticBufferSize" @@ -67750,6 +68554,13 @@ op { name: "table_name" type: "string" } + attr { + name: "num_sparsecores_per_device" + type: "int" + default_value { + i: -1 + } + } } op { name: "XlaSparseDenseMatmulGradWithAdagradMomentumAndCsrInput" @@ -67843,6 +68654,13 @@ op { name: "table_name" type: "string" } + attr { + name: "num_sparsecores_per_device" + type: "int" + default_value { + i: -1 + } + } } op { name: "XlaSparseDenseMatmulGradWithAdagradMomentumAndStaticBufferSize" @@ -67948,6 +68766,13 @@ op { name: "table_name" type: "string" } + attr { + name: "num_sparsecores_per_device" + type: "int" + default_value { + i: -1 + } + } } op { name: "XlaSparseDenseMatmulGradWithAdamAndCsrInput" @@ -68037,6 +68862,13 @@ op { name: "table_name" type: "string" } + attr { + name: "num_sparsecores_per_device" + type: "int" + default_value { + i: -1 + } + } } op { name: "XlaSparseDenseMatmulGradWithAdamAndStaticBufferSize" @@ -68138,6 +68970,13 @@ op { name: "table_name" type: "string" } + attr { + name: "num_sparsecores_per_device" + type: "int" + default_value { + i: -1 + } + } } op { name: "XlaSparseDenseMatmulGradWithCsrInput" @@ -68200,6 +69039,13 @@ op { name: "table_name" type: "string" } + attr { + name: 
"num_sparsecores_per_device" + type: "int" + default_value { + i: -1 + } + } } op { name: "XlaSparseDenseMatmulGradWithFtrlAndCsrInput" @@ -68293,6 +69139,13 @@ op { name: "table_name" type: "string" } + attr { + name: "num_sparsecores_per_device" + type: "int" + default_value { + i: -1 + } + } } op { name: "XlaSparseDenseMatmulGradWithFtrlAndStaticBufferSize" @@ -68398,6 +69251,13 @@ op { name: "table_name" type: "string" } + attr { + name: "num_sparsecores_per_device" + type: "int" + default_value { + i: -1 + } + } } op { name: "XlaSparseDenseMatmulGradWithSgdAndCsrInput" @@ -68455,6 +69315,13 @@ op { name: "table_name" type: "string" } + attr { + name: "num_sparsecores_per_device" + type: "int" + default_value { + i: -1 + } + } } op { name: "XlaSparseDenseMatmulGradWithSgdAndStaticBufferSize" @@ -68524,6 +69391,13 @@ op { name: "table_name" type: "string" } + attr { + name: "num_sparsecores_per_device" + type: "int" + default_value { + i: -1 + } + } } op { name: "XlaSparseDenseMatmulWithCsrInput" @@ -68577,6 +69451,13 @@ op { name: "table_name" type: "string" } + attr { + name: "num_sparsecores_per_device" + type: "int" + default_value { + i: -1 + } + } } op { name: "XlaSparseDenseMatmulWithStaticBufferSize" @@ -68642,6 +69523,13 @@ op { name: "table_name" type: "string" } + attr { + name: "num_sparsecores_per_device" + type: "int" + default_value { + i: -1 + } + } } op { name: "XlaSplitND" diff --git a/tensorflow/core/platform/strcat.h b/tensorflow/core/platform/strcat.h index 9a11dd2dbb529f..c4052364a66db8 100644 --- a/tensorflow/core/platform/strcat.h +++ b/tensorflow/core/platform/strcat.h @@ -28,21 +28,6 @@ namespace strings { // NOLINTBEGIN(misc-unused-using-decls) using tsl::strings::AlphaNum; using tsl::strings::Hex; -using tsl::strings::kZeroPad10; -using tsl::strings::kZeroPad11; -using tsl::strings::kZeroPad12; -using tsl::strings::kZeroPad13; -using tsl::strings::kZeroPad14; -using tsl::strings::kZeroPad15; -using tsl::strings::kZeroPad16; -using 
tsl::strings::kZeroPad2; -using tsl::strings::kZeroPad3; -using tsl::strings::kZeroPad4; -using tsl::strings::kZeroPad5; -using tsl::strings::kZeroPad6; -using tsl::strings::kZeroPad7; -using tsl::strings::kZeroPad8; -using tsl::strings::kZeroPad9; using tsl::strings::PadSpec; using tsl::strings::StrAppend; using tsl::strings::StrCat; diff --git a/tensorflow/core/platform/tensor_coding.cc b/tensorflow/core/platform/tensor_coding.cc index b5aa5ffe150c8e..53328afe0bcf79 100644 --- a/tensorflow/core/platform/tensor_coding.cc +++ b/tensorflow/core/platform/tensor_coding.cc @@ -27,6 +27,7 @@ limitations under the License. #if defined(TENSORFLOW_PROTOBUF_USES_CORD) #include "strings/cord_varint.h" +#include "util/gtl/stl_util.h" #endif // defined(TENSORFLOW_PROTOBUF_USES_CORD) namespace tensorflow { diff --git a/tensorflow/core/profiler/convert/BUILD b/tensorflow/core/profiler/convert/BUILD index 5359e934e2bf01..0d5c900684814b 100644 --- a/tensorflow/core/profiler/convert/BUILD +++ b/tensorflow/core/profiler/convert/BUILD @@ -1,1477 +1,74 @@ -load("//tensorflow:tensorflow.bzl", "if_oss", "tf_cc_test") load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") -load("//tensorflow/core/profiler/builds:build_config.bzl", "tf_profiler_alias", "tf_profiler_copts") +load("//tensorflow/core/profiler/builds:build_config.bzl", "tf_profiler_copts") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//visibility:public"], # TODO(matthurd): Update to profiler:internal after xprof migration. 
+ default_visibility = ["//tensorflow/core/profiler:internal"], licenses = ["notice"], ) cc_library( - name = "xplane_to_op_metrics_db", - srcs = ["xplane_to_op_metrics_db.cc"], - hdrs = ["xplane_to_op_metrics_db.h"], - copts = tf_profiler_copts(), - deps = [ - ":op_metrics_db_combiner", - ":op_stack", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core/profiler/protobuf:op_metrics_proto_cc", - "//tensorflow/core/profiler/utils:cost_utils", - "//tensorflow/core/profiler/utils:op_metrics_db_utils", - "//tensorflow/core/profiler/utils:op_utils", - "//tensorflow/core/profiler/utils:trace_utils", - "//tensorflow/core/profiler/utils:xplane_schema", - "//tensorflow/core/profiler/utils:xplane_visitor", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/log", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:optional", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_xla//xla/tsl/profiler/utils:tf_op_utils", - "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", - "@local_xla//xla/tsl/profiler/utils:timespan", - "@local_xla//xla/tsl/profiler/utils:xplane_schema", - "@local_xla//xla/tsl/profiler/utils:xplane_utils", - ], -) - -tf_cc_test( - name = "xplane_to_op_metrics_db_test", - size = "small", - srcs = ["xplane_to_op_metrics_db_test.cc"], - deps = [ - ":xplane_to_op_metrics_db", - "//tensorflow/core:lib", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core/profiler/protobuf:op_metrics_proto_cc", - "//tensorflow/core/profiler/utils:op_metrics_db_utils", - "//tensorflow/core/profiler/utils:xplane_builder", - "//tensorflow/core/profiler/utils:xplane_schema", - "//tensorflow/core/profiler/utils:xplane_test_utils", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_xla//xla/tsl/profiler/utils:math_utils", - 
"@local_xla//xla/tsl/profiler/utils:xplane_schema", - ], -) - -cc_library( - name = "op_metrics_db_combiner", - srcs = ["op_metrics_db_combiner.cc"], - hdrs = ["op_metrics_db_combiner.h"], - copts = tf_profiler_copts(), - deps = [ - "//tensorflow/core:lib", - "//tensorflow/core/platform:protobuf", - "//tensorflow/core/profiler/protobuf:op_metrics_proto_cc", - "//tensorflow/core/profiler/utils:op_metrics_db_utils", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/log:check", - ], -) - -cc_library( - name = "op_metrics_to_record", - srcs = ["op_metrics_to_record.cc"], - hdrs = ["op_metrics_to_record.h"], - copts = tf_profiler_copts(), - deps = [ - "//tensorflow/core/profiler/protobuf:hardware_types_proto_cc", - "//tensorflow/core/profiler/protobuf:op_metrics_proto_cc", - "//tensorflow/core/profiler/protobuf:op_stats_proto_cc", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/strings", - "@local_xla//xla/tsl/profiler/utils:math_utils", - ], -) - -cc_library( - name = "op_stack", - hdrs = ["op_stack.h"], - copts = tf_profiler_copts(), - deps = [ - "//tensorflow/core:lib", - ], -) - -cc_library( - name = "op_stats_to_hlo_stats", - srcs = ["op_stats_to_hlo_stats.cc"], - hdrs = ["op_stats_to_hlo_stats.h"], - deps = [ - ":data_table_utils", - ":op_metrics_to_record", - "//tensorflow/core/profiler/protobuf:hlo_stats_proto_cc", - "//tensorflow/core/profiler/protobuf:op_metrics_proto_cc", - "//tensorflow/core/profiler/protobuf:op_stats_proto_cc", - "@com_google_absl//absl/strings", - "@local_xla//xla/tsl/profiler/convert:xla_op_utils", - "@local_xla//xla/tsl/profiler/utils:math_utils", - "@local_xla//xla/tsl/profiler/utils:tf_op_utils", - ], -) - -cc_library( - name = "op_stats_to_roofline_model", - srcs = ["op_stats_to_roofline_model.cc"], - hdrs = ["op_stats_to_roofline_model.h"], - deps = [ - ":op_metrics_db_combiner", - ":op_metrics_to_record", - "//tensorflow/core/profiler/protobuf:hardware_types_proto_cc", - 
"//tensorflow/core/profiler/protobuf:op_metrics_proto_cc", - "//tensorflow/core/profiler/protobuf:op_stats_proto_cc", - "//tensorflow/core/profiler/protobuf:roofline_model_proto_cc", - "//tensorflow/core/profiler/protobuf:steps_db_proto_cc", - "//tensorflow/core/profiler/utils:diagnostics", - "@com_google_absl//absl/log:check", - "@local_tsl//tsl/platform:protobuf", - "@local_xla//xla/tsl/profiler/convert:xla_op_utils", - "@local_xla//xla/tsl/profiler/utils:math_utils", - ], -) - -cc_library( - name = "op_stats_to_op_profile", - srcs = ["op_stats_to_op_profile.cc"], - hdrs = ["op_stats_to_op_profile.h"], - deps = [ - ":op_profile_builder", - "//tensorflow/core/profiler/protobuf:hardware_types_proto_cc", - "//tensorflow/core/profiler/protobuf:op_metrics_proto_cc", - "//tensorflow/core/profiler/protobuf:op_profile_proto_cc", - "//tensorflow/core/profiler/protobuf:op_stats_proto_cc", - "//tensorflow/core/profiler/utils:op_metrics_db_utils", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/strings", - "@local_xla//xla/tsl/profiler/utils:math_utils", - ], -) - -cc_library( - name = "op_stats_to_overview_page", - srcs = ["op_stats_to_overview_page.cc"], - hdrs = ["op_stats_to_overview_page.h"], - copts = tf_profiler_copts(), - deps = [ - ":op_metrics_to_record", - ":op_stats_to_input_pipeline_analysis", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core/profiler/protobuf:hardware_types_proto_cc", - "//tensorflow/core/profiler/protobuf:input_pipeline_proto_cc", - "//tensorflow/core/profiler/protobuf:kernel_stats_proto_cc", - "//tensorflow/core/profiler/protobuf:op_metrics_proto_cc", - "//tensorflow/core/profiler/protobuf:op_stats_proto_cc", - "//tensorflow/core/profiler/protobuf:overview_page_proto_cc", - "//tensorflow/core/profiler/protobuf:power_metrics_proto_cc", - "//tensorflow/core/profiler/protobuf:steps_db_proto_cc", - "//tensorflow/core/profiler/protobuf:tf_function_proto_cc", - 
"//tensorflow/core/profiler/utils:diagnostics", - "//tensorflow/core/profiler/utils:hardware_type_utils", - "//tensorflow/core/profiler/utils:html_utils", - "//tensorflow/core/profiler/utils:kernel_stats_utils", - "//tensorflow/core/profiler/utils:op_metrics_db_utils", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_xla//xla/tsl/profiler/utils:format_utils", - "@local_xla//xla/tsl/profiler/utils:math_utils", - "@local_xla//xla/tsl/profiler/utils:tf_op_utils", - ], -) - -cc_library( - name = "op_stats_to_pod_stats", - srcs = ["op_stats_to_pod_stats.cc"], - hdrs = ["op_stats_to_pod_stats.h"], - copts = tf_profiler_copts(), - deps = [ - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core/profiler/protobuf:op_stats_proto_cc", - "//tensorflow/core/profiler/protobuf:pod_stats_proto_cc", - "//tensorflow/core/profiler/protobuf:steps_db_proto_cc", - "//tensorflow/core/profiler/utils:diagnostics", - "//tensorflow/core/profiler/utils:event_span", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/strings", - "@local_xla//xla/tsl/profiler/utils:math_utils", - ], -) - -tf_cc_test( - name = "op_stats_to_pod_stats_test", - srcs = ["op_stats_to_pod_stats_test.cc"], - deps = [ - ":op_stats_to_pod_stats", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core/profiler/protobuf:diagnostics_proto_cc", - "//tensorflow/core/profiler/protobuf:op_stats_proto_cc", - "//tensorflow/core/profiler/protobuf:steps_db_proto_cc", - "//tensorflow/core/profiler/utils:diagnostics", - "//tensorflow/core/profiler/utils:event_span", - "@local_xla//xla/tsl/profiler/utils:math_utils", - ], -) - -cc_library( - name = "op_stats_to_pod_viewer", - srcs = ["op_stats_to_pod_viewer.cc"], - hdrs = ["op_stats_to_pod_viewer.h"], - copts = tf_profiler_copts(), - deps = [ - ":op_stats_to_pod_stats", - 
"//tensorflow/core/profiler/protobuf:op_stats_proto_cc", - "//tensorflow/core/profiler/protobuf:pod_stats_proto_cc", - "//tensorflow/core/profiler/protobuf:pod_viewer_proto_cc", - "//tensorflow/core/profiler/protobuf:steps_db_proto_cc", - "//tensorflow/core/profiler/utils:diagnostics", - "@com_google_absl//absl/log:check", - ], -) - -tf_cc_test( - name = "op_stats_to_pod_viewer_test", - srcs = ["op_stats_to_pod_viewer_test.cc"], - deps = [ - ":op_stats_to_pod_viewer", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core/profiler/protobuf:diagnostics_proto_cc", - "//tensorflow/core/profiler/protobuf:op_stats_proto_cc", - "//tensorflow/core/profiler/protobuf:pod_stats_proto_cc", - "//tensorflow/core/profiler/protobuf:steps_db_proto_cc", - "//tensorflow/core/profiler/utils:diagnostics", - "//tensorflow/core/profiler/utils:event_span", - "@local_xla//xla/tsl/profiler/utils:math_utils", - ], -) - -tf_cc_test( - name = "xplane_to_trace_container_test", - srcs = ["xplane_to_trace_container_test.cc"], - deps = [ - ":xplane_to_trace_container", - "//tensorflow/core/util/proto:proto_utils", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest_main", - "@local_xla//xla/tsl/profiler/utils:math_utils", - ], -) - -cc_library( - name = "op_stats_to_input_pipeline_analysis", - srcs = ["op_stats_to_input_pipeline_analysis.cc"], - hdrs = ["op_stats_to_input_pipeline_analysis.h"], - copts = tf_profiler_copts(), - deps = [ - ":op_metrics_to_record", - ":profile_time_breakdown", - ":step_events_to_steps_db", - ":tpu_input_pipeline_analysis_constants", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core/platform:logging", - "//tensorflow/core/profiler/protobuf:hardware_types_proto_cc", - "//tensorflow/core/profiler/protobuf:input_pipeline_proto_cc", - "//tensorflow/core/profiler/protobuf:op_metrics_proto_cc", - 
"//tensorflow/core/profiler/protobuf:op_stats_proto_cc", - "//tensorflow/core/profiler/protobuf:steps_db_proto_cc", - "//tensorflow/core/profiler/protobuf:tpu_input_pipeline_proto_cc", - "//tensorflow/core/profiler/utils:diagnostics", - "//tensorflow/core/profiler/utils:event_span", - "//tensorflow/core/profiler/utils:html_utils", - "//tensorflow/core/profiler/utils:op_metrics_db_utils", - "//tensorflow/core/profiler/utils:tpu_step_breakdown_utils", - "//tensorflow/core/profiler/utils:tpu_step_details_utils", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@local_xla//xla/hlo/ir:hlo", - "@local_xla//xla/tsl/profiler/convert:xla_op_utils", - "@local_xla//xla/tsl/profiler/utils:format_utils", - "@local_xla//xla/tsl/profiler/utils:math_utils", - "@local_xla//xla/tsl/profiler/utils:tf_op_utils", - "@local_xla//xla/tsl/util:stats_calculator_portable", - ], -) - -tf_cc_test( - name = "op_stats_to_input_pipeline_analysis_test", - srcs = ["op_stats_to_input_pipeline_analysis_test.cc"], - deps = [ - ":op_stats_to_input_pipeline_analysis", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core/profiler/protobuf:steps_db_proto_cc", - "//tensorflow/core/profiler/utils:event_span", - "//tensorflow/core/profiler/utils:op_metrics_db_utils", - "@local_xla//xla/hlo/ir:hlo", - "@local_xla//xla/tsl/profiler/utils:timespan", - ], -) - -cc_library( - name = "op_stats_to_tf_stats", - srcs = ["op_stats_to_tf_stats.cc"], - hdrs = ["op_stats_to_tf_stats.h"], - copts = tf_profiler_copts(), - deps = [ - ":op_metrics_to_record", - "//tensorflow/core:lib", - "//tensorflow/core/profiler/protobuf:op_metrics_proto_cc", - "//tensorflow/core/profiler/protobuf:op_stats_proto_cc", - "//tensorflow/core/profiler/protobuf:tf_stats_proto_cc", - "//tensorflow/core/profiler/utils:kernel_stats_utils", - 
"//tensorflow/core/profiler/utils:op_metrics_db_utils", - "@local_xla//xla/tsl/profiler/utils:math_utils", - ], -) - -tf_cc_test( - name = "op_stats_to_tf_stats_test", - size = "small", - srcs = ["op_stats_to_tf_stats_test.cc"], - deps = [ - ":op_stats_to_tf_stats", - ":xplane_to_op_stats", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core/profiler/protobuf:op_stats_proto_cc", - "//tensorflow/core/profiler/protobuf:tf_stats_proto_cc", - "//tensorflow/core/profiler/utils:xplane_builder", - "//tensorflow/core/profiler/utils:xplane_schema", - "//tensorflow/core/profiler/utils:xplane_test_utils", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_xla//xla/tsl/profiler/utils:math_utils", - ], -) - -cc_library( - name = "step_events_to_steps_db", - srcs = ["step_events_to_steps_db.cc"], - hdrs = ["step_events_to_steps_db.h"], - copts = tf_profiler_copts(), - deps = [ - ":op_metrics_db_combiner", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core/profiler/protobuf:steps_db_proto_cc", - "//tensorflow/core/profiler/utils:event_span", - "//tensorflow/core/profiler/utils:op_metrics_db_utils", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/log", - "@local_xla//xla/tsl/profiler/utils:timespan", - ], -) - -cc_library( - name = "xplane_to_op_stats", - srcs = ["xplane_to_op_stats.cc"], - hdrs = ["xplane_to_op_stats.h"], + name = "xplane_to_step_stats", + srcs = ["xplane_to_step_stats.cc"], + hdrs = ["xplane_to_step_stats.h"], copts = tf_profiler_copts(), - visibility = ["@local_xla//xla/tsl/profiler:friends"], deps = [ - ":duty_cycle_combiner", - ":duty_cycle_tracker", - ":op_metrics_db_combiner", - ":repository", - ":step_events_to_steps_db", - ":xplane_to_kernel_stats_db", - ":xplane_to_op_metrics_db", - ":xplane_to_step_events", - ":xplane_to_tf_functions", - 
"//tensorflow/core/profiler/protobuf:diagnostics_proto_cc", - "//tensorflow/core/profiler/protobuf:hardware_types_proto_cc", - "//tensorflow/core/profiler/protobuf:op_metrics_proto_cc", - "//tensorflow/core/profiler/protobuf:op_stats_proto_cc", - "//tensorflow/core/profiler/protobuf:steps_db_proto_cc", - "//tensorflow/core/profiler/protobuf:tf_function_proto_cc", - "//tensorflow/core/profiler/utils:device_caps_utils", - "//tensorflow/core/profiler/utils:event_span", + "//tensorflow/core:protos_all_cc", "//tensorflow/core/profiler/utils:gpu_event_stats", - "//tensorflow/core/profiler/utils:hardware_type_utils", - "//tensorflow/core/profiler/utils:hlo_module_map", - "//tensorflow/core/profiler/utils:hlo_proto_map", - "//tensorflow/core/profiler/utils:kernel_stats_utils", - "//tensorflow/core/profiler/utils:op_utils", "//tensorflow/core/profiler/utils:xplane_schema", "//tensorflow/core/profiler/utils:xplane_utils", "//tensorflow/core/profiler/utils:xplane_visitor", - "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_xla//xla/tsl/profiler/convert:xla_op_utils", "@local_xla//xla/tsl/profiler/utils:math_utils", "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", - "@local_xla//xla/tsl/profiler/utils:timespan", - "@local_xla//xla/tsl/profiler/utils:tpu_xplane_utils", - "@local_xla//xla/tsl/profiler/utils:xplane_utils", - ], -) - -cc_library( - name = "multi_xplanes_to_op_stats", - srcs = ["multi_xplanes_to_op_stats.cc"], - hdrs = ["multi_xplanes_to_op_stats.h"], - copts = tf_profiler_copts(), - deps = [ - ":op_stats_combiner", - ":preprocess_single_host_xplane", - ":repository", - ":xplane_to_op_stats", - "//tensorflow/core:portable_gif_internal", - "//tensorflow/core/platform:status", - "//tensorflow/core/profiler/protobuf:op_stats_proto_cc", - 
"//tensorflow/core/profiler/utils:hardware_type_utils", - "//tensorflow/core/profiler/utils:step_intersection", - "@com_google_absl//absl/status", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_xla//xla/tsl/platform:statusor", - ], -) - -tf_cc_test( - name = "xplane_to_op_stats_test", - size = "small", - srcs = ["xplane_to_op_stats_test.cc"], - deps = [ - ":duty_cycle_tracker", - ":multi_xplanes_to_op_stats", - ":repository", - ":step_events_to_steps_db", - ":xplane_to_op_stats", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core/profiler/protobuf:diagnostics_proto_cc", - "//tensorflow/core/profiler/protobuf:op_metrics_proto_cc", - "//tensorflow/core/profiler/protobuf:op_stats_proto_cc", - "//tensorflow/core/profiler/protobuf:steps_db_proto_cc", - "//tensorflow/core/profiler/protobuf:tf_function_proto_cc", - "//tensorflow/core/profiler/utils:xplane_builder", - "//tensorflow/core/profiler/utils:xplane_schema", - "//tensorflow/core/profiler/utils:xplane_test_utils", - "//tensorflow/core/profiler/utils:xplane_visitor", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_xla//xla/tsl/profiler/convert:xla_op_utils", - "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", - "@local_xla//xla/tsl/profiler/utils:xplane_schema", - ], -) - -cc_library( - name = "xplane_to_step_events", - srcs = ["xplane_to_step_events.cc"], - hdrs = ["xplane_to_step_events.h"], - copts = tf_profiler_copts(), - deps = [ - "//tensorflow/core:lib", - "//tensorflow/core/profiler/protobuf:op_metrics_proto_cc", - "//tensorflow/core/profiler/protobuf:steps_db_proto_cc", - "//tensorflow/core/profiler/utils:event_span", - "//tensorflow/core/profiler/utils:op_metrics_db_utils", - "//tensorflow/core/profiler/utils:trace_utils", - "//tensorflow/core/profiler/utils:xplane_schema", - 
"//tensorflow/core/profiler/utils:xplane_visitor", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_xla//xla/tsl/profiler/utils:tf_op_utils", - "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", - "@local_xla//xla/tsl/profiler/utils:timespan", - "@local_xla//xla/tsl/profiler/utils:tpu_xplane_utils", - "@local_xla//xla/tsl/profiler/utils:xplane_schema", - "@local_xla//xla/tsl/profiler/utils:xplane_utils", - ], -) - -tf_cc_test( - name = "xplane_to_step_events_test", - size = "small", - srcs = ["xplane_to_step_events_test.cc"], - deps = [ - ":xplane_to_step_events", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core/profiler/utils:event_span", - "//tensorflow/core/profiler/utils:xplane_builder", - "//tensorflow/core/profiler/utils:xplane_schema", - "//tensorflow/core/profiler/utils:xplane_test_utils", - "@com_google_absl//absl/container:flat_hash_map", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_xla//xla/tsl/profiler/utils:group_events", - "@local_xla//xla/tsl/profiler/utils:xplane_schema", - ], -) - -cc_library( - name = "xplane_to_kernel_stats_db", - srcs = ["xplane_to_kernel_stats_db.cc"], - hdrs = ["xplane_to_kernel_stats_db.h"], - copts = tf_profiler_copts(), - deps = [ - "//tensorflow/core/profiler/protobuf:kernel_stats_proto_cc", - "//tensorflow/core/profiler/utils:gpu_event_stats", - "//tensorflow/core/profiler/utils:hlo_module_map", - "//tensorflow/core/profiler/utils:kernel_stats_utils", - "//tensorflow/core/profiler/utils:trace_utils", - "//tensorflow/core/profiler/utils:xplane_visitor", - "@com_google_absl//absl/log", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_xla//xla/tsl/profiler/utils:tf_op_utils", - "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", + "@org_xprof//xprof/utils:gpu_event_stats", ], ) -tf_cc_test( - name = 
"xplane_to_kernel_stats_db_test", - size = "small", - srcs = ["xplane_to_kernel_stats_db_test.cc"], - deps = [ - ":xplane_to_kernel_stats_db", - "//tensorflow/core:lib_internal", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core/profiler/protobuf:kernel_stats_proto_cc", - "//tensorflow/core/profiler/utils:kernel_stats_utils", - "//tensorflow/core/profiler/utils:xplane_builder", - "//tensorflow/core/profiler/utils:xplane_schema", - "//tensorflow/core/profiler/utils:xplane_test_utils", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - ], -) +# DO NOT ADD NEW DEPENDENCIES TO ANY TARGET IN THIS FILE. +# Instead, use //third_party/xprof/convert. cc_library( - name = "xplane_to_tf_functions", - srcs = ["xplane_to_tf_functions.cc"], - hdrs = ["xplane_to_tf_functions.h"], + name = "hlo_proto_to_memory_visualization_utils", + hdrs = ["hlo_proto_to_memory_visualization_utils.h"], copts = tf_profiler_copts(), - deps = [ - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core/profiler/protobuf:tf_function_proto_cc", - "//tensorflow/core/profiler/utils:xplane_schema", - "//tensorflow/core/profiler/utils:xplane_visitor", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_xla//xla/tsl/profiler/utils:math_utils", - "@local_xla//xla/tsl/profiler/utils:timespan", + visibility = [ + "//learning/deepmind/jax/statix:__subpackages__", + "//platforms/xla/tools/shardy_migration:__subpackages__", + "//smartass/brain/tpu_worker:__subpackages__", + "//tensorflow/core/profiler:internal", ], -) - -tf_cc_test( - name = "xplane_to_tf_functions_test", - size = "small", - srcs = ["xplane_to_tf_functions_test.cc"], deps = [ - ":xplane_to_tf_functions", - "//tensorflow/core:lib_internal", - "//tensorflow/core:test", - 
"//tensorflow/core:test_main", - "//tensorflow/core/profiler/protobuf:tf_function_proto_cc", - "//tensorflow/core/profiler/utils:xplane_builder", - "//tensorflow/core/profiler/utils:xplane_schema", - "//tensorflow/core/profiler/utils:xplane_test_utils", - "//tensorflow/core/profiler/utils:xplane_utils", - "//tensorflow/core/profiler/utils:xplane_visitor", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", + "@org_xprof//xprof/convert:hlo_proto_to_memory_visualization_utils", ], ) cc_library( - name = "xplane_to_memory_profile", - srcs = ["xplane_to_memory_profile.cc"], - hdrs = ["xplane_to_memory_profile.h"], + name = "profile_time_breakdown", + hdrs = ["profile_time_breakdown.h"], copts = tf_profiler_copts(), - deps = [ - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core/framework:protos_all_cc", - "//tensorflow/core/profiler/protobuf:memory_profile_proto_cc", - "//tensorflow/core/profiler/utils:xplane_schema", - "//tensorflow/core/profiler/utils:xplane_utils", - "//tensorflow/core/profiler/utils:xplane_visitor", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", + visibility = [ + "//platforms/performance/autograppler/utils:__subpackages__", + "//tensorflow/core/profiler:internal", ], -) - -tf_cc_test( - name = "xplane_to_memory_profile_test", - size = "small", - srcs = ["xplane_to_memory_profile_test.cc"], deps = [ - ":xplane_to_memory_profile", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core/profiler/protobuf:memory_profile_proto_cc", - 
"//tensorflow/core/profiler/utils:xplane_builder", - "//tensorflow/core/profiler/utils:xplane_schema", - "//tensorflow/core/profiler/utils:xplane_test_utils", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_xla//xla/tsl/profiler/utils:group_events", + "@org_xprof//xprof/convert:profile_time_breakdown", ], ) cc_library( - name = "op_stats_combiner", - srcs = ["op_stats_combiner.cc"], - hdrs = ["op_stats_combiner.h"], + name = "xplane_to_op_stats", + hdrs = ["xplane_to_op_stats.h"], copts = tf_profiler_copts(), - deps = [ - ":op_metrics_db_combiner", - ":xplane_to_tf_functions", - "//tensorflow/core:lib", - "//tensorflow/core/profiler/protobuf:diagnostics_proto_cc", - "//tensorflow/core/profiler/protobuf:hardware_types_proto_cc", - "//tensorflow/core/profiler/protobuf:kernel_stats_proto_cc", - "//tensorflow/core/profiler/protobuf:op_metrics_proto_cc", - "//tensorflow/core/profiler/protobuf:op_stats_proto_cc", - "//tensorflow/core/profiler/protobuf:power_metrics_proto_cc", - "//tensorflow/core/profiler/protobuf:steps_db_proto_cc", - "//tensorflow/core/profiler/protobuf:topology_proto_cc", - "//tensorflow/core/profiler/utils:hardware_type_utils", - "//tensorflow/core/profiler/utils:kernel_stats_utils", - "//tensorflow/core/profiler/utils:step_intersection", - "@com_google_absl//absl/container:flat_hash_map", - ], -) - -tf_cc_test( - name = "op_stats_combiner_test", - srcs = ["op_stats_combiner_test.cc"], - deps = [ - ":op_stats_combiner", - "//tensorflow/core:portable_gif_internal", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core/profiler/protobuf:hardware_types_proto_cc", - "//tensorflow/core/profiler/protobuf:op_stats_proto_cc", - "//tensorflow/core/profiler/protobuf:steps_db_proto_cc", - "//tensorflow/core/profiler/utils:step_intersection", - "@com_google_absl//absl/container:flat_hash_map", + visibility = [ + 
"//platforms/xla/tools/multihost_hlo_runner/hybrid_sim:__subpackages__", + "//tensorflow/core/profiler:internal", ], -) - -cc_library( - name = "preprocess_single_host_xplane", - srcs = ["preprocess_single_host_xplane.cc"], - hdrs = ["preprocess_single_host_xplane.h"], - copts = tf_profiler_copts(), - visibility = ["//tensorflow/core/profiler:internal"], - deps = [ - "//tensorflow/core/profiler/utils:derived_timeline", - "//tensorflow/core/profiler/utils:xplane_schema", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_xla//xla/tsl/profiler/utils:group_events", - "@local_xla//xla/tsl/profiler/utils:preprocess_xplane", - "@local_xla//xla/tsl/profiler/utils:xplane_utils", - ], -) - -cc_library( - name = "xplane_to_tools_data", - srcs = ["xplane_to_tools_data.cc"], - hdrs = ["xplane_to_tools_data.h"], - copts = tf_profiler_copts(), - deps = [ - ":compute_inference_latency", - ":hlo_to_tools_data", - ":multi_xplanes_to_op_stats", - ":multi_xspace_to_inference_stats", - ":op_stats_to_hlo_stats", - ":op_stats_to_input_pipeline_analysis", - ":op_stats_to_op_profile", - ":op_stats_to_overview_page", - ":op_stats_to_pod_viewer", - ":op_stats_to_roofline_model", - ":op_stats_to_tf_stats", - ":preprocess_single_host_xplane", - ":process_megascale_dcn", - ":repository", - ":tool_options", - ":xplane_to_dcn_collective_stats", - ":xplane_to_memory_profile", - ":xplane_to_op_stats", - ":xplane_to_tf_data_stats", - ":xplane_to_tool_names", - ":xplane_to_trace_container", - "//tensorflow/core:lib", - "//tensorflow/core/profiler/protobuf:dcn_slack_analysis_proto_cc", - "//tensorflow/core/profiler/protobuf:hardware_types_proto_cc", - "//tensorflow/core/profiler/protobuf:inference_stats_proto_cc", - "//tensorflow/core/profiler/protobuf:input_pipeline_proto_cc", - "//tensorflow/core/profiler/protobuf:kernel_stats_proto_cc", - "//tensorflow/core/profiler/protobuf:op_profile_proto_cc", - 
"//tensorflow/core/profiler/protobuf:op_stats_proto_cc", - "//tensorflow/core/profiler/protobuf:overview_page_proto_cc", - "//tensorflow/core/profiler/protobuf:roofline_model_proto_cc", - "//tensorflow/core/profiler/protobuf:tf_data_stats_proto_cc", - "//tensorflow/core/profiler/protobuf:tf_stats_proto_cc", - "//tensorflow/core/profiler/utils:hardware_type_utils", - "//tensorflow/core/profiler/utils:xplane_schema", - "//tensorflow/core/profiler/utils:xplane_utils", - "@com_google_absl//absl/log", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_xla//xla/tsl/profiler/convert:xplane_to_trace_events", - "@local_xla//xla/tsl/profiler/utils:timespan", - "@org_xprof//xprof/convert/trace_viewer:trace_events_to_json", - "@org_xprof//xprof/convert/trace_viewer:trace_viewer_visibility", - ], -) - -cc_library( - name = "xplane_to_tf_data_stats", - srcs = ["xplane_to_tf_data_stats.cc"], - hdrs = ["xplane_to_tf_data_stats.h"], - copts = tf_profiler_copts(), - deps = [ - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core/profiler/protobuf:tf_data_stats_proto_cc", - "//tensorflow/core/profiler/utils:html_utils", - "//tensorflow/core/profiler/utils:xplane_schema", - "//tensorflow/core/profiler/utils:xplane_visitor", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_xla//xla/tsl/profiler/utils:group_events", - "@local_xla//xla/tsl/profiler/utils:tf_op_utils", - "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", - "@local_xla//xla/tsl/profiler/utils:timespan", - ], -) - -tf_cc_test( - name = "xplane_to_tf_data_stats_test", - size = "small", - srcs = 
["xplane_to_tf_data_stats_test.cc"], - tags = if_oss([ - "manual", - "no_oss", - ]), # b/169705709, no protobuf matchers in OSS. - deps = [ - ":xplane_to_tf_data_stats", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core/profiler/protobuf:tf_data_stats_proto_cc", - "//tensorflow/core/profiler/utils:xplane_builder", - "//tensorflow/core/profiler/utils:xplane_schema", - "//tensorflow/core/profiler/utils:xplane_test_utils", - "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - ], -) - -cc_library( - name = "xplane_to_step_stats", - srcs = ["xplane_to_step_stats.cc"], - hdrs = ["xplane_to_step_stats.h"], - copts = tf_profiler_copts(), - deps = [ - "//tensorflow/core:protos_all_cc", - "//tensorflow/core/profiler/utils:gpu_event_stats", - "//tensorflow/core/profiler/utils:xplane_schema", - "//tensorflow/core/profiler/utils:xplane_utils", - "//tensorflow/core/profiler/utils:xplane_visitor", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_xla//xla/tsl/profiler/utils:math_utils", - "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", - ], -) - -cc_library( - name = "hlo_to_tools_data", - srcs = ["hlo_to_tools_data.cc"], - hdrs = ["hlo_to_tools_data.h"], - copts = tf_profiler_copts(), - visibility = ["//visibility:private"], - deps = [ - ":hlo_proto_to_graph_view", - ":hlo_proto_to_memory_visualization_utils", - ":repository", - ":tool_options", - ":xplane_to_hlo", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core/profiler/protobuf:memory_viewer_preprocess_proto_cc", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@local_xla//xla/service:hlo_proto_cc", - ], -) - -cc_library( - name = "hlo_proto_to_memory_visualization_utils", - srcs = 
["hlo_proto_to_memory_visualization_utils.cc"], - hdrs = ["hlo_proto_to_memory_visualization_utils.h"], - copts = tf_profiler_copts(), - visibility = ["//tensorflow/core/profiler/protobuf:memory_viewer_friends"], - deps = [ - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core/profiler/protobuf:memory_viewer_preprocess_proto_cc", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/types:span", - "@local_xla//xla:shape_util", - "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/service:hlo_proto_cc", - ], -) - -tf_cc_test( - name = "hlo_proto_to_memory_visualization_utils_test", - srcs = ["hlo_proto_to_memory_visualization_utils_test.cc"], - deps = [ - ":hlo_proto_to_memory_visualization_utils", - "//tensorflow/core:lib_internal", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core/profiler/protobuf:memory_viewer_preprocess_proto_cc", - "//tensorflow/core/util/proto:proto_utils", - "@com_google_absl//absl/strings:str_format", - "@local_xla//xla/service:hlo_proto_cc", - ], -) - -cc_library( - name = "xplane_to_hlo", - srcs = ["xplane_to_hlo.cc"], - hdrs = ["xplane_to_hlo.h"], - copts = tf_profiler_copts(), - deps = [ - ":repository", - "//tensorflow/core:lib", - "//tensorflow/core/platform:errors", - "//tensorflow/core/platform:statusor", - "//tensorflow/core/profiler/utils:hlo_proto_map", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_xla//xla/service:hlo_proto_cc", - "@local_xla//xla/tsl/profiler/utils:file_system_utils", - ], -) - -cc_library( - name = "op_profile_builder", - srcs = ["op_profile_builder.cc"], - hdrs = 
["op_profile_builder.h"], - deps = [ - ":op_metrics_db_combiner", - ":op_metrics_to_record", - "//tensorflow/core:lib", - "//tensorflow/core:lib_headers_for_pybind", - "//tensorflow/core/profiler/protobuf:op_metrics_proto_cc", - "//tensorflow/core/profiler/protobuf:op_profile_proto_cc", - "//tensorflow/core/profiler/utils:op_metrics_db_utils", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:node_hash_map", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/strings", - "@local_xla//xla/tsl/profiler/convert:xla_op_utils", - "@local_xla//xla/tsl/profiler/utils:math_utils", - ], -) - -cc_library( - name = "hlo_proto_to_graph_view", - srcs = ["hlo_proto_to_graph_view.cc"], - hdrs = ["hlo_proto_to_graph_view.h"], - deps = [ - ":tool_options", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - # copybara:uncomment(b/360874576) "//tensorflow/compiler/mlir/lite/experimental/google/tooling/hlo_adapter:direct_hlo_to_json_graph_convert", - "@local_xla//xla/hlo/ir:hlo", - "@local_xla//xla/service:hlo_graph_dumper", - "@local_xla//xla/service:hlo_proto_cc", - "@local_xla//xla/tsl/platform:errors", - "@local_xla//xla/tsl/platform:statusor", - "//tensorflow/core/platform:statusor", - "//tensorflow/core/profiler/utils:hlo_module_utils", - "//tensorflow/core/profiler/utils:hlo_proto_to_module", - # copybara:uncomment "@com_github_nlohmann_json//:json", - ], -) - -tf_cc_test( - name = "hlo_proto_to_graph_view_test", - size = "small", - srcs = ["hlo_proto_to_graph_view_test.cc"], - deps = [ - ":hlo_proto_to_graph_view", - ":tool_options", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core/protobuf:error_codes_proto_impl_cc", - "@com_google_googletest//:gtest_main", - "@local_xla//xla/service:hlo_graph_dumper", - 
"@local_xla//xla/tsl/platform:status_matchers", - "@local_xla//xla/tsl/platform:statusor", - ], -) - -cc_library( - name = "tool_options", - hdrs = ["tool_options.h"], - deps = [ - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/strings", - ], -) - -cc_library( - name = "repository", - srcs = ["repository.cc"], - hdrs = ["repository.h"], - deps = [ - "//tensorflow/core:lib", - "//tensorflow/core/platform:errors", - "//tensorflow/core/profiler/utils:hlo_module_map", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_xla//xla/tsl/platform:statusor", - "@local_xla//xla/tsl/profiler/utils:file_system_utils", - ], -) - -tf_cc_test( - name = "repository_test", - size = "small", - srcs = ["repository_test.cc"], - deps = [ - ":repository", - "//tensorflow/core/platform:errors", - "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_xla//xla/tsl/platform:status", - ], -) - -cc_library( - name = "xplane_to_tool_names", - srcs = ["xplane_to_tool_names.cc"], - hdrs = ["xplane_to_tool_names.h"], - deps = [ - ":repository", - ":xplane_to_dcn_collective_stats", - ":xplane_to_hlo", - "//tensorflow/core/platform:statusor", - "//tensorflow/core/profiler/utils:xplane_schema", - "//tensorflow/core/profiler/utils:xplane_utils", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@local_xla//xla/tsl/platform:statusor", - ], -) - -tf_cc_test( - name = "xplane_to_tool_names_test", - size = "small", - srcs = ["xplane_to_tool_names_test.cc"], - deps = [ - ":repository", - ":xplane_to_tool_names", - "//tensorflow/core:lib", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core/profiler/utils:xplane_schema", - "//tensorflow/core/profiler/utils:xplane_utils", - 
"@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_xla//xla/tsl/platform:status", - ], -) - -cc_library( - name = "xplane_to_trace_container", - srcs = ["xplane_to_trace_container.cc"], - hdrs = ["xplane_to_trace_container.h"], - copts = tf_profiler_copts(), - deps = [ - "//tensorflow/core/profiler/protobuf:trace_events_proto_cc", - "//tensorflow/core/profiler/protobuf:trace_events_raw_proto_cc", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", - "@local_xla//xla/tsl/profiler/utils:timespan", - "@local_xla//xla/tsl/profiler/utils:trace_utils", - "@local_xla//xla/tsl/profiler/utils:xplane_schema", - "@local_xla//xla/tsl/profiler/utils:xplane_utils", - "@local_xla//xla/tsl/profiler/utils:xplane_visitor", - "@org_xprof//xprof/convert/trace_viewer:trace_event_arguments_builder", - "@org_xprof//xprof/convert/trace_viewer:trace_events", - "@org_xprof//xprof/convert/trace_viewer:trace_events_util", - ], -) - -cc_library( - name = "dcn_utils", - srcs = ["dcn_utils.cc"], - hdrs = ["dcn_utils.h"], - visibility = ["//visibility:public"], - deps = [ - "@com_google_absl//absl/strings", - "@local_xla//xla/tsl/profiler/utils:math_utils", - "@local_xla//xla/tsl/profiler/utils:xplane_schema", - "@local_xla//xla/tsl/profiler/utils:xplane_visitor", - ], -) - -tf_cc_test( - name = "dcn_utils_test", - srcs = ["dcn_utils_test.cc"], - deps = [ - ":dcn_utils", - "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest_main", - "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", - "@local_xla//xla/tsl/profiler/utils:xplane_builder", - "@local_xla//xla/tsl/profiler/utils:xplane_schema", - "@local_xla//xla/tsl/profiler/utils:xplane_visitor", - ], -) - -cc_library( - name = "dcn_analysis", - srcs = ["dcn_analysis.cc"], - hdrs = 
["dcn_analysis.h"], - visibility = ["//visibility:public"], - deps = [ - ":dcn_utils", - "//tensorflow/core/profiler/utils:xplane_builder", - "//tensorflow/core/profiler/utils:xplane_visitor", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/log", - "@com_google_absl//absl/strings", - "@local_xla//xla/tsl/profiler/utils:math_utils", - "@local_xla//xla/tsl/profiler/utils:tpu_xplane_utils", - "@local_xla//xla/tsl/profiler/utils:xplane_builder", - "@local_xla//xla/tsl/profiler/utils:xplane_schema", - "@local_xla//xla/tsl/profiler/utils:xplane_visitor", - ], -) - -cc_library( - name = "process_megascale_dcn", - srcs = ["process_megascale_dcn.cc"], - hdrs = ["process_megascale_dcn.h"], - deps = [ - ":dcn_analysis", - "//tensorflow/core/profiler/utils:xplane_utils", - "//tensorflow/core/profiler/utils:xplane_visitor", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", - "@local_xla//xla/tsl/profiler/utils:tpu_xplane_utils", - "@local_xla//xla/tsl/profiler/utils:xplane_schema", - ], -) - -tf_cc_test( - name = "dcn_analysis_test", - srcs = ["dcn_analysis_test.cc"], - deps = [ - ":dcn_analysis", - ":dcn_utils", - "@com_google_googletest//:gtest_main", - "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", - "@local_xla//xla/tsl/profiler/utils:xplane_builder", - "@local_xla//xla/tsl/profiler/utils:xplane_schema", - "@local_xla//xla/tsl/profiler/utils:xplane_visitor", - ], -) - -cc_library( - name = "xspace_to_dcn_slack_analysis", - srcs = ["xspace_to_dcn_slack_analysis.cc"], - hdrs = ["xspace_to_dcn_slack_analysis.h"], - deps = [ - "//tensorflow/core/profiler/protobuf:dcn_collective_info_proto_cc", - "//tensorflow/core/profiler/protobuf:dcn_slack_analysis_proto_cc", - "//tensorflow/core/profiler/protobuf:topology_proto_cc", - "//tensorflow/core/profiler/utils:hlo_module_utils", - "//tensorflow/core/profiler/utils:hlo_proto_map", - 
"//tensorflow/core/profiler/utils:hlo_proto_to_module", - "//tensorflow/core/profiler/utils:xplane_utils", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/log", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:regexp", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_xla//xla:shape_util", - "@local_xla//xla:side_effect_util", - "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/hlo/ir:hlo", - "@local_xla//xla/tsl/platform:statusor", - "@local_xla//xla/tsl/profiler/utils:math_utils", - "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", - "@local_xla//xla/tsl/profiler/utils:timespan", - "@local_xla//xla/tsl/profiler/utils:tpu_xplane_utils", - "@local_xla//xla/tsl/profiler/utils:xplane_schema", - "@local_xla//xla/tsl/profiler/utils:xplane_utils", - "@local_xla//xla/tsl/profiler/utils:xplane_visitor", - ], -) - -cc_library( - name = "dcn_slack_analysis_combiner", - srcs = ["dcn_slack_analysis_combiner.cc"], - hdrs = ["dcn_slack_analysis_combiner.h"], - deps = [ - "//tensorflow/core/profiler/protobuf:dcn_slack_analysis_proto_cc", - "@com_google_absl//absl/container:flat_hash_map", - "@local_xla//xla/tsl/profiler/utils:math_utils", - ], -) - -cc_library( - name = "xplane_to_dcn_collective_stats", - srcs = ["xplane_to_dcn_collective_stats.cc"], - hdrs = ["xplane_to_dcn_collective_stats.h"], - copts = tf_profiler_copts(), - deps = [ - ":dcn_slack_analysis_combiner", - ":repository", - ":xspace_to_dcn_slack_analysis", - "//tensorflow/core:lib", - "//tensorflow/core/platform:statusor", - "//tensorflow/core/profiler/protobuf:dcn_slack_analysis_proto_cc", - "//tensorflow/core/profiler/utils:xplane_utils", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_xla//xla/tsl/platform:statusor", - "@local_xla//xla/tsl/profiler/utils:xplane_schema", - 
"@local_xla//xla/tsl/profiler/utils:xplane_visitor", - ], -) - -tf_cc_test( - name = "xplane_to_dcn_collective_stats_test", - size = "small", - srcs = ["xplane_to_dcn_collective_stats_test.cc"], - deps = [ - ":repository", - ":xplane_to_dcn_collective_stats", - "//tensorflow/core:lib", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core/profiler/protobuf:dcn_slack_analysis_proto_cc", - "//tensorflow/core/profiler/utils:xplane_builder", - "//tensorflow/core/profiler/utils:xplane_utils", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_xla//xla/tsl/platform:status", - ], -) - -cc_library( - name = "inference_stats", - srcs = ["inference_stats.cc"], - hdrs = ["inference_stats.h"], - deps = [ - "//tensorflow/core/lib/gtl:map_util", - "//tensorflow/core/profiler/protobuf:inference_stats_proto_cc", - "//tensorflow/core/profiler/utils:event_span", - "//tensorflow/core/profiler/utils:xplane_schema", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:protobuf", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_xla//xla/tsl/platform:logging", - "@local_xla//xla/tsl/profiler/utils:device_utils", - "@local_xla//xla/tsl/profiler/utils:group_events", - "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", - "@local_xla//xla/tsl/profiler/utils:timespan", - "@local_xla//xla/tsl/profiler/utils:xplane_builder", - "@local_xla//xla/tsl/profiler/utils:xplane_schema", - "@local_xla//xla/tsl/profiler/utils:xplane_utils", - "@local_xla//xla/tsl/profiler/utils:xplane_visitor", - ], -) - -cc_library( - name = "inference_stats_sampler", - srcs = 
["inference_stats_sampler.cc"], - hdrs = ["inference_stats_sampler.h"], - deps = [ - "//tensorflow/core/profiler/protobuf:inference_stats_proto_cc", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/strings:string_view", - "@local_xla//xla/tsl/profiler/utils:math_utils", - ], -) - -cc_library( - name = "inference_stats_grouping", - srcs = ["inference_stats_grouping.cc"], - hdrs = ["inference_stats_grouping.h"], - deps = [ - "//tensorflow/core/profiler/protobuf:inference_stats_proto_cc", - "@com_google_absl//absl/container:flat_hash_map", - "@local_tsl//tsl/platform:protobuf", - "@local_xla//xla/tsl/lib/gtl:map_util", - "@local_xla//xla/tsl/profiler/utils:math_utils", - "@local_xla//xla/tsl/profiler/utils:timespan", - ], -) - -cc_library( - name = "inference_stats_combiner", - srcs = ["inference_stats_combiner.cc"], - hdrs = ["inference_stats_combiner.h"], - deps = [ - "//tensorflow/core/profiler/protobuf:inference_stats_proto_cc", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/strings:string_view", - "@local_xla//xla/tsl/lib/gtl:map_util", - ], -) - -cc_library( - name = "multi_xspace_to_inference_stats", - srcs = ["multi_xspace_to_inference_stats.cc"], - hdrs = ["multi_xspace_to_inference_stats.h"], - deps = [ - ":inference_stats", - ":inference_stats_combiner", - ":inference_stats_grouping", - ":inference_stats_sampler", - ":preprocess_single_host_xplane", - ":repository", - ":xplane_to_step_events", - "//tensorflow/core/profiler/protobuf:inference_stats_proto_cc", - "//tensorflow/core/profiler/utils:event_span", - "//tensorflow/core/profiler/utils:xplane_schema", - "//tensorflow/core/profiler/utils:xplane_visitor", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_xla//xla/tsl/platform:statusor", - 
"@local_xla//xla/tsl/profiler/utils:device_utils", - "@local_xla//xla/tsl/profiler/utils:group_events", - "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", - "@local_xla//xla/tsl/profiler/utils:tpu_xplane_utils", - "@local_xla//xla/tsl/profiler/utils:xplane_utils", - ], -) - -tf_cc_test( - name = "inference_stats_grouping_test", - srcs = ["inference_stats_grouping_test.cc"], - tags = [ - "no_oss", - ], - deps = [ - ":inference_stats_grouping", - "//tensorflow/core:test", - "//tensorflow/core/profiler/protobuf:inference_stats_proto_cc", - "@com_google_googletest//:gtest_main", - "@local_xla//xla/tests:test_utils", - ], -) - -tf_cc_test( - name = "inference_stats_sampler_test", - srcs = ["inference_stats_sampler_test.cc"], - tags = [ - "no_oss", - ], - deps = [ - ":inference_stats_sampler", - "//tensorflow/core:test", - "//tensorflow/core/profiler/protobuf:inference_stats_proto_cc", - "@com_google_absl//absl/status:statusor", - "@com_google_googletest//:gtest_main", - "@local_xla//xla/tests:test_utils", - ], -) - -cc_library( - name = "compute_inference_latency", - srcs = ["compute_inference_latency.cc"], - hdrs = ["compute_inference_latency.h"], - visibility = ["//perftools/accelerators/xprof/convert:__pkg__"], - deps = [ - "//tensorflow/core/profiler/protobuf:inference_stats_proto_cc", - "//tensorflow/core/profiler/protobuf:overview_page_proto_cc", - "@local_xla//xla/tsl/profiler/utils:math_utils", - ], -) - -cc_library( - name = "profile_time_breakdown", - srcs = ["profile_time_breakdown.cc"], - hdrs = ["profile_time_breakdown.h"], - visibility = ["@local_xla//xla/tsl/profiler:friends"], - deps = [ - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:string_view", - "@local_xla//xla/tsl/profiler/convert:xla_op_utils", - "@local_xla//xla/tsl/profiler/utils:math_utils", - ], -) - -cc_library( - name = "tpu_input_pipeline_analysis_constants", - srcs = 
[tf_profiler_alias("@org_xprof//xprof/convert/", "tpu_input_pipeline_analysis_constants.cc")], - hdrs = ["tpu_input_pipeline_analysis_constants.h"], - visibility = ["@local_xla//xla/tsl/profiler:friends"], - deps = [ - "@com_google_absl//absl/strings:string_view", - "@local_xla//xla/tsl/platform:macros", - ], -) - -cc_library( - name = "duty_cycle_tracker", - srcs = ["duty_cycle_tracker.cc"], - hdrs = ["duty_cycle_tracker.h"], - deps = [ - "@com_google_absl//absl/container:btree", - "@com_google_absl//absl/log:check", - "@local_xla//xla/tsl/profiler/utils:math_utils", - "@local_xla//xla/tsl/profiler/utils:timespan", - ], -) - -cc_library( - name = "duty_cycle_combiner", - hdrs = ["duty_cycle_combiner.h"], - deps = [ - ":duty_cycle_tracker", - "//tensorflow/core/profiler/protobuf:op_stats_proto_cc", - "@com_google_absl//absl/container:flat_hash_map", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - ], -) - -tf_cc_test( - name = "duty_cycle_tracker_test", - srcs = ["duty_cycle_tracker_test.cc"], - deps = [ - ":duty_cycle_tracker", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "@com_google_absl//absl/log:check", - "@com_google_googletest//:gtest", - "@local_xla//xla/tsl/profiler/utils:timespan", - ], -) - -tf_cc_test( - name = "compute_inference_latency_test", - srcs = ["compute_inference_latency_test.cc"], - deps = [ - ":compute_inference_latency", - "//tensorflow/core/profiler/protobuf:inference_stats_proto_cc", - "@com_google_googletest//:gtest_main", - ], -) - -tf_cc_test( - name = "duty_cycle_combiner_test", - srcs = ["duty_cycle_combiner_test.cc"], - deps = [ - ":duty_cycle_combiner", - ":duty_cycle_tracker", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "@com_google_googletest//:gtest", - "@local_xla//xla/tsl/profiler/utils:timespan", - ], -) - -cc_library( - name = "data_table_utils", - hdrs = ["data_table_utils.h"], - deps = [ - "@com_github_nlohmann_json//:json", - "@com_google_absl//absl/container:btree", - 
"@com_google_absl//absl/strings", - ], -) - -tf_cc_test( - name = "data_table_utils_test", - srcs = ["data_table_utils_test.cc"], deps = [ - ":data_table_utils", - "@com_github_nlohmann_json//:json", - "@com_google_googletest//:gtest", - "@com_google_googletest//:gtest_main", + "@org_xprof//xprof/convert:xplane_to_op_stats", ], ) diff --git a/tensorflow/core/profiler/convert/compute_inference_latency.cc b/tensorflow/core/profiler/convert/compute_inference_latency.cc deleted file mode 100644 index ba0c8245fd033e..00000000000000 --- a/tensorflow/core/profiler/convert/compute_inference_latency.cc +++ /dev/null @@ -1,144 +0,0 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/core/profiler/convert/compute_inference_latency.h" - -#include -#include -#include -#include -#include - -#include "xla/tsl/profiler/utils/math_utils.h" -#include "tensorflow/core/profiler/protobuf/inference_stats.pb.h" -#include "tensorflow/core/profiler/protobuf/overview_page.pb.h" - -namespace tensorflow::profiler { - -struct LatencyBreakdown { - double total_latency_us = 0.0; - double host_latency_us = 0.0; - double device_latency_us = 0.0; - double communication_latency_us = 0.0; -}; - -void SetLatencyBreakdown(const LatencyBreakdown& src, - OverviewLatencyBreakdown* res) { - res->set_total_latency_us(src.total_latency_us); - res->set_host_latency_us(src.host_latency_us); - res->set_device_latency_us(src.device_latency_us); - res->set_communication_latency_us(src.communication_latency_us); -} - -void SafeDivide(int64_t count, double* num) { - constexpr double kEpsilon = 1.0e-20; - if (count == 0 || std::abs(*num) < kEpsilon) { - *num = 0.0; - } else { - *num /= count; - } -} - -void ComputeAverage(int64_t count, LatencyBreakdown* breakdown) { - SafeDivide(count, &breakdown->total_latency_us); - SafeDivide(count, &breakdown->host_latency_us); - SafeDivide(count, &breakdown->device_latency_us); - SafeDivide(count, &breakdown->communication_latency_us); -} - -void ComputeBreakdownFromSessionRun( - const tensorflow::profiler::RequestDetail& request_detail, - LatencyBreakdown* res, LatencyBreakdown* avg) { - double session_run_duration_us = tsl::profiler::PicoToMicro( - request_detail.end_time_ps() - request_detail.start_time_ps()); - double device_time_us = - tsl::profiler::PicoToMicro(request_detail.device_time_ps()); - double communication_time_us = - tsl::profiler::PicoToMicro(request_detail.read_from_device_time_ps() + - request_detail.write_to_device_time_ps()); - double host_time_us = - session_run_duration_us - device_time_us - communication_time_us; - 
*res = {session_run_duration_us, host_time_us, device_time_us, - communication_time_us}; - - avg->total_latency_us += session_run_duration_us; - avg->device_latency_us += device_time_us; - avg->communication_latency_us += communication_time_us; - avg->host_latency_us += - session_run_duration_us - device_time_us - communication_time_us; -} - -// Compute the inference latency from inference stats proto. -OverviewInferenceLatency ComputeInferenceLatencyResult( - const tensorflow::profiler::InferenceStats& inference_stats) { - OverviewInferenceLatency result; - // If inference_stats is empty, return early with empty result. - // The following code is able to return empty result even - // without early return. - if (inference_stats.inference_stats_per_model_size() == 0) return result; - - // Target percentiles over all session runs. - // Default is [50.0, 75.0, 90.0, 99.0, 99.9]. - constexpr double kTargetPercentiles[] = {50.0, 75.0, 90.0, 99.0, 99.9}; - // Saves the latency corresponding to each percentile. - - std::vector sessions; - double total_sessioins_per_sec = 0; - double max_latency = 0.0; - double min_latency = std::numeric_limits::max(); - LatencyBreakdown avg; - // Iterate over all session runs from all models, calculate the device, - // communication, and host time for each session run, and push in the - // vector sessions. Also update the max, min, count, avg. 
- for (const auto& model_inference_stats : - inference_stats.inference_stats_per_model()) { - total_sessioins_per_sec += - model_inference_stats.second.request_throughput(); - for (const auto& request_detail : - model_inference_stats.second.request_details()) { - LatencyBreakdown session_breakdown; - ComputeBreakdownFromSessionRun(request_detail, &session_breakdown, &avg); - sessions.push_back(session_breakdown); - double session_run_duration_us = tsl::profiler::PicoToMicro( - request_detail.end_time_ps() - request_detail.start_time_ps()); - max_latency = std::max(max_latency, session_run_duration_us); - min_latency = std::min(min_latency, session_run_duration_us); - } - } - // Return empty result if there is no session found. - if (sessions.empty()) return result; - result.set_sessions_per_second(total_sessioins_per_sec); - result.set_max_latency_us(max_latency); - result.set_min_latency_us(min_latency); - ComputeAverage(sessions.size(), &avg); - - // Sort the sessions based on session run duration. For a specified - // percentile, get the corresponding session with the (lower-bound) index. - std::sort(sessions.begin(), sessions.end(), - [](const LatencyBreakdown& a, const LatencyBreakdown& b) { - return a.total_latency_us < b.total_latency_us; - }); - for (const auto& percent : kTargetPercentiles) { - result.add_percentile_numbers(percent); - int64_t index = percent / 100.0 * sessions.size(); - SetLatencyBreakdown(sessions[index], result.add_latency_breakdowns()); - } - // Set the average latency stats. - SetLatencyBreakdown(avg, result.add_latency_breakdowns()); - - return result; -} - -} // namespace tensorflow::profiler diff --git a/tensorflow/core/profiler/convert/compute_inference_latency.h b/tensorflow/core/profiler/convert/compute_inference_latency.h deleted file mode 100644 index 91632c907cbdf2..00000000000000 --- a/tensorflow/core/profiler/convert/compute_inference_latency.h +++ /dev/null @@ -1,33 +0,0 @@ -/* Copyright 2024 The TensorFlow Authors. 
All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_COMPUTE_INFERENCE_LATENCY_H_ -#define TENSORFLOW_CORE_PROFILER_CONVERT_COMPUTE_INFERENCE_LATENCY_H_ - -#include -#include - -#include "tensorflow/core/profiler/protobuf/inference_stats.pb.h" -#include "tensorflow/core/profiler/protobuf/overview_page.pb.h" - -namespace tensorflow::profiler { - -// Compute the inference latency from inference stats proto. -OverviewInferenceLatency ComputeInferenceLatencyResult( - const InferenceStats& inference_stats); - -} // namespace tensorflow::profiler - -#endif // TENSORFLOW_CORE_PROFILER_CONVERT_COMPUTE_INFERENCE_LATENCY_H_ diff --git a/tensorflow/core/profiler/convert/compute_inference_latency_test.cc b/tensorflow/core/profiler/convert/compute_inference_latency_test.cc deleted file mode 100644 index efd931c1384739..00000000000000 --- a/tensorflow/core/profiler/convert/compute_inference_latency_test.cc +++ /dev/null @@ -1,72 +0,0 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/profiler/convert/compute_inference_latency.h" - -#include -#include "tensorflow/core/profiler/protobuf/inference_stats.pb.h" - -namespace tensorflow::profiler { -namespace { - -constexpr double kMaxError = 0.0001; - -TEST(ComputeInferenceLatencyResult, InferenceLatencyTest) { - InferenceStats inference_stats; - auto& model = (*inference_stats.mutable_inference_stats_per_model())[0]; - - // Generates requests for testing. - for (int i = 0; i < 100; i++) { - RequestDetail request_detail; - request_detail.set_start_time_ps(0); - request_detail.set_end_time_ps(i * 10000); - request_detail.set_device_time_ps(i * 1000); - request_detail.set_write_to_device_time_ps(i * 1000); - model.add_request_details()->Swap(&request_detail); - } - - auto result = ComputeInferenceLatencyResult(inference_stats); - - // 5 percentiles and 1 average, so 6 results in total. - ASSERT_EQ(result.latency_breakdowns_size(), 6); - - // Verify 50 percentile result. - EXPECT_NEAR(result.latency_breakdowns(0).total_latency_us(), 0.5, kMaxError); - EXPECT_NEAR(result.latency_breakdowns(0).host_latency_us(), 0.4, kMaxError); - EXPECT_NEAR(result.latency_breakdowns(0).device_latency_us(), 0.05, - kMaxError); - EXPECT_NEAR(result.latency_breakdowns(0).communication_latency_us(), 0.05, - kMaxError); - - // Verify 99.9 percentile result. 
- EXPECT_NEAR(result.latency_breakdowns(4).total_latency_us(), 0.99, kMaxError); - EXPECT_NEAR(result.latency_breakdowns(4).host_latency_us(), 0.792, kMaxError); - EXPECT_NEAR(result.latency_breakdowns(4).device_latency_us(), 0.099, - kMaxError); - EXPECT_NEAR(result.latency_breakdowns(4).communication_latency_us(), 0.099, - kMaxError); - - // Verify average result. - EXPECT_NEAR(result.latency_breakdowns(5).total_latency_us(), 0.495, - kMaxError); - EXPECT_NEAR(result.latency_breakdowns(5).host_latency_us(), 0.396, kMaxError); - EXPECT_NEAR(result.latency_breakdowns(5).device_latency_us(), 0.0495, - kMaxError); - EXPECT_NEAR(result.latency_breakdowns(5).communication_latency_us(), 0.0495, - kMaxError); -} - -} // namespace -} // namespace tensorflow::profiler diff --git a/tensorflow/core/profiler/convert/data_table_utils.h b/tensorflow/core/profiler/convert/data_table_utils.h deleted file mode 100644 index 34bc248b356db5..00000000000000 --- a/tensorflow/core/profiler/convert/data_table_utils.h +++ /dev/null @@ -1,155 +0,0 @@ - -/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ -#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_DATA_TABLE_UTILS_H_ -#define TENSORFLOW_CORE_PROFILER_CONVERT_DATA_TABLE_UTILS_H_ -#include -#include -#include - -#include "absl/container/btree_map.h" -#include "absl/strings/str_replace.h" -#include "nlohmann/json_fwd.hpp" -#include "nlohmann/json.hpp" -namespace tensorflow { -namespace profiler { -// We Don't deal with formatted values on backend now. -struct TableCell { - TableCell() = default; - explicit TableCell(nlohmann::json value) : value(value) {}; - explicit TableCell( - nlohmann::json value, - absl::btree_map custom_properties) - : value(value), custom_properties(custom_properties) {}; - std::string value_str() const { - return absl::StrReplaceAll(value.dump(), {{"\"", ""}}); - } - nlohmann::json value; - absl::btree_map custom_properties; -}; -struct TableColumn { - TableColumn() = default; - explicit TableColumn(std::string id, std::string type, std::string label) - : id(id), type(type), label(label) {}; - explicit TableColumn( - std::string id, std::string type, std::string label, - absl::btree_map custom_properties) - : id(id), type(type), label(label), custom_properties(custom_properties) { - }; - std::string id; - std::string type; - std::string label; - absl::btree_map custom_properties; -}; -class TableRow { - public: - TableRow() = default; - virtual ~TableRow() = default; - // Adds a value of a single cell to the end of the row. - // Memory will be freed by the TableRow. 
- TableCell* AddCell(nlohmann::json value) { - cells_.push_back(std::make_unique(value)); - return cells_.back().get(); - } - std::vector GetCells() const { - std::vector cells; - cells.reserve(cells_.size()); - for (const std::unique_ptr& cell : cells_) { - cells.push_back(cell.get()); - } - return cells; - } - void SetCustomProperties( - const absl::btree_map& custom_properties) { - custom_properties_ = custom_properties; - } - void AddCustomProperty(std::string name, std::string value) { - custom_properties_[name] = value; - } - const absl::btree_map& GetCustomProperties() const { - return custom_properties_; - } - int RowSize() const { return cells_.size(); } - - private: - std::vector> cells_; - absl::btree_map custom_properties_; -}; -// A DataTable class that can be used to create a DataTable JSON/CSV -// serialization. We need this class instead raw JSON manipulation because we -// need to support custom properties. -class DataTable { - public: - DataTable() = default; - void AddColumn(TableColumn column) { table_descriptions_.push_back(column); } - const std::vector& GetColumns() { return table_descriptions_; } - // Create an empty row and return a pointer to it. - // DataTable takes the ownership of the returned TableRow. 
- TableRow* AddRow() { - table_rows_.push_back(std::make_unique()); - return table_rows_.back().get(); - } - std::vector GetRows() { - std::vector rows; - rows.reserve(table_rows_.size()); - for (const std::unique_ptr& row : table_rows_) { - rows.push_back(row.get()); - } - return rows; - } - void AddCustomProperty(std::string name, std::string value) { - custom_properties_[name] = value; - } - std::string ToJson() { - nlohmann::json table; - table["cols"] = nlohmann::json::array(); - table["rows"] = nlohmann::json::array(); - if (!custom_properties_.empty()) { - table["p"] = custom_properties_; - } - for (const TableColumn& col : table_descriptions_) { - nlohmann::json column_json; - column_json["id"] = col.id; - column_json["type"] = col.type; - column_json["label"] = col.label; - if (!col.custom_properties.empty()) { - column_json["p"] = col.custom_properties; - } - table["cols"].push_back(column_json); - } - for (const std::unique_ptr& row : table_rows_) { - nlohmann::json row_json; - row_json["c"] = nlohmann::json::array(); - for (const TableCell* cell : row->GetCells()) { - nlohmann::json cell_json; - cell_json["v"] = cell->value; - if (!cell->custom_properties.empty()) { - cell_json["p"] = cell->custom_properties; - } - row_json["c"].push_back(cell_json); - } - if (!row->GetCustomProperties().empty()) { - row_json["p"] = row->GetCustomProperties(); - } - table["rows"].push_back(row_json); - } - return table.dump(); - } - - private: - std::vector table_descriptions_; - std::vector> table_rows_; - absl::btree_map custom_properties_; -}; -} // namespace profiler -} // namespace tensorflow -#endif // TENSORFLOW_CORE_PROFILER_CONVERT_DATA_TABLE_UTILS_H_ diff --git a/tensorflow/core/profiler/convert/data_table_utils_test.cc b/tensorflow/core/profiler/convert/data_table_utils_test.cc deleted file mode 100644 index 58c89f0e25f306..00000000000000 --- a/tensorflow/core/profiler/convert/data_table_utils_test.cc +++ /dev/null @@ -1,100 +0,0 @@ -/* Copyright 2025 The 
TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/profiler/convert/data_table_utils.h" - -#include -#include -#include - -#include -#include "nlohmann/json_fwd.hpp" -#include "nlohmann/json.hpp" - -namespace tensorflow::profiler { -namespace { - -std::vector> GetTestColumns() { - return {{"rank", "number", "Rank"}, - {"program_id", "string", "Program Id"}, - {"op_category", "string", "Op Category"}, - {"op_name", "string", "Op Name"}, - {"bytes_accessed", "number", "Bytes Accessed"}, - {"model_flops", "number", "Model Flops"}, - {"occurrences", "number", "#Occurrences"}}; -} - -std::vector GetTestRows() { - return {{1, "11111", "category1", "op1", 200000000, 123123123, 10}, - {2, "22222", "category2", "op2", 1000000, 0, 20}, - {3, "33333", "category3", "op3", 3000000, 565656, 30}}; -} - -std::unique_ptr CreateTestDataTable() { - auto data_table = std::make_unique(); - for (const std::vector& col : GetTestColumns()) { - data_table->AddColumn(TableColumn(col[0], col[1], col[2])); - } - for (const nlohmann::json& row_json : GetTestRows()) { - TableRow* row = data_table->AddRow(); - for (int i = 0; i < row_json.size(); ++i) { - row->AddCell(row_json[i]); - } - } - return data_table; -} - -std::unique_ptr -CreateTestDataTableWithCustomProperties() { - auto data_table = std::make_unique(); - data_table->AddCustomProperty("key1", "value1"); - 
data_table->AddCustomProperty("key2", "value2"); - return data_table; -} - -TEST(DataTableUtilsTest, ToJson) { - std::unique_ptr data_table = - CreateTestDataTable(); - std::string json_string = data_table->ToJson(); - const nlohmann::basic_json<> parsed_json = nlohmann::json::parse(json_string); - auto test_columns = GetTestColumns(); - auto test_rows = GetTestRows(); - EXPECT_EQ(parsed_json["cols"].size(), test_columns.size()); - EXPECT_EQ(parsed_json["rows"].size(), test_rows.size()); - for (int i = 0; i < test_columns.size(); ++i) { - EXPECT_EQ(parsed_json["cols"][i]["id"], test_columns[i][0]); - EXPECT_EQ(parsed_json["cols"][i]["label"], test_columns[i][2]); - EXPECT_EQ(parsed_json["cols"][i]["type"], test_columns[i][1]); - } - for (int i = 0; i < test_rows.size(); ++i) { - for (int j = 0; j < test_columns.size(); ++j) { - EXPECT_EQ(parsed_json["rows"][i]["c"][j]["v"], GetTestRows()[i][j]); - } - } -} - -TEST(DataTableUtilsTest, ToJsonWithCustomProperties) { - std::unique_ptr data_table = - CreateTestDataTableWithCustomProperties(); - std::string table_json_string = data_table->ToJson(); - const nlohmann::basic_json<> parsed_json = - nlohmann::json::parse(table_json_string); - EXPECT_EQ(parsed_json.find("p")->size(), 2); - EXPECT_EQ(parsed_json.find("p")->at("key1"), "value1"); - EXPECT_EQ(parsed_json.find("p")->at("key2"), "value2"); -} - -} // namespace -} // namespace tensorflow::profiler diff --git a/tensorflow/core/profiler/convert/dcn_analysis.cc b/tensorflow/core/profiler/convert/dcn_analysis.cc deleted file mode 100644 index 15de6d44400def..00000000000000 --- a/tensorflow/core/profiler/convert/dcn_analysis.cc +++ /dev/null @@ -1,471 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/core/profiler/convert/dcn_analysis.h" - -#include -#include -#include -#include -#include -#include -#include - -#include "absl/log/log.h" -#include "absl/strings/string_view.h" -#include "xla/tsl/profiler/utils/math_utils.h" -#include "xla/tsl/profiler/utils/tpu_xplane_utils.h" -#include "xla/tsl/profiler/utils/xplane_schema.h" -#include "tensorflow/core/profiler/convert/dcn_utils.h" -#include "tensorflow/core/profiler/utils/xplane_builder.h" -#include "tensorflow/core/profiler/utils/xplane_visitor.h" - -namespace tensorflow { -namespace profiler { - -using tsl::profiler::kMaxCollectivesToDisplay; -using tsl::profiler::kMegaScaleDcnReceive; -using tsl::profiler::LineIdType; -using tsl::profiler::MicroToNano; - -void DcnBurstManager::ResetBurstState() { - active_burst_messages_ = 0; - straggler_idx_ = 0; - active_burst_.num_messages = 0; - active_burst_.max_overlapping_messages = 0; - active_burst_.start_timestamp_ns = 0; - active_burst_.end_timestamp_ns = 0; - active_burst_.burst_size_bytes = 0; -} - -void DcnBurstManager::CreateBursts(const TimestampMap& tm_events) { - ResetBurstState(); - for (const auto& tm_event : tm_events) { - if (active_burst_messages_ < 0) { - LOG_FIRST_N(WARNING, 10) - << "Negative messages in burst, bursts will be incorrect."; - } - if (active_burst_messages_ == 0) { - // When no messages are active, next event starts a new burst - active_burst_.start_timestamp_ns = tm_event.first; - } - active_burst_messages_ += 
tm_event.second->message_diff; - if (tm_event.second->message_diff > 0) { - // On beginning of message increase messages and bytes - active_burst_.num_messages += tm_event.second->message_diff; - active_burst_.burst_size_bytes += tm_event.second->size_diff; - } else { - // On end of message, register straggler - Straggler straggler = {tm_event.second->duration_ns, // duration_ns - tm_event.second->timestamp_ns, // end_timestamp_ns - tm_event.second->size_diff * (-1), // size_bytes - tm_event.second->src_slice_id}; // src_slice_id - active_burst_.stragglers[straggler_idx_] = straggler; - straggler_idx_ = (straggler_idx_ + 1) % kMaxStragglersPerBurst; - } - active_burst_.max_overlapping_messages = - std::max(active_burst_.max_overlapping_messages, - static_cast(active_burst_messages_)); - // If we are back at 0 messages, the burst has finished and can be added - // to the bursts_ vector. - if (active_burst_messages_ == 0) { - active_burst_.end_timestamp_ns = tm_event.first; - total_latency_ += - (active_burst_.end_timestamp_ns - active_burst_.start_timestamp_ns); - bursts_.emplace_back(std::move(active_burst_)); - ResetBurstState(); - } - } -} - -DcnEventsProcessor::DcnEventsProcessor(uint32_t num_tpu_tensor_cores, - bool is_megacore) - : num_tpu_tensor_cores_(num_tpu_tensor_cores), is_megacore_(is_megacore) { - // Register all MSXLA messages we may need to analyze. Currently only - // receive messages are processed. - registered_dcn_messages_.push_back(kMegaScaleDcnReceive); - tpu_collective_ts_map_.resize(num_tpu_tensor_cores_); - tpu_collective_bursts_.resize(num_tpu_tensor_cores_); -} - -// Sets up map between registered Megascale messages and their event metadata -// so they can be captured from host events. 
-void DcnEventsProcessor::SetupMessageInfo(const XPlaneVisitor& plane) { - plane.ForEachEventMetadata([&](const XEventMetadataVisitor& event_metadata) { - if (std::find(registered_dcn_messages_.begin(), - registered_dcn_messages_.end(), - event_metadata.Name()) != registered_dcn_messages_.end()) { - megascale_msg_[event_metadata.Name()] = event_metadata.Id(); - } - }); -} - -// If we use megacore, collective traffic goes to even TPU tensor cores. -// Odd ones are woken up from their even pair (e.g. 0 wakes up 1). -uint32_t DcnEventsProcessor::FindTpuIdx(int tpu) { - uint32_t num_tpus = num_tpu_tensor_cores_; - if (is_megacore_) { - num_tpus /= 2; - } - uint32_t tpu_idx = tpu % num_tpus; - if (is_megacore_) { - tpu_idx = tpu_idx * 2; - } - return tpu_idx; -} - -void DcnEventsProcessor::GenerateTimestampEvents( - const DcnMessage& dcn_message) { - // Create one event for the beginning and one for the end of the message - std::shared_ptr start_event( - new TimestampEvent{dcn_message.start_timestamp_ns, 0, 1, - dcn_message.size_bytes, dcn_message.slice_src}); - std::shared_ptr end_event(new TimestampEvent{ - dcn_message.end_timestamp_ns, - static_cast(MicroToNano(dcn_message.duration_us)), -1, - -1 * dcn_message.size_bytes, dcn_message.slice_src}); - - // Add messages to host timestamp event map - std::pair> start_event_entry = - std::make_pair(dcn_message.start_timestamp_ns, start_event); - std::pair> end_event_entry = - std::make_pair(dcn_message.end_timestamp_ns, end_event); - host_ts_map_.insert(start_event_entry); - host_ts_map_.insert(end_event_entry); - - // Add messages to the proper TPU collective timestamp event map. 
- const std::string& collective_name = dcn_message.collective_name; - uint32_t tpu_idx = FindTpuIdx(dcn_message.tpu_dst); - auto& m = tpu_collective_ts_map_[tpu_idx][collective_name]; - m.insert(start_event_entry); - m.insert(end_event_entry); -} - -void DcnEventsProcessor::PrintTimestampEvents() { - for (const auto& host_ts : host_ts_map_) { - LOG(INFO) << host_ts.first << ": " << host_ts.second->timestamp_ns << " " - << host_ts.second->duration_ns << " " - << host_ts.second->message_diff << " " - << host_ts.second->size_diff << " " - << host_ts.second->src_slice_id; - } - for (uint32_t tpu_idx = 0; tpu_idx < num_tpu_tensor_cores_; tpu_idx++) { - LOG(INFO) << "TPU: " << tpu_idx; - for (const auto& col_id : tpu_collective_ts_map_[tpu_idx]) { - LOG(INFO) << col_id.first; - for (const auto& tpu_col_ts : - tpu_collective_ts_map_[tpu_idx][col_id.first]) { - LOG(INFO) << tpu_col_ts.first << ": " << tpu_col_ts.second->timestamp_ns - << " " << tpu_col_ts.second->duration_ns << " " - << tpu_col_ts.second->message_diff << " " - << tpu_col_ts.second->size_diff << " " - << tpu_col_ts.second->src_slice_id; - } - } - } -} - -// Uses heuristics to qualify a good enough amount of collectives. -// kMaxCollectivesToDisplay - 1 are displayed. -// Collectives with < 5% of total host BW time are never qualified -// Collectives with < 20% of total host BW time are qualified if less than 4 -// collectives have already been qualified. -// Top 8 collectives with > 20% of total host BW time are qualified -uint32_t DcnEventsProcessor::NumCollectivesQualified( - const std::vector& latencies) { - uint32_t num_collectives_qualified = 0; - // Allow for 1 line to display stragglers of non-qualified collectives. 
- uint32_t max_collectives = kMaxCollectivesToDisplay - 1; - for (const auto& lat : latencies) { - if (lat < host_dcn_bursts_.TotalLatency() * 0.05) { - return num_collectives_qualified; - } else if (lat < host_dcn_bursts_.TotalLatency() * 0.2 && - num_collectives_qualified >= (max_collectives / 2)) { - return num_collectives_qualified; - } else if (num_collectives_qualified >= max_collectives) { - return num_collectives_qualified; - } else { - num_collectives_qualified++; - } - } - return latencies.size(); -} - -// Find which collectives you are going to display in details (dedicated line) -// and which not (shared line for stragglers). -// Order collectives based on burst latency -- then qualify the top ones based -// on NumCollectivesQualified function. -void DcnEventsProcessor::QualifyCollectives() { - for (auto tpu_idx = 0; tpu_idx < num_tpu_tensor_cores_; tpu_idx++) { - std::vector latency_to_order; - latency_to_order.reserve(tpu_collective_bursts_[tpu_idx].size()); - for (const auto& col_info : tpu_collective_bursts_[tpu_idx]) { - latency_to_order.emplace_back(col_info.second.TotalLatency()); - } - std::sort(latency_to_order.begin(), latency_to_order.end(), - std::greater()); - uint32_t num_collectives_qualified = - NumCollectivesQualified(latency_to_order); - if (num_collectives_qualified > 0) { - uint32_t min_latency_to_qualify = - latency_to_order[num_collectives_qualified - 1]; - uint32_t col_num = 0; - for (auto& col_info : tpu_collective_bursts_[tpu_idx]) { - if (col_info.second.TotalLatency() >= min_latency_to_qualify) { - col_info.second.SetToDisplay(true); - if (++col_num == kMaxCollectivesToDisplay - 1) break; - } - } - } - } -} - -void DcnEventsProcessor::GenerateBursts() { - host_dcn_bursts_.CreateBursts(host_ts_map_); - host_dcn_bursts_.SetToDisplay(true); - - for (auto tpu_idx = 0; tpu_idx < num_tpu_tensor_cores_; tpu_idx++) { - for (const auto& col_info : tpu_collective_ts_map_[tpu_idx]) { - 
tpu_collective_bursts_[tpu_idx][col_info.first].CreateBursts( - tpu_collective_ts_map_[tpu_idx][col_info.first]); - } - } - QualifyCollectives(); -} - -void DcnEventsProcessor::ProcessReceiveMessages(const XPlaneVisitor& plane) { - plane.ForEachLine([&](const XLineVisitor& line) { - uint32_t recv_msg_id = megascale_msg_[kMegaScaleDcnReceive]; - line.ForEachEvent([&](const XEventVisitor& event) { - if (event.Id() == recv_msg_id) { - DcnMessage dcn_message = GetDcnMessageFromXEvent(event); - // TODO(emizan): Report invalid and clock skew messages somehow. - // TODO(emizan): Bring back loopback messages when MSXLA fixes them. - if (dcn_message.validity_info == DCN_MESSAGE_VALID) { - GenerateTimestampEvents(dcn_message); - } - received_messages_.emplace_back(std::move(dcn_message)); - } - }); - }); - GenerateBursts(); -} - -absl::string_view DcnEventsProcessor::GetBwInfo(bool is_per_tpu, - const DcnBurst& burst, - float& burst_mean_bw, - float& burst_bw_utilization) { - absl::string_view bw_level; - uint32_t bw_divider = 1; - burst_mean_bw = static_cast(burst.burst_size_bytes) / - (burst.end_timestamp_ns - burst.start_timestamp_ns); - if (is_per_tpu) { - bw_divider = num_tpu_tensor_cores_; - if (is_megacore_) { - bw_divider /= 2; - } - } - // Have 3 BW categories (low/med/high) to limit the amount of colors in the - // trace viewer - if (burst_mean_bw < kLimitLowHostDcnBw / bw_divider) { - bw_level = "Low BW"; - } else if (burst_mean_bw < kLimitMedHostDcnBw / bw_divider) { - bw_level = "Med BW"; - } else { - bw_level = "High BW"; - } - burst_bw_utilization = burst_mean_bw / (kMaxHostDcnBw / bw_divider); - return bw_level; -} - -void DcnEventsProcessor::AddHostDcnTrafficToXPlane(XPlane* host_xplane) { - if (!host_dcn_bursts_.ToDisplay()) return; - XPlaneBuilder plane_builder(host_xplane); - XLineBuilder line = - plane_builder.GetOrCreateLine(LineIdType::kDcnHostTraffic); - line.SetNameIfEmpty("DCN Host Bandwidth"); - line.SetTimestampNs(0); - XStatMetadata* 
bw_stat_metadata = - plane_builder.GetOrCreateStatMetadata("Bandwidth (GBytes/sec)"); - XStatMetadata* bw_util_stat_metadata = - plane_builder.GetOrCreateStatMetadata("Bandwidth Utilization"); - XStatMetadata* num_msg_stat_metadata = - plane_builder.GetOrCreateStatMetadata("Total Messages"); - XStatMetadata* max_overlap_msg_stat_metadata = - plane_builder.GetOrCreateStatMetadata("Max Overlapping Messages"); - XStatMetadata* avg_msg_size_stat_metadata = - plane_builder.GetOrCreateStatMetadata("Average Message Size (Bytes)"); - for (const auto& host_burst : host_dcn_bursts_.GetBursts()) { - float burst_mean_bw, bw_utilization; - absl::string_view bw_level = - GetBwInfo(false, host_burst, burst_mean_bw, bw_utilization); - XEventMetadata* event_metadata = - plane_builder.GetOrCreateEventMetadata(bw_level); - XEventBuilder event = line.AddEvent(*event_metadata); - event.SetOffsetNs(host_burst.start_timestamp_ns); - event.SetDurationNs(host_burst.end_timestamp_ns - - host_burst.start_timestamp_ns); - - // Using std::string to limit number of decimals. 
- event.ParseAndAddStatValue(*bw_stat_metadata, - std::to_string(burst_mean_bw)); - event.ParseAndAddStatValue(*bw_util_stat_metadata, - std::to_string(bw_utilization)); - event.AddStatValue(*num_msg_stat_metadata, host_burst.num_messages); - event.AddStatValue(*max_overlap_msg_stat_metadata, - host_burst.max_overlapping_messages); - uint32_t avg_message_size = - host_burst.burst_size_bytes / host_burst.num_messages; - event.AddStatValue(*avg_msg_size_stat_metadata, avg_message_size); - } -} - -void DcnEventsProcessor::AddUnqualifiedCollectivesToXPlane( - XPlaneBuilder& plane_builder, uint32_t tpu_idx) { - XLineBuilder line = - plane_builder.GetOrCreateLine(LineIdType::kDcnCollectiveTrafficMax); - line.SetNameIfEmpty("Remaining collectives"); - line.SetTimestampNs(0); - for (const auto& col_item : tpu_collective_bursts_[tpu_idx]) { - if (col_item.second.ToDisplay()) continue; - for (const auto& col_burst : col_item.second.GetBursts()) { - XEventMetadata* straggler_event_metadata = - plane_builder.GetOrCreateEventMetadata(col_item.first); - uint32_t stragglers_processed = 0; - XStatMetadata* straggler_src_slice_stat_metadata = - plane_builder.GetOrCreateStatMetadata("Source slice"); - XStatMetadata* straggler_duration_ns_stat_metadata = - plane_builder.GetOrCreateStatMetadata("Duration ns"); - XStatMetadata* straggler_send_time_ns_stat_metadata = - plane_builder.GetOrCreateStatMetadata("Send timestamp ns"); - XStatMetadata* straggler_recv_time_ns_stat_metadata = - plane_builder.GetOrCreateStatMetadata("Recv timestamp ns"); - for (const auto& straggler : col_burst.stragglers) { - XEventBuilder straggler_event = - line.AddEvent(*straggler_event_metadata); - straggler_event.SetOffsetNs(straggler.end_timestamp_ns - 10000); - straggler_event.SetDurationNs(10000); - straggler_event.AddStatValue(*straggler_src_slice_stat_metadata, - straggler.src_slice_id); - straggler_event.AddStatValue(*straggler_duration_ns_stat_metadata, - straggler.duration_ns); - 
straggler_event.AddStatValue( - *straggler_send_time_ns_stat_metadata, - straggler.end_timestamp_ns - straggler.duration_ns); - straggler_event.AddStatValue(*straggler_recv_time_ns_stat_metadata, - straggler.end_timestamp_ns); - if (++stragglers_processed >= col_burst.num_messages) break; - } - } - } -} - -void DcnEventsProcessor::AddQualifiedCollectivesToXPlane( - XPlaneBuilder& plane_builder, uint32_t tpu_idx) { - uint32_t total_collectives = 0; - for (const auto& col_item : tpu_collective_bursts_[tpu_idx]) { - // Skip collectives not enabled for display. - if (!col_item.second.ToDisplay()) continue; - const std::string& col_name = col_item.first; - XLineBuilder line = plane_builder.GetOrCreateLine( - LineIdType::kDcnCollectiveTraffic + total_collectives++); - line.SetNameIfEmpty(col_name); - line.SetTimestampNs(0); - XStatMetadata* bw_stat_metadata = - plane_builder.GetOrCreateStatMetadata("Bandwidth (GBytes/sec)"); - XStatMetadata* bw_util_stat_metadata = - plane_builder.GetOrCreateStatMetadata("Bandwidth Utilization"); - XStatMetadata* num_msg_stat_metadata = - plane_builder.GetOrCreateStatMetadata("Total Messages"); - XStatMetadata* max_overlap_msg_stat_metadata = - plane_builder.GetOrCreateStatMetadata("Max Overlapping Messages"); - XStatMetadata* avg_msg_size_stat_metadata = - plane_builder.GetOrCreateStatMetadata("Average Message Size (Bytes)"); - XStatMetadata* straggler_details_metadata = - plane_builder.GetOrCreateStatMetadata("Straggler info:"); - XStatMetadata* straggler_src_slice_stat_metadata = - plane_builder.GetOrCreateStatMetadata("Source slice"); - XStatMetadata* straggler_duration_ns_stat_metadata = - plane_builder.GetOrCreateStatMetadata("Duration ns"); - XStatMetadata* straggler_send_time_ns_stat_metadata = - plane_builder.GetOrCreateStatMetadata("Send timestamp ns"); - XStatMetadata* straggler_recv_time_ns_stat_metadata = - plane_builder.GetOrCreateStatMetadata("Recv timestamp ns"); - for (const auto& col_burst : col_item.second.GetBursts()) 
{ - float burst_mean_bw, bw_utilization; - absl::string_view bw_level = - GetBwInfo(true, col_burst, burst_mean_bw, bw_utilization); - XEventMetadata* event_metadata = - plane_builder.GetOrCreateEventMetadata(bw_level); - XEventBuilder event = line.AddEvent(*event_metadata); - event.SetOffsetNs(col_burst.start_timestamp_ns); - event.SetDurationNs(col_burst.end_timestamp_ns - - col_burst.start_timestamp_ns); - event.ParseAndAddStatValue(*bw_stat_metadata, - std::to_string(burst_mean_bw)); - event.ParseAndAddStatValue(*bw_util_stat_metadata, - std::to_string(bw_utilization)); - event.AddStatValue(*num_msg_stat_metadata, col_burst.num_messages); - event.AddStatValue(*max_overlap_msg_stat_metadata, - col_burst.max_overlapping_messages); - event.AddStatValue(*avg_msg_size_stat_metadata, - col_burst.burst_size_bytes / col_burst.num_messages); - // Add straggler info. - XEventMetadata* straggler_event_metadata = - plane_builder.GetOrCreateEventMetadata("Straggler"); - uint32_t stragglers_processed = 0; - std::string straggler_details = "Stragglers:\n"; - for (const auto& straggler : col_burst.stragglers) { - // Add an event for the last straggler - if (straggler.end_timestamp_ns == col_burst.end_timestamp_ns) { - XEventBuilder straggler_event = - line.AddEvent(*straggler_event_metadata); - straggler_event.SetOffsetNs(straggler.end_timestamp_ns - - straggler.duration_ns); - straggler_event.SetDurationNs(straggler.duration_ns); - straggler_event.AddStatValue(*straggler_src_slice_stat_metadata, - straggler.src_slice_id); - straggler_event.AddStatValue(*straggler_duration_ns_stat_metadata, - straggler.duration_ns); - straggler_event.AddStatValue( - *straggler_send_time_ns_stat_metadata, - straggler.end_timestamp_ns - straggler.duration_ns); - straggler_event.AddStatValue(*straggler_recv_time_ns_stat_metadata, - straggler.end_timestamp_ns); - } - // Add text metadata for all stragglers. 
- straggler_details += - " Src slice: " + std::to_string(straggler.src_slice_id) + - " -- Duration (ns): " + std::to_string(straggler.duration_ns) + - " -- [Send Timestamp, Recv Timestamp]: [" + - std::to_string(straggler.end_timestamp_ns - straggler.duration_ns) + - ", " + std::to_string(straggler.end_timestamp_ns) + "]\n"; - if (++stragglers_processed >= col_burst.num_messages) break; - } - event.AddStatValue(*straggler_details_metadata, straggler_details); - } - } -} - -void DcnEventsProcessor::AddTpuCollectiveDcnTrafficToXPlane( - XPlane* device_xplane) { - XPlaneBuilder plane_builder(device_xplane); - auto tpu = tsl::profiler::GetTensorCoreId(plane_builder.Name()); - if (!tpu.has_value()) return; - uint32_t tpu_idx = FindTpuIdx(tpu.value()); - AddQualifiedCollectivesToXPlane(plane_builder, tpu_idx); - AddUnqualifiedCollectivesToXPlane(plane_builder, tpu_idx); -} -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/dcn_analysis.h b/tensorflow/core/profiler/convert/dcn_analysis.h deleted file mode 100644 index d17cfc9f31764a..00000000000000 --- a/tensorflow/core/profiler/convert/dcn_analysis.h +++ /dev/null @@ -1,227 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ -#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_DCN_ANALYSIS_H_ -#define TENSORFLOW_CORE_PROFILER_CONVERT_DCN_ANALYSIS_H_ - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/log/log.h" -#include "absl/strings/string_view.h" -#include "xla/tsl/profiler/utils/xplane_builder.h" -#include "xla/tsl/profiler/utils/xplane_visitor.h" -#include "tensorflow/core/profiler/convert/dcn_utils.h" - -namespace tensorflow { -namespace profiler { - -// Structure representing a DcnMessage using two entries: -// One for the start of the message and one for the end. -struct TimestampEvent { - uint64_t timestamp_ns; // TraceMe logging timestamp - uint64_t duration_ns; // 0 for start of message, duration for end of message - int32_t message_diff; // +1/-1 for start/end of message. - // Makes handling 0-sized messages easier and is - // convenient for the burst generation algorithm. - size_t size_diff; // +size/-size for start/end of message. - int32_t src_slice_id; // Source slice for message, used for stragglers -}; - -// We use an multi map since TimestampEvents will be ordered and we -// need separate entries for possible events happening at exactly the -// same time. -typedef std::multimap> TimestampMap; -typedef absl::flat_hash_map CollectiveTimestampMap; - -// Straggler messages. These are shown at the end of the bursts they belong to. -struct Straggler { - uint64_t duration_ns; // Message duration in ns - uint64_t end_timestamp_ns; // End of the message. For the last straggler - // this will be the end of the burst - size_t size_bytes; // Size of the message in bytes - int32_t src_slice_id; // Source slice of the message - // TODO(emizan) Add host info. -}; - -static constexpr uint32_t kMaxStragglersPerBurst = 4; - -// DCN Burst description. 
-// A burst is defined as a period of time during which there is at least one -// message in the network. Since DCN traffic is bursty this structure is -// convenient to summarize 100K+ messages in a few 10s of bursts. -// Burst scope is flexible. In this analysis we have per-host bursts, which -// include messages arriving on a single host independent of sender/target TPU/ -// and collective. We also have per collective/TPU bursts which include messages -// for a single collective+TPU combination. -struct DcnBurst { - uint64_t start_timestamp_ns; // Beginning of burst in ns - uint64_t end_timestamp_ns; // End of burst in ns - uint64_t burst_size_bytes; // Total number of bytes in burst - uint64_t num_messages; // Messages in burst - uint64_t max_overlapping_messages; // Max overlapping messages in burst - // Buffer of stragglers in a bursts. Contains the last few messages in a burst - std::array stragglers; -}; - -// Class with functionality to generate DcnBursts out of TimestampEvents. -// Burst creation is a non-trivial state machine -class DcnBurstManager { - public: - DcnBurstManager() = default; - uint64_t TotalLatency() const { return total_latency_; } - void SetToDisplay(bool to_display) { to_display_ = to_display; } - bool ToDisplay() const { return to_display_; } - const std::vector &GetBursts() const { return bursts_; } - - // Run burst state machine creation out of timestamp map. - void CreateBursts(const TimestampMap &tm_events); - // For debugging purposes. 
- void PrintBursts() { - for (const auto &burst : bursts_) { - LOG(INFO) << burst.start_timestamp_ns << " " << burst.end_timestamp_ns - << " " << burst.num_messages << " " << burst.burst_size_bytes - << " " << burst.max_overlapping_messages; - } - } - - private: - std::vector bursts_; // Bursts created by this manager - uint64_t total_latency_ = 0; // Total latency of all bursts created - // Used to see if bursts will be displayed - bool to_display_ = false; // Set to true to enable burst display - - int32_t active_burst_messages_; // Used by burst creation state machine. - DcnBurst active_burst_; // Active burst in creation - uint32_t straggler_idx_; - - // Initializes state machine when new burst is detected. - void ResetBurstState(); -}; - -typedef absl::flat_hash_map - CollectiveBurstManager; - -class DcnEventsProcessor { - public: - DcnEventsProcessor() = delete; - DcnEventsProcessor(uint32_t num_tpu_tensor_cores, bool is_megacore); - - uint32_t NumTpuTensorCores() const { return num_tpu_tensor_cores_; } - bool IsMegacore() const { return is_megacore_; } - - // Populates available megascale messages from event metadata. - void SetupMessageInfo(const tsl::profiler::XPlaneVisitor &plane); - - std::optional MegaScaleMessageId(absl::string_view msg_name) const { - auto iter = megascale_msg_.find(msg_name); - if (iter != megascale_msg_.end()) { - return iter->second; - } - return std::nullopt; - } - - uint32_t NumReceivedMessages() const { return received_messages_.size(); } - const tensorflow::profiler::DcnMessage &GetMessage(uint32_t i) const { - return received_messages_[i]; - } - - // Checks if messages with msg event name have been found in event metadata. 
- bool HasDcnMessages(absl::string_view msg_name) const { - return (megascale_msg_.find(msg_name) != megascale_msg_.end()); - } - - const TimestampMap &HostTsMap() const { return host_ts_map_; } - const std::vector &GetHostBursts() const { - return host_dcn_bursts_.GetBursts(); - } - - // Main function to process receive messages, and call other functions - // to generate timestamp events and bursts. - void ProcessReceiveMessages(const tsl::profiler::XPlaneVisitor &plane); - - // Update XPlanes using DCN traffic info - void AddHostDcnTrafficToXPlane(tsl::profiler::XPlane *host_xplane); - void AddTpuCollectiveDcnTrafficToXPlane(tsl::profiler::XPlane *device_xplane); - - private: - // Tensor cores and megacore flag for this host. DCN messages are sent to a - // TPU chip, so we need to know the number of tensor cores and whether - // megacore is used to map DCN traffic to the proper tensor core. - const uint32_t num_tpu_tensor_cores_; - const bool is_megacore_; - - // Used for visualization of BW and computation of BW utilization. - static constexpr float kLimitLowHostDcnBw = 4.17; - static constexpr float kLimitMedHostDcnBw = 8.34; - static constexpr float kMaxHostDcnBw = 12.5; - - std::vector registered_dcn_messages_; - - // Available megascale messages for this trace. - absl::flat_hash_map megascale_msg_; - - std::vector received_messages_; - - // TimestampMaps for messages that arrive to this host - // and for messages of distinct collectives going to different TPUs. - TimestampMap host_ts_map_; - std::vector tpu_collective_ts_map_; - - // DcnBurstManagers for bursts that arrive to this host - // and for burst from distinct collectives going to different TPUs. - DcnBurstManager host_dcn_bursts_; - std::vector tpu_collective_bursts_; - - // Find the TPU index a DCN message goes to. - uint32_t FindTpuIdx(int tpu); - - // Generates BW info to display in the trace viewer. - // This included trace event BW level string, mean BW per burst and - // utilization. 
- absl::string_view GetBwInfo(bool is_per_tpu, const DcnBurst &burst, - float &burst_mean_bw, - float &burst_bw_utilization); - - // Qualify collectives to display on trace viewer. - // Qualified collectives are given a dedicated line, while for the rest - // we share a single line for their stragglers. - uint32_t NumCollectivesQualified(const std::vector &latencies); - void QualifyCollectives(); - // Export collective DCN activity to trace viewer. - void AddQualifiedCollectivesToXPlane( - tsl::profiler::XPlaneBuilder &plane_builder, uint32_t tpu_idx); - void AddUnqualifiedCollectivesToXPlane( - tsl::profiler::XPlaneBuilder &plane_builder, uint32_t tpu_idx); - - // Create timestamp events for every message - void GenerateTimestampEvents( - const tensorflow::profiler::DcnMessage &dcn_message); - // For debugging purposes - void PrintTimestampEvents(); - // Generate bursts (host and TPU/collective) from timestamp events. - void GenerateBursts(); -}; - -} // namespace profiler -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_PROFILER_CONVERT_DCN_ANALYSIS_H_ diff --git a/tensorflow/core/profiler/convert/dcn_analysis_test.cc b/tensorflow/core/profiler/convert/dcn_analysis_test.cc deleted file mode 100644 index b71a583bf26d65..00000000000000 --- a/tensorflow/core/profiler/convert/dcn_analysis_test.cc +++ /dev/null @@ -1,363 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ -#include "tensorflow/core/profiler/convert/dcn_analysis.h" - -#include -#include -#include - -#include -#include -#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" -#include "xla/tsl/profiler/utils/xplane_builder.h" -#include "xla/tsl/profiler/utils/xplane_schema.h" -#include "xla/tsl/profiler/utils/xplane_visitor.h" -#include "tensorflow/core/profiler/convert/dcn_utils.h" - -namespace tensorflow { -namespace profiler { - -namespace { - -using tensorflow::profiler::DCN_MESSAGE_INVALID_BAD_KEY; -using tensorflow::profiler::DCN_MESSAGE_INVALID_CLOCK_SKEW; -using tensorflow::profiler::DCN_MESSAGE_VALID; -using tensorflow::profiler::DCN_MESSAGE_VALID_LOOPBACK; -using ::testing::FieldsAre; -using tsl::profiler::kMegaScaleDcnReceive; -using tsl::profiler::kMegaScaleDcnSend; -using tsl::profiler::XEventBuilder; -using tsl::profiler::XEventMetadata; -using tsl::profiler::XLineBuilder; -using tsl::profiler::XPlane; -using tsl::profiler::XPlaneBuilder; -using tsl::profiler::XPlaneVisitor; -using tsl::profiler::XSpace; - -TEST(DcnAnalysis, SetupMessageInfoTest) { - XSpace space; - XPlane *host_trace = space.add_planes(); - XPlaneBuilder host_trace_builder(host_trace); - - XEventMetadata *event_metadata_1 = - host_trace_builder.GetOrCreateEventMetadata(1); - event_metadata_1->set_name(std::string(kMegaScaleDcnReceive)); - XEventMetadata *event_metadata_2 = - host_trace_builder.GetOrCreateEventMetadata(2); - event_metadata_2->set_name(std::string(kMegaScaleDcnSend)); - - XPlaneVisitor plane = tsl::profiler::CreateTfXPlaneVisitor(host_trace); - DcnEventsProcessor dcn_events_processor(/*num_tpu_tensor_cores*/ 4, - /*is_megacore*/ false); - dcn_events_processor.SetupMessageInfo(plane); - ASSERT_FALSE(dcn_events_processor.HasDcnMessages(kMegaScaleDcnSend)); - ASSERT_TRUE(dcn_events_processor.HasDcnMessages(kMegaScaleDcnReceive)); - ASSERT_FALSE(dcn_events_processor.HasDcnMessages("Another 
Message")); - ASSERT_EQ(dcn_events_processor.MegaScaleMessageId(kMegaScaleDcnReceive), 1); - ASSERT_EQ(dcn_events_processor.MegaScaleMessageId(kMegaScaleDcnSend), - std::nullopt); -} - -// Test processing of valid messages and that all of them are received. -TEST(DcnAnalysis, CreateMessageTestValidMessages) { - XSpace space; - XPlane *host_trace = space.add_planes(); - XPlaneBuilder xplane_builder(host_trace); - - XEventMetadata *event_metadata_1 = xplane_builder.GetOrCreateEventMetadata(1); - event_metadata_1->set_name(std::string(kMegaScaleDcnReceive)); - - XLineBuilder xline_builder_0 = xplane_builder.GetOrCreateLine(0); - XLineBuilder xline_builder_1 = xplane_builder.GetOrCreateLine(1); - - // 1st event - XEventBuilder event_builder = xline_builder_0.AddEvent(*event_metadata_1); - event_builder.SetOffsetNs(100000); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("dcn_label"), - "all-reduce.273_312"); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("dcn_source_slice_id"), 2); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("dcn_source_per_slice_device_id"), - 3); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("dcn_destination_slice_id"), 1); - event_builder.AddStatValue(*xplane_builder.GetOrCreateStatMetadata( - "dcn_destination_per_slice_device_id"), - 3); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("dcn_chunk"), 0); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("dcn_loop_index"), 24); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("duration_us"), 50); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("payload_size_bytes"), 32768); - - // 2nd event, same line - event_builder = xline_builder_0.AddEvent(*event_metadata_1); - event_builder.SetOffsetNs(175000); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("dcn_label"), - 
"super-collective.1234"); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("dcn_source_slice_id"), 112); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("dcn_source_per_slice_device_id"), - 1); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("dcn_destination_slice_id"), 34); - event_builder.AddStatValue(*xplane_builder.GetOrCreateStatMetadata( - "dcn_destination_per_slice_device_id"), - 2); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("dcn_chunk"), 4); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("dcn_loop_index"), 0); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("duration_us"), 50); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("payload_size_bytes"), 1); - - // 3rd event event, new line, no chunk/loop index - event_builder = xline_builder_1.AddEvent(*event_metadata_1); - event_builder.SetOffsetNs(150000); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("dcn_label"), "super-collective"); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("dcn_source_slice_id"), 9); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("dcn_source_per_slice_device_id"), - 3); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("dcn_destination_slice_id"), 0); - event_builder.AddStatValue(*xplane_builder.GetOrCreateStatMetadata( - "dcn_destination_per_slice_device_id"), - 0); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("duration_us"), 75); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("payload_size_bytes"), 10); - - XPlaneVisitor plane = tsl::profiler::CreateTfXPlaneVisitor(host_trace); - DcnEventsProcessor dcn_events_processor(4, false); - dcn_events_processor.SetupMessageInfo(plane); - dcn_events_processor.ProcessReceiveMessages(plane); - - 
ASSERT_EQ(dcn_events_processor.NumReceivedMessages(), 3); - EXPECT_THAT(dcn_events_processor.GetMessage(0), - FieldsAre("all-reduce.273_312", /* collective name */ - 2, 3, 1, 3, /* slice_src, tpu_src, slice_dst, tpu_dst */ - /* start_timestamp_ns, end_timestamp_ns, duration_us */ - 50000, 100000, 50, - /* size_bytes, chunk_id, loop_index_id */ - 32768, 0, 24, - /* validity_info */ - DCN_MESSAGE_VALID)); - EXPECT_THAT(dcn_events_processor.GetMessage(1), - FieldsAre("super-collective.1234", /* collective name */ - /* slice_src, tpu_src, slice_dst, tpu_dst */ - 112, 1, 34, 2, - /* start_timestamp_ns. end_timestamp_ns, duration_us */ - 125000, 175000, 50, - /* size_bytes, chunk_id, loop_index_id */ - 1, 4, 0, - /* validity_info */ - DCN_MESSAGE_VALID)); - EXPECT_THAT( - dcn_events_processor.GetMessage(2), - FieldsAre("super-collective", /* collective name */ - 9, 3, 0, 0, /* slice_src, tpu_src, slice_dst, tpu_dst */ - 75000, 150000, /* start_timestamp_ns. end_timestamp_ns */ - 75, /* duration_us */ - 10, -1, -1, /* size_bytes, chunk_id, loop_index_id */ - /* validity_info */ - DCN_MESSAGE_VALID)); - TimestampMap host_ts_map = dcn_events_processor.HostTsMap(); - ASSERT_EQ(host_ts_map.size(), 6); - for (const auto &ts_map_item : host_ts_map) { - ASSERT_EQ(ts_map_item.first, ts_map_item.second->timestamp_ns); - if (ts_map_item.first == 50000) { - ASSERT_EQ(ts_map_item.second->duration_ns, 0); - ASSERT_EQ(ts_map_item.second->message_diff, 1); - ASSERT_EQ(ts_map_item.second->size_diff, 32768); - } else if (ts_map_item.first == 125000) { - ASSERT_EQ(ts_map_item.second->duration_ns, 0); - ASSERT_EQ(ts_map_item.second->message_diff, 1); - ASSERT_EQ(ts_map_item.second->size_diff, 1); - } else if (ts_map_item.first == 75000) { - ASSERT_EQ(ts_map_item.second->duration_ns, 0); - ASSERT_EQ(ts_map_item.second->message_diff, 1); - ASSERT_EQ(ts_map_item.second->size_diff, 10); - } else if (ts_map_item.first == 100000) { - ASSERT_EQ(ts_map_item.second->duration_ns, 50000); - 
ASSERT_EQ(ts_map_item.second->message_diff, -1); - ASSERT_EQ(ts_map_item.second->size_diff, -32768); - } else if (ts_map_item.first == 175000) { - ASSERT_EQ(ts_map_item.second->duration_ns, 50000); - ASSERT_EQ(ts_map_item.second->message_diff, -1); - ASSERT_EQ(ts_map_item.second->size_diff, -1); - } else if (ts_map_item.first == 150000) { - ASSERT_EQ(ts_map_item.second->duration_ns, 75000); - ASSERT_EQ(ts_map_item.second->message_diff, -1); - ASSERT_EQ(ts_map_item.second->size_diff, -10); - } else { - FAIL() << "Unexpected timestamp entry."; - } - } - const std::vector &host_bursts = - dcn_events_processor.GetHostBursts(); - ASSERT_EQ(host_bursts.size(), 1); - ASSERT_EQ(host_bursts[0].num_messages, 3); - ASSERT_EQ(host_bursts[0].start_timestamp_ns, 50000); - ASSERT_EQ(host_bursts[0].end_timestamp_ns, 175000); - ASSERT_EQ(host_bursts[0].burst_size_bytes, 32779); - ASSERT_EQ(host_bursts[0].max_overlapping_messages, 2); -} - -// Loopback message test, currently interpreted as valid. -TEST(DcnAnalysis, CreateLoopBackMessageTest) { - XSpace space; - XPlane *host_trace = space.add_planes(); - XPlaneBuilder xplane_builder(host_trace); - - XEventMetadata *event_metadata_1 = xplane_builder.GetOrCreateEventMetadata(1); - event_metadata_1->set_name(std::string(kMegaScaleDcnReceive)); - - XLineBuilder xline_builder = xplane_builder.GetOrCreateLine(0); - XEventBuilder event_builder = xline_builder.AddEvent(*event_metadata_1); - event_builder.SetOffsetNs(5000000); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("dcn_label"), "all-gather.1234"); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("dcn_source_slice_id"), 2); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("dcn_source_per_slice_device_id"), - 3); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("dcn_destination_slice_id"), 2); - event_builder.AddStatValue(*xplane_builder.GetOrCreateStatMetadata( - 
"dcn_destination_per_slice_device_id"), - 1); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("dcn_chunk"), 4); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("dcn_loop_index"), 40); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("duration_us"), 1000); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("payload_size_bytes"), 1000); - XPlaneVisitor plane = tsl::profiler::CreateTfXPlaneVisitor(host_trace); - DcnEventsProcessor dcn_events_processor(4, false); - dcn_events_processor.SetupMessageInfo(plane); - dcn_events_processor.ProcessReceiveMessages(plane); - ASSERT_EQ(dcn_events_processor.NumReceivedMessages(), 1); - EXPECT_THAT(dcn_events_processor.GetMessage(0), - FieldsAre("all-gather.1234", /* collective name */ - 2, 3, 2, 1, /* slice_src, tpu_src, slice_dst, tpu_dst */ - /* start_timestamp_ns. end_timestamp_ns, duration_us */ - 4000000, 5000000, 1000, - /* size_bytes, chunk_id, loop_index_id */ - 1000, 4, 40, - /* validity_info */ - DCN_MESSAGE_VALID_LOOPBACK)); -} - -// Zero duration message, this is due to a bug or clock skew between source -// and destination. Any analysis will just cause confusion, mark it as invalid. 
-TEST(DcnAnalysis, CreateZeroDurationMessageTest) { - XSpace space; - XPlane *host_trace = space.add_planes(); - XPlaneBuilder xplane_builder(host_trace); - - XEventMetadata *event_metadata_1 = xplane_builder.GetOrCreateEventMetadata(1); - event_metadata_1->set_name(std::string(kMegaScaleDcnReceive)); - - XLineBuilder xline_builder = xplane_builder.GetOrCreateLine(0); - XEventBuilder event_builder = xline_builder.AddEvent(*event_metadata_1); - event_builder.SetOffsetNs(20000); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("dcn_label"), - "all-reduce.273_312"); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("dcn_source_slice_id"), 2); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("dcn_source_per_slice_device_id"), - 3); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("dcn_destination_slice_id"), 1); - event_builder.AddStatValue(*xplane_builder.GetOrCreateStatMetadata( - "dcn_destination_per_slice_device_id"), - 1); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("dcn_chunk"), 0); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("dcn_loop_index"), 25); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("duration_us"), 0); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("payload_size_bytes"), 512); - XPlaneVisitor plane = tsl::profiler::CreateTfXPlaneVisitor(host_trace); - DcnEventsProcessor dcn_events_processor(4, false); - dcn_events_processor.SetupMessageInfo(plane); - dcn_events_processor.ProcessReceiveMessages(plane); - EXPECT_THAT( - dcn_events_processor.GetMessage(0), - FieldsAre("all-reduce.273_312", /* collective name */ - 2, 3, 1, 1, /* slice_src, tpu_src, slice_dst, tpu_dst */ - 20000, 20000, - 0, /* start_timestamp_ns. 
end_timestamp_ns, duration_us */ - 512, 0, 25, /* size_bytes, chunk_id, loop_index_id */ - /* validity_info */ - DCN_MESSAGE_INVALID_CLOCK_SKEW)); -} - -// Missing key test, make sure it is invalid and correctly initialized. -TEST(DcnAnalysis, CreateMissingKeyTest) { - XSpace space; - XPlane *host_trace = space.add_planes(); - XPlaneBuilder xplane_builder(host_trace); - - XEventMetadata *event_metadata_1 = xplane_builder.GetOrCreateEventMetadata(1); - event_metadata_1->set_name(std::string(kMegaScaleDcnReceive)); - - XLineBuilder xline_builder = xplane_builder.GetOrCreateLine(0); - XEventBuilder event_builder = xline_builder.AddEvent(*event_metadata_1); - event_builder.SetOffsetNs(50000); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("duration_us"), 10); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("payload_size_bytes"), 100); - - XPlaneVisitor plane = tsl::profiler::CreateTfXPlaneVisitor(host_trace); - DcnEventsProcessor dcn_events_processor(4, false); - dcn_events_processor.SetupMessageInfo(plane); - dcn_events_processor.ProcessReceiveMessages(plane); - EXPECT_THAT( - dcn_events_processor.GetMessage(0), - FieldsAre("", /* collective name */ - -1, -1, -1, -1, /* slice_src, tpu_src, slice_dst, tpu_dst */ - 40000, 50000, /* start_timestamp_ns. end_timestamp_ns, */ - 10, /* duration_us */ - 100, -1, -1, /* size_bytes, chunk_id, loop_index_id */ - /* validity_info */ - DCN_MESSAGE_INVALID_BAD_KEY)); -} - -} // namespace - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/dcn_slack_analysis_combiner.cc b/tensorflow/core/profiler/convert/dcn_slack_analysis_combiner.cc deleted file mode 100644 index d2b1e7abd59a3b..00000000000000 --- a/tensorflow/core/profiler/convert/dcn_slack_analysis_combiner.cc +++ /dev/null @@ -1,92 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/core/profiler/convert/dcn_slack_analysis_combiner.h" - -#include - -#include "xla/tsl/profiler/utils/math_utils.h" -#include "tensorflow/core/profiler/protobuf/dcn_slack_analysis.pb.h" - -namespace tensorflow { -namespace profiler { - -using tensorflow::profiler::DcnSlackAnalysis; -using tensorflow::profiler::DcnSlackSummary; -using tsl::profiler::SafeDivide; - -void DcnSlackAnalysisCombiner::Combine(const DcnSlackAnalysis& slack_analysis) { - for (const auto& slack : slack_analysis.dcn_slack_summary()) { - uint64_t occurrences = slack.occurrences(); - DcnSlackSummary& summary = slack_summary_[slack.rendezvous()]; - summary.set_slack_us(summary.slack_us() + slack.slack_us() * occurrences); - summary.set_observed_duration_us(summary.observed_duration_us() + - slack.observed_duration_us() * - occurrences); - summary.set_stall_duration_us(summary.stall_duration_us() + - slack.stall_duration_us() * occurrences); - summary.set_send_done_duration_us(summary.send_done_duration_us() + - slack.send_done_duration_us() * - occurrences); - summary.set_recv_done_duration_us(summary.recv_done_duration_us() + - slack.recv_done_duration_us() * - occurrences); - summary.set_send_duration_us(summary.send_duration_us() + - slack.send_duration_us() * occurrences); - summary.set_recv_duration_us(summary.recv_duration_us() + - slack.recv_duration_us() * 
occurrences); - summary.set_host_stall_us(summary.host_stall_us() + - slack.host_stall_us() * occurrences); - summary.set_occurrences(summary.occurrences() + slack.occurrences()); - summary.set_bytes_transmitted_over_network( - slack.bytes_transmitted_over_network()); - summary.set_recv_op_name(slack.recv_op_name()); - summary.set_send_op_name(slack.send_op_name()); - summary.set_transfer_type(slack.transfer_type()); - } -} - -DcnSlackAnalysis DcnSlackAnalysisCombiner::Finalize() { - DcnSlackAnalysis analysis; - for (const auto& [rendezvous, summary] : slack_summary_) { - auto* slack = analysis.add_dcn_slack_summary(); - slack->set_rendezvous(rendezvous); - slack->set_recv_op_name(summary.recv_op_name()); - slack->set_send_op_name(summary.send_op_name()); - slack->set_transfer_type(summary.transfer_type()); - slack->set_slack_us(SafeDivide(summary.slack_us(), summary.occurrences())); - slack->set_observed_duration_us( - SafeDivide(summary.observed_duration_us(), summary.occurrences())); - slack->set_stall_duration_us( - SafeDivide(summary.stall_duration_us(), summary.occurrences())); - slack->set_send_done_duration_us( - SafeDivide(summary.send_done_duration_us(), summary.occurrences())); - slack->set_recv_done_duration_us( - SafeDivide(summary.recv_done_duration_us(), summary.occurrences())); - slack->set_send_duration_us( - SafeDivide(summary.send_duration_us(), summary.occurrences())); - slack->set_recv_duration_us( - SafeDivide(summary.recv_duration_us(), summary.occurrences())); - slack->set_host_stall_us( - SafeDivide(summary.host_stall_us(), summary.occurrences())); - slack->set_occurrences(summary.occurrences()); - slack->set_bytes_transmitted_over_network( - summary.bytes_transmitted_over_network()); - } - - return analysis; -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/dcn_slack_analysis_combiner.h b/tensorflow/core/profiler/convert/dcn_slack_analysis_combiner.h deleted file mode 100644 index 
f0fc727a62dcc1..00000000000000 --- a/tensorflow/core/profiler/convert/dcn_slack_analysis_combiner.h +++ /dev/null @@ -1,47 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_DCN_SLACK_ANALYSIS_COMBINER_H_ -#define TENSORFLOW_CORE_PROFILER_CONVERT_DCN_SLACK_ANALYSIS_COMBINER_H_ - -#include - -#include "absl/container/flat_hash_map.h" -#include "tensorflow/core/profiler/protobuf/dcn_slack_analysis.pb.h" - -namespace tensorflow { -namespace profiler { - -using tensorflow::profiler::DcnSlackAnalysis; -using tensorflow::profiler::DcnSlackSummary; - -class DcnSlackAnalysisCombiner { - private: - absl::flat_hash_map slack_summary_; - - public: - // Combine the DCN Slack Summary in the DcnSlackAnalysis. - // The DcnSlackAnalysis consists of average durations, The combine phase, the - // summary consists of the total duration for all the occurrences. Finazile - // must be called to get the accurate value. - void Combine(const DcnSlackAnalysis& slack_analysis); - - // Finalize the DcnSlackSummary by converting total durations to averages. 
- DcnSlackAnalysis Finalize(); -}; - -} // namespace profiler -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_PROFILER_CONVERT_DCN_SLACK_ANALYSIS_COMBINER_H_ diff --git a/tensorflow/core/profiler/convert/dcn_utils.cc b/tensorflow/core/profiler/convert/dcn_utils.cc deleted file mode 100644 index 6a457053c30b85..00000000000000 --- a/tensorflow/core/profiler/convert/dcn_utils.cc +++ /dev/null @@ -1,121 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/core/profiler/convert/dcn_utils.h" - -#include "absl/strings/match.h" -#include "absl/strings/string_view.h" -#include "xla/tsl/profiler/utils/math_utils.h" -#include "xla/tsl/profiler/utils/xplane_schema.h" -#include "xla/tsl/profiler/utils/xplane_visitor.h" - -namespace tensorflow { -namespace profiler { -namespace { - -using tsl::profiler::MicroToNano; -using tsl::profiler::StatType; -using tsl::profiler::XEventVisitor; -using tsl::profiler::XStatVisitor; - -DcnMessage CreateDcnMessageFromStats(const XEventVisitor& event_visitor) { - DcnMessage dcn_message; - event_visitor.ForEachStat([&](const XStatVisitor& stat) { - if (!stat.Type()) return; - switch (static_cast(*stat.Type())) { - case StatType::kDcnLabel: { - dcn_message.collective_name = stat.ToString(); - break; - } - case StatType::kDcnSourceSliceId: { - dcn_message.slice_src = stat.IntValue(); - break; - } - case StatType::kDcnSourcePerSliceDeviceId: { - dcn_message.tpu_src = stat.IntValue(); - break; - } - case StatType::kDcnDestinationSliceId: { - dcn_message.slice_dst = stat.IntValue(); - break; - } - case StatType::kDcnDestinationPerSliceDeviceId: { - dcn_message.tpu_dst = stat.IntValue(); - break; - } - case StatType::kDcnChunk: { - dcn_message.chunk_id = stat.IntValue(); - break; - } - case StatType::kDcnLoopIndex: { - dcn_message.loop_index_id = stat.IntValue(); - - break; - } - case StatType::kPayloadSizeBytes: { - dcn_message.size_bytes = stat.IntValue(); - break; - } - case StatType::kDuration: { - dcn_message.duration_us = stat.IntOrUintValue(); - dcn_message.start_timestamp_ns = - event_visitor.TimestampNs() - MicroToNano(dcn_message.duration_us); - dcn_message.end_timestamp_ns = event_visitor.TimestampNs(); - break; - } - default: - break; - } - }); - return dcn_message; -} - -// Analyze message to see if it can be directly processed or it falls under -// corner-case categories, or if 
there is something wrong with it. -void SetMessageValidity(DcnMessage& dcn_message) { - // Message should not be valid if fields have not been set properly - // The main use of that is to detect unexpected key format changes that do - // not cause crashes. - if (dcn_message.collective_name.empty() || dcn_message.slice_src == -1 || - dcn_message.tpu_src == -1 || dcn_message.slice_dst == -1 || - dcn_message.tpu_dst == -1 || dcn_message.size_bytes == -1) { - dcn_message.validity_info = DCN_MESSAGE_INVALID_BAD_KEY; - } else if (dcn_message.duration_us == 0) { - // Destination timestamp smaller than the source timestamp likely due to - // clock skew - dcn_message.validity_info = DCN_MESSAGE_INVALID_CLOCK_SKEW; - } else if (dcn_message.slice_src == dcn_message.slice_dst) { - // Loopback messages remain on the same host, so they are valid - // even though they should not go through DCN. - // TODO(emizan): Get host/TPU info and check host, not slice. - dcn_message.validity_info = DCN_MESSAGE_VALID_LOOPBACK; - } else { - dcn_message.validity_info = DCN_MESSAGE_VALID; - } -} -} // namespace - -DcnMessage GetDcnMessageFromXEvent(const XEventVisitor& event_visitor) { - DcnMessage dcn_message = CreateDcnMessageFromStats(event_visitor); - SetMessageValidity(dcn_message); - return dcn_message; -} - -bool IsDcnEvent(const tsl::profiler::XEventVisitor& event) { - return absl::StartsWith(event.Name(), "MegaScale:"); -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/dcn_utils.h b/tensorflow/core/profiler/convert/dcn_utils.h deleted file mode 100644 index e0dd3a174df919..00000000000000 --- a/tensorflow/core/profiler/convert/dcn_utils.h +++ /dev/null @@ -1,76 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_DCN_UTILS_H_ -#define TENSORFLOW_CORE_PROFILER_CONVERT_DCN_UTILS_H_ - -#include - -#include "xla/tsl/profiler/utils/xplane_visitor.h" - -namespace tensorflow { -namespace profiler { - -// DCN Message Validity -enum DcnMessageValidity { - // Valid message - DCN_MESSAGE_VALID = 1, - // Valid message, but should not go through DCN, so it should not use BW. - DCN_MESSAGE_VALID_LOOPBACK = 2, - // Invalid message with 0 duration due to clock skew. Should be ignored. - DCN_MESSAGE_INVALID_CLOCK_SKEW = 3, - // Message that cannot be decoded. Should be ignored. - DCN_MESSAGE_INVALID_BAD_KEY = 4 -}; - -// Structure representing a DCN event -struct DcnMessage { - // Unique collective that generated this message, format should be - // _, e.g. all_gather_34 - std::string collective_name = ""; - // Src info - // TODO(emizan) Add host info when you figure out how to get it from - // slice+tpu. - int32_t slice_src = -1; - int32_t tpu_src = -1; - // Dst info - int32_t slice_dst = -1; - int32_t tpu_dst = -1; - // Timing info in ns. Since MSXLA TraceMe's have us timestamps, we need to - // multiply by 1000 to get these timestamps. 
- uint64_t start_timestamp_ns = 0; - uint64_t end_timestamp_ns = 0; - uint64_t duration_us = 0; - // Size info - size_t size_bytes = 0; - // Chunk and Loop index - int32_t chunk_id = -1; - int32_t loop_index_id = -1; - // Is message valid/invalid and why - DcnMessageValidity validity_info = DCN_MESSAGE_INVALID_BAD_KEY; - // TBD: Add flow events in case you need to connect to other events pointed to - // by MSXLA TraceMe's -}; - -DcnMessage GetDcnMessageFromXEvent( - const tsl::profiler::XEventVisitor& event_visitor); - -// Check if the XEventVisitor is a DCN Message -bool IsDcnEvent(const tsl::profiler::XEventVisitor& event); - -} // namespace profiler -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_PROFILER_CONVERT_DCN_UTILS_H_ diff --git a/tensorflow/core/profiler/convert/dcn_utils_test.cc b/tensorflow/core/profiler/convert/dcn_utils_test.cc deleted file mode 100644 index 8789da9d07b8f8..00000000000000 --- a/tensorflow/core/profiler/convert/dcn_utils_test.cc +++ /dev/null @@ -1,141 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/core/profiler/convert/dcn_utils.h" - -#include -#include - -#include -#include -#include "absl/strings/string_view.h" -#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" -#include "xla/tsl/profiler/utils/xplane_builder.h" -#include "xla/tsl/profiler/utils/xplane_schema.h" -#include "xla/tsl/profiler/utils/xplane_visitor.h" - -namespace tensorflow { -namespace profiler { -namespace { - -using tsl::profiler::kMegaScaleDcnReceive; -using tsl::profiler::XEventBuilder; -using tsl::profiler::XEventVisitor; -using tsl::profiler::XLineBuilder; -using tsl::profiler::XPlaneBuilder; -using tsl::profiler::XPlaneVisitor; - -void PopulateXPlane(XPlane &xplane, absl::string_view event_name, int offset, - absl::string_view label, int64_t source_slice_id, - int64_t source_per_slice_device_id, - int64_t destination_slice_id, - int64_t destination_per_slice_device_id, int64_t chunk, - int64_t loop_index, int64_t payload_size, - int64_t duration) { - XPlaneBuilder xplane_builder(&xplane); - - XEventMetadata *event_metadata = xplane_builder.GetOrCreateEventMetadata(1); - event_metadata->set_name(std::string(event_name)); - - XLineBuilder xline_builder = xplane_builder.GetOrCreateLine(0); - XEventBuilder event_builder = xline_builder.AddEvent(*event_metadata); - event_builder.SetOffsetNs(offset); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("dcn_label"), label); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("dcn_source_slice_id"), - source_slice_id); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("dcn_source_per_slice_device_id"), - source_per_slice_device_id); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("dcn_destination_slice_id"), - destination_slice_id); - event_builder.AddStatValue(*xplane_builder.GetOrCreateStatMetadata( - "dcn_destination_per_slice_device_id"), - 
destination_per_slice_device_id); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("dcn_chunk"), chunk); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("dcn_loop_index"), loop_index); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("duration_us"), duration); - event_builder.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata("payload_size_bytes"), - payload_size); -} - -TEST(DcnUtilsTest, IsDcnEvent) { - XPlane xplane; - PopulateXPlane(xplane, kMegaScaleDcnReceive, 0, "test", 0, 0, 0, 0, 0, 0, 0, - 0); - XLine line = xplane.lines()[0]; - XPlaneVisitor xplane_visitor = tsl::profiler::CreateTfXPlaneVisitor(&xplane); - - XEventVisitor visitor(&xplane_visitor, &line, &line.events()[0]); - EXPECT_TRUE(IsDcnEvent(visitor)); -} - -TEST(DcnUtilsTest, IsNotDcnEvent) { - XPlane xplane; - PopulateXPlane(xplane, "test", 0, "test", 0, 0, 0, 0, 0, 0, 0, 0); - XLine line = xplane.lines()[0]; - XPlaneVisitor xplane_visitor = tsl::profiler::CreateTfXPlaneVisitor(&xplane); - - XEventVisitor visitor(&xplane_visitor, &line, &line.events()[0]); - EXPECT_FALSE(IsDcnEvent(visitor)); -} - -TEST(DcnUtilsTest, GetDcnMessageFromXEvent) { - XPlane xplane; - PopulateXPlane(xplane, kMegaScaleDcnReceive, 100000, "all-reduce.273_312", 2, - 3, 1, 3, 0, 24, 32768, 50); - XPlaneVisitor xplane_visitor = tsl::profiler::CreateTfXPlaneVisitor(&xplane); - XEventVisitor visitor(&xplane_visitor, &xplane.lines()[0], - &xplane.lines()[0].events()[0]); - EXPECT_THAT(GetDcnMessageFromXEvent(visitor), - testing::FieldsAre( - "all-reduce.273_312", /* collective name */ - 2, 3, 1, 3, /* slice_src, tpu_src, slice_dst, tpu_dst */ - /* start_timestamp_ns, end_timestamp_ns, duration_us */ - 50000, 100000, 50, - /* size_bytes, chunk_id, loop_index_id */ - 32768, 0, 24, - /* validity_info */ - DCN_MESSAGE_VALID)); -} - -TEST(DcnUtilsTest, GetDcnMessageFromXEventLoopBack) { - XPlane xplane; - PopulateXPlane(xplane, kMegaScaleDcnReceive, 5000000, 
"all-gather.1234", 2, 3, - 2, 1, 4, 40, 1000, 1000); - XPlaneVisitor xplane_visitor = tsl::profiler::CreateTfXPlaneVisitor(&xplane); - XEventVisitor visitor(&xplane_visitor, &xplane.lines()[0], - &xplane.lines()[0].events()[0]); - EXPECT_THAT(GetDcnMessageFromXEvent(visitor), - testing::FieldsAre( - "all-gather.1234", /* collective name */ - 2, 3, 2, 1, /* slice_src, tpu_src, slice_dst, tpu_dst */ - /* start_timestamp_ns. end_timestamp_ns, duration_us */ - 4000000, 5000000, 1000, - /* size_bytes, chunk_id, loop_index_id */ - 1000, 4, 40, - /* validity_info */ - DCN_MESSAGE_VALID_LOOPBACK)); -} - -} // namespace -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/duty_cycle_combiner.h b/tensorflow/core/profiler/convert/duty_cycle_combiner.h deleted file mode 100644 index 74b2e0ebdc9aba..00000000000000 --- a/tensorflow/core/profiler/convert/duty_cycle_combiner.h +++ /dev/null @@ -1,72 +0,0 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ -#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_DUTY_CYCLE_COMBINER_H_ -#define TENSORFLOW_CORE_PROFILER_CONVERT_DUTY_CYCLE_COMBINER_H_ - -#include - -#include - -#include "absl/container/flat_hash_map.h" -#include "tensorflow/core/profiler/convert/duty_cycle_tracker.h" -#include "tensorflow/core/profiler/protobuf/op_stats.pb.h" -#include "tsl/profiler/protobuf/xplane.pb.h" - -namespace tensorflow { -namespace profiler { - -// Responsible for combining the duty cycle trackers for all cores and chips. -class DutyCycleCombiner { - public: - // Combines the given core tracker with the tracker for the given chip. - // NOTE: The given chip_id should be unique across all chips being combined. - void CombineCore(const DutyCycleTracker& core_tracker, uint32_t chip_id) { - chip_duty_cycle_trackers_[chip_id].Union(core_tracker); - } - - // Combines the given chip tracker with the tracker for other chips. - void CombineChip(const DutyCycleTracker& chip_tracker) { - chip_active_time_ps_ += chip_tracker.GetActiveTimePs(); - chip_idle_time_ps_ += chip_tracker.GetIdleTimePs(); - } - - // Returns the total active time across all chips and cores. - uint64_t GetTotalActiveTimePs() const { - uint64_t total_busy_time_ps = chip_active_time_ps_; - for (const auto& [chip_id, tracker] : chip_duty_cycle_trackers_) { - total_busy_time_ps += tracker.GetActiveTimePs(); - } - return total_busy_time_ps; - } - - // Returns the total idle time across all chips and cores. 
- uint64_t GetTotalIdleTimePs() const { - uint64_t total_idle_time_ps = chip_idle_time_ps_; - for (const auto& [chip_id, tracker] : chip_duty_cycle_trackers_) { - total_idle_time_ps += tracker.GetIdleTimePs(); - } - return total_idle_time_ps; - } - - private: - absl::flat_hash_map chip_duty_cycle_trackers_; - uint64_t chip_active_time_ps_ = 0; - uint64_t chip_idle_time_ps_ = 0; -}; - -} // namespace profiler -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_PROFILER_CONVERT_DUTY_CYCLE_COMBINER_H_ diff --git a/tensorflow/core/profiler/convert/duty_cycle_combiner_test.cc b/tensorflow/core/profiler/convert/duty_cycle_combiner_test.cc deleted file mode 100644 index 6a9e158b43da5b..00000000000000 --- a/tensorflow/core/profiler/convert/duty_cycle_combiner_test.cc +++ /dev/null @@ -1,82 +0,0 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ -#include "tensorflow/core/profiler/convert/duty_cycle_combiner.h" - -#include -#include "xla/tsl/profiler/utils/timespan.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/profiler/convert/duty_cycle_tracker.h" - -namespace tensorflow { -namespace profiler { -namespace { - -using ::tsl::profiler::Timespan; - -TEST(DutyCycleAnalysisTest, CombineMultiCoreChipTest) { - DutyCycleTracker core0_tracker; - core0_tracker.AddInterval(Timespan::FromEndPoints(10, 20), true); - core0_tracker.AddInterval(Timespan::FromEndPoints(20, 30), false); - DutyCycleTracker core1_tracker; - core1_tracker.AddInterval(Timespan::FromEndPoints(10, 20), false); - core1_tracker.AddInterval(Timespan::FromEndPoints(20, 30), true); - - DutyCycleCombiner combiner; - combiner.CombineCore(core0_tracker, 0); - combiner.CombineCore(core1_tracker, 0); - - EXPECT_EQ(combiner.GetTotalActiveTimePs(), 20); - EXPECT_EQ(combiner.GetTotalIdleTimePs(), 0); -} - -TEST(DutyCycleAnalysisTest, CombineMultiChipTest) { - DutyCycleTracker chip0_tracker; - chip0_tracker.AddInterval(Timespan::FromEndPoints(10, 20), true); - chip0_tracker.AddInterval(Timespan::FromEndPoints(20, 30), false); - DutyCycleTracker chip1_tracker; - chip1_tracker.AddInterval(Timespan::FromEndPoints(10, 20), true); - chip1_tracker.AddInterval(Timespan::FromEndPoints(20, 30), false); - - DutyCycleCombiner combiner; - combiner.CombineChip(chip0_tracker); - combiner.CombineChip(chip1_tracker); - - EXPECT_EQ(combiner.GetTotalActiveTimePs(), 20); - EXPECT_EQ(combiner.GetTotalIdleTimePs(), 20); -} - -TEST(DutyCycleAnalysisTest, CombineMultiChipAndCoreTest) { - DutyCycleTracker chip0_core0_tracker; - chip0_core0_tracker.AddInterval(Timespan::FromEndPoints(10, 20), false); - chip0_core0_tracker.AddInterval(Timespan::FromEndPoints(20, 30), true); - DutyCycleTracker chip0_core1_tracker; - chip0_core1_tracker.AddInterval(Timespan::FromEndPoints(10, 
20), true); - chip0_core1_tracker.AddInterval(Timespan::FromEndPoints(20, 30), false); - DutyCycleTracker chip1_tracker; - chip1_tracker.AddInterval(Timespan::FromEndPoints(15, 25), true); - chip1_tracker.AddInterval(Timespan::FromEndPoints(10, 30), false); - - DutyCycleCombiner combiner; - combiner.CombineCore(chip0_core0_tracker, 0); - combiner.CombineCore(chip0_core1_tracker, 0); - combiner.CombineChip(chip1_tracker); - - EXPECT_EQ(combiner.GetTotalActiveTimePs(), 30); - EXPECT_EQ(combiner.GetTotalIdleTimePs(), 10); -} - -} // namespace -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/duty_cycle_tracker.cc b/tensorflow/core/profiler/convert/duty_cycle_tracker.cc deleted file mode 100644 index 96d793e86b31ea..00000000000000 --- a/tensorflow/core/profiler/convert/duty_cycle_tracker.cc +++ /dev/null @@ -1,103 +0,0 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ -#include "tensorflow/core/profiler/convert/duty_cycle_tracker.h" - -#include - -#include -#include -#include - -#include "absl/container/btree_set.h" -#include "absl/log/check.h" -#include "xla/tsl/profiler/utils/timespan.h" - -namespace tensorflow { -namespace profiler { - -using tsl::profiler::Timespan; - -DutyCycleTracker::ActiveTimeSpans::const_iterator -DutyCycleTracker::MergeOrInsert(const Timespan& timespan, - ActiveTimeSpans::const_iterator hint) { - DCHECK(hint == active_time_spans_.end() || - hint == active_time_spans_.begin() || - hint->begin_ps() <= timespan.begin_ps()); - ActiveTimeSpans::const_iterator merge_begin = hint; - while (merge_begin != active_time_spans_.end() && - merge_begin->end_ps() < timespan.begin_ps()) { - ++merge_begin; - } - - // timespan is fully contained in an existing timespan. - if (merge_begin != active_time_spans_.end() && - merge_begin->Includes(timespan)) { - return merge_begin; - } - - ActiveTimeSpans::const_iterator merge_end = merge_begin; - while (merge_end != active_time_spans_.end() && - merge_end->begin_ps() <= timespan.end_ps()) { - ++merge_end; - } - if (merge_begin != merge_end) { - Timespan merged = Timespan::FromEndPoints( - std::min(timespan.begin_ps(), merge_begin->begin_ps()), - std::max(timespan.end_ps(), std::prev(merge_end)->end_ps())); - merge_end = active_time_spans_.erase(merge_begin, merge_end); - return active_time_spans_.insert(merge_end, merged); - } else { - // There is no overlap with the existing timespans. 
- return active_time_spans_.insert(merge_begin, timespan); - } -} - -void DutyCycleTracker::AddInterval(tsl::profiler::Timespan time_span, - bool is_active) { - total_time_span_.ExpandToInclude(time_span); - if (!is_active) { - return; - } - - auto hint = active_time_spans_.lower_bound(time_span); - if (hint != active_time_spans_.begin()) --hint; - MergeOrInsert(time_span, hint); -} - -void DutyCycleTracker::Union(const DutyCycleTracker& other) { - total_time_span_.ExpandToInclude(other.total_time_span_); - if (other.active_time_spans_.empty()) return; - ActiveTimeSpans::const_iterator hint_it = - active_time_spans_.lower_bound(*other.active_time_spans_.begin()); - if (hint_it != active_time_spans_.begin()) --hint_it; - for (const auto& interval : other.active_time_spans_) { - hint_it = MergeOrInsert(interval, hint_it); - } -} - -uint64_t DutyCycleTracker::GetActiveTimePs() const { - uint64_t active_time_ps = 0; - for (const auto& interval : active_time_spans_) { - DCHECK(!interval.Empty()); - active_time_ps += interval.duration_ps(); - } - return active_time_ps; -} - -uint64_t DutyCycleTracker::GetIdleTimePs() const { - return total_time_span_.duration_ps() - GetActiveTimePs(); -} -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/duty_cycle_tracker.h b/tensorflow/core/profiler/convert/duty_cycle_tracker.h deleted file mode 100644 index bf5160d97d3037..00000000000000 --- a/tensorflow/core/profiler/convert/duty_cycle_tracker.h +++ /dev/null @@ -1,71 +0,0 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_DUTY_CYCLE_TRACKER_H_ -#define TENSORFLOW_CORE_PROFILER_CONVERT_DUTY_CYCLE_TRACKER_H_ - -#include - -#include "absl/container/btree_set.h" -#include "xla/tsl/profiler/utils/math_utils.h" -#include "xla/tsl/profiler/utils/timespan.h" - -namespace tensorflow { -namespace profiler { - -// Tracks the active time intervals for a given TPU core. -// Disjoint intervals of time in ps for which this core was active. -class DutyCycleTracker { - public: - DutyCycleTracker() : active_time_spans_() {} - void AddInterval(tsl::profiler::Timespan time_span, bool is_active); - void Union(const DutyCycleTracker& other); - uint64_t GetActiveTimePs() const; - uint64_t GetIdleTimePs() const; - uint64_t GetDurationPs() const { return total_time_span_.duration_ps(); } - double DutyCycle() const { - return tsl::profiler::SafeDivide(GetActiveTimePs(), GetDurationPs()); - } - - private: - struct TimespanComparator { - // Order by increasing begin_ps, then decreasing duration_ps. - bool operator()(const tsl::profiler::Timespan& a, - const tsl::profiler::Timespan& b) const { - return a.begin_ps() < b.begin_ps() || (a.begin_ps() == b.begin_ps() && - a.duration_ps() > b.duration_ps()); - } - }; - using ActiveTimeSpans = - absl::btree_set; - - /** - * Merge or insert the given timespan into the set of active time spans. - * - * @param timespan The timespan to merge or insert. - * @param hint The iterator indicating where to begin the merge search. 
- * @return The iterator where the timespan was merged or inserted. - */ - ActiveTimeSpans::const_iterator MergeOrInsert( - const tsl::profiler::Timespan& timespan, - ActiveTimeSpans::const_iterator hint); - - ActiveTimeSpans active_time_spans_; - tsl::profiler::Timespan total_time_span_; -}; - -} // namespace profiler -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_PROFILER_CONVERT_DUTY_CYCLE_TRACKER_H_ diff --git a/tensorflow/core/profiler/convert/duty_cycle_tracker_test.cc b/tensorflow/core/profiler/convert/duty_cycle_tracker_test.cc deleted file mode 100644 index 2ee0218d986f54..00000000000000 --- a/tensorflow/core/profiler/convert/duty_cycle_tracker_test.cc +++ /dev/null @@ -1,148 +0,0 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ -#include "tensorflow/core/profiler/convert/duty_cycle_tracker.h" - -#include - -#include -#include - -#include -#include "absl/log/check.h" -#include "xla/tsl/platform/test_benchmark.h" -#include "xla/tsl/profiler/utils/timespan.h" -#include "tensorflow/core/platform/test.h" - -namespace tensorflow { -namespace profiler { -namespace { - -using ::tsl::profiler::Timespan; - -TEST(DutyCycleTrackerTest, NonOverlappingIntervalsTest) { - DutyCycleTracker tracker; - tracker.AddInterval(Timespan::FromEndPoints(10, 20), true); - tracker.AddInterval(Timespan::FromEndPoints(30, 40), true); - EXPECT_EQ(tracker.GetActiveTimePs(), 20); - EXPECT_EQ(tracker.GetIdleTimePs(), 10); - EXPECT_EQ(tracker.GetDurationPs(), 30); - EXPECT_NEAR(tracker.DutyCycle(), 0.6666, 0.0001); -} - -TEST(DutyCycleTrackerTest, OverlappingIntervalsTest) { - DutyCycleTracker tracker; - tracker.AddInterval(Timespan::FromEndPoints(10, 20), true); - tracker.AddInterval(Timespan::FromEndPoints(30, 40), true); - tracker.AddInterval(Timespan::FromEndPoints(20, 35), true); - EXPECT_EQ(tracker.GetActiveTimePs(), 30); - EXPECT_EQ(tracker.GetIdleTimePs(), 0); - EXPECT_EQ(tracker.GetDurationPs(), 30); - EXPECT_EQ(tracker.DutyCycle(), 1.0); -} - -TEST(DutyCycleTrackerTest, DutyCycleTestWithIncludedIntervals) { - DutyCycleTracker tracker; - tracker.AddInterval(Timespan::FromEndPoints(10, 40), true); - tracker.AddInterval(Timespan::FromEndPoints(20, 30), true); - EXPECT_EQ(tracker.GetActiveTimePs(), 30); - EXPECT_EQ(tracker.GetIdleTimePs(), 0); - EXPECT_EQ(tracker.GetDurationPs(), 30); - EXPECT_EQ(tracker.DutyCycle(), 1.0); -} - -TEST(DutyCycleTrackerTest, UnionTest) { - DutyCycleTracker tracker; - tracker.AddInterval(Timespan::FromEndPoints(0, 10), true); - tracker.AddInterval(Timespan::FromEndPoints(20, 30), true); - - DutyCycleTracker other_tracker; - other_tracker.AddInterval(Timespan::FromEndPoints(10, 20), true); - 
other_tracker.AddInterval(Timespan::FromEndPoints(30, 40), true); - - tracker.Union(other_tracker); - EXPECT_EQ(tracker.GetActiveTimePs(), 40); - EXPECT_EQ(tracker.GetIdleTimePs(), 0); - EXPECT_EQ(tracker.GetDurationPs(), 40); -} - -TEST(DutyCycleTrackerTest, OverlappingMixedIntervalsTest) { - DutyCycleTracker tracker; - EXPECT_EQ(tracker.GetActiveTimePs(), 0); - tracker.AddInterval(Timespan::FromEndPoints(10, 20), true); - tracker.AddInterval(Timespan::FromEndPoints(20, 30), false); - EXPECT_EQ(tracker.GetActiveTimePs(), 10); - EXPECT_EQ(tracker.GetIdleTimePs(), 10); -} - -void BM_DutyCycleTracker_AddInterval(::testing::benchmark::State& state) { - std::vector timespans; - timespans.reserve(state.range(0)); - for (uint64_t i = 0; i < state.range(0); ++i) { - timespans.push_back(Timespan::FromEndPoints(i * 2, i * 2 + 1)); - } - for (auto s : state) { - DutyCycleTracker tracker; - for (const auto& timespan : timespans) { - tracker.AddInterval(timespan, true); - } - } - state.SetItemsProcessed(state.iterations() * timespans.size()); -} - -BENCHMARK(BM_DutyCycleTracker_AddInterval)->Range(1 << 15, 1 << 21); - -void BM_DutyCycleTracker_AddInterval_Merge(::testing::benchmark::State& state) { - std::vector timespans; - timespans.reserve(state.range(0)); - for (uint64_t i = 0; i < state.range(0); ++i) { - timespans.push_back(Timespan::FromEndPoints(i, i + 1)); - } - for (auto s : state) { - DutyCycleTracker tracker; - for (const auto& timespan : timespans) { - tracker.AddInterval(timespan, true); - } - } - state.SetItemsProcessed(state.iterations() * timespans.size()); -} - -BENCHMARK(BM_DutyCycleTracker_AddInterval_Merge)->Range(1 << 15, 1 << 21); - -void BM_DutyCycleTracker_Union(::testing::benchmark::State& state) { - DCHECK_GT(state.range(1), 1); - DCHECK_LT(state.range(1), state.range(0)); - DutyCycleTracker tracker_a; - DutyCycleTracker tracker_b; - uint64_t merge_rate = state.range(1); - for (uint64_t i = 0; i < state.range(0); ++i) { - 
tracker_a.AddInterval(Timespan(i * 2, 1), true); - if (i % merge_rate == 0) { - tracker_b.AddInterval(Timespan(i * 2 + 1, merge_rate * 2 - 1), true); - } - } - for (auto s : state) { - DutyCycleTracker unioned_tracker; - unioned_tracker.Union(tracker_a); - unioned_tracker.Union(tracker_b); - } - state.SetItemsProcessed(state.iterations() * - (state.range(0) + state.range(0) / merge_rate)); -} - -BENCHMARK(BM_DutyCycleTracker_Union)->RangePair(1 << 10, 1 << 16, 2, 10); - -} // namespace -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/hlo_proto_to_graph_view.cc b/tensorflow/core/profiler/convert/hlo_proto_to_graph_view.cc deleted file mode 100644 index 2f153dd850b2e9..00000000000000 --- a/tensorflow/core/profiler/convert/hlo_proto_to_graph_view.cc +++ /dev/null @@ -1,555 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/core/profiler/convert/hlo_proto_to_graph_view.h" - -#include -#include -#include -#include -#include -#include - -#include "absl/container/flat_hash_set.h" -#include "absl/log/log.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_replace.h" -#include "absl/strings/string_view.h" -#include "xla/hlo/ir/hlo_opcode.h" -#include "xla/hlo/ir/hlo_print_options.h" -#include "xla/tsl/platform/statusor.h" -#ifdef PLATFORM_GOOGLE -#include "nlohmann/json.hpp" -#include "tensorflow/compiler/mlir/lite/experimental/google/tooling/hlo_adapter/direct_hlo_to_json_graph_convert.h" -#endif // PLATFORM_GOOGLE -#include "xla/hlo/ir/hlo_computation.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/service/hlo.pb.h" -#include "xla/service/hlo_graph_dumper.h" -#include "xla/tsl/platform/errors.h" -#include "tensorflow/core/profiler/convert/tool_options.h" -#include "tensorflow/core/profiler/utils/hlo_module_utils.h" -#include "tensorflow/core/profiler/utils/hlo_proto_to_module.h" - -namespace tensorflow { -namespace profiler { -namespace { - -using ::tsl::StatusOr; -using ::tsl::errors::InvalidArgument; -using ::xla::HloComputation; -using ::xla::HloInstruction; -using ::xla::HloModule; -using ::xla::HloPrintOptions; -using ::xla::HloProto; -using ::xla::HloRenderOptions; -using ::xla::RenderedGraphFormat; - -constexpr char kCenterNodeKey[] = "centerNode"; - -void CleanUpHloModuleForGraphviz(HloModule* hlo_module) { - // Infeed config is escaped serialized proto, and graphviz server complains. 
- for (HloComputation* computation : hlo_module->computations()) { - for (HloInstruction* inst : computation->instructions()) { - if (inst->opcode() == xla::HloOpcode::kInfeed) { - inst->set_infeed_config(""); - } else if (inst->opcode() == xla::HloOpcode::kOutfeed) { - inst->set_outfeed_config(""); - } - } - } -} - -#ifdef PLATFORM_GOOGLE -// Add a custom group node on the graph level, for the center node chosen by the -// user set its attributes like `id`, `name` or `opcode` in `graph_json`. -void AddCenterNodeMetadata(nlohmann::json& graph_json, std::string id, - absl::string_view name, absl::string_view opcode) { - nlohmann::json centerGroupNodeAttributes; - centerGroupNodeAttributes["name"] = name; - centerGroupNodeAttributes["id"] = id; - if (!opcode.empty()) { - centerGroupNodeAttributes["opcode"] = opcode; - } - // Follow ModelExplorer's Graph typing: GraphCollectionFromBuiltinAdapters - graph_json[0]["subgraphs"][0]["groupNodeAttributes"][kCenterNodeKey] = - centerGroupNodeAttributes; -} -#endif // PLATFORM_GOOGLE - -void AddGraphMetadata(std::string& graph_json_str, - const HloInstruction& instr) { -#ifdef PLATFORM_GOOGLE - nlohmann::json graph_json = nlohmann::json::parse(graph_json_str); - // 1. Fusion instruction is represented as a layer on client, use its - // pinned node as the center node, id of the pinned node is the fusion name. - // 2. Other instructions are represented as nodes on client, use iteself as - // the center node, where node id is the instruction name. 
- std::string id = absl::StrCat(instr.name()); - AddCenterNodeMetadata(graph_json, id, instr.name(), - HloOpcodeString(instr.opcode())); - graph_json_str = graph_json.dump(); -#endif // PLATFORM_GOOGLE -} - -void AddGraphMetadata(std::string& graph_json_str, const HloComputation& comp) { -#ifdef PLATFORM_GOOGLE - nlohmann::json graph_json = nlohmann::json::parse(graph_json_str); - // Computation is represented as a layer on client, use its pinned node as the - // center node,id of the pinned node is the computation name. - AddCenterNodeMetadata(graph_json, absl::StrCat(comp.name()), comp.name(), ""); - graph_json_str = graph_json.dump(); -#endif // PLATFORM_GOOGLE -} - -// This function does the same thing as Plot() but uses the ModelExplorer -// instead of graphviz. -absl::StatusOr PlotMe(std::unique_ptr module, - const std::string& node_name, - int graph_width) { - if (node_name.empty()) { - // This should not happen. - return InvalidArgument("node_name should not be empty"); - } - // Find the node with the given name. - const HloInstruction* instr = FindInstruction(*module, node_name); - const HloComputation* comp = FindComputation(*module, node_name); - - if (!instr && !comp) { - return InvalidArgument( - absl::StrCat("Couldn't find HloInstruction or HloComputation named ", - node_name, ".")); - } - // Generate the graph and print the resulting string. - absl::StatusOr graph_handle; - std::string graph_json_str; -// b/360874576: Enable when the adapter is open sourced. 
-#ifdef PLATFORM_GOOGLE - if (comp) { - graph_handle = tooling::visualization_client::HloGraphAdapter(*comp); - } else { - graph_handle = - tooling::visualization_client::HloGraphAdapter(*instr, graph_width); - } -#endif // PLATFORM_GOOGLE - if (graph_handle.ok()) { - VLOG(1) << graph_handle.value(); - graph_json_str = graph_handle.value(); - if (comp) { - AddGraphMetadata(graph_json_str, *comp); - } else { - AddGraphMetadata(graph_json_str, *instr); - } - return graph_json_str; - } else { - LOG(ERROR) << "Unable to render graph: " << graph_handle.status(); - } - - return graph_handle; -} - -absl::StatusOr Plot(std::unique_ptr module, - const std::string& node_name, int graph_width, - const HloRenderOptions& render_options, - const RenderedGraphFormat& format) { - if (node_name.empty()) { - // This should not happen. - return InvalidArgument("node_name should not be empty"); - } - // Find the node with the given name. - const HloInstruction* instr = FindInstruction(*module, node_name); - const HloComputation* comp = FindComputation(*module, node_name); - if (!instr && !comp) { - return InvalidArgument( - absl::StrCat("Couldn't find HloInstruction or HloComputation named ", - node_name, ".")); - } - // Generate the graph and print the resulting string. - absl::StatusOr graph_handle; - - CleanUpHloModuleForGraphviz(module.get()); - if (comp) { - graph_handle = - RenderGraphView(*comp, "", comp->parent()->config().debug_options(), - format, render_options); - } else { - graph_handle = RenderGraphNeighborhoodAround(*instr, graph_width, format, - render_options); - } - if (graph_handle.ok()) { - VLOG(1) << graph_handle.value(); - } else { - LOG(ERROR) << "Unable to render graph: " << graph_handle.status(); - } - - return graph_handle; -} - -// Default parameter constants for graph viewer. 
-static constexpr char kGraphTypeName[] = "graph"; -static constexpr char kShortTxtTypeName[] = "short_txt"; -static constexpr char kLongTxtTypeName[] = "long_txt"; -static constexpr char kDefaultFormatString[] = "url"; -static constexpr int kDefaultWidth = 3; -static constexpr int kDefaultShowMetadata = 0; -static constexpr int kDefaultMergeFusion = 0; - -} // namespace - -absl::StatusOr GetNodeStyles() { - std::vector async_op_codes = {xla::HloOpcode::kAsyncStart, - xla::HloOpcode::kAsyncUpdate, - xla::HloOpcode::kAsyncDone}; - std::vector brown_op_codes = { - xla::HloOpcode::kAllGather, - xla::HloOpcode::kAllGatherStart, - xla::HloOpcode::kAllGatherDone, - xla::HloOpcode::kAllReduce, - xla::HloOpcode::kReduceScatter, - xla::HloOpcode::kAllReduceStart, - xla::HloOpcode::kAllReduceDone, - xla::HloOpcode::kAllToAll, - xla::HloOpcode::kCollectiveBroadcast, - xla::HloOpcode::kCollectivePermute, - xla::HloOpcode::kCollectivePermuteStart, - xla::HloOpcode::kCollectivePermuteDone, - xla::HloOpcode::kInfeed, - xla::HloOpcode::kOutfeed, - xla::HloOpcode::kPartitionId, - xla::HloOpcode::kRecv, - xla::HloOpcode::kRecvDone, - xla::HloOpcode::kSend, - xla::HloOpcode::kSendDone, - xla::HloOpcode::kReplicaId}; - std::vector dark_blue_op_codes = { - xla::HloOpcode::kConvolution, xla::HloOpcode::kDot, xla::HloOpcode::kFft, - xla::HloOpcode::kTriangularSolve, xla::HloOpcode::kCholesky}; - std::vector dark_green_op_codes = { - xla::HloOpcode::kCall, xla::HloOpcode::kConditional, - xla::HloOpcode::kCustomCall, xla::HloOpcode::kWhile}; - std::vector gray_op_codes = { - xla::HloOpcode::kDomain, xla::HloOpcode::kFusion, xla::HloOpcode::kMap, - xla::HloOpcode::kGetDimensionSize, xla::HloOpcode::kSetDimensionSize}; - std::vector green_op_codes = { - xla::HloOpcode::kConcatenate, xla::HloOpcode::kDynamicSlice, - xla::HloOpcode::kReshape, xla::HloOpcode::kDynamicReshape, - xla::HloOpcode::kReverse, xla::HloOpcode::kTranspose, - xla::HloOpcode::kCopy, xla::HloOpcode::kCopyStart, - 
xla::HloOpcode::kCopyDone}; - std::vector orange_op_codes = {xla::HloOpcode::kParameter}; - std::vector purple_op_codes = { - xla::HloOpcode::kBatchNormGrad, xla::HloOpcode::kBatchNormInference, - xla::HloOpcode::kBatchNormTraining, xla::HloOpcode::kReduce, - xla::HloOpcode::kReduceWindow, xla::HloOpcode::kScatter, - xla::HloOpcode::kSelectAndScatter, xla::HloOpcode::kGather}; - std::vector yellow_op_codes = { - xla::HloOpcode::kBroadcast, xla::HloOpcode::kDynamicUpdateSlice}; - - auto OpCodesToNames = - [&](std::vector op_codes) -> std::string { - std::string op_names = ""; - for (const auto& op_code : op_codes) { - if (!op_names.empty()) { - op_names += ","; - } - op_names += std::string(xla::HloOpcodeString(op_code)); - } - return op_names; - }; - - return absl::StrReplaceAll( - R"json({ - "kBlue": "$asyncOpNames", - "kBrown": "$brownOpNames", - "kDarkBlue": "$darkBlueOpNames", - "kDarkGreen": "$darkGreenOpNames", - "kGray": "$grayOpNames", - "kGreen": "$greenOpNames", - "kOrange": "$orangeOpNames", - "kPurple": "$purpleOpNames", - "kYellow": "$yellowOpNames" - })json", - { - {"$asyncOpNames", OpCodesToNames(async_op_codes)}, - {"$brownOpNames", OpCodesToNames(brown_op_codes)}, - {"$darkBlueOpNames", OpCodesToNames(dark_blue_op_codes)}, - {"$darkGreenOpNames", OpCodesToNames(dark_green_op_codes)}, - {"$grayOpNames", OpCodesToNames(gray_op_codes)}, - {"$greenOpNames", OpCodesToNames(green_op_codes)}, - {"$orangeOpNames", OpCodesToNames(orange_op_codes)}, - {"$purpleOpNames", OpCodesToNames(purple_op_codes)}, - {"$yellowOpNames", OpCodesToNames(yellow_op_codes)}, - }); -} - -absl::StatusOr ParseGraphViewerParams( - const ToolOptions& options) { - GraphViewerParams params; - std::optional type = GetParam(options, "type"); - if (!type.has_value()) { - return InvalidArgument("Graph viewer must provide a type option."); - } - - // For graph type. 
- if (type == kGraphTypeName) { - params.type = type.value(); - if (std::optional node_name = - GetParam(options, "node_name")) { - params.node_name = node_name.value(); - } - - params.graph_width = - GetParamWithDefault(options, "graph_width", kDefaultWidth); - params.render_options.show_backend_config = GetParamWithDefault( - options, "show_metadata", kDefaultShowMetadata); - params.render_options.show_fusion_subcomputations = - !GetParamWithDefault(options, "merge_fusion", kDefaultMergeFusion); - params.format = GetRenderFormat(GetParamWithDefault( - options, "format", kDefaultFormatString)); - - return params; - } - - // For txt type. - if (type == kShortTxtTypeName || type == kLongTxtTypeName) { - params.type = type.value(); - params.verbose = (type == kLongTxtTypeName); - params.show_metadata = - GetParamWithDefault(options, "show_metadata", kDefaultShowMetadata); - return params; - } - - // Unknown type. - return InvalidArgument("Unknown graph viewer type option: ", type.value()); -} - -xla::RenderedGraphFormat GetRenderFormat(const std::string& format_string) { - if (format_string == "html") { - return xla::RenderedGraphFormat::kHtml; - } else if (format_string == "dot") { - return xla::RenderedGraphFormat::kDot; - } else if (format_string == "url") { - return xla::RenderedGraphFormat::kUrl; - } else { - LOG(ERROR) << "Invalid graph format argument: " << format_string - << ", fallback to default url"; - return xla::RenderedGraphFormat::kUrl; - } -} - -absl::StatusOr ConvertHloProtoToGraph( - const HloProto& hlo_proto, const std::string& node_name, int graph_width, - const HloRenderOptions& render_options, const RenderedGraphFormat& format) { - TF_ASSIGN_OR_RETURN(std::unique_ptr hlo_module, - ConvertHloProtoToModule(hlo_proto)); - return Plot(std::move(hlo_module), node_name, graph_width, render_options, - format); -} - -absl::StatusOr ConvertHloProtoToMeGraph( - const HloProto& hlo_proto, const std::string& node_name, int graph_width) { - 
TF_ASSIGN_OR_RETURN(std::unique_ptr hlo_module, - ConvertHloProtoToModule(hlo_proto)); - return PlotMe(std::move(hlo_module), node_name, graph_width); -} - -absl::StatusOr ConvertHloProtoToStringView( - const HloProto& hlo_proto, bool verbose, bool metadata) { - TF_ASSIGN_OR_RETURN(std::unique_ptr hlo_module, - ConvertHloProtoToModule(hlo_proto)); - HloPrintOptions options; - if (!verbose) { - options = HloPrintOptions::ShortParsable(); - } - options.set_print_large_constants(verbose); - options.set_print_metadata(metadata); - return hlo_module->ToString(options); -} - -std::function(absl::string_view)>* url_renderer = - nullptr; - -// Precondition: (url_renderer != nullptr || format != kUrl). -// -// (We specify this as a precondition rather than checking it in here and -// returning an error because we want to fail quickly when there's no URL -// renderer available, and this function runs only after we've done all the work -// of producing dot for the graph.) -absl::Status CheckPrecondition(xla::RenderedGraphFormat format) { - if (format == xla::RenderedGraphFormat::kUrl && url_renderer == nullptr) { - return absl::FailedPreconditionError( - "Can't render as URL; no URL renderer was registered."); - } - return absl::OkStatus(); -} - -absl::StatusOr RenderGraphView( - const xla::HloComputation& computation, absl::string_view label, - const xla::DebugOptions& debug_options, xla::RenderedGraphFormat format, - xla::HloRenderOptions hlo_render_options) { - auto precheck_status = CheckPrecondition(format); - if (!precheck_status.ok()) { - return precheck_status; - } - auto rendered_dot = - xla::RenderGraph(computation, label, debug_options, - RenderedGraphFormat::kDot, hlo_render_options); - if (!rendered_dot.ok()) { - return rendered_dot.status(); - } - return WrapDotInFormat(rendered_dot.value(), format); -} - -absl::StatusOr RenderGraphNeighborhoodAround( - const xla::HloInstruction& node, int radius, - xla::RenderedGraphFormat format, xla::HloRenderOptions 
hlo_render_options, - const absl::flat_hash_set& boundary) { - auto precheck_status = CheckPrecondition(format); - if (!precheck_status.ok()) { - return precheck_status; - } - auto rendered_dot = xla::RenderNeighborhoodAround( - node, radius, RenderedGraphFormat::kDot, hlo_render_options, boundary); - if (!rendered_dot.ok()) { - return rendered_dot.status(); - } - return WrapDotInFormat(rendered_dot.value(), format); -} - -absl::StatusOr WrapDotInFormat(std::string dot, - xla::RenderedGraphFormat format) { - switch (format) { - case xla::RenderedGraphFormat::kUrl: - if (url_renderer == nullptr) { - return absl::InternalError("url_renderer is null"); - } - return (*url_renderer)(dot); - case xla::RenderedGraphFormat::kHtml: - return WrapDotInHtml(dot); - case xla::RenderedGraphFormat::kDot: - return std::string(dot); - } -} - -std::string WrapDotInHtml(std::string dot) { - return absl::StrReplaceAll(R"html( - - - - - - - - - -
- - - -)html", - { - {"$DOT", dot}, - }); -} - -void RegisterGraphvizURLRenderer( - std::function(absl::string_view)> renderer) { - if (url_renderer != nullptr) { - LOG(WARNING) << "Multiple calls to RegisterGraphToURLRenderer. Last call " - "wins, but because order of initialization in C++ is " - "nondeterministic, this may not be what you want."; - } - delete url_renderer; - url_renderer = - new std::function(absl::string_view)>( - std::move(renderer)); -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/hlo_proto_to_graph_view.h b/tensorflow/core/profiler/convert/hlo_proto_to_graph_view.h deleted file mode 100644 index b3a3a7c45e1175..00000000000000 --- a/tensorflow/core/profiler/convert/hlo_proto_to_graph_view.h +++ /dev/null @@ -1,101 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_HLO_PROTO_TO_GRAPH_VIEW_H_ -#define TENSORFLOW_CORE_PROFILER_CONVERT_HLO_PROTO_TO_GRAPH_VIEW_H_ - -#include -#include -#include - -#include "xla/service/hlo.pb.h" -#include "xla/service/hlo_graph_dumper.h" -#include "tensorflow/core/platform/statusor.h" -#include "tensorflow/core/profiler/convert/tool_options.h" - -namespace tensorflow { -namespace profiler { - -// All the parameters for graph viewer. 
-struct GraphViewerParams { - // Whether to use GraphView or TxtView. - std::string type; - // Parameters for GraphView. - std::string node_name; - int graph_width; - xla::HloRenderOptions render_options; - xla::RenderedGraphFormat format; - // Parameters for TxtView. - bool verbose; - bool show_metadata; -}; - -// Return mapping from style key word to op names separated by comma. -// following hlo_graph_dumper styling -absl::StatusOr GetNodeStyles(); - -// Parse tool options to get the parameters for graph viewer. -absl::StatusOr ParseGraphViewerParams( - const ToolOptions& options); - -// Get graph render format. -xla::RenderedGraphFormat GetRenderFormat(const std::string& format_string); - -// Convert `hlo_proto` to GraphView with the provided render options. -absl::StatusOr ConvertHloProtoToGraph( - const xla::HloProto& hlo_proto, const std::string& node_name, - int graph_width, const xla::HloRenderOptions& render_options, - const xla::RenderedGraphFormat& format); - -// Convert `hlo_proto` to ModelExplorer Graph JSON data. -absl::StatusOr ConvertHloProtoToMeGraph( - const xla::HloProto& hlo_proto, const std::string& node_name, - int graph_width); - -// Render graph with the provided render options. -absl::StatusOr RenderGraphView( - const xla::HloComputation& computation, absl::string_view label, - const xla::DebugOptions& debug_options, xla::RenderedGraphFormat format, - xla::HloRenderOptions hlo_render_options = {}); - -// Render graph with centered node and depth -absl::StatusOr RenderGraphNeighborhoodAround( - const xla::HloInstruction& node, int radius, - xla::RenderedGraphFormat format, - xla::HloRenderOptions hlo_render_options = {}, - const absl::flat_hash_set& boundary = {}); - -// Convert `hlo_proto` to StringView. 
-absl::StatusOr ConvertHloProtoToStringView( - const xla::HloProto& hlo_proto, bool verbose, bool metadata); - -// Convert dot into certain format -absl::StatusOr WrapDotInFormat(std::string dot, - xla::RenderedGraphFormat format); - -// Convert dot into visual graph in html -std::string WrapDotInHtml(std::string dot); - -// Registers a function which implements RenderedGraphFormat::kUrl. -// The input to the function is dot, and the output should be a URL or an error. -// There can only be one active renderer, and the last call to this function -// wins. -void RegisterGraphvizURLRenderer( - std::function(absl::string_view dot)> renderer); - -} // namespace profiler -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_PROFILER_CONVERT_HLO_PROTO_TO_GRAPH_VIEW_H_ diff --git a/tensorflow/core/profiler/convert/hlo_proto_to_graph_view_test.cc b/tensorflow/core/profiler/convert/hlo_proto_to_graph_view_test.cc deleted file mode 100644 index b53ec03de2822e..00000000000000 --- a/tensorflow/core/profiler/convert/hlo_proto_to_graph_view_test.cc +++ /dev/null @@ -1,123 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/core/profiler/convert/hlo_proto_to_graph_view.h" - -#include - -#include -#include "xla/service/hlo_graph_dumper.h" -#include "xla/tsl/platform/status_matchers.h" -#include "xla/tsl/platform/statusor.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/profiler/convert/tool_options.h" -#include "tensorflow/core/protobuf/error_codes.pb.h" - -namespace tensorflow { -namespace profiler { -namespace { - -using ::testing::HasSubstr; -using ::tsl::testing::StatusIs; - -TEST(GraphViewerParamsTest, GraphType) { - // Default for graph type. - ToolOptions options1; - options1["type"] = "graph"; - TF_ASSERT_OK_AND_ASSIGN(GraphViewerParams params1, - ParseGraphViewerParams(options1)); - EXPECT_EQ(params1.type, "graph"); - EXPECT_EQ(params1.node_name, ""); - EXPECT_EQ(params1.graph_width, 3); - EXPECT_EQ(params1.render_options.show_backend_config, false); - EXPECT_EQ(params1.render_options.show_fusion_subcomputations, true); - EXPECT_EQ(params1.format, xla::RenderedGraphFormat::kUrl); - - // User defined options for graph type. - ToolOptions options2; - options2["type"] = "graph"; - options2["node_name"] = "fusion.111"; - options2["graph_width"] = 10; - options2["show_metadata"] = 1; - options2["merge_fusion"] = 1; - options2["format"] = "html"; - TF_ASSERT_OK_AND_ASSIGN(GraphViewerParams params2, - ParseGraphViewerParams(options2)); - EXPECT_EQ(params2.type, "graph"); - EXPECT_EQ(params2.node_name, "fusion.111"); - EXPECT_EQ(params2.graph_width, 10); - EXPECT_EQ(params2.render_options.show_backend_config, true); - EXPECT_EQ(params2.render_options.show_fusion_subcomputations, false); - EXPECT_EQ(params2.format, xla::RenderedGraphFormat::kHtml); -} - -TEST(GraphViewerParamsTest, ShortTxtType) { - // Default for short txt type. 
- ToolOptions options1; - options1["type"] = "short_txt"; - TF_ASSERT_OK_AND_ASSIGN(GraphViewerParams params1, - ParseGraphViewerParams(options1)); - EXPECT_EQ(params1.type, "short_txt"); - EXPECT_EQ(params1.verbose, false); - EXPECT_EQ(params1.show_metadata, false); - - // User defined options for short txt type. - ToolOptions options2; - options2["type"] = "short_txt"; - options2["show_metadata"] = 1; - TF_ASSERT_OK_AND_ASSIGN(GraphViewerParams params2, - ParseGraphViewerParams(options2)); - EXPECT_EQ(params2.type, "short_txt"); - EXPECT_EQ(params2.verbose, false); - EXPECT_EQ(params2.show_metadata, true); -} - -TEST(GraphViewerParamsTest, LongTxtType) { - // Default for long txt type. - ToolOptions options1; - options1["type"] = "long_txt"; - TF_ASSERT_OK_AND_ASSIGN(GraphViewerParams params1, - ParseGraphViewerParams(options1)); - EXPECT_EQ(params1.type, "long_txt"); - EXPECT_EQ(params1.verbose, true); - EXPECT_EQ(params1.show_metadata, false); - - // User defined options for long txt type. 
- ToolOptions options2; - options2["type"] = "long_txt"; - options2["show_metadata"] = 1; - TF_ASSERT_OK_AND_ASSIGN(GraphViewerParams params2, - ParseGraphViewerParams(options2)); - EXPECT_EQ(params2.type, "long_txt"); - EXPECT_EQ(params2.verbose, true); - EXPECT_EQ(params2.show_metadata, true); -} - -TEST(GraphViewerParamsTest, OtherTypes) { - ToolOptions options1; - EXPECT_THAT(ParseGraphViewerParams(options1), - StatusIs(error::INVALID_ARGUMENT, - HasSubstr("Graph viewer must provide a type option"))); - - ToolOptions options2; - options2["type"] = "abcd"; - EXPECT_THAT(ParseGraphViewerParams(options2), - StatusIs(error::INVALID_ARGUMENT, - HasSubstr("Unknown graph viewer type option: abcd"))); -} - -} // namespace -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/hlo_proto_to_memory_visualization_utils.cc b/tensorflow/core/profiler/convert/hlo_proto_to_memory_visualization_utils.cc deleted file mode 100644 index cf4fce7aecde69..00000000000000 --- a/tensorflow/core/profiler/convert/hlo_proto_to_memory_visualization_utils.cc +++ /dev/null @@ -1,1108 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/core/profiler/convert/hlo_proto_to_memory_visualization_utils.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/container/flat_hash_set.h" -#include "absl/log/log.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_format.h" -#include "absl/strings/str_join.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "xla/layout_util.h" -#include "xla/service/hlo.pb.h" -#include "xla/shape.h" -#include "xla/shape_util.h" -#include "xla/tsl/platform/errors.h" -#include "xla/xla_data.pb.h" -#include "tensorflow/core/platform/errors.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/profiler/protobuf/memory_viewer_preprocess.pb.h" - -namespace tensorflow { -namespace profiler { -namespace { - -using ::xla::BufferAllocationProto; -using ::xla::HeapSimulatorTrace; -using ::xla::HloInstructionProto; -using ::xla::HloProto; -using ::xla::LayoutUtil; -using ::xla::LogicalBufferProto; -using ::xla::Shape; -using ::xla::ShapeUtil; - -Shape ResolveShapeIndex(const xla::ShapeProto& shape_proto, - absl::Span shape_index) { - if (shape_index.empty()) return Shape(shape_proto); - // Choosing the last subshape to maintain historical behavior. - int64_t i = shape_index.back(); - if (i >= shape_proto.tuple_shapes_size()) { - return Shape(shape_proto); - } - return Shape(shape_proto.tuple_shapes(i)); -} - -std::string ShapeDescription(const Shape& shape) { - return ShapeUtil::HumanStringWithLayout(shape); -} - -// A wrapper around ShapeUtil::ByteSizeOf that clears out the layout/padding, -// since that is considered in the ByteSizeOf calculation. 
-int64_t ShapeUnpaddedSize(Shape shape) { - // Ensure the layout has no padding by making it the default layout. - LayoutUtil::SetToDefaultLayout(&shape); - // Note: we make a simplifying assumption here that a "minimal" size for a - // tuple member would be the size of a `void*` -- there may be even fancier - // ways of doing things, but this should give a good enough approximation of - // what a minimal tuple size is. - return ShapeUtil::ByteSizeOf(shape, /*pointer_size=*/sizeof(void*)); -} - -class BufferAllocationStruct { - public: - explicit BufferAllocationStruct(const BufferAllocationProto& proto) - : buffer_allocation_((proto)) {} - bool IsIndefinite() const { - return buffer_allocation_.is_thread_local() || - buffer_allocation_.is_entry_computation_parameter() || - buffer_allocation_.is_constant() || - buffer_allocation_.maybe_live_out(); - } - const BufferAllocationProto& proto() const { return buffer_allocation_; } - size_t size() const { return buffer_allocation_.size(); } - int64_t color() const { return buffer_allocation_.color(); } - int64_t index() const { return buffer_allocation_.index(); } - std::optional heap_simulator_trace_id() const { - return heap_simulator_trace_id_; - } - void set_heap_simulator_trace_id(int64_t id) { - heap_simulator_trace_id_ = id; - } - - // Get buffer allocation category. 
- std::string category() const { - if (buffer_allocation_.is_entry_computation_parameter()) { - return "Parameter"; - } else if (buffer_allocation_.maybe_live_out()) { - return "Output"; - } else if (buffer_allocation_.is_thread_local()) { - return "Thread-local"; - } else if (buffer_allocation_.is_constant()) { - return "Constant"; - } else { - return "Temporary"; - } - } - - std::string description() const { - return absl::StrFormat( - "buffer_allocation_id:%d\nsize:%d\nbuffer_counts:%d\n", - buffer_allocation_.index(), size(), buffer_allocation_.assigned_size()); - } - - private: - const BufferAllocationProto& buffer_allocation_; - std::optional heap_simulator_trace_id_; -}; - -struct LogicalBufferStruct { - LogicalBufferStruct(const LogicalBufferProto& p, - const BufferAllocationStruct& b, - const ::xla::HloInstructionProto& i, uint64_t offset) - : proto(p), - buffer_allocation(b), - hlo_instruction(i), - offset(offset), - shape(ResolveShapeIndex(hlo_instruction.shape(), - proto.defined_at().shape_index())) {} - - absl::string_view instruction_name() const { return hlo_instruction.name(); } - - int64_t color() const { return proto.color(); } - size_t size() const { return proto.size(); } - size_t unpadded_size() const { return ShapeUnpaddedSize(shape); } - - // reference counting related - int64_t inc() { - if (canonical_buffer) return canonical_buffer->inc(); - return ++ref_count; - } - int64_t dec() { - if (canonical_buffer) return canonical_buffer->dec(); - return --ref_count; - } - int64_t share_with(LogicalBufferStruct* buffer) { - canonical_buffer = buffer; - return canonical_buffer->inc(); - } - LogicalBufferStruct* get_canonical_buffer() { - return canonical_buffer ? canonical_buffer->get_canonical_buffer() : this; - } - - // Get the instruction name with shape index for a logical buffer. 
- std::string GetInstructionNameWithShapeIndex() const { - if (proto.defined_at().shape_index().empty()) { - return std::string(instruction_name()); - } else { - return absl::StrCat(instruction_name(), "{", - absl::StrJoin(proto.defined_at().shape_index(), ","), - "}"); - } - } - - std::string description() const { - return absl::StrFormat( - "buffer_id:%d\nhlo_op:%s\nshape:%s\nsize:%d\nunpadded_size:%d\n" - "offset:%d\nspan:(%lld,%lld)", - proto.id(), instruction_name(), ShapeDescription(shape), size(), - unpadded_size(), offset, span ? span->first : -1, - span ? span->second : -1); - } - - const LogicalBufferProto& proto; - const BufferAllocationStruct& buffer_allocation; - const ::xla::HloInstructionProto& hlo_instruction; - uint64_t offset; // within the buffer allocation; - // Span within the specific simulator trace. - std::optional> span; - xla::Shape shape; - int64_t ref_count = 0; - LogicalBufferStruct* canonical_buffer = nullptr; -}; - -// A wrapper of HLO BufferAssignment, with lookup maps for logical buffers and -// buffer allocations. -class HloProtoBufferWrapper { - public: - explicit HloProtoBufferWrapper(const ::xla::HloProto& hlo_proto) - : hlo_proto_(hlo_proto) { - Init(); - } - - // Get the heap simulator trace ID using memory color. - // If unable to find the heap simulator trace, return -1. - int64_t GetHeapSimulatorTraceId(const int64_t memory_color) const { - int64_t id = GetHeapSimulatorTraceIdFromBufferAllocationIndex(memory_color); - if (id != -1) { - return id; - } - return GetHeapSimulatorTraceIdFromEvents(memory_color); - } - - // Get the raw HLO proto. 
- const ::xla::HloProto& GetHloProto() const { return hlo_proto_; } - - std::vector GetBufferAllocations( - int64_t memory_color) const { - std::vector buffer_allocations; - for (const auto& iter : id_to_buffer_allocation_) { - if (iter.second->proto().color() != memory_color) continue; - buffer_allocations.push_back(iter.second.get()); - } - return buffer_allocations; - } - - LogicalBufferStruct* GetLogicalBuffer(int64_t logical_buffer_id) const { - if (!id_to_logical_buffer_.contains(logical_buffer_id)) { - LOG(DFATAL) << "logical_buffer_id " << logical_buffer_id << "not found."; - return nullptr; - } - return id_to_logical_buffer_.at(logical_buffer_id).get(); - } - - // Get the logical buffers with indefinite lifetime (excluding thread_local). - std::vector LogicalBuffersWithIndefiniteLifetime( - int64_t memory_color) const { - std::vector indefinite_logical_buffers; - - for (const auto& buffer_assignment : GetBufferAllocations(memory_color)) { - if (!buffer_assignment->IsIndefinite()) continue; - if (buffer_assignment->proto().is_thread_local()) continue; - // A indefinite buffer allocation will contain multiple logical buffers. - // None of them have a offset, and may have different size than the buffer - // allocation's size. In most cases, if not all cases, one of the logical - // buffer will have the size equal to buffer allocation's size. We will - // pick the biggest logical buffer. 
- const LogicalBufferStruct* best_logical_buffer = nullptr; - size_t best_size = 0; - for (const auto& assigned : buffer_assignment->proto().assigned()) { - const LogicalBufferStruct* logical_buffer_struct = - GetLogicalBuffer(assigned.logical_buffer_id()); - if (logical_buffer_struct == nullptr) continue; - if (logical_buffer_struct->size() > best_size) { - best_size = logical_buffer_struct->size(); - best_logical_buffer = logical_buffer_struct; - } - } - if (best_logical_buffer) { - indefinite_logical_buffers.push_back(best_logical_buffer); - } - } - return indefinite_logical_buffers; - } - - private: - // Initialize the mappings of logical buffers and buffer allocations. - void Init() { - // A mapping from name to HLO instruction. - absl::flat_hash_map - name_to_hlo; - absl::flat_hash_map - unique_id_to_hlo; - - for (const auto& computation : hlo_proto_.hlo_module().computations()) { - for (const auto& instruction : computation.instructions()) { - name_to_hlo[instruction.name()] = &instruction; - unique_id_to_hlo[instruction.id()] = &instruction; - } - } - - absl::flat_hash_map - id_to_logical_buffer_proto; - for (const auto& logical_buffer : - hlo_proto_.buffer_assignment().logical_buffers()) { - id_to_logical_buffer_proto[logical_buffer.id()] = &logical_buffer; - } - - for (const auto& buffer_allocation : - hlo_proto_.buffer_assignment().buffer_allocations()) { - auto& buffer_allocation_s = - id_to_buffer_allocation_[buffer_allocation.index()]; - buffer_allocation_s = - std::make_unique(buffer_allocation); - for (const auto& assigned : buffer_allocation.assigned()) { - const auto id = assigned.logical_buffer_id(); - if (!id_to_logical_buffer_proto.contains(id)) { - LOG(DFATAL) << "logical_buffer_id " << id << " not found."; - continue; - } - const auto* logical_buffer = id_to_logical_buffer_proto.at(id); - int64_t inst_id = logical_buffer->defined_at().instruction_id(); - if (!unique_id_to_hlo.contains(inst_id)) { - LOG(DFATAL) << "instruction_id " << inst_id 
<< " not found."; - continue; - } - const auto* instruction = unique_id_to_hlo.at(inst_id); - id_to_logical_buffer_[id] = std::make_unique( - *logical_buffer, *buffer_allocation_s, *instruction, - assigned.offset()); - } - } - - const auto& heap_simulator_traces = - hlo_proto_.buffer_assignment().heap_simulator_traces(); - for (int64_t i = 0; i < heap_simulator_traces.size(); i++) { - // The trace's buffer_allocation_index is not trustful, so we are trying - // to obtain the buffer allocation index ourselves. - if (heap_simulator_traces[i].events().empty()) continue; - int logical_buffer_id = heap_simulator_traces[i].events(0).buffer_id(); - if (!id_to_logical_buffer_.contains(logical_buffer_id)) continue; - auto* logical_buffer = id_to_logical_buffer_[logical_buffer_id].get(); - auto buffer_allocation_index = logical_buffer->buffer_allocation.index(); - id_to_buffer_allocation_[buffer_allocation_index] - ->set_heap_simulator_trace_id(i); - } - } - - // From a list of heap simulator traces, identify the one that has the largest - // number of memory events with color . - int64_t GetHeapSimulatorTraceIdFromEvents(const int64_t memory_color) const { - int64_t best_index = -1; - int64_t best_event_count = 0; - for (int64_t i = 0; - i < hlo_proto_.buffer_assignment().heap_simulator_traces_size(); i++) { - const auto& heap_simulator_trace = - hlo_proto_.buffer_assignment().heap_simulator_traces(i); - int64_t event_count = 0; - for (const auto& event : heap_simulator_trace.events()) { - if (!id_to_logical_buffer_.contains(event.buffer_id())) { - LOG(DFATAL) << "buffer_id " << event.buffer_id() << "not found."; - continue; - } - const auto& logical_buffer = - id_to_logical_buffer_.at(event.buffer_id()); - if (logical_buffer->color() == memory_color) { - event_count++; - } - } - if (event_count > best_event_count) { - best_index = i; - best_event_count = event_count; - } - } - return best_index; - } - - // Tries to get heap simulator trace based on buffer_allocation_index. 
- int64_t GetHeapSimulatorTraceIdFromBufferAllocationIndex( - const int64_t memory_color) const { - auto buffer_allocations = GetBufferAllocations(memory_color); - for (const auto* buffer_allocation : buffer_allocations) { - if (buffer_allocation->IsIndefinite()) continue; - // TODO(xprof): handle multiple temporary buffer allocations for the same - // color. - if (buffer_allocation->heap_simulator_trace_id()) { - return *buffer_allocation->heap_simulator_trace_id(); - } - } - return -1; - } - - // Reference to the original HLO proto. - const ::xla::HloProto& hlo_proto_; - - // A mapping from logical buffer ID to logical buffer. - absl::flat_hash_map> - id_to_logical_buffer_; - - // A mapping from buffer allocation ID to BufferAllocationProto. - absl::flat_hash_map> - id_to_buffer_allocation_; -}; - -double BytesToMiB(int64_t bytes) { - return static_cast(bytes) / (1ULL << 20); -} - -HeapObject MakeHeapObjectCommon(std::string label, int32_t color, - int64_t logical_buffer_id, - int64_t logical_buffer_size_bytes, - int64_t unpadded_shape_bytes) { - HeapObject result; - result.set_numbered(color); - result.set_label(std::move(label)); - result.set_logical_buffer_id(logical_buffer_id); - result.set_logical_buffer_size_mib(BytesToMiB(logical_buffer_size_bytes)); - result.set_unpadded_shape_mib(BytesToMiB(unpadded_shape_bytes)); - return result; -} - -HeapObject MakeHeapObject(const LogicalBufferStruct& logical_buffer, - int32_t color) { - const HloInstructionProto& hlo_instruction = logical_buffer.hlo_instruction; - std::string shape_string = ShapeDescription(logical_buffer.shape); - std::string label = - absl::StrFormat("%s: %s # %s", logical_buffer.instruction_name(), - shape_string, hlo_instruction.metadata().op_name()); - HeapObject result = MakeHeapObjectCommon( - std::move(label), color, logical_buffer.proto.id(), logical_buffer.size(), - logical_buffer.unpadded_size()); - result.set_instruction_name( - logical_buffer.GetInstructionNameWithShapeIndex()); - 
result.set_group_name(logical_buffer.buffer_allocation.category()); - result.set_tf_op_name(hlo_instruction.metadata().op_name()); - result.set_shape_string(shape_string); - result.set_op_code(hlo_instruction.opcode()); - return result; -} - -BufferSpan MakeBufferSpan(int32 start, int32 limit) { - BufferSpan result; - result.set_start(start); - result.set_limit(limit); - return result; -} - -void Convert(const xla::BufferAllocationProto_Assigned& assigned, - const HloProtoBufferWrapper& wrapper, LogicalBuffer* result) { - result->set_id(assigned.logical_buffer_id()), - result->set_size_mib(BytesToMiB(assigned.size())); - const LogicalBufferStruct* logical_buffer = - wrapper.GetLogicalBuffer(assigned.logical_buffer_id()); - if (logical_buffer == nullptr) return; - result->set_hlo_name(std::string(logical_buffer->instruction_name())); - result->mutable_shape_index()->CopyFrom( - logical_buffer->proto.defined_at().shape_index()); - result->set_shape(ShapeDescription(logical_buffer->shape)); -} - -bool IsReusable(const BufferAllocationProto& buffer_allocation) { - return !buffer_allocation.is_thread_local() && !buffer_allocation.is_tuple(); -} - -void Convert(const BufferAllocationProto& proto, - const HloProtoBufferWrapper& wrapper, BufferAllocation* result) { - result->set_id(proto.index()); - result->set_size_mib(BytesToMiB(proto.size())); - if (proto.is_entry_computation_parameter()) { - result->add_attributes("entry computation parameter"); - } - if (proto.maybe_live_out()) { - result->add_attributes("may-be live out"); - } - if (IsReusable(proto)) { - result->add_attributes("reusable"); - } - for (const auto& assigned : proto.assigned()) { - Convert(assigned, wrapper, result->add_logical_buffers()); - } - // Check whether all logical buffers for this buffer allocation have a common - // shape. 
- if (!result->logical_buffers().empty()) { - std::string common_shape = result->logical_buffers(0).shape(); - for (int64_t i = 1; i < result->logical_buffers_size(); ++i) { - if (result->logical_buffers(i).shape() != common_shape) { - common_shape = ""; - break; - } - } - if (!common_shape.empty()) { - result->set_common_shape(common_shape); - } - } -} - -void NoteSpecialAllocations(const HloProtoBufferWrapper& wrapper, - int64_t memory_color, int64_t small_buffer_size, - PreprocessResult* result) { - int64_t entry_parameters_bytes = 0; - int64_t non_reusable_bytes = 0; - int64_t maybe_live_out_bytes = 0; - int64_t total_buffer_allocation_bytes = 0; - int64_t indefinite_buffer_allocation_bytes = 0; - for (const auto* buffer_allocation_struct : - wrapper.GetBufferAllocations(memory_color)) { - const auto& buffer_allocation = buffer_allocation_struct->proto(); - if (buffer_allocation.is_entry_computation_parameter()) { - entry_parameters_bytes += buffer_allocation.size(); - } - if (!IsReusable(buffer_allocation)) { - non_reusable_bytes += buffer_allocation.size(); - } - if (buffer_allocation.maybe_live_out()) { - if (buffer_allocation.size() > small_buffer_size) { - VLOG(1) << "Maybe live out buffer allocation: " - << buffer_allocation.size() - << " bytes :: " << buffer_allocation.ShortDebugString(); - } - maybe_live_out_bytes += buffer_allocation.size(); - } - if (buffer_allocation_struct->IsIndefinite()) { - indefinite_buffer_allocation_bytes += buffer_allocation.size(); - Convert(buffer_allocation, wrapper, result->add_indefinite_lifetimes()); - } - total_buffer_allocation_bytes += buffer_allocation.size(); - } - - result->set_entry_computation_parameters_mib( - BytesToMiB(entry_parameters_bytes)); - result->set_non_reusable_mib(BytesToMiB(non_reusable_bytes)); - result->set_maybe_live_out_mib(BytesToMiB(maybe_live_out_bytes)); - result->set_total_buffer_allocation_mib( - BytesToMiB(total_buffer_allocation_bytes)); - result->set_indefinite_buffer_allocation_mib( 
- BytesToMiB(indefinite_buffer_allocation_bytes)); -} - -// Memory usage statistics collected from heap simulator trace. -struct HeapSimulatorStats { - explicit HeapSimulatorStats(const HloProtoBufferWrapper& wrapper) - : wrapper(wrapper) {} - - void SetSimulatorTraceEventSize(int64_t size) { - simulator_trace_event_size = size; - } - - // Update stats for general simulator event. - void UpdateOnSimulatorEvent(const HeapSimulatorTrace::Event& event) { - // Update memory timelines and seen buffers. - heap_size_bytes_timeline.push_back(heap_size_bytes); - unpadded_heap_size_bytes_timeline.push_back(unpadded_heap_size_bytes); - hlo_instruction_name_timeline.push_back(event.instruction_name()); - const LogicalBufferStruct* logical_buffer = - wrapper.GetLogicalBuffer(event.buffer_id()); - if (logical_buffer == nullptr) return; - seen_logical_buffers.insert(logical_buffer); - seen_buffer_allocations.insert(&logical_buffer->buffer_allocation.proto()); - } - - // Update stats when memory usage increase. - void IncreaseMemoryUsage(LogicalBufferStruct* canonical_logical_buffer, - bool init_buffer_span) { - logical_buffers.push_back(canonical_logical_buffer->proto.id()); - heap_size_bytes += canonical_logical_buffer->size(); - unpadded_heap_size_bytes += canonical_logical_buffer->unpadded_size(); - - // Increase peak memory usage if needed. - int64_t prior_peak_heap_size_bytes = peak_heap_size_bytes; - peak_heap_size_bytes = std::max(peak_heap_size_bytes, heap_size_bytes); - if (prior_peak_heap_size_bytes != peak_heap_size_bytes) { - peak_heap_size_position = heap_size_bytes_timeline.size() - 1; - peak_unpadded_heap_size_bytes = unpadded_heap_size_bytes; - VLOG(1) << absl::StrFormat("New peak heap size on %d :: %d bytes", - peak_heap_size_position, peak_heap_size_bytes); - peak_logical_buffers = logical_buffers; - } - // Initialize the buffer lifespan if needed. 
- if (init_buffer_span) { - // Initialize the buffer span from the current event to the last event in - // heap simulator trace. - canonical_logical_buffer->span.emplace( - heap_size_bytes_timeline.size() - 1, simulator_trace_event_size - 1); - } - } - - // Update stats when memory usage decrease. - absl::Status DecreaseMemoryUsage( - LogicalBufferStruct* canonical_logical_buffer) { - int64_t canonical_buffer_id = canonical_logical_buffer->proto.id(); - logical_buffers.remove(canonical_buffer_id); - heap_size_bytes -= canonical_logical_buffer->size(); - if (heap_size_bytes < 0) { - return errors::InvalidArgument(absl::StrCat( - "Heap size should be non-negative, but get: ", heap_size_bytes)); - } - unpadded_heap_size_bytes -= canonical_logical_buffer->unpadded_size(); - // Mark the end of this buffer. - if (canonical_logical_buffer->span) { - canonical_logical_buffer->span->second = - heap_size_bytes_timeline.size() - 1; - } - return absl::OkStatus(); - } - - // Finalize the memory usage stats from heap simulator trace. - absl::Status FinalizeMemoryUsage() { - // Add the final heap size after simulating the entire heap trace. - heap_size_bytes_timeline.push_back(heap_size_bytes); - unpadded_heap_size_bytes_timeline.push_back(unpadded_heap_size_bytes); - // Add an empty instruction name just so that this array is the same size as - // the other two. - hlo_instruction_name_timeline.push_back(""); - - if (seen_buffer_allocations.size() != 1) { - return errors::InvalidArgument( - absl::StrCat("All heap simulation should work out of a single buffer " - "allocation, actual seen_buffer_allocations.size():", - seen_buffer_allocations.size())); - } - - // Log stats. 
- VLOG(1) << "Found " << peak_logical_buffers.size() - << " logical buffers alive at point of peak heap usage."; - - VLOG(1) << "Peak logical buffers: [" - << absl::StrJoin(peak_logical_buffers, ", ") << "]"; - - return absl::OkStatus(); - } - - // Keep track of memory usage when iterating through heap simulator trace - // events. - int64_t heap_size_bytes = 0; - int64_t unpadded_heap_size_bytes = 0; - // Memory usage at peak. - int64_t peak_heap_size_bytes = 0; - int64_t peak_unpadded_heap_size_bytes = 0; - - // Keep track of logical buffer IDs when iterating through heap simulator - // trace events. It is important this is in "program order", i.e. heap - // simulator's order. - std::list logical_buffers; - // Logical buffer IDs at peak. - std::list peak_logical_buffers; - - // Heap size timeline. - std::vector heap_size_bytes_timeline; - std::vector unpadded_heap_size_bytes_timeline; - std::vector hlo_instruction_name_timeline; - - // Position of peak memory usage in the timeline. - int64_t peak_heap_size_position = 0; - - // Logical buffers and buffer allocations that exists in heap simulator trace. - absl::flat_hash_set seen_logical_buffers; - absl::flat_hash_set seen_buffer_allocations; - - // Constants while iterating through heap simulator trace. - const HloProtoBufferWrapper& wrapper; - int64_t simulator_trace_event_size; -}; - -absl::Status ProcessHeapSimulatorTrace(const HloProtoBufferWrapper& wrapper, - const int64_t memory_color, - HeapSimulatorStats* stats) { - int64_t heap_simulator_trace_id = - wrapper.GetHeapSimulatorTraceId(memory_color); - - // If unable to get a valid heap simulator trace id, skip heap simulator - // trace and process the rest of the buffers. 
- if (heap_simulator_trace_id < 0 || - heap_simulator_trace_id >= wrapper.GetHloProto() - .buffer_assignment() - .heap_simulator_traces_size()) { - return absl::OkStatus(); - } - - // Run through all the simulator events in the given trace, and simulate the - // heap in order to find the point of peak memory usage and record its - // associated metadata. - const auto& trace = - wrapper.GetHloProto().buffer_assignment().heap_simulator_traces( - heap_simulator_trace_id); - - stats->SetSimulatorTraceEventSize(trace.events_size()); - for (const auto& event : trace.events()) { - stats->UpdateOnSimulatorEvent(event); - LogicalBufferStruct* logical_buffer = - wrapper.GetLogicalBuffer(event.buffer_id()); - if (logical_buffer == nullptr) { - continue; - } - if (event.kind() == HeapSimulatorTrace::Event::ALLOC) { - // ALLOC event increases memory usage and initializes the buffer lifetime - // span. - logical_buffer->inc(); - stats->IncreaseMemoryUsage(logical_buffer, - /*init_buffer_span=*/true); - } else if (event.kind() == HeapSimulatorTrace::Event::FREE) { - auto ref_count = logical_buffer->dec(); - if (ref_count < 0) { - return errors::InvalidArgument(absl::StrCat( - "Buffer ", logical_buffer->proto.id(), "is freed multiple times.")); - } - if (ref_count == 0) { - // There is no more reference to the canonical buffer, the canonical - // buffer is finally freed. Update memory usage and memory timespan - // using the metadata of canonical buffer. 
- auto& canonical_buffer = *logical_buffer->get_canonical_buffer(); - TF_RETURN_IF_ERROR(stats->DecreaseMemoryUsage(&canonical_buffer)); - } - } else if (event.kind() == HeapSimulatorTrace::Event::SHARE_WITH) { - int64_t canonical_buffer_id = event.share_with_canonical_id(); - LogicalBufferStruct* canonical_buffer = - wrapper.GetLogicalBuffer(canonical_buffer_id); - if (canonical_buffer == nullptr) { - continue; - } - auto ref_count = logical_buffer->share_with(canonical_buffer); - - if (ref_count == 1) { - // SHARE_WITH happens after the FREE of a canonical buffer. - // SHARE_WITH event does not initialize buffer lifetime span, it was - // initialized by ALLOC event using the canonical logical buffer. - stats->IncreaseMemoryUsage(canonical_buffer, - /*init_buffer_span=*/false); - } - } else { - return errors::InvalidArgument( - absl::StrCat("Unhandled event kind: ", event.kind())); - } - } - TF_RETURN_IF_ERROR(stats->FinalizeMemoryUsage()); - return absl::OkStatus(); -} - -// The stats when processing buffer allocations and logical buffers. -struct PeakUsageSnapshot { - PeakUsageSnapshot(const HloProtoBufferWrapper& wrapper, - const HeapSimulatorStats& simulator_stats, - int64_t small_buffer_size) - : wrapper(wrapper), - simulator_stats(simulator_stats), - small_buffer_size(small_buffer_size) {} - - // Add a HeapObject derived from logical buffer and buffer allocation. - void AddHeapObject(const LogicalBufferStruct& logical_buffer) { - if (logical_buffer.size() < small_buffer_size) { - // Accumulate small buffers, don't make a HeapObject. - total_small_buffer_size_bytes += logical_buffer.size(); - } else { - // Make a new HeapObject, assign a new color to visualize it. - max_heap_objects.push_back(MakeHeapObject(logical_buffer, colorno++)); - } - } - - void FinalizeBufferUsage() { - // Buffers from HeapSimulatorTrace. 
- for (const int64_t logical_buffer_id : - simulator_stats.peak_logical_buffers) { - const LogicalBufferStruct* logical_buffer = - wrapper.GetLogicalBuffer(logical_buffer_id); - if (logical_buffer == nullptr) return; - AddHeapObject(*logical_buffer); - } - - // Make a single HeapObject out of all the small buffers. - if (total_small_buffer_size_bytes != 0) { - max_heap_objects.push_back(MakeHeapObjectCommon( - absl::StrFormat("small (<%d bytes)", small_buffer_size), colorno++, - /*logical_buffer_id=*/-1, total_small_buffer_size_bytes, - /*unpadded_shape_bytes=*/0)); - } - } - - // All the HeapObjects at peak memory time. - std::vector max_heap_objects; - // The total size of all memory buffers with indefinite lifetime. - int64_t indefinite_memory_usage_bytes = 0; - // The accumulated size of all small buffers. - int64_t total_small_buffer_size_bytes = 0; - // Tracker of memory viewer color. - int32_t colorno = 0; - - const HloProtoBufferWrapper& wrapper; - const HeapSimulatorStats& simulator_stats; - const int64_t small_buffer_size; -}; - -void CreatePeakUsageSnapshot(const HloProtoBufferWrapper& wrapper, - int64_t memory_color, - PeakUsageSnapshot* peak_snapshot) { - // Add indefinite (global) buffers to peak usage snapshot. - for (const auto* logical_buffer : - wrapper.LogicalBuffersWithIndefiniteLifetime(memory_color)) { - const auto& buffer_allocation = logical_buffer->buffer_allocation; - peak_snapshot->indefinite_memory_usage_bytes += buffer_allocation.size(); - peak_snapshot->AddHeapObject(*logical_buffer); - } - - // Add temporary buffers (traced by heap simulator) to peak usage snapshot. - peak_snapshot->FinalizeBufferUsage(); -} - -void ConvertAllocationTimeline(const HloProtoBufferWrapper& wrapper, - const HeapSimulatorStats& simulator_stats, - const int64_t memory_color, - PreprocessResult* result) { - // The color constants from https://graphviz.org/doc/info/colors.html. 
- const char* lb_colors[] = { - "antiquewhite3", - "aqua", - "aquamarine", - "bisque", - "blanchedalmond", - "blue", - "blueviolet", - "brown", - "burlywood", - "cadetblue", - "chartreuse", - "chocolate", - "coral", - "cornflowerblue", - "crimson", - "cyan", - "darkblue", - "darkcyan", - "darkgoldenrod", - "darkgray", - "darkgreen", - "darkkhaki", - "darkmagenta", - "darkolivegreen", - "darkorange", - "darkorchid", - "darkred", - "darksalmon", - "darkseagreen", - "darkslateblue", - "darkslategray", - "darkturquoise", - "darkviolet", - "deeppink", - "deepskyblue", - "dimgray", - "dodgerblue", - "firebrick", - "floralwhite", - "forestgreen", - "fuchsia", - "gainsboro", - "gold", - "goldenrod", - "green", - "greenyellow", - "goldenrod", - "greenyellow", - "honeydew", - "hotpink", - "indianred", - "indigo", - "ivory3", - "khaki", - "lavender", - "lavenderblush", - "lawngreen", - "lemonchiffon", - "lightblue", - "lightcoral", - "lightcyan", - "lightpink", - "limegreen", - "lightsalmon", - "lightseagreen", - "lightskyblue", - "lime", - "magenta", - "maroon", - "mediumaquamarine", - "mediumblue", - "mediumorchid", - "mediumpurple", - "midnightblue", - "mediumvioletred", - "mistyrose", - "moccasin", - "olive", - "orange", - "orangered", - "orchid", - "palegoldenrod", - "palegreen", - "paleturquoise", - "palevioletred", - "papayawhip", - "peachpuff", - "peachpuff", - "pink", - "plum", - "powderblue", - "purple", - "rebeccapurple", - "red", - "rosybrown", - "royalblue", - "salmon", - "sandybrown", - "seagreen", - "seashell", - "sienna", - "skyblue", - "tan", - "teal", - "turquoise", - "tomato", - "violet", - "violetred", - "yellow", - }; - - struct RenderOptions { - size_t graph_width = 2048; - size_t graph_height = 2048; - } render_options; - - const char* ba_colors[] = { - "azure", - "beige", - "cornsilk", - }; - - int num_lb_colors = sizeof(lb_colors) / sizeof(lb_colors[0]); - int num_ba_colors = sizeof(ba_colors) / sizeof(ba_colors[0]); - std::vector 
buffer_allocation_offsets; - size_t total_y_size = 0; // Range of y dimension. - size_t total_x_size = 0; // Range of x dimension. - std::vector rects; - auto buffer_allocations = wrapper.GetBufferAllocations(memory_color); - const auto& heap_simulator_traces = - wrapper.GetHloProto().buffer_assignment().heap_simulator_traces(); - for (const auto& buffer_allocation : buffer_allocations) { - // Exclude BAs for "global variables". The timeline provides little value. - if (buffer_allocation->IsIndefinite()) continue; - auto heap_simulator_trace_id = buffer_allocation->heap_simulator_trace_id(); - if (!heap_simulator_trace_id) continue; - buffer_allocation_offsets.push_back(total_y_size); - total_y_size += buffer_allocation->size(); - if (*heap_simulator_trace_id >= heap_simulator_traces.size()) { - LOG(DFATAL) << "heap_simulator_trace_id " << *heap_simulator_trace_id - << " out of bounds."; - continue; - } - total_x_size = std::max( - total_x_size, - heap_simulator_traces.at(*heap_simulator_trace_id).events_size()); - } - if (!total_y_size || !total_x_size) return; - double scale_x = - static_cast(render_options.graph_width) / total_x_size; - double scale_y = - static_cast(render_options.graph_height) / total_y_size; - - int node_id = 0; - auto add_rect = [&](size_t x, size_t y, size_t width, size_t height, - const string& description, const char* color) { - size_t center_x = x + (width >> 1); - size_t center_y = y + (height >> 1); - int pos_x = center_x * scale_x; - int pos_y = center_y * scale_y; - int rect_w = width * scale_x; - int rect_h = height * scale_y; - // Skip when block size is smaller than half a pixel in output size. - if (height * scale_y < 0.5) return; - rect_h = std::max(rect_h, 1); // Rounding up. 
- std::string rect = absl::StrFormat( - R"("%d" [tooltip="%s", pos="%d,%d!", width="%d!", height="%d!", color=%s];)", - node_id++, description, pos_x, pos_y, rect_w, rect_h, color); - rects.push_back(rect); - }; - int buffer_id = 0; - for (const auto& buffer_allocation : buffer_allocations) { - // Exclude BAs for "global variables". The timeline provides little value. - if (buffer_allocation->IsIndefinite()) continue; - auto buffer_allocation_offset = buffer_allocation_offsets[buffer_id++]; - add_rect(0, buffer_allocation_offset, total_x_size, - buffer_allocation->size(), buffer_allocation->description(), - ba_colors[buffer_id % num_ba_colors]); - - for (const auto& assigned : buffer_allocation->proto().assigned()) { - const LogicalBufferStruct* logical_buffer = - wrapper.GetLogicalBuffer(assigned.logical_buffer_id()); - if (logical_buffer == nullptr) continue; - // Exclude non-canonical logical buffers. - if (!logical_buffer->span || logical_buffer->canonical_buffer) continue; - size_t width = logical_buffer->span->second - logical_buffer->span->first; - size_t height = buffer_allocation_offset + logical_buffer->size(); - add_rect(logical_buffer->span->first, logical_buffer->offset, width, - height, logical_buffer->description(), - lb_colors[node_id % num_lb_colors]); - } - } - VLOG(1) << "rects:" << rects.size(); - result->set_allocation_timeline( - absl::StrFormat("graph G {\n node [shape=box,style=filled];\n %s\n}", - absl::StrJoin(rects, "\n"))); -} - -void GeneratePreprocessResult(const HloProtoBufferWrapper& wrapper, - const HeapSimulatorStats& simulator_stats, - const PeakUsageSnapshot& peak_snapshot, - const int64_t memory_color, - PreprocessResult* result) { - // Module info. - result->set_module_name(wrapper.GetHloProto().hlo_module().name()); - result->set_entry_computation_name( - wrapper.GetHloProto().hlo_module().entry_computation_name()); - - // Build HeapObjects and index. 
- std::vector max_heap_by_size; - max_heap_by_size.reserve(peak_snapshot.max_heap_objects.size()); - for (const auto& object : peak_snapshot.max_heap_objects) { - max_heap_by_size.push_back(&object); - } - std::sort(max_heap_by_size.begin(), max_heap_by_size.end(), - [](const HeapObject* a, const HeapObject* b) { - return a->logical_buffer_size_mib() > - b->logical_buffer_size_mib(); - }); - - std::vector max_heap_to_by_size; - max_heap_to_by_size.reserve(max_heap_by_size.size()); - for (const auto& object : peak_snapshot.max_heap_objects) { - auto it = - std::find(max_heap_by_size.begin(), max_heap_by_size.end(), &object); - int index = std::distance(max_heap_by_size.begin(), it); - max_heap_to_by_size.push_back(index); - } - - std::vector by_size_to_max_heap; - for (const auto* object : max_heap_by_size) { - int index = object - &peak_snapshot.max_heap_objects[0]; - by_size_to_max_heap.push_back(index); - } - - *result->mutable_max_heap() = {peak_snapshot.max_heap_objects.begin(), - peak_snapshot.max_heap_objects.end()}; - result->mutable_max_heap_by_size()->Reserve(max_heap_by_size.size()); - for (const HeapObject* o : max_heap_by_size) { - *result->add_max_heap_by_size() = *o; - } - *result->mutable_max_heap_to_by_size() = {max_heap_to_by_size.begin(), - max_heap_to_by_size.end()}; - *result->mutable_by_size_to_max_heap() = {by_size_to_max_heap.begin(), - by_size_to_max_heap.end()}; - - // For the buffers that have indefinite lifetime (that is, lifetime not - // reflected by the heap simulation) add it to the peak values and the vectors - // of heap sizes. 
- size_t timeline_size = simulator_stats.heap_size_bytes_timeline.size(); - double add_mib = BytesToMiB(peak_snapshot.indefinite_memory_usage_bytes); - result->mutable_heap_sizes()->Reserve(timeline_size); - result->mutable_unpadded_heap_sizes()->Reserve(timeline_size); - for (size_t i = 0; i < timeline_size; i++) { - result->add_heap_sizes( - BytesToMiB(simulator_stats.heap_size_bytes_timeline[i]) + add_mib); - result->add_unpadded_heap_sizes( - BytesToMiB(simulator_stats.unpadded_heap_size_bytes_timeline[i]) + - add_mib); - result->add_hlo_instruction_names( - simulator_stats.hlo_instruction_name_timeline[i]); - } - - result->set_peak_heap_mib(BytesToMiB(simulator_stats.peak_heap_size_bytes) + - add_mib); - result->set_peak_unpadded_heap_mib( - BytesToMiB(simulator_stats.peak_unpadded_heap_size_bytes) + add_mib); - result->set_peak_heap_size_position(simulator_stats.peak_heap_size_position); - - // Build buffer lifespan. - for (const auto* logical_buffer : simulator_stats.seen_logical_buffers) { - if (!logical_buffer->span) continue; - (*result->mutable_logical_buffer_spans())[logical_buffer->proto.id()] = - MakeBufferSpan(logical_buffer->span->first, - logical_buffer->span->second); - } - - NoteSpecialAllocations(wrapper, memory_color, peak_snapshot.small_buffer_size, - result); - - ConvertAllocationTimeline(wrapper, simulator_stats, memory_color, result); -} - -} // namespace - -absl::StatusOr ConvertHloProtoToPreprocessResult( - const HloProto& hlo_proto, int64_t small_buffer_size, - int64_t memory_color) { - HloProtoBufferWrapper wrapper(hlo_proto); - - // Process heap simulator trace. - HeapSimulatorStats simulator_stats(wrapper); - auto status = - ProcessHeapSimulatorTrace(wrapper, memory_color, &simulator_stats); - if (!status.ok()) { - return absl::InvalidArgumentError(absl::StrCat( - "Failed to process heap simulator trace: ", status.message())); - } - - // Process buffers with indefinite lifetime. 
- PeakUsageSnapshot peak_snapshot(wrapper, simulator_stats, small_buffer_size); - CreatePeakUsageSnapshot(wrapper, memory_color, &peak_snapshot); - - PreprocessResult result; - GeneratePreprocessResult(wrapper, simulator_stats, peak_snapshot, - memory_color, &result); - return result; -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/hlo_proto_to_memory_visualization_utils.h b/tensorflow/core/profiler/convert/hlo_proto_to_memory_visualization_utils.h index e7a681de51c393..d5a6061f90187a 100644 --- a/tensorflow/core/profiler/convert/hlo_proto_to_memory_visualization_utils.h +++ b/tensorflow/core/profiler/convert/hlo_proto_to_memory_visualization_utils.h @@ -16,29 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_CORE_PROFILER_CONVERT_HLO_PROTO_TO_MEMORY_VISUALIZATION_UTILS_H_ #define TENSORFLOW_CORE_PROFILER_CONVERT_HLO_PROTO_TO_MEMORY_VISUALIZATION_UTILS_H_ -#include - -#include "absl/status/statusor.h" -#include "xla/service/hlo.pb.h" -#include "tensorflow/core/profiler/protobuf/memory_viewer_preprocess.pb.h" - -namespace tensorflow { -namespace profiler { - -constexpr int kSmallBufferSize = 16 * 1024; - -// Convert HloProto to PreprocessResult proto for memory visualization. -// small_buffer_size sets the byte size within which we collapse buffer entries -// for the max-heap display. -// is the index of heap simulator trace to be -// displayed. By default it is -1, which means the profiler will infer the heap -// simulator trace id from . -// By default the memory color is 0, which is HBM. 
-absl::StatusOr ConvertHloProtoToPreprocessResult( - const xla::HloProto& hlo_proto, - int64_t small_buffer_size = kSmallBufferSize, int64_t memory_color = 0); - -} // namespace profiler -} // namespace tensorflow +#include "xprof/convert/hlo_proto_to_memory_visualization_utils.h" // from @org_xprof // IWYU pragma: export #endif // TENSORFLOW_CORE_PROFILER_CONVERT_HLO_PROTO_TO_MEMORY_VISUALIZATION_UTILS_H_ diff --git a/tensorflow/core/profiler/convert/hlo_proto_to_memory_visualization_utils_test.cc b/tensorflow/core/profiler/convert/hlo_proto_to_memory_visualization_utils_test.cc deleted file mode 100644 index d92dea32152a36..00000000000000 --- a/tensorflow/core/profiler/convert/hlo_proto_to_memory_visualization_utils_test.cc +++ /dev/null @@ -1,114 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/core/profiler/convert/hlo_proto_to_memory_visualization_utils.h" - -#include - -#include "absl/strings/str_format.h" -#include "xla/service/hlo.pb.h" -#include "xla/tsl/platform/statusor.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/profiler/protobuf/memory_viewer_preprocess.pb.h" -#include "tensorflow/core/util/proto/proto_utils.h" - -namespace tensorflow { -namespace profiler { -namespace { - -// 1 buffer allocation of 1MB -// 2 logical buffers, each is 0.5MB -static constexpr char kHLOBase[] = R"pb( - hlo_module { - name: "test_module" - entry_computation_name: "test_computation" - computations { - name: "test_computation" - instructions { - name: "fusion.1" - id: 0 - shape { tuple_shapes { element_type: U64 } } - } - instructions { - name: "fusion.2" - id: 1 - shape { tuple_shapes { element_type: U64 } } - } - } - } - buffer_assignment { - buffer_allocations { - index: 0 - size: 1048576 - color: 0 - assigned { logical_buffer_id: 1 offset: 0 size: 524288 } - assigned { logical_buffer_id: 2 offset: 524288 size: 524288 } - } - logical_buffers { - id: 1 - size: 524288 - color: 0 - defined_at { instruction_id: 0 shape_index: 0 } - } - logical_buffers { - id: 2 - size: 524288 - color: 0 - defined_at { instruction_id: 1 shape_index: 0 } - } - heap_simulator_traces { %s } - } -)pb"; - -TEST(MemoryViewerTest, TestHeapSimulatorTraceShareWith_1) { - // Allocate and then share, the memory usage is not doubled. 
- static constexpr char kHeapSimulatorTrace[] = R"pb( - events { kind: ALLOC buffer_id: 1 } - events { kind: SHARE_WITH buffer_id: 2 share_with_canonical_id: 1 } - events { kind: FREE buffer_id: 1 } - events { kind: FREE buffer_id: 2 } - )pb"; - std::string hlo_string = absl::StrFormat(kHLOBase, kHeapSimulatorTrace); - xla::HloProto hlo_proto; - ASSERT_TRUE( - proto_utils::ParseTextFormatFromString(hlo_string, &hlo_proto).ok()); - TF_ASSERT_OK_AND_ASSIGN( - PreprocessResult preprocess_result, - ConvertHloProtoToPreprocessResult(hlo_proto, /*small_buffer_size=*/0)); - EXPECT_EQ(preprocess_result.peak_heap_mib(), 0.5); -} - -TEST(MemoryViewerTest, TestHeapSimulatorTraceShareWith_2) { - // Allocate, free and then share, the memory usage is not doubled. - static constexpr char kHeapSimulatorTrace[] = R"pb( - events { kind: ALLOC buffer_id: 1 } - events { kind: FREE buffer_id: 1 } - events { kind: SHARE_WITH buffer_id: 2 share_with_canonical_id: 1 } - events { kind: FREE buffer_id: 2 } - )pb"; - std::string hlo_string = absl::StrFormat(kHLOBase, kHeapSimulatorTrace); - xla::HloProto hlo_proto; - ASSERT_TRUE( - proto_utils::ParseTextFormatFromString(hlo_string, &hlo_proto).ok()); - TF_ASSERT_OK_AND_ASSIGN( - PreprocessResult preprocess_result, - ConvertHloProtoToPreprocessResult(hlo_proto, /*small_buffer_size=*/0)); - EXPECT_EQ(preprocess_result.peak_heap_mib(), 0.5); - EXPECT_FALSE(preprocess_result.allocation_timeline().empty()); -} - -} // namespace -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/hlo_to_tools_data.cc b/tensorflow/core/profiler/convert/hlo_to_tools_data.cc deleted file mode 100644 index 0978f0211d4d8f..00000000000000 --- a/tensorflow/core/profiler/convert/hlo_to_tools_data.cc +++ /dev/null @@ -1,135 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/profiler/convert/hlo_to_tools_data.h" - -#include -#include -#include - -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "xla/service/hlo.pb.h" -#include "xla/tsl/platform/statusor.h" -#include "tensorflow/core/platform/errors.h" -#include "tensorflow/core/profiler/convert/hlo_proto_to_graph_view.h" -#include "tensorflow/core/profiler/convert/hlo_proto_to_memory_visualization_utils.h" -#include "tensorflow/core/profiler/convert/repository.h" -#include "tensorflow/core/profiler/convert/tool_options.h" -#include "tensorflow/core/profiler/convert/xplane_to_hlo.h" -#include "tensorflow/core/profiler/protobuf/memory_viewer_preprocess.pb.h" -#include "tsl/platform/protobuf.h" - -namespace tensorflow { -namespace profiler { - -namespace { - -absl::StatusOr GetMemoryViewerPreprocessResult( - const xla::HloProto& hlo_proto) { - static constexpr int kSmallBufferSize = 16 * 1024; // 16KB - static constexpr int kMemorySpaceColor = 0; // HBM - - auto result_or = ConvertHloProtoToPreprocessResult( - hlo_proto, kSmallBufferSize, kMemorySpaceColor); - if (!result_or.ok()) { - return errors::Internal( - "Failed to convert HLO proto to memory viewer result: ", - result_or.status().message()); - } - return result_or; -} - -absl::StatusOr ConvertHloProtoToMemoryViewer( - const xla::HloProto& hlo_proto) { - auto result_or = GetMemoryViewerPreprocessResult(hlo_proto); - if (!result_or.ok()) { - return result_or.status(); - } - - std::string 
json_output; - tsl::protobuf::util::JsonPrintOptions options; - options.always_print_primitive_fields = true; - auto encoded_status = tsl::protobuf::util::MessageToJsonString( - result_or.value(), &json_output, options); - if (!encoded_status.ok()) { - const auto& error_message = encoded_status.message(); - return errors::Internal( - "Failed to convert memory viewer result to JSON format: ", - absl::string_view(error_message.data(), error_message.length())); - } - - return json_output; -} - -absl::StatusOr ConvertHloProtoToAllocationTimeline( - const xla::HloProto& hlo_proto) { - auto result_or = GetMemoryViewerPreprocessResult(hlo_proto); - if (!result_or.ok()) { - return result_or.status(); - } - - return WrapDotInHtml(std::move(result_or.value().allocation_timeline())); -} - -absl::StatusOr ConvertHloProtoToGraphViewer( - const xla::HloProto& hlo_proto, const ToolOptions& options) { - TF_ASSIGN_OR_RETURN(GraphViewerParams params, - ParseGraphViewerParams(options)); - if (params.type == "graph") { - return ConvertHloProtoToGraph(hlo_proto, params.node_name, - params.graph_width, params.render_options, - params.format); - } else { - return ConvertHloProtoToStringView(hlo_proto, params.verbose, - params.show_metadata); - } -} - -} // namespace - -absl::StatusOr ConvertHloProtoToToolData( - const SessionSnapshot& session_snapshot, const absl::string_view tool_name, - const ToolOptions& options) { - // must provide a hlo module_name field to identify the HLO module. - std::optional hlo_module_name = - GetParam(options, "module_name"); - if (!hlo_module_name.has_value() || hlo_module_name->empty()) { - return errors::InvalidArgument( - "Can not find HLO module name from options."); - } - - // Load HLO module from file. - TF_ASSIGN_OR_RETURN( - xla::HloProto hlo_proto, - GetHloProtoByModuleName(session_snapshot, *hlo_module_name)); - - // Convert from HLO proto to tools data. 
- if (tool_name == "memory_viewer") { - if (GetParamWithDefault(options, "view_memory_allocation_timeline", 0)) { - return ConvertHloProtoToAllocationTimeline(hlo_proto); - } - return ConvertHloProtoToMemoryViewer(hlo_proto); - } else if (tool_name == "graph_viewer") { - return ConvertHloProtoToGraphViewer(hlo_proto, options); - } else { - return errors::InvalidArgument( - "Can not find tool: ", tool_name, - ". Please update to the latest version of Tensorflow."); - } -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/hlo_to_tools_data.h b/tensorflow/core/profiler/convert/hlo_to_tools_data.h deleted file mode 100644 index b567c973382997..00000000000000 --- a/tensorflow/core/profiler/convert/hlo_to_tools_data.h +++ /dev/null @@ -1,41 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_HLO_TO_TOOLS_DATA_H_ -#define TENSORFLOW_CORE_PROFILER_CONVERT_HLO_TO_TOOLS_DATA_H_ - -#include - -#include "absl/strings/string_view.h" -#include "tensorflow/core/platform/statusor.h" -#include "tensorflow/core/profiler/convert/repository.h" -#include "tensorflow/core/profiler/convert/tool_options.h" - -namespace tensorflow { -namespace profiler { - -// Convert HLO proto to tool specific data. 
-// must provide a "module_name" field to identify which HLO proto -// is used for the conversion. -// Return the serialized string of tool specific data when the conversion is -// successful, else return an error status. -absl::StatusOr ConvertHloProtoToToolData( - const SessionSnapshot& session_snapshot, absl::string_view tool_name, - const ToolOptions& options); - -} // namespace profiler -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_PROFILER_CONVERT_HLO_TO_TOOLS_DATA_H_ diff --git a/tensorflow/core/profiler/convert/inference_stats.cc b/tensorflow/core/profiler/convert/inference_stats.cc deleted file mode 100644 index 25e87b31dab352..00000000000000 --- a/tensorflow/core/profiler/convert/inference_stats.cc +++ /dev/null @@ -1,1510 +0,0 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/core/profiler/convert/inference_stats.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "absl/algorithm/container.h" -#include "absl/base/macros.h" -#include "absl/container/flat_hash_map.h" -#include "absl/container/flat_hash_set.h" -#include "absl/strings/match.h" -#include "absl/strings/numbers.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_join.h" -#include "absl/strings/str_split.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "xla/tsl/platform/logging.h" -#include "xla/tsl/profiler/utils/device_utils.h" -#include "xla/tsl/profiler/utils/group_events.h" -#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" -#include "xla/tsl/profiler/utils/timespan.h" -#include "xla/tsl/profiler/utils/xplane_schema.h" -#include "xla/tsl/profiler/utils/xplane_utils.h" -#include "xla/tsl/profiler/utils/xplane_visitor.h" -#include "tensorflow/core/lib/gtl/map_util.h" -#include "tensorflow/core/profiler/protobuf/inference_stats.pb.h" -#include "tensorflow/core/profiler/utils/event_span.h" -#include "tensorflow/core/profiler/utils/xplane_schema.h" -#include "tsl/platform/protobuf.h" -#include "tsl/profiler/protobuf/xplane.pb.h" - -namespace tensorflow { -namespace profiler { -namespace { - -using ::tensorflow::profiler::EventType; -using ::tensorflow::profiler::EventTypeSpan; -using ::tensorflow::profiler::StepEvents; -using ::tensorflow::profiler::ToNonOverlappedEvents; -using ::tsl::profiler::CreateTfXPlaneVisitor; -using ::tsl::profiler::DeviceType; -using ::tsl::profiler::GroupMetadata; -using ::tsl::profiler::GroupMetadataMap; -using ::tsl::profiler::HostEventType; -using ::tsl::profiler::StatType; -using ::tsl::profiler::Timespan; -using ::tsl::profiler::XEventVisitor; -using ::tsl::profiler::XLineVisitor; -using ::tsl::profiler::XPlane; -using 
::tsl::profiler::XPlaneVisitor; -using ::tsl::profiler::XSpace; -using ::tsl::profiler::XStatVisitor; - -using EventsByType = - absl::flat_hash_map>; - -// Holds all the events within a user facing request. -// A user facing request can be a Session.Run without batching, or a -// BatchingSession.Run with Batching, or a Session.Run with -// BatchingFunctionOp. -struct RequestEvents { - // Index to the model id. - int32_t model_id_index; - // The timespan of the entire request(including both host and device). - Timespan request_timespan; - // The latency between a request is scheduled and is processed in a batch. - int64_t batching_request_delay_ps; - // Size of a request in batching mode. - int32_t batching_request_size; - - // Timestamps of the events used for the detailed execution time breakdown. - struct EventTimestamps { - std::optional ts_batch_schedule; - std::optional ts_batch_concat_input; - std::optional ts_tpu_execute; - std::optional ts_tpu_program_launch; - std::optional ts_tpu_complete_callback; - }; - // Mapping from group ID to the timestamps, there can be multiple group IDs - // in a single request, because if request splitting is enabled, one request - // can be split to multiple batches for execution, and each batch has - // different group ID. - absl::flat_hash_map timestamps; - - // The events that record tensor details like shape, type and layout. - std::vector tensor_events; - // The final tensor details in proto format. - std::vector - tensor_event_detail_protos; - - // The batch ids related to this request. - std::vector related_batch_ids; - // All the events. 
- std::vector events; -}; - -// Helper functions to handle absl::optional -void MinOfOptional(std::optional& min, std::optional value) { - if (!min.has_value()) - min = value; - else - min = std::min(min, value); -} -void MaxOfOptional(std::optional& max, std::optional value) { - if (!max.has_value()) - max = value; - else - max = std::max(max, value); -} - -// Helper functions to set timestamps in RequestEvents. -void UpdateTsBatchSchedule(int64_t group_id, int64_t value, - RequestEvents* events) { - events->timestamps[group_id].ts_batch_schedule = value; -} -void UpdateTsBatchConcatInput(int64_t group_id, int64_t value, - RequestEvents* events) { - events->timestamps[group_id].ts_batch_concat_input = value; -} -void UpdateTsTPUExecute(int64_t group_id, int64_t value, - RequestEvents* events) { - events->timestamps[group_id].ts_tpu_execute = value; -} -void UpdateTsTPUProgramLaunch(int64_t group_id, int64_t value, - RequestEvents* events) { - // There might be multiple TPUProgramLaunch events in a single request. - // Set ts_tpu_program_launch to the earlist timestamp. - MinOfOptional(events->timestamps[group_id].ts_tpu_program_launch, value); -} -void UpdateTsTPUCompleteCallback(int64_t group_id, int64_t value, - RequestEvents* events) { - events->timestamps[group_id].ts_tpu_complete_callback = value; -} - -// Map from the ID of a request to its events. -using RequestEventsMap = - absl::flat_hash_map; - -// An internal data structure that holds all the events within a batch. -struct BatchEvents { - // The events that record tensor details like shape, type and layout. - std::vector tensor_events; - - // The BatchDetail proto. - tensorflow::profiler::BatchDetail batch_detail_proto; - - // All the events. - std::vector events; -}; - -// Map from the ID of a batch to its events. -using BatchEventsMap = absl::flat_hash_map; - -// Map from the ID of a request to its model ID. 
-using ModelIdMap = absl::flat_hash_map; - -int32_t AssignIndexToModelId( - const std::string& model_id, - tensorflow::profiler::ModelIdDatabase* model_id_db) { - if (model_id.empty()) return -1; - auto [iter, inserted] = model_id_db->mutable_id_to_index()->insert( - {model_id, model_id_db->ids_size()}); - if (inserted) { - model_id_db->add_ids(model_id); - } - return iter->second; -} - -// Updates timestamps in RequestEvents. -// is the timestamp to update, is the updated value. -void UpdateEventTimestamps( - const GroupMetadataMap& group_metadata_map, int64_t group_id, int64_t value, - std::function function, - RequestEventsMap* request_events_map, - BatchEventsMap* batch_events_map = nullptr) { - // Update RequestEvents that are directly associated with . - if (request_events_map != nullptr) { - if (auto request_events = gtl::FindOrNull(*request_events_map, group_id)) { - function(group_id, value, request_events); - } - - // Update all the parent RequestEvents of . - const GroupMetadata* group_metadata = - gtl::FindOrNull(group_metadata_map, group_id); - if (!group_metadata) return; - for (const int64_t parent_group_id : group_metadata->parents) { - if (auto parent_request_events = - gtl::FindOrNull(*request_events_map, parent_group_id)) { - // Update parent events, but still use instead of - // , because xprof needs to track where these event - // timestamps originally come from. - function(group_id, value, parent_request_events); - } - } - } - // Note: Timestamp updates for batch analysis is not supported yet. -} - -void UpdateBatchEvents(const GroupMetadataMap& group_metadata_map, - absl::Span events, int64_t group_id, - BatchEventsMap* batch_events_map) { - // Update BatchEvents that are directly associated with . 
- if (auto batch_events = gtl::FindOrNull(*batch_events_map, group_id)) { - batch_events->events.insert(batch_events->events.end(), events.begin(), - events.end()); - } -} - -// Updates RequestEvents using ReadFromDevice, WriteToDevice and DeviceRun. -void UpdateRequestEvents(const GroupMetadataMap& group_metadata_map, - absl::Span events, - int64_t group_id, - RequestEventsMap* request_events_map) { - // Update RequestEvents that are directly associated with . - if (auto request_events = gtl::FindOrNull(*request_events_map, group_id)) { - request_events->events.insert(request_events->events.end(), events.begin(), - events.end()); - } - - // Update all the parent RequestEvents of with the same - // and . Parent RequestEvents are all the requests - // in a batch. - const GroupMetadata* group_metadata = - gtl::FindOrNull(group_metadata_map, group_id); - if (!group_metadata) return; - for (const int64_t parent_group_id : group_metadata->parents) { - if (auto parent_request_events = - gtl::FindOrNull(*request_events_map, parent_group_id)) { - parent_request_events->events.insert(parent_request_events->events.end(), - events.begin(), events.end()); - } - } -} - -// Initializes RequestEvents. -// determines whether this event is a -// BatchingSession.Run -void InitializeRequestEvents( - const XEventVisitor& event, const GroupMetadataMap& group_metadata_map, - const absl::flat_hash_set& process_batch_group_ids, - const ModelIdMap& model_id_map, bool is_batching_request, - bool is_user_defined_request, - tensorflow::profiler::ModelIdDatabase* model_id_db, - RequestEventsMap* request_events_map) { - std::optional optional_group_id = - event.GetStat(StatType::kGroupId); - if (!optional_group_id.has_value()) return; - int64_t group_id = optional_group_id->IntValue(); - - // If the event has ProcessBatch event as a parent, then do not consider - // it as a request. 
- if (process_batch_group_ids.contains(group_id)) return; - - RequestEvents& request_events = (*request_events_map)[group_id]; - const GroupMetadata* group_metadata = - gtl::FindOrNull(group_metadata_map, group_id); - if (!group_metadata) return; - // The children group_ids of a request are the batches related to this - // request. - for (const int64_t child_group_id : group_metadata->children) { - request_events.related_batch_ids.push_back(child_group_id); - } - // Sort related_batch_ids to get deterministic result. - absl::c_sort(request_events.related_batch_ids); - if (is_batching_request) { - // The children events of BatchingSession.Run are multiple Session.Run, - // use the first child event to initialize ModelId information, because - // all the children events should have the same ModelId. - if (group_metadata->children.empty()) return; - int64_t children_group_id = *group_metadata->children.begin(); - const std::string* children_model_id = - gtl::FindOrNull(model_id_map, children_group_id); - request_events.model_id_index = AssignIndexToModelId( - children_model_id ? *children_model_id : "", model_id_db); - } else if (is_user_defined_request) { - const std::string* model_id = gtl::FindOrNull(model_id_map, group_id); - if (model_id) { - request_events.model_id_index = - AssignIndexToModelId(*model_id, model_id_db); - } else { - // In some cases (e.g., BrainServer::Estimate), a single request might - // dispatch batches for multiple models. If all children events - // have the same ModelId, we assign that ModelId to the request. - if (group_metadata->children.empty()) return; - int32_t model_id_index_for_all_children = -1; - bool all_children_have_same_model_id = true; - for (int64_t children_group_id : group_metadata->children) { - const std::string* children_model_id = - gtl::FindOrNull(model_id_map, children_group_id); - int32_t child_model_id_index = AssignIndexToModelId( - children_model_id ? 
*children_model_id : "", model_id_db); - if (model_id_index_for_all_children == -1) { - model_id_index_for_all_children = child_model_id_index; - } else if (child_model_id_index != model_id_index_for_all_children) { - all_children_have_same_model_id = false; - } - } - request_events.model_id_index = - all_children_have_same_model_id - ? model_id_index_for_all_children - : AssignIndexToModelId("", model_id_db); - } - } else { - const std::string* model_id = gtl::FindOrNull(model_id_map, group_id); - request_events.model_id_index = - AssignIndexToModelId(model_id ? *model_id : "", model_id_db); - } -} - -// Set the begin and end timestamp of the request. -// The timespan of the request is marked by the earliest timestamp and latest -// timestamp of the events with the same group_id. -void UpdateRequestTimespan(const EventsByType& host_events_by_type, - RequestEventsMap* request_events_map) { - for (const auto& [_, events] : host_events_by_type) { - for (const auto& event : events) { - auto optional_group_id = event.GetStat(StatType::kGroupId); - if (optional_group_id.has_value()) { - if (RequestEvents* request = gtl::FindOrNull( - *request_events_map, optional_group_id->IntValue())) { - auto begin_ps = request->request_timespan.begin_ps() == 0 - ? event.GetTimespan().begin_ps() - : std::min(request->request_timespan.begin_ps(), - event.GetTimespan().begin_ps()); - auto end_ps = std::max(request->request_timespan.end_ps(), - event.GetTimespan().end_ps()); - request->request_timespan = Timespan::FromEndPoints(begin_ps, end_ps); - } - } - } - } -} - -// Update RequestEventsMap using data transfer events in tpu::system. -// Each data transfer is associated with a start event, an end event, and a -// transfer type (H2D or D2H). 
-void UpdateTpuDataTransferEventsInTpuSystem( - const EventsByType& host_events_by_type, - const GroupMetadataMap& group_metadata_map, - const HostEventType data_transfer_start_event, - const HostEventType data_transfer_end_event, - const EventType data_transfer_type, RequestEventsMap* request_events_map, - BatchEventsMap* batch_events_map) { - absl::flat_hash_map> - events_per_transfer; - - auto build_events = - [&](const HostEventType event_type, - std::function func) { - if (const auto* events = - gtl::FindOrNull(host_events_by_type, event_type)) { - for (const XEventVisitor& event : *events) { - std::optional optional_group_id = - event.GetStat(StatType::kGroupId); - if (!optional_group_id.has_value()) continue; - std::optional context_id = - event.GetStat(StatType::kConsumerId); - if (!context_id.has_value()) continue; - func(context_id->IntValue(), &event); - } - } - }; - - // Build start event. - build_events(data_transfer_start_event, - [&](uint64_t id, const XEventVisitor* start_event) { - events_per_transfer[id] = {start_event, nullptr}; - }); - - // Build end event. - // This only happens when the start event exists, the end event has the same - // group ID as the start event, and the end event timestamp is larger than - // start event timestamp. - build_events(data_transfer_end_event, - [&](uint64_t id, const XEventVisitor* end_event) { - if (auto* value = gtl::FindOrNull(events_per_transfer, id)) { - const XEventVisitor* start_event = value->at(0); - if (start_event->TimestampPs() < end_event->TimestampPs()) { - value->at(1) = end_event; - } - } - }); - - std::vector event_to_update = { - {data_transfer_type, Timespan(0, 0)}}; - for (const auto& [id, events] : events_per_transfer) { - if (events[0] != nullptr && events[1] != nullptr) { - // Duration of the data transfer is measured as the timespan between - // start and end events. 
- event_to_update[0].span = - Timespan(events[0]->TimestampPs(), - events[1]->EndTimestampPs() - events[0]->TimestampPs()); - if (request_events_map != nullptr) { - UpdateRequestEvents(group_metadata_map, event_to_update, - events[0]->GetStat(StatType::kGroupId)->IntValue(), - request_events_map); - } - if (batch_events_map != nullptr) { - UpdateBatchEvents(group_metadata_map, event_to_update, - events[0]->GetStat(StatType::kGroupId)->IntValue(), - batch_events_map); - } - } - } -} - -// Initializes device side events for TPU. -void BuildTPUDeviceEvents(const std::vector& device_traces, - const EventsByType& host_events_by_type, - const GroupMetadataMap& group_metadata_map, - RequestEventsMap* request_events_map, - BatchEventsMap* batch_events_map) { - static constexpr int64_t kDataTransferTypes[] = { - HostEventType::kReadHbm, HostEventType::kTransferD2HRequest, - HostEventType::kWriteHbm, HostEventType::kTransferH2DRequest, - HostEventType::kTransferPreprocessedH2DRequest}; - auto data_transfer_type_to_enum = [](const int64_t type) { - switch (type) { - case HostEventType::kReadHbm: - case HostEventType::kTransferD2HRequest: - return EventType::DEVICE_TO_HOST; - case HostEventType::kWriteHbm: - case HostEventType::kTransferH2DRequest: - case HostEventType::kTransferPreprocessedH2DRequest: - return EventType::HOST_TO_DEVICE; - default: - return EventType::UNKNOWN_TIME; - } - }; - - // Initialize a TPU device event for future updates. - // In order to reuse the same UpdateRequestEvents function with GPU device - // events, here we create a vector of size 1 for TPU event. - std::vector event_to_update = { - {EventType::UNKNOWN_TIME, Timespan(0, 0)}}; - - // Update RequestEventsMap using data transfer events. 
- for (const int64_t data_transfer_type : kDataTransferTypes) { - if (const auto* data_transfer_events = - gtl::FindOrNull(host_events_by_type, data_transfer_type)) { - for (const XEventVisitor& data_transfer_event : *data_transfer_events) { - std::optional optional_group_id = - data_transfer_event.GetStat(StatType::kGroupId); - if (!optional_group_id.has_value()) continue; - int64_t group_id = optional_group_id->IntValue(); - event_to_update[0] = {data_transfer_type_to_enum(data_transfer_type), - data_transfer_event.GetTimespan()}; - if (request_events_map != nullptr) { - UpdateRequestEvents(group_metadata_map, event_to_update, group_id, - request_events_map); - } - if (batch_events_map != nullptr) { - UpdateBatchEvents(group_metadata_map, event_to_update, group_id, - batch_events_map); - } - } - } - } - - UpdateTpuDataTransferEventsInTpuSystem( - host_events_by_type, group_metadata_map, - HostEventType::kTransferToDeviceIssueEvent, - HostEventType::kTransferToDeviceDone, EventType::HOST_TO_DEVICE, - request_events_map, batch_events_map); - - UpdateTpuDataTransferEventsInTpuSystem( - host_events_by_type, group_metadata_map, - HostEventType::kTransferFromDeviceIssueEvent, - HostEventType::kTransferFromDeviceDone, EventType::DEVICE_TO_HOST, - request_events_map, batch_events_map); - - for (const XPlane* device_trace : device_traces) { - XPlaneVisitor device_plane = CreateTfXPlaneVisitor(device_trace); - device_plane.ForEachLine([request_events_map, batch_events_map, - &event_to_update, - &group_metadata_map](const XLineVisitor& line) { - if (line.Name() != tsl::profiler::kXlaModuleLineName) return; - line.ForEachEvent([request_events_map, batch_events_map, &event_to_update, - &group_metadata_map](const XEventVisitor& event) { - std::optional group_id = - event.GetStat(StatType::kGroupId); - if (!group_id) return; - // TPU compute does not specify 32bit or 16bit, use - // DEVICE_COMPUTE_32 to annotate this is a compute event. 
- event_to_update[0] = {EventType::DEVICE_COMPUTE_32, - event.GetTimespan()}; - if (request_events_map != nullptr) { - UpdateRequestEvents(group_metadata_map, event_to_update, - group_id->IntValue(), request_events_map); - } - if (batch_events_map != nullptr) { - UpdateBatchEvents(group_metadata_map, event_to_update, - group_id->IntValue(), batch_events_map); - } - }); - }); - } - - // Update timestamp for TPU execute event. It is used as the beginning of - // TPU runtime. For old TPU runtime, it is the TPUPartitionedCall events, - // for the new TPU runtime, it is the tpu::system::Execute event. There - // might be multiple TPU execute events in the same request, - // UpdateTsTPUExecute is implemented as getting the earlist timestamp of TPU - // execute event. - static constexpr int64_t kTPUExecuteTypes[] = { - HostEventType::kTpuPartitionedCallOpExecuteLocal, - HostEventType::kTpuPartitionedCallOpExecuteRemote, - HostEventType::kTpuPartitionedCallOpInitializeVarOnTpu, - HostEventType::kTpuSystemExecute}; - for (const int64_t tpu_execute_type : kTPUExecuteTypes) { - if (const auto* tpu_execute_events = - gtl::FindOrNull(host_events_by_type, tpu_execute_type)) { - for (const XEventVisitor& tpu_execute_event : *tpu_execute_events) { - std::optional optional_group_id = - tpu_execute_event.GetStat(StatType::kGroupId); - if (!optional_group_id.has_value()) continue; - int64_t group_id = optional_group_id->IntValue(); - UpdateEventTimestamps( - group_metadata_map, group_id, tpu_execute_event.TimestampPs(), - UpdateTsTPUExecute, request_events_map, batch_events_map); - } - } - } - - // Update timestamp for TPU program launch events. This is used as the end - // of TPU runtime. Only one of the following program launch events will - // appear in a single profile. 
- static constexpr int64_t kTPUProgramLaunchTypes[] = { - HostEventType::kDoEnqueueProgram, - HostEventType::kDoEnqueueContinuationProgram}; - for (const int64_t tpu_program_launch_type : kTPUProgramLaunchTypes) { - if (const auto* tpu_program_launch_events = - gtl::FindOrNull(host_events_by_type, tpu_program_launch_type)) { - for (const XEventVisitor& tpu_program_launch_event : - *tpu_program_launch_events) { - std::optional optional_group_id = - tpu_program_launch_event.GetStat(StatType::kGroupId); - if (!optional_group_id.has_value()) continue; - int64_t group_id = optional_group_id->IntValue(); - UpdateEventTimestamps(group_metadata_map, group_id, - tpu_program_launch_event.TimestampPs(), - UpdateTsTPUProgramLaunch, request_events_map, - batch_events_map); - } - } - } - - // Update timestamp for TPU complete callbacks. This is used as the start of - // host postprocessing. - if (const auto* tpu_complete_callback_events = gtl::FindOrNull( - host_events_by_type, HostEventType::kCompleteCallbacks)) { - for (const XEventVisitor& tpu_complete_callback_event : - *tpu_complete_callback_events) { - std::optional optional_group_id = - tpu_complete_callback_event.GetStat(StatType::kGroupId); - if (!optional_group_id.has_value()) continue; - int64_t group_id = optional_group_id->IntValue(); - UpdateEventTimestamps(group_metadata_map, group_id, - tpu_complete_callback_event.TimestampPs(), - UpdateTsTPUCompleteCallback, request_events_map, - batch_events_map); - } - } -} - -// Initializes device side events for GPU. 
-void BuildGPUDeviceEvents(const StepEvents& nonoverlapped_step_events, - const GroupMetadataMap& group_metadata_map, - RequestEventsMap* request_events_map, - BatchEventsMap* batch_events_map) { - if (request_events_map != nullptr) { - for (const auto& [step_id, step_details] : nonoverlapped_step_events) { - UpdateRequestEvents(group_metadata_map, step_details.Events(), step_id, - request_events_map); - } - } - if (batch_events_map != nullptr) { - for (const auto& [step_id, step_details] : nonoverlapped_step_events) { - UpdateBatchEvents(group_metadata_map, step_details.Events(), step_id, - batch_events_map); - } - } -} - -// Initialize the mapping from group_id to model_id. Skip the event if it -// doesn't have group_id or model_id. -ModelIdMap InitializeModelIdMap( - const EventsByType& host_events_by_type, - const std::vector& user_defined_root_events) { - ModelIdMap model_id_map; - - // Helper function to process model id. - auto process_model_id = [&](const XEventVisitor& event) { - auto group_id = event.GetStat(StatType::kGroupId); - if (!group_id.has_value()) return; - std::optional model_id = event.GetStat(StatType::kModelId); - if (!model_id.has_value()) return; - model_id_map[group_id->IntValue()] = model_id->ToString(); - }; - - static constexpr int64_t kModelIdRequestTypes[] = { - HostEventType::kSessionRun, HostEventType::kTfrtModelRun, - HostEventType::kServingModelRun}; - for (const int64_t event_type : kModelIdRequestTypes) { - auto event_list = gtl::FindOrNull(host_events_by_type, event_type); - if (!event_list) continue; - for (const XEventVisitor& event : *event_list) { - process_model_id(event); - } - } - - for (const XEventVisitor* event : user_defined_root_events) { - process_model_id(*event); - } - - return model_id_map; -} - -// Builds a request_events_map from the given trace events. 
-void BuildRequestEventsMap(const std::vector& device_traces, - const EventsByType& host_events_by_type, - const GroupMetadataMap& group_metadata_map, - const StepEvents& nonoverlapped_step_events, - DeviceType device_type, - tensorflow::profiler::ModelIdDatabase* model_id_db, - RequestEventsMap* request_events_map) { - static constexpr int64_t kBatchingRequestTypes[] = { - HostEventType::kBatchingSessionRun}; - static constexpr int64_t kNonBatchingRequestTypes[] = { - HostEventType::kSessionRun, HostEventType::kRunGraph}; - // TODO(wffw): Merge them once go/pathways-tfrt-serving-unification is done. - static constexpr int64_t kTfrtRequestTypes[] = {HostEventType::kTfrtModelRun}; - static constexpr int64_t kPathwayRequestTypes[] = { - HostEventType::kServingModelRun}; - - static constexpr int64_t kScheduleEventTypes[] = { - HostEventType::kScheduleWithSplit, HostEventType::kScheduleWithoutSplit, - HostEventType::kScheduleWithEagerSplit, - HostEventType::kASBSQueueSchedule}; - - // Events marked with "_r:-1" are user defined root events. - std::vector user_defined_root_events; - for (const auto& [_, events] : host_events_by_type) { - for (const auto& event : events) { - std::optional stat = event.GetStat(StatType::kIsRoot); - if (stat.has_value() && stat->IntValue() == -1) { - user_defined_root_events.push_back(&event); - } - } - } - - // Group IDs of ProcessBatch events. - absl::flat_hash_set process_batch_group_ids; - if (const auto* process_batch_events = - gtl::FindOrNull(host_events_by_type, HostEventType::kProcessBatch)) { - for (const XEventVisitor& process_batch_event : *process_batch_events) { - std::optional optional_group_id = - process_batch_event.GetStat(StatType::kGroupId); - if (!optional_group_id.has_value()) continue; - process_batch_group_ids.insert(optional_group_id->IntValue()); - } - } - - ModelIdMap model_id_map = - InitializeModelIdMap(host_events_by_type, user_defined_root_events); - - // Initialize RequestEventsMap. 
- bool is_batching_request = - host_events_by_type.contains(HostEventType::kBatchingSessionRun); - bool is_tfrt_request = - host_events_by_type.contains(HostEventType::kTfrtModelRun); - // TODO(wffw): Merge them once go/pathways-tfrt-serving-unification is done. - bool is_pathway_request = - host_events_by_type.contains(HostEventType::kServingModelRun); - absl::Span request_types; - if (is_batching_request) { - request_types = absl::Span(kBatchingRequestTypes); - } else if (is_tfrt_request) { - request_types = absl::Span(kTfrtRequestTypes); - } else if (is_pathway_request) { - request_types = absl::Span(kPathwayRequestTypes); - } else { - request_types = absl::Span(kNonBatchingRequestTypes); - } - for (const int64_t request_type : request_types) { - if (const auto* request_events = - gtl::FindOrNull(host_events_by_type, request_type)) { - for (const XEventVisitor& request_event : *request_events) { - InitializeRequestEvents(request_event, group_metadata_map, - process_batch_group_ids, model_id_map, - is_batching_request, - /* is_user_defined_request=*/false, model_id_db, - request_events_map); - } - } - } - - for (const XEventVisitor* event : user_defined_root_events) { - InitializeRequestEvents( - *event, group_metadata_map, process_batch_group_ids, model_id_map, - /*is_batching_request=*/false, - /* is_user_defined_request=*/true, model_id_db, request_events_map); - } - - // Set the begin and end timestamp of the request. - UpdateRequestTimespan(host_events_by_type, request_events_map); - - // Update RequestEventsMap using the request size in schedule event. 
- for (const int64_t schedule_type : kScheduleEventTypes) { - if (const auto* schedule_events = - gtl::FindOrNull(host_events_by_type, schedule_type)) { - for (const XEventVisitor& schedule_event : *schedule_events) { - std::optional optional_group_id = - schedule_event.GetStat(StatType::kGroupId); - if (!optional_group_id.has_value()) continue; - int64_t group_id = optional_group_id->IntValue(); - // Update timestamp for schedule events. It is used as the beginning - // of batch formation. - UpdateEventTimestamps(group_metadata_map, group_id, - schedule_event.TimestampPs(), - UpdateTsBatchSchedule, request_events_map); - if (auto* request_events = - gtl::FindOrNull(*request_events_map, group_id)) { - std::optional batching_request_size = - schedule_event.GetStat(StatType::kBatchingInputTaskSize); - if (!batching_request_size.has_value()) continue; - request_events->batching_request_size = - batching_request_size->IntValue(); - } - } - } - } - - if (device_type == DeviceType::kTpu) { - BuildTPUDeviceEvents(device_traces, host_events_by_type, group_metadata_map, - request_events_map, nullptr); - } else if (device_type == DeviceType::kGpu) { - BuildGPUDeviceEvents(nonoverlapped_step_events, group_metadata_map, - request_events_map, nullptr); - } -} - -// Extracts batch details from . -void BuildBatchEventsMap(const std::vector& device_traces, - const EventsByType& host_events_by_type, - const GroupMetadataMap& group_metadata_map, - const StepEvents& nonoverlapped_step_events, - DeviceType device_type, - RequestEventsMap* request_events_map, - BatchEventsMap* batch_events_map) { - // Initialize BatchDetails from ProcessBatch events. 
- if (const auto* process_batch_events = - gtl::FindOrNull(host_events_by_type, HostEventType::kProcessBatch)) { - for (const XEventVisitor& process_batch_event : *process_batch_events) { - std::optional optional_group_id = - process_batch_event.GetStat(StatType::kGroupId); - if (!optional_group_id.has_value()) continue; - int64_t group_id = optional_group_id->IntValue(); - const GroupMetadata* group_metadata = - gtl::FindOrNull(group_metadata_map, group_id); - if (!group_metadata) continue; - BatchEvents& batch_events = (*batch_events_map)[group_id]; - tensorflow::profiler::BatchDetail& batch_detail = - batch_events.batch_detail_proto; - batch_detail.set_batch_id(group_id); - batch_detail.set_start_time_ps(process_batch_event.TimestampPs()); - batch_detail.set_end_time_ps(process_batch_event.EndTimestampPs()); - // The parent group_ids of a batch are the requests related to this - // batch. - for (const int64_t parent_group_id : group_metadata->parents) { - batch_detail.add_related_request_ids(parent_group_id); - } - // Sort related_request_ids to get deterministic result. - std::sort(batch_detail.mutable_related_request_ids()->begin(), - batch_detail.mutable_related_request_ids()->end()); - } - } - - // Update BatchDetailsMap with padding information. Only one of - // ConcatInputTensors (for in-graph batching) or MergeInputTensors (for - // BatchingSession), or BrainSessionRun will appear in the - // same profile. - static constexpr int64_t kPaddingEventTypes[] = { - HostEventType::kConcatInputTensors, - HostEventType::kMergeInputTensors, - HostEventType::kBrainSessionRun, - }; - for (const int64_t padding_event_type : kPaddingEventTypes) { - if (const auto* padding_events = - gtl::FindOrNull(host_events_by_type, padding_event_type)) { - for (const XEventVisitor& padding_event : *padding_events) { - // Update timestamp for padding events. They are used as the - // beginning of batch processing. 
- std::optional optional_group_id = - padding_event.GetStat(StatType::kGroupId); - if (!optional_group_id.has_value()) continue; - int64_t group_id = optional_group_id->IntValue(); - UpdateEventTimestamps(group_metadata_map, group_id, - padding_event.TimestampPs(), - UpdateTsBatchConcatInput, request_events_map); - BatchEvents* batch_events = - gtl::FindOrNull(*batch_events_map, group_id); - if (!batch_events) continue; - std::optional padding_amount = - padding_event.GetStat(StatType::kPaddingAmount); - if (!padding_amount.has_value()) continue; - std::optional batch_size_after_padding = - padding_event.GetStat(StatType::kBatchSizeAfterPadding); - if (!batch_size_after_padding.has_value()) continue; - tensorflow::profiler::BatchDetail* batch_detail = - &batch_events->batch_detail_proto; - batch_detail->set_batch_size_after_padding( - batch_size_after_padding->IntValue()); - batch_detail->set_padding_amount(padding_amount->IntValue()); - } - } - } - - // Populate BatchDetailsMap with model_id information from the corresponding - // requests in RequestEventsMap. - for (auto& [batch_id, batch_events] : *batch_events_map) { - tensorflow::profiler::BatchDetail& batch_detail = - batch_events.batch_detail_proto; - if (!batch_detail.related_request_ids().empty()) { - // Set the model_id of a batch using the model_id of the corresponding - // request. All requests in the same batch must share the same model_id, - // so we can pick any request in the batch here. 
- int32_t first_request_id = batch_detail.related_request_ids(0); - const RequestEvents* request_events = - gtl::FindOrNull(*request_events_map, first_request_id); - if (request_events) { - batch_detail.set_model_id_index(request_events->model_id_index); - } - } - } - - if (device_type == DeviceType::kTpu) { - BuildTPUDeviceEvents(device_traces, host_events_by_type, group_metadata_map, - nullptr, batch_events_map); - } else if (device_type == DeviceType::kGpu) { - BuildGPUDeviceEvents(nonoverlapped_step_events, group_metadata_map, nullptr, - batch_events_map); - } -} - -// Calculates the delay between request and batch. -void GenerateRequestAndBatchDelay(RequestEventsMap* request_events_map, - BatchEventsMap* batch_events_map) { - for (auto& [request_id, request_event] : *request_events_map) { - const tensorflow::profiler::BatchDetail* first_batch_detail = nullptr; - const tensorflow::profiler::BatchDetail* last_batch_detail = nullptr; - // For each request, measure the latency between the request and the first - // batch that processes this request. 
- for (const int64_t batch_id : request_event.related_batch_ids) { - const auto* batch_events = gtl::FindOrNull(*batch_events_map, batch_id); - if (!batch_events) continue; - const tensorflow::profiler::BatchDetail* batch_detail = - &batch_events->batch_detail_proto; - if (!first_batch_detail || (first_batch_detail->has_start_time_ps() > - batch_detail->has_start_time_ps())) { - first_batch_detail = batch_detail; - } - if (!last_batch_detail || (last_batch_detail->has_end_time_ps() < - batch_detail->has_end_time_ps())) { - last_batch_detail = batch_detail; - } - } - if (first_batch_detail) { - request_event.batching_request_delay_ps = - first_batch_detail->start_time_ps() - - request_event.request_timespan.begin_ps(); - } - if (last_batch_detail && request_event.request_timespan.end_ps() < - last_batch_detail->end_time_ps()) { - request_event.request_timespan = - Timespan::FromEndPoints(request_event.request_timespan.begin_ps(), - last_batch_detail->end_time_ps()); - } - } - - for (auto& [batch_id, batch_events] : *batch_events_map) { - const RequestEvents* first_request_events = nullptr; - tensorflow::profiler::BatchDetail& batch_detail = - batch_events.batch_detail_proto; - // For each batch, measure the latency between the first request in this - // batch and the start time of this batch. - for (const int64_t request_id : batch_detail.related_request_ids()) { - const auto* request_events = - gtl::FindOrNull(*request_events_map, request_id); - if (!request_events) continue; - if (!first_request_events || - (first_request_events->request_timespan.begin_ps() > - request_events->request_timespan.begin_ps())) { - first_request_events = request_events; - } - } - if (first_request_events) { - batch_detail.set_batch_delay_ps( - batch_detail.start_time_ps() - - first_request_events->request_timespan.begin_ps()); - } - } -} - -// Generates detailed breakdown for a request by generating events using the -// timestamps in RequestEvents. 
-void GenerateRequestDetailedBreakdown(RequestEventsMap* request_events_map) { - for (auto& [_, request] : *request_events_map) { - std::optional first_tpu_execute; - std::optional first_batch_concat_input; - std::optional last_tpu_complete_callback; - std::optional only_batch_schedule; - for (const auto& [group_id, timestamps] : request.timestamps) { - if (timestamps.ts_tpu_execute.has_value()) { - MinOfOptional(first_tpu_execute, timestamps.ts_tpu_execute); - - // Host runtime: From the start of TPU execute event to the start of - // TPU program launch. Because of request splitting, there can be - // multiple host runtime in a single request, one for each batch. - if (timestamps.ts_tpu_program_launch.has_value()) { - request.events.push_back( - {EventType::HOST_RUNTIME, - Timespan::FromEndPoints( - timestamps.ts_tpu_execute.value(), - timestamps.ts_tpu_program_launch.value())}); - } - } - - if (timestamps.ts_batch_concat_input.has_value()) { - MinOfOptional(first_batch_concat_input, - timestamps.ts_batch_concat_input); - } - - if (timestamps.ts_tpu_complete_callback.has_value()) { - MaxOfOptional(last_tpu_complete_callback, - timestamps.ts_tpu_complete_callback); - } - - if (timestamps.ts_batch_schedule.has_value()) { - if (only_batch_schedule.has_value()) { - LOG(ERROR) << "Found multiple batch schedule events in a single " - << "request."; - } else { - only_batch_schedule = timestamps.ts_batch_schedule; - } - } - } - - // Host preprocessing: From the start of the request to the start of the - // first execute event. There is only one host preprocess even if there - // are multiple batches caused by request splitting. 
- if (first_tpu_execute.has_value()) { - request.events.push_back( - {EventType::HOST_PREPROCESS, - Timespan::FromEndPoints(request.request_timespan.begin_ps(), - first_tpu_execute.value())}); - } - - // Host postprocessing: If there are CompleteCallback events for this - // request, use the last CompleteCallback event as the beginning of host - // postprocessing. Else, use the end time of the last TPU device compute - // events. There is only one host postprocessing even if there are - // multiple batches caused by request splitting. - if (last_tpu_complete_callback.has_value()) { - request.events.push_back( - {EventType::HOST_POSTPROCESS, - Timespan::FromEndPoints(last_tpu_complete_callback.value(), - request.request_timespan.end_ps())}); - } else { - // Get the latest end time of TPU device compute events. - // These events are annotated with type DEVICE_COMPUTE_32. - // TODO(tianrun): Deprecate this code path after CompleteCallback is - // enabled in all Tensorflow binaries. - uint64_t device_compute_end = 0; - for (const auto& event : request.events) { - if (event.type == EventType::DEVICE_COMPUTE_32) { - device_compute_end = - std::max(device_compute_end, event.span.end_ps()); - } - } - if (device_compute_end != 0) { - request.events.push_back( - {EventType::HOST_POSTPROCESS, - Timespan::FromEndPoints(device_compute_end, - request.request_timespan.end_ps())}); - } - } - - // Batch formation: From the start of batch schedule, to the start of the - // first concat input. This is only applicable when batching is enabled, - // and it overlaps with host preprocessing. - if (only_batch_schedule.has_value() && - first_batch_concat_input.has_value()) { - request.events.push_back( - {EventType::HOST_BATCH_FORMATION, - Timespan::FromEndPoints(only_batch_schedule.value(), - first_batch_concat_input.value())}); - } - } -} - -// Generates tensor patterns from tensor related EventNodes. -// If there is any error during the generation, return an empty string. 
-std::string GenerateTensorPattern( - const std::vector& tensor_events) { - // Generate one sub pattern for each tensor event, the sub pattern records - // the tensor shape, type, and layout. - std::vector sub_patterns; - sub_patterns.reserve(tensor_events.size()); - for (const XEventVisitor* tensor_event : tensor_events) { - std::optional shape = - tensor_event->GetStat(StatType::kTensorShapes); - if (!shape.has_value()) return ""; - std::optional layout = - tensor_event->GetStat(StatType::kTensorLayout); - if (!layout.has_value()) return ""; - sub_patterns.push_back(absl::StrCat(tensor_event->Name(), " ", - shape->StrOrRefValue(), " ", - layout->StrOrRefValue())); - } - // Sort the sub patterns to get a deterministic result. - std::sort(sub_patterns.begin(), sub_patterns.end()); - // The final tensor pattern is generated as the concatenation of all sub - // patterns. Use
as separator so it can be displayed properly in - // frontend. - return absl::StrJoin(sub_patterns, "
"); -} - -// Generates the total time spent on linearize and delinearize tensors. -uint64_t GenerateTensorLinearizeDelinearizeTime( - const std::vector& tensor_events) { - uint64_t result = 0; - for (const XEventVisitor* tensor_event : tensor_events) { - result += tensor_event->DurationPs(); - } - return result; -} - -// Generates the details related to tensor shape, type, and layout. -void GenerateTensorDetails( - const EventsByType& host_events_by_type, - RequestEventsMap* request_events_map, BatchEventsMap* batch_events_map, - tensorflow::profiler::InferenceStats* inference_stats) { - static constexpr int64_t kTensorDetailEventTypes[] = { - HostEventType::kLinearize, HostEventType::kDelinearize, - HostEventType::kTransferBufferFromDeviceFastPath}; - - for (const int64_t tensor_detail_event_type : kTensorDetailEventTypes) { - if (const auto* tensor_detail_events = - gtl::FindOrNull(host_events_by_type, tensor_detail_event_type)) { - for (const XEventVisitor& tensor_detail_event : *tensor_detail_events) { - std::optional optional_group_id = - tensor_detail_event.GetStat(StatType::kGroupId); - if (!optional_group_id.has_value()) continue; - int64_t group_id = optional_group_id->IntValue(); - // Add events to corresponding requests and batches. - if (auto* request_events = - gtl::FindOrNull(*request_events_map, group_id)) { - request_events->tensor_events.push_back(&tensor_detail_event); - } else if (auto* batch_events = - gtl::FindOrNull(*batch_events_map, group_id)) { - batch_events->tensor_events.push_back(&tensor_detail_event); - } - } - } - } - - absl::flat_hash_map tensor_patterns; - auto get_tensor_pattern_index = - [&tensor_patterns](const std::string& tensor_pattern) { - if (int* index = gtl::FindOrNull(tensor_patterns, tensor_pattern)) { - return *index; - } - int index = tensor_patterns.size(); - tensor_patterns.insert(std::make_pair(tensor_pattern, index)); - return index; - }; - - // Generates the tensor details that are owned by request. 
- for (auto& [group_id, request_events] : *request_events_map) { - if (request_events.tensor_events.empty()) continue; - std::string tensor_pattern = - GenerateTensorPattern(request_events.tensor_events); - if (tensor_pattern.empty()) continue; - int index = get_tensor_pattern_index(tensor_pattern); - tensorflow::profiler::TensorEventDetail tensor_event_detail; - tensor_event_detail.set_tensor_pattern_index(index); - tensor_event_detail.set_owner( - tensorflow::profiler::TensorEventDetail::REQUEST); - tensor_event_detail.set_linearize_delinearize_time_ps( - GenerateTensorLinearizeDelinearizeTime(request_events.tensor_events)); - request_events.tensor_event_detail_protos.push_back( - std::move(tensor_event_detail)); - } - - // Generates the tensor details that are owned by batch. - for (auto& [group_id, batch_events] : *batch_events_map) { - if (batch_events.tensor_events.empty()) continue; - std::string tensor_pattern = - GenerateTensorPattern(batch_events.tensor_events); - if (tensor_pattern.empty()) continue; - int index = get_tensor_pattern_index(tensor_pattern); - auto* tensor_event_detail = - batch_events.batch_detail_proto.mutable_tensor_event_detail(); - tensor_event_detail->set_tensor_pattern_index(index); - tensor_event_detail->set_owner( - tensorflow::profiler::TensorEventDetail::BATCH); - tensor_event_detail->set_linearize_delinearize_time_ps( - GenerateTensorLinearizeDelinearizeTime(batch_events.tensor_events)); - } - - // Populates the tensor details from batch to the related requests. These - // tensor details are still owned by the batches and will not be used to - // calculate statistics like the number of occurrence of each tensor - // pattern. 
- for (const auto& [group_id, batch_events] : *batch_events_map) { - if (!batch_events.batch_detail_proto.has_tensor_event_detail()) continue; - for (const int64_t request_id : - batch_events.batch_detail_proto.related_request_ids()) { - if (auto* request_events = - gtl::FindOrNull(*request_events_map, request_id)) { - request_events->tensor_event_detail_protos.push_back( - batch_events.batch_detail_proto.tensor_event_detail()); - } - } - } - - // Generates TensorPatternDatabase. - if (tensor_patterns.empty()) { - return; - } - absl::flat_hash_map reversed_tensor_patterns; - for (const auto& tensor_pattern : tensor_patterns) { - reversed_tensor_patterns[tensor_pattern.second] = &tensor_pattern.first; - } - for (int i = 0; i < static_cast(tensor_patterns.size()); i++) { - inference_stats->mutable_tensor_pattern_db()->add_tensor_pattern( - *reversed_tensor_patterns.at(i)); - } -} - -// Generate batch details from batch events. -// host runtime breakdown (added in request details) is not supported. -void BatchEventsToDetails(DeviceType device_type, int64_t group_id, - const BatchEvents& batch_events, - tensorflow::profiler::BatchDetail* batch_detail) { - std::vector tpu_non_overlapped_events; - const std::vector* non_overlapped_events = - &tpu_non_overlapped_events; - if (device_type == DeviceType::kTpu) { - // For TPU device events, batch_events.events may be overlapped in the - // timeline. So first converts it to non-overlapped events in the timeline - // before the breakdown. - tpu_non_overlapped_events = ToNonOverlappedEvents(batch_events.events); - } else if (device_type == DeviceType::kGpu) { - // For GPU device events, batch_events.events come from non overlapped - // StepEvents, so there is no need to convert to non overlapping events - // again. 
- non_overlapped_events = &(batch_events.events); - } - - int64_t device_time_ps = 0; - for (const auto& event : *non_overlapped_events) { - const auto& duration_ps = event.span.duration_ps(); - switch (event.type) { - case EventType::DEVICE_COMPUTE_16: - case EventType::DEVICE_COMPUTE_32: - device_time_ps += duration_ps; - break; - default: - break; - } - } - batch_detail->set_device_time_ps(device_time_ps); -} - -// Generates the request details proto from its events. -void RequestEventsToDetails( - DeviceType device_type, int64_t group_id, - const RequestEvents& request_events, - tensorflow::profiler::RequestDetail* request_detail) { - request_detail->set_request_id(group_id); - request_detail->set_model_id_index(request_events.model_id_index); - request_detail->set_start_time_ps(request_events.request_timespan.begin_ps()); - request_detail->set_end_time_ps(request_events.request_timespan.end_ps()); - request_detail->set_batching_request_delay_ps( - request_events.batching_request_delay_ps); - request_detail->set_batching_request_size( - request_events.batching_request_size); - for (const auto& tensor_event_detail : - request_events.tensor_event_detail_protos) { - *request_detail->add_tensor_event_details() = tensor_event_detail; - } - for (const int64_t batch_id : request_events.related_batch_ids) { - request_detail->add_related_batch_ids(batch_id); - } - - std::vector tpu_non_overlapped_events; - const std::vector* non_overlapped_events = - &tpu_non_overlapped_events; - if (device_type == DeviceType::kTpu) { - // For TPU device events, request_events.events may be overlapped in the - // timeline. So first converts it to non-overlapped events in the timeline - // before the breakdown. 
- tpu_non_overlapped_events = ToNonOverlappedEvents(request_events.events); - } else if (device_type == DeviceType::kGpu) { - // For GPU device events, request_events.events come from non overlapped - // StepEvents, so there is no need to convert to non overlapping events - // again. - non_overlapped_events = &(request_events.events); - } - - int64_t device_time_ps = 0; - int64_t write_time_ps = 0; - int64_t read_time_ps = 0; - int64_t host_preprocess_ps = 0; - int64_t host_postprocess_ps = 0; - int64_t host_runtime_ps = 0; - int64_t host_batch_formation_ps = 0; - int64_t idle_time_ps = 0; - for (const auto& event : *non_overlapped_events) { - const auto& duration_ps = event.span.duration_ps(); - switch (event.type) { - case EventType::DEVICE_COMPUTE_16: - case EventType::DEVICE_COMPUTE_32: - device_time_ps += duration_ps; - break; - case EventType::HOST_TO_DEVICE: - write_time_ps += duration_ps; - break; - case EventType::DEVICE_TO_HOST: - read_time_ps += duration_ps; - break; - case EventType::HOST_PREPROCESS: - host_preprocess_ps += duration_ps; - break; - case EventType::HOST_POSTPROCESS: - host_postprocess_ps += duration_ps; - break; - case EventType::HOST_RUNTIME: - host_runtime_ps += duration_ps; - break; - case EventType::HOST_BATCH_FORMATION: - host_batch_formation_ps += duration_ps; - break; - case EventType::UNKNOWN_TIME: - idle_time_ps += duration_ps; - break; - default: - break; - } - } - request_detail->set_device_time_ps(device_time_ps); - request_detail->set_write_to_device_time_ps(write_time_ps); - request_detail->set_read_from_device_time_ps(read_time_ps); - request_detail->set_host_preprocessing_ps(host_preprocess_ps); - request_detail->set_host_postprocessing_ps(host_postprocess_ps); - request_detail->set_host_runtime_ps(host_runtime_ps); - request_detail->set_host_batch_formation_ps(host_batch_formation_ps); - request_detail->set_idle_time_ps(idle_time_ps); -} - -// Compares two data points by duration. 
-// DataType can be either RequestDetail or BatchDetail. -template -bool CompareByDuration(const DataType& a, const DataType& b) { - return Timespan::ByDuration( - Timespan::FromEndPoints(a.start_time_ps(), a.end_time_ps()), - Timespan::FromEndPoints(b.start_time_ps(), b.end_time_ps())); -} - -void BuildRequestDetails( - const RequestEventsMap& request_events_map, DeviceType device_type, - const int32_t host_id, - tsl::protobuf::RepeatedPtrField* - request_details) { - for (auto& [group_id, request_events] : request_events_map) { - if (request_events.request_timespan.duration_ps() == 0) continue; - tensorflow::profiler::RequestDetail* request_detail = - request_details->Add(); - request_detail->set_host_id(host_id); - RequestEventsToDetails(device_type, group_id, request_events, - request_detail); - } - std::sort(request_details->begin(), request_details->end(), - CompareByDuration); -} - -void BuildBatchDetails( - BatchEventsMap batch_events_map, DeviceType device_type, - const int32_t host_id, - tsl::protobuf::RepeatedPtrField* - batch_details) { - for (auto& [group_id, batch_events] : batch_events_map) { - tensorflow::profiler::BatchDetail* batch_detail = batch_details->Add(); - *batch_detail = std::move(batch_events.batch_detail_proto); - batch_detail->set_host_id(host_id); - BatchEventsToDetails(device_type, group_id, batch_events, batch_detail); - } - std::sort(batch_details->begin(), batch_details->end(), - CompareByDuration); -} - -// Parses TFstreamz xplane to get batching parameters, and stores the -// parameters to . -void ParseTfstreamzForBatchingParameter( - const XSpace& xspace, tensorflow::profiler::ModelIdDatabase* model_id_db) { - const XPlane* tfstreamz_plane = ::tsl::profiler::FindPlaneWithName( - xspace, tsl::profiler::kTFStreamzPlaneName); - // There are two TFStreamz events per profile, one at the beginning, one at - // the end of the profile, each represents a snapshot of the TFstreamz. 
- // Use the last one as the source to get batching parameters because the - // first snapshot might be taken before Tensorflow setting up the batching - // parameters. - if (tfstreamz_plane == nullptr || tfstreamz_plane->lines().empty() || - tfstreamz_plane->lines(0).events_size() != 2) { - return; - } - XPlaneVisitor plane(tfstreamz_plane); - XEventVisitor event(&plane, &tfstreamz_plane->lines(0), - &tfstreamz_plane->lines(0).events(1)); - - static constexpr char kBatchingParamPrefix[] = - "/tensorflow/serving/batching/"; - static constexpr char kBatchingParamNumBatchThreads[] = "num_batch_threads"; - static constexpr char kBatchingParamBatchTimeoutMicros[] = - "batch_timeout_micros"; - static constexpr char kBatchingParamMaxBatchSize[] = "max_batch_size"; - static constexpr char kBatchingParamMaxEnqueuedBatches[] = - "max_enqueued_batches"; - static constexpr char kBatchingParamAllowedBatchSizes[] = - "allowed_batch_sizes"; - - // Parse the batching parameters from TFstreamz and associate it them with - // model IDs. 
- absl::flat_hash_map - model_params; - event.ForEachStat([&](const XStatVisitor& stat) { - if (!absl::StartsWith(stat.Name(), kBatchingParamPrefix)) return; - - absl::string_view param_detail = - stat.Name().substr(ABSL_ARRAYSIZE(kBatchingParamPrefix) - 1); - auto [parse_success, model_id_tfstreamz] = ParseModelName(param_detail); - if (!parse_success) { - return; - } - - if (absl::StartsWith(param_detail, kBatchingParamNumBatchThreads)) { - model_params[model_id_tfstreamz].set_num_batch_threads(stat.IntValue()); - } else if (absl::StartsWith(param_detail, - kBatchingParamBatchTimeoutMicros)) { - model_params[model_id_tfstreamz].set_batch_timeout_micros( - stat.IntValue()); - } else if (absl::StartsWith(param_detail, kBatchingParamMaxBatchSize)) { - model_params[model_id_tfstreamz].set_max_batch_size(stat.IntValue()); - } else if (absl::StartsWith(param_detail, - kBatchingParamMaxEnqueuedBatches)) { - model_params[model_id_tfstreamz].set_max_enqueued_batches( - stat.IntValue()); - } else if (absl::StartsWith(param_detail, - kBatchingParamAllowedBatchSizes)) { - model_params[model_id_tfstreamz].set_allowed_batch_sizes( - std::string(stat.StrOrRefValue())); - } - }); - - // It is possible that the model IDs from Session.Run is in the format of - // :, while the model IDs in TFstreamz is in the format - // of (without the version number). Build a map to connect the - // model IDs in TFstreamz and Session.Run. - absl::flat_hash_map> - model_id_map; - for (const auto& model_id_and_version : model_id_db->ids()) { - size_t i = model_id_and_version.find_last_of(':'); - if (i == std::string::npos) { - model_id_map[model_id_and_version].push_back(model_id_and_version); - } else { - // If there is a version number at the end of model_id, remove the - // version number. 
- absl::string_view version_str(model_id_and_version.data() + i + 1); - int64_t version; - bool success = absl::SimpleAtoi(version_str, &version); - if (success) { - absl::string_view model_id_only(model_id_and_version.data(), i); - model_id_map[model_id_only].push_back(model_id_and_version); - } else { - LOG(ERROR) << "Can not parse model version number: " << version_str; - } - } - } - - // One model ID from TFstreamz might map to multiple model IDs in - // Session.Run, update the batching parameters of all the model IDs in - // Session.Run. - for (const auto& [model_id_tfstreamz, params] : model_params) { - if (const std::vector* model_ids_session_run = - gtl::FindOrNull(model_id_map, model_id_tfstreamz)) { - for (const absl::string_view model_id_session_run : - *model_ids_session_run) { - (*model_id_db->mutable_id_to_batching_params())[model_id_session_run] = - params; - } - } - } -} - -} // namespace - -std::pair ParseModelName(absl::string_view param) { - // Param can be in one of the two following formats: - // batching_param{model_name=} - // batching_param{model_name=, op_name=} - size_t label_begin = param.find_first_of('{'); - size_t label_end = param.find_last_of('}'); - if (label_begin == absl::string_view::npos || - label_end == absl::string_view::npos || label_end <= label_begin) { - return {false, ""}; - } - // Go over all the labels to look for model name. - std::vector labels = absl::StrSplit( - param.substr(label_begin + 1, label_end - label_begin - 1), ", "); - for (const absl::string_view label : labels) { - std::vector key_value = absl::StrSplit(label, '='); - if (key_value.size() != 2) continue; - if (key_value[0] == "model_name") { - return {true, key_value[1]}; - } - } - // Unable to find model name. 
- return {false, ""}; -} - -void GenerateInferenceStats( - const std::vector& device_traces, - const StepEvents& nonoverlapped_step_events, - const GroupMetadataMap& group_metadata_map, const XSpace& xspace, - DeviceType device_type, int32_t host_id, - tensorflow::profiler::InferenceStats* inference_stats) { - tensorflow::profiler::PerHostInferenceStats* per_host_inference_stats = - &(*inference_stats->mutable_inference_stats_per_host())[host_id]; - RequestEventsMap request_events_map; - - // Build the mapping from host event type to events. - EventsByType host_events_by_type; - const XPlane* host = tsl::profiler::FindPlaneWithName( - xspace, tsl::profiler::kHostThreadsPlaneName); - if (!host) return; - XPlaneVisitor host_plane = CreateTfXPlaneVisitor(host); - for (const auto& line : host->lines()) { - for (const auto& event : line.events()) { - XEventVisitor event_visitor(&host_plane, &line, &event); - auto type = event_visitor.Type(); - if (!type.has_value()) { - type = HostEventType::kUnknownHostEventType; - } - host_events_by_type[type.value()].push_back(event_visitor); - } - } - - BuildRequestEventsMap(device_traces, host_events_by_type, group_metadata_map, - nonoverlapped_step_events, device_type, - inference_stats->mutable_model_id_db(), - &request_events_map); - BatchEventsMap batch_events_map; - BuildBatchEventsMap(device_traces, host_events_by_type, group_metadata_map, - nonoverlapped_step_events, device_type, - &request_events_map, &batch_events_map); - - GenerateRequestAndBatchDelay(&request_events_map, &batch_events_map); - GenerateRequestDetailedBreakdown(&request_events_map); - - GenerateTensorDetails(host_events_by_type, &request_events_map, - &batch_events_map, inference_stats); - - auto* request_details = per_host_inference_stats->mutable_request_details(); - BuildRequestDetails(request_events_map, device_type, host_id, - request_details); - auto* batch_details = per_host_inference_stats->mutable_batch_details(); - 
BuildBatchDetails(std::move(batch_events_map), device_type, host_id, - batch_details); - - ParseTfstreamzForBatchingParameter(xspace, - inference_stats->mutable_model_id_db()); -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/inference_stats.h b/tensorflow/core/profiler/convert/inference_stats.h deleted file mode 100644 index 36b0aa600a2125..00000000000000 --- a/tensorflow/core/profiler/convert/inference_stats.h +++ /dev/null @@ -1,53 +0,0 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_INFERENCE_STATS_H_ -#define TENSORFLOW_CORE_PROFILER_CONVERT_INFERENCE_STATS_H_ - -#include -#include -#include - -#include "absl/strings/string_view.h" -#include "xla/tsl/profiler/utils/device_utils.h" -#include "xla/tsl/profiler/utils/group_events.h" -#include "xla/tsl/profiler/utils/xplane_builder.h" -#include "tensorflow/core/profiler/protobuf/inference_stats.pb.h" -#include "tensorflow/core/profiler/utils/event_span.h" -#include "tsl/profiler/protobuf/xplane.pb.h" - -namespace tensorflow { -namespace profiler { - -// Generates PerHostInferenceStats from the given trace events. -// For TPU, get time breakdown from device_traces. For GPU, get time breakdown -// from nonoverlapped_step_events. -// Get batching parameters from TFstreamz xplane in . 
-void GenerateInferenceStats( - const std::vector& device_traces, - const tensorflow::profiler::StepEvents& nonoverlapped_step_events, - const tsl::profiler::GroupMetadataMap& group_metadata_map, - const tsl::profiler::XSpace& xspace, tsl::profiler::DeviceType device_type, - int32_t host_id, tensorflow::profiler::InferenceStats* inference_stats); - -// Parses model name from TFstreamz. -// Returns whether the parsing is successful and the actual model name. If -// parsing failed, returns false and an empty string. -std::pair ParseModelName(absl::string_view param); - -} // namespace profiler -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_PROFILER_CONVERT_INFERENCE_STATS_H_ diff --git a/tensorflow/core/profiler/convert/inference_stats_combiner.cc b/tensorflow/core/profiler/convert/inference_stats_combiner.cc deleted file mode 100644 index fcca1310061d16..00000000000000 --- a/tensorflow/core/profiler/convert/inference_stats_combiner.cc +++ /dev/null @@ -1,171 +0,0 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ -#include "tensorflow/core/profiler/convert/inference_stats_combiner.h" - -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/log/check.h" -#include "absl/log/log.h" -#include "absl/strings/string_view.h" -#include "xla/tsl/lib/gtl/map_util.h" - -namespace tensorflow::profiler { -namespace { -// Combines two ModelIdDatabases. Returns true if this combination requires -// updating the model_id_index in the SessionRunTimes of dst. This will be -// the case if: (1) Src has a model name that doesn't already exist in dst; -// or (2) Src has a model name that does exist in dst but has a different -// index. -bool CombineModelIdDatabases(const ModelIdDatabase& src, ModelIdDatabase* dst) { - if (dst->ids_size() == 0) { - // dst is empty. Simply copy src to dst. This avoids rebuilding - // dst from src from scratch, which may change the name-to-index mapping. - *dst = src; - return false; - } - // TODO(tianrun): For now, assume a model is always served with the same - // parameter on different hosts. In the future, we might consider the case - // when the same model are served with different batching parameters on - // different hosts. - for (const auto& id_and_param : src.id_to_batching_params()) { - dst->mutable_id_to_batching_params()->insert(id_and_param); - } - bool need_update = false; - for (const auto& [src_id, index] : src.id_to_index()) { - auto [iter, was_inserted] = - dst->mutable_id_to_index()->insert({src_id, dst->ids_size()}); - if (was_inserted) { - *dst->add_ids() = src_id; - need_update = true; - continue; - } - if (iter->second != index) { - // src_id is already in dst but has a different index. - need_update = true; - } - } - return need_update; -} - -// Combines two TensorPatternDatabase. Returns true if this combination requires -// updating the tensor_pattern_index. 
This will be the case if: (1) Src has a -// tensor pattern that doesn't exist in dst; or (2) Src has a tensor pattern -// that does exist in dst but has a different index. -bool CombineTensorPatternDatabase( - const TensorPatternDatabase& src, TensorPatternDatabase* dst, - absl::flat_hash_map* dst_pattern_to_index) { - if (dst->tensor_pattern().empty()) { - *dst = src; - return false; - } - - bool need_update = false; - for (int i = 0; i < static_cast(src.tensor_pattern_size()); i++) { - auto [iter, inserted] = dst_pattern_to_index->insert( - {src.tensor_pattern(i), dst_pattern_to_index->size()}); - if (inserted) { - // Src has a tensor pattern that doesn't exist in dst. - dst->add_tensor_pattern(src.tensor_pattern(i)); - need_update = true; - } else if (iter->second != i) { - // Src has a tensor pattern with different index than dst. - need_update = true; - } - } - return need_update; -} - -void UpdateTensorPatternIndex( - const TensorPatternDatabase& src, - const absl::flat_hash_map& dst_pattern_to_index, - TensorEventDetail* detail) { - absl::string_view tensor_pattern = - src.tensor_pattern(detail->tensor_pattern_index()); - if (const int* new_index = - tsl::gtl::FindOrNull(dst_pattern_to_index, tensor_pattern)) { - detail->set_tensor_pattern_index(*new_index); - } else { - LOG(WARNING) << "Tensor pattern " << tensor_pattern - << " is not found in dst->tensor_pattern_db()"; - } -} -} // namespace - -void CombineInferenceStatsResult(int src_host_id, const InferenceStats& src, - InferenceStats* dst) { - // There should be one key-value pair inside src.inference_stats_per_host(), - // because the src comes from one XprofResponse (i.e., one host). 
- DCHECK_LE(src.inference_stats_per_host_size(), 1); - bool need_update_model_id = - CombineModelIdDatabases(src.model_id_db(), dst->mutable_model_id_db()); - absl::flat_hash_map dst_pattern_to_index; - for (int i = 0; - i < static_cast(dst->tensor_pattern_db().tensor_pattern_size()); - i++) { - dst_pattern_to_index[dst->tensor_pattern_db().tensor_pattern(i)] = i; - } - bool need_update_tensor_pattern = CombineTensorPatternDatabase( - src.tensor_pattern_db(), dst->mutable_tensor_pattern_db(), - &dst_pattern_to_index); - for (const auto& [host_id, inf_stats] : src.inference_stats_per_host()) { - auto [iter, was_inserted] = dst->mutable_inference_stats_per_host()->insert( - {src_host_id, inf_stats}); - if (!was_inserted) { - LOG(INFO) << "Duplicate host_id: " << iter->first; - } - if (need_update_model_id || need_update_tensor_pattern) { - // Needs to update the model_id_index in the dst. - PerHostInferenceStats* dst_inference_stats = - &(*dst->mutable_inference_stats_per_host())[src_host_id]; - for (RequestDetail& request_detail : - *dst_inference_stats->mutable_request_details()) { - if (need_update_model_id && request_detail.model_id_index() != -1) { - // "model_id_index = -1" means there is no model_id associated with - // the group id in this event if client doesn't specify "model_id" in - // TraceMeEncode. so we don't need to update model_id if it doesn't - // have a model. 
- const std::string& model_id = - src.model_id_db().ids(request_detail.model_id_index()); - auto iter = dst->model_id_db().id_to_index().find(model_id); - if (iter == dst->model_id_db().id_to_index().end()) { - LOG(WARNING) << "Model ID " << model_id - << " is not found in dst->model_id_db()"; - continue; - } - request_detail.set_model_id_index(iter->second); - } - if (need_update_tensor_pattern) { - for (auto& tensor_event_details : - *request_detail.mutable_tensor_event_details()) { - UpdateTensorPatternIndex(src.tensor_pattern_db(), - dst_pattern_to_index, - &tensor_event_details); - } - } - } - } - if (need_update_tensor_pattern) { - PerHostInferenceStats* dst_inference_stats = - &(*dst->mutable_inference_stats_per_host())[src_host_id]; - for (BatchDetail& batch_detail : - *dst_inference_stats->mutable_batch_details()) { - UpdateTensorPatternIndex(src.tensor_pattern_db(), dst_pattern_to_index, - batch_detail.mutable_tensor_event_detail()); - } - } - } -} -} // namespace tensorflow::profiler diff --git a/tensorflow/core/profiler/convert/inference_stats_combiner.h b/tensorflow/core/profiler/convert/inference_stats_combiner.h deleted file mode 100644 index ceccc9cca2608a..00000000000000 --- a/tensorflow/core/profiler/convert/inference_stats_combiner.h +++ /dev/null @@ -1,25 +0,0 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_INFERENCE_STATS_COMBINER_H_ -#define TENSORFLOW_CORE_PROFILER_CONVERT_INFERENCE_STATS_COMBINER_H_ -#include "tensorflow/core/profiler/protobuf/inference_stats.pb.h" - -namespace tensorflow::profiler { -void CombineInferenceStatsResult(int src_host_id, const InferenceStats& src, - InferenceStats* dst); -} // namespace tensorflow::profiler - -#endif // TENSORFLOW_CORE_PROFILER_CONVERT_INFERENCE_STATS_COMBINER_H_ diff --git a/tensorflow/core/profiler/convert/inference_stats_grouping.cc b/tensorflow/core/profiler/convert/inference_stats_grouping.cc deleted file mode 100644 index fad8330ac72f63..00000000000000 --- a/tensorflow/core/profiler/convert/inference_stats_grouping.cc +++ /dev/null @@ -1,475 +0,0 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ -#include "tensorflow/core/profiler/convert/inference_stats_grouping.h" - -#include -#include -#include -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "xla/tsl/lib/gtl/map_util.h" -#include "xla/tsl/profiler/utils/math_utils.h" -#include "xla/tsl/profiler/utils/timespan.h" -#include "tensorflow/core/profiler/protobuf/inference_stats.pb.h" -#include "tsl/platform/protobuf.h" - -namespace tensorflow::profiler { - -namespace { - -using ::tensorflow::profiler::BatchDetail; -using ::tensorflow::profiler::InferenceStats; -using ::tensorflow::profiler::ModelIdDatabase; -using ::tensorflow::profiler::PerBatchSizeAggregatedResult; -using ::tensorflow::profiler::PerModelInferenceStats; -using ::tensorflow::profiler::RequestDetail; -using ::tensorflow::profiler::TensorEventDetail; -using ::tsl::profiler::Timespan; - -template -void push_down_heap(size_t hole, RandIt first, RandIt last, Compare comp) { - size_t size = last - first; - assert(hole < size); - auto value = std::move(first[hole]); - while (true) { - size_t l_child = 2 * hole + 1; - size_t r_child = l_child + 1; - size_t max_child = l_child; - if (r_child < size && comp(first[l_child], first[r_child])) { - max_child = r_child; - } - if (max_child >= size) break; - if (!comp(value, first[max_child])) break; - first[hole] = std::move(first[max_child]); - hole = max_child; - } - first[hole] = std::move(value); -} -// Pushes the root down the heap. -template -void push_root_heap(RandIt first, RandIt last, Compare comp) { - push_down_heap(0, std::move(first), std::move(last), std::move(comp)); -} - -template -Out nway_merge(const ContainerContainer& containers, Out out, Cmp cmp) { - using std::begin; - using std::end; - using In = decltype(begin(*begin(containers))); // The input iterator type. 
- using Range = std::pair; - std::vector sources; - for (const auto& container : containers) { - Range r(begin(container), end(container)); - if (r.first != r.second) sources.push_back(std::move(r)); - } - // Zero, one or two collections can be merged without a priority queue. - switch (sources.size()) { - case 0: - return out; - case 1: - return std::copy(sources[0].first, sources[0].second, out); - case 2: - return std::merge(sources[0].first, sources[0].second, sources[1].first, - sources[1].second, out, cmp); - } - // Take a comparator for T and produce an inverse comparator - // for std::pair, In>, inverted so as to produce a min-heap. - auto heap_cmp = [&](const Range& a, const Range& b) { - // Compares b < a instead of a < b. - return cmp(*b.first, *a.first); - }; - auto heap_data = sources.data(); - auto heap_size = sources.size(); - std::make_heap(heap_data, heap_data + heap_size, heap_cmp); - auto& top = sources.front(); - auto pop = [&]() { - *out = *top.first; - ++out; - ++top.first; - }; - - for (; heap_size > 2;) { - for (pop(); top.first != top.second; pop()) { - push_root_heap(heap_data, heap_data + heap_size, heap_cmp); - } - top = std::move(sources[--heap_size]); - push_root_heap(heap_data, heap_data + heap_size, heap_cmp); - } - - return std::merge(sources[0].first, sources[0].second, sources[1].first, - sources[1].second, out, cmp); -} - -double GetThroughput(size_t data_size, uint64_t start_time_ps, - uint64_t end_time_ps) { - return data_size / tsl::profiler::PicoToUni(end_time_ps - start_time_ps); -} - -// Compute throughput and average latency. -// DataType can either be RequestDetail or BatchDetail. -template -std::pair ComputeThroughputAndAverageLatencyUs( - const std::vector& all_data) { - if (all_data.empty()) { - // Return 0 immediately to avoid divide by zero error. 
- return std::make_pair(0.0, 0.0); - } - - uint64_t min_start_time_ps = std::numeric_limits::max(); - uint64_t max_end_time_ps = 0; - uint64_t total_latency_ps = 0; - - for (const DataType* data : all_data) { - min_start_time_ps = std::min(min_start_time_ps, data->start_time_ps()); - max_end_time_ps = std::max(max_end_time_ps, data->end_time_ps()); - total_latency_ps += (data->end_time_ps() - data->start_time_ps()); - } - - double throughput = - GetThroughput(all_data.size(), min_start_time_ps, max_end_time_ps); - double average_latency_us = - tsl::profiler::PicoToMicro(total_latency_ps) / all_data.size(); - return std::make_pair(throughput, average_latency_us); -} - -template -bool CompareByDuration(const DataType* a, const DataType* b) { - return Timespan::ByDuration( - Timespan::FromEndPoints(a->start_time_ps(), a->end_time_ps()), - Timespan::FromEndPoints(b->start_time_ps(), b->end_time_ps())); -} - -// Regroup data in using model id for future analysis. -// DataType can be either RequestDetail or BatchDetail. -template -void RegroupDataByModelId( - const ModelIdDatabase& model_id_db, - const std::vector*>& - data_by_host, - std::vector>* data_by_model_id) { - // First group data by model_id and host. - std::vector>> - data_by_model_id_by_host; - - // If model_id_db is empty, this means model_id is not available in the trace, - // so we simply consider the entire execution as a single model_id. - bool no_model_id = model_id_db.ids_size() == 0; - int model_index_size = no_model_id ? 1 : model_id_db.ids_size(); - int host_index_size = data_by_host.size(); - data_by_model_id_by_host.resize(model_index_size); - for (size_t model_index = 0; model_index < model_index_size; ++model_index) { - data_by_model_id_by_host[model_index].resize(host_index_size); - } - - int32_t host_index = 0; - for (const tsl::protobuf::RepeatedPtrField* single_host_data : - data_by_host) { - for (const DataType& data : *single_host_data) { - int model_index = no_model_id ? 
0 : data.model_id_index(); - // If model_id_db is not empty, and a session/batch does not have - // model_id, ignore it in per model analysis. - if (model_index == -1) { - continue; - } - data_by_model_id_by_host[model_index][host_index].push_back(&data); - } - ++host_index; - } - - // data_by_host is already sorted by the latency, so - // data_by_model_id_by_host is also sorted by the latency. Therefore, - // we just need to do a n way merge instead of a real sorting. - data_by_model_id->resize(model_index_size); - for (size_t model_index = 0; model_index < model_index_size; ++model_index) { - int total_size = 0; - for (const auto& per_model_per_host : - data_by_model_id_by_host[model_index]) { - total_size += per_model_per_host.size(); - } - data_by_model_id->at(model_index).reserve(total_size); - } - for (size_t model_index = 0; model_index < model_index_size; ++model_index) { - nway_merge(data_by_model_id_by_host[model_index], - std::back_inserter(data_by_model_id->at(model_index)), - CompareByDuration); - } -} - -// Generates the tensor transfer aggregated result using the per model data in -// . -void GenerateTensorTransferAggregatedResult(PerModelInferenceStats* per_model) { - absl::flat_hash_map> - tensor_events_by_index; - // For requests, only count the tensor events with owner REQUEST, because if - // inference batching is enabled, there will be tensor events that are owned - // by batches and just inherited by requests. Counting these tensor events - // will lead to double counting. 
- for (const auto& request : per_model->request_details()) { - for (const auto& tensor_event : request.tensor_event_details()) { - if (tensor_event.owner() == TensorEventDetail::REQUEST) { - tensor_events_by_index[tensor_event.tensor_pattern_index()].push_back( - &tensor_event); - } - } - } - for (const auto& batch : per_model->batch_details()) { - if (batch.has_tensor_event_detail()) { - tensor_events_by_index[batch.tensor_event_detail().tensor_pattern_index()] - .push_back(&batch.tensor_event_detail()); - } - } - - if (tensor_events_by_index.empty()) return; - - static constexpr double kPercentiles[] = {50.0, 75.0, 90.0, 95.0, 99.0, 99.9}; - for (auto& [index, events] : tensor_events_by_index) { - auto* tensor_pattern_result = - per_model->mutable_tensor_transfer_aggregated_result() - ->add_tensor_pattern_results(); - tensor_pattern_result->set_tensor_pattern_index(index); - tensor_pattern_result->set_count(events.size()); - std::sort(events.begin(), events.end(), - [](const TensorEventDetail* a, const TensorEventDetail* b) { - return a->linearize_delinearize_time_ps() < - b->linearize_delinearize_time_ps(); - }); - for (const double percentile : kPercentiles) { - int index = static_cast(percentile / 100.0 * events.size()); - auto* percentile_time = - tensor_pattern_result->add_linearize_delinearize_percentile_time(); - percentile_time->set_percentile(percentile); - percentile_time->set_time_ps( - events[index]->linearize_delinearize_time_ps()); - } - } -} - -void AggregateRequest(const RequestDetail& input, RequestDetail* result) { - // In aggregated result, start_time is set to 0, and end time is set to the - // sum of the duration of the input requests. 
- result->set_end_time_ps(input.end_time_ps() - input.start_time_ps() + - result->end_time_ps()); - result->set_device_time_ps(result->device_time_ps() + input.device_time_ps()); - result->set_read_from_device_time_ps(result->read_from_device_time_ps() + - input.read_from_device_time_ps()); - result->set_write_to_device_time_ps(result->write_to_device_time_ps() + - input.write_to_device_time_ps()); - result->set_batching_request_delay_ps(result->batching_request_delay_ps() + - input.batching_request_delay_ps()); - result->set_batching_request_size(result->batching_request_size() + - input.batching_request_size()); - result->set_host_preprocessing_ps(result->host_preprocessing_ps() + - input.host_preprocessing_ps()); - result->set_host_batch_formation_ps(result->host_batch_formation_ps() + - input.host_batch_formation_ps()); - result->set_host_runtime_ps(result->host_runtime_ps() + - input.host_runtime_ps()); - result->set_host_postprocessing_ps(result->host_postprocessing_ps() + - input.host_postprocessing_ps()); - result->set_idle_time_ps(result->idle_time_ps() + input.idle_time_ps()); -} - -RequestDetail GetAverageRequestDetails(const RequestDetail& request, - int64_t size) { - RequestDetail result; - if (size == 0) return result; - // Average request detail does not have a request ID. - result.set_request_id(-1); - result.set_start_time_ps(0); - // Calculating average by dividing aggregated request by size. 
- result.set_end_time_ps(request.end_time_ps() / size); - result.set_device_time_ps(request.device_time_ps() / size); - result.set_write_to_device_time_ps(request.write_to_device_time_ps() / size); - result.set_read_from_device_time_ps(request.read_from_device_time_ps() / - size); - result.set_batching_request_delay_ps(request.batching_request_delay_ps() / - size); - result.set_batching_request_size(request.batching_request_size() / size); - result.set_host_preprocessing_ps(request.host_preprocessing_ps() / size); - result.set_host_batch_formation_ps(request.host_batch_formation_ps() / size); - result.set_host_runtime_ps(request.host_runtime_ps() / size); - result.set_host_postprocessing_ps(request.host_postprocessing_ps() / size); - result.set_idle_time_ps(request.idle_time_ps() / size); - return result; -} - -void AggregateBatch(const BatchDetail& input, BatchDetail* result) { - // In aggregated result, start_time is set to 0, and end time is set to the - // sum of the duration of the input batches. - result->set_end_time_ps(input.end_time_ps() - input.start_time_ps() + - result->end_time_ps()); - result->set_batch_delay_ps(result->batch_delay_ps() + input.batch_delay_ps()); - result->set_padding_amount(result->padding_amount() + input.padding_amount()); - result->set_batch_size_after_padding(result->batch_size_after_padding() + - input.batch_size_after_padding()); - result->set_device_time_ps(result->device_time_ps() + input.device_time_ps()); -} - -BatchDetail GetAverageBatchDetails(const BatchDetail& batch, int64_t size) { - BatchDetail result; - if (size == 0) return result; - // Average batch detail does not have a batch ID. - result.set_batch_id(-1); - result.set_start_time_ps(0); - // Calculating average by dividing aggregated batch by size. 
- result.set_end_time_ps(batch.end_time_ps() / size); - result.set_batch_delay_ps(batch.batch_delay_ps() / size); - result.set_padding_amount(batch.padding_amount() / size); - result.set_batch_size_after_padding(batch.batch_size_after_padding() / size); - result.set_device_time_ps(batch.device_time_ps() / size); - return result; -} - -void AggregatePerModelInferenceStats(InferenceStats* inference_stats) { - for (auto& [model_index, per_model_stats] : - *inference_stats->mutable_inference_stats_per_model()) { - // TODO: remove batch size aggregation from request table. - absl::flat_hash_map batch_id_to_batch; - for (const BatchDetail& b : per_model_stats.batch_details()) { - batch_id_to_batch[b.batch_id()] = &b; - } - - // Aggregated result for all data. - RequestDetail aggregated_r; - BatchDetail aggregated_b; - - struct PerBatchSizeInfo { - PerBatchSizeAggregatedResult result; - int request_count; - int batch_count; - }; - // Aggregated result per batch size. - absl::flat_hash_map per_batch_size_info; - - for (const RequestDetail& r : per_model_stats.request_details()) { - // Aggregate all data. - AggregateRequest(r, &aggregated_r); - // Aggregate per batch size. - // TODO: remove batch size aggregation from request table. - for (const auto batch_id : r.related_batch_ids()) { - if (const BatchDetail* batch = - ::tsl::gtl::FindPtrOrNull(batch_id_to_batch, batch_id)) { - int batch_size = batch->batch_size_after_padding(); - auto& info = per_batch_size_info[batch_size]; - AggregateRequest(r, info.result.mutable_aggregated_request_result()); - info.request_count++; - } - } - } - - for (const BatchDetail& b : per_model_stats.batch_details()) { - // Aggregate all data. - AggregateBatch(b, &aggregated_b); - // Aggregate per batch size. 
- int batch_size = b.batch_size_after_padding(); - auto& info = per_batch_size_info[batch_size]; - AggregateBatch(b, info.result.mutable_aggregated_batch_result()); - info.batch_count++; - } - - *per_model_stats.mutable_aggregated_request_detail() = - GetAverageRequestDetails(aggregated_r, - per_model_stats.request_details().size()); - *per_model_stats.mutable_aggregated_batch_detail() = GetAverageBatchDetails( - aggregated_b, per_model_stats.batch_details().size()); - - std::vector sorted_batch_sizes; - for (const auto& [batch_size, _] : per_batch_size_info) { - sorted_batch_sizes.push_back(batch_size); - } - std::sort(sorted_batch_sizes.begin(), sorted_batch_sizes.end()); - for (const int batch_size : sorted_batch_sizes) { - auto* result = per_model_stats.add_per_batch_size_aggregated_result(); - result->set_batch_size(batch_size); - auto& info = per_batch_size_info[batch_size]; - *result->mutable_aggregated_request_result() = GetAverageRequestDetails( - info.result.aggregated_request_result(), info.request_count); - result->set_request_throughput(info.request_count * - per_model_stats.request_throughput() / - per_model_stats.request_details_size()); - *result->mutable_aggregated_batch_result() = GetAverageBatchDetails( - info.result.aggregated_batch_result(), info.batch_count); - result->set_batch_throughput(info.batch_count * - per_model_stats.batch_throughput() / - per_model_stats.batch_details_size()); - } - } -} - -} // namespace - -void RegroupInferenceStatsByModel(InferenceStats* inference_stats) { - if (inference_stats->inference_stats_per_host().empty()) { - return; - } - std::vector*> - all_requests_by_host; - for (const auto& [host_id, per_host_inference_stats] : - inference_stats->inference_stats_per_host()) { - all_requests_by_host.push_back(&per_host_inference_stats.request_details()); - } - std::vector> requests_by_model_id; - RegroupDataByModelId(inference_stats->model_id_db(), all_requests_by_host, - &requests_by_model_id); - - std::vector*> - 
all_batches_by_host; - for (const auto& [host_id, per_host_inference_stats] : - inference_stats->inference_stats_per_host()) { - all_batches_by_host.push_back(&per_host_inference_stats.batch_details()); - } - std::vector> batches_by_model_id; - RegroupDataByModelId(inference_stats->model_id_db(), all_batches_by_host, - &batches_by_model_id); - - for (size_t index = 0; index < requests_by_model_id.size(); index++) { - auto* per_model = - &(*inference_stats->mutable_inference_stats_per_model())[index]; - for (const RequestDetail* request : requests_by_model_id[index]) { - *per_model->add_request_details() = *request; - } - for (const BatchDetail* batch : batches_by_model_id[index]) { - *per_model->add_batch_details() = *batch; - } - auto [request_throughput, request_latency] = - ComputeThroughputAndAverageLatencyUs(requests_by_model_id[index]); - per_model->set_request_throughput(request_throughput); - per_model->set_request_average_latency_us(request_latency); - auto [batch_throughput, batch_latency] = - ComputeThroughputAndAverageLatencyUs(batches_by_model_id[index]); - per_model->set_batch_throughput(batch_throughput); - per_model->set_batch_average_latency_us(batch_latency); - GenerateTensorTransferAggregatedResult(per_model); - } - - AggregatePerModelInferenceStats(inference_stats); - - // If there is no model id provided by user, create a fake "ALL" model id to - // represent all the requests during profiling. - // This ALL model id is mapped to index 0, which is consistent with the index - // used by RegroupDataByModelId. 
- if (inference_stats->model_id_db().ids().empty()) { - inference_stats->mutable_model_id_db()->add_ids("ALL"); - inference_stats->mutable_model_id_db()->mutable_id_to_index()->insert( - {"ALL", 0}); - } - inference_stats->clear_inference_stats_per_host(); -} - -} // namespace tensorflow::profiler diff --git a/tensorflow/core/profiler/convert/inference_stats_grouping.h b/tensorflow/core/profiler/convert/inference_stats_grouping.h deleted file mode 100644 index 7d60da0f311826..00000000000000 --- a/tensorflow/core/profiler/convert/inference_stats_grouping.h +++ /dev/null @@ -1,29 +0,0 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_INFERENCE_STATS_GROUPING_H_ -#define TENSORFLOW_CORE_PROFILER_CONVERT_INFERENCE_STATS_GROUPING_H_ - -#include "tensorflow/core/profiler/protobuf/inference_stats.pb.h" - -namespace tensorflow::profiler { - -// Change inference stats from per host to per model_id by doing a regroup. -// Future analysis of inference_stats will be on a per model_id basis. 
-void RegroupInferenceStatsByModel( - tensorflow::profiler::InferenceStats* inference_stats); - -} // namespace tensorflow::profiler - -#endif // TENSORFLOW_CORE_PROFILER_CONVERT_INFERENCE_STATS_GROUPING_H_ diff --git a/tensorflow/core/profiler/convert/inference_stats_grouping_test.cc b/tensorflow/core/profiler/convert/inference_stats_grouping_test.cc deleted file mode 100644 index 5d6d43e5ba8150..00000000000000 --- a/tensorflow/core/profiler/convert/inference_stats_grouping_test.cc +++ /dev/null @@ -1,508 +0,0 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/core/profiler/convert/inference_stats_grouping.h" - -#include -#include "xla/tests/test_utils.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/profiler/protobuf/inference_stats.pb.h" - -namespace tensorflow::profiler { -namespace { - -using ::testing::EqualsProto; -using ::xla::ParseTextProto; - -TEST(InferenceStatsGroupingTest, TestWithModelId) { - // An inference stats with two hosts, two models. 
- InferenceStats inference_stats = ParseTextProto(R"pb( - inference_stats_per_host { - key: 0 - value { - request_details { - start_time_ps: 1000 - end_time_ps: 2000 - model_id_index: 0 - request_id: 0 - device_time_ps: 100 - } - request_details { - start_time_ps: 2000 - end_time_ps: 3000 - model_id_index: 1 - request_id: 1 - device_time_ps: 100 - } - } - } - inference_stats_per_host { - key: 1 - value { - request_details { - start_time_ps: 3000 - end_time_ps: 4000 - model_id_index: 0 - request_id: 2 - device_time_ps: 100 - } - request_details { - start_time_ps: 4000 - end_time_ps: 5000 - model_id_index: 1 - request_id: 3 - device_time_ps: 100 - } - } - } - model_id_db { - ids: "Model-A:1" - ids: "Model-B:1" - id_to_index { key: "Model-A:1" value: 0 } - id_to_index { key: "Model-B:1" value: 1 } - } - )pb") - .value(); - - RegroupInferenceStatsByModel(&inference_stats); - - // Verifies that requests with the same model ID are grouped together. - EXPECT_THAT(inference_stats, EqualsProto(R"pb( - model_id_db { - ids: "Model-A:1" - ids: "Model-B:1" - id_to_index { key: "Model-A:1" value: 0 } - id_to_index { key: "Model-B:1" value: 1 } - } - inference_stats_per_model { - key: 0 - value { - request_details { - start_time_ps: 1000 - end_time_ps: 2000 - model_id_index: 0 - request_id: 0 - device_time_ps: 100 - } - request_details { - start_time_ps: 3000 - end_time_ps: 4000 - model_id_index: 0 - request_id: 2 - device_time_ps: 100 - } - aggregated_request_detail { - request_id: -1 - start_time_ps: 0 - end_time_ps: 1000 - write_to_device_time_ps: 0 - read_from_device_time_ps: 0 - device_time_ps: 100 - batching_request_delay_ps: 0 - batching_request_size: 0 - host_preprocessing_ps: 0 - host_batch_formation_ps: 0 - host_runtime_ps: 0 - host_postprocessing_ps: 0 - idle_time_ps: 0 - } - aggregated_batch_detail {} - request_throughput: 666666666.66666663 - request_average_latency_us: 0.001 - batch_throughput: 0 - batch_average_latency_us: 0 - } - } - inference_stats_per_model { - 
key: 1 - value { - request_details { - start_time_ps: 2000 - end_time_ps: 3000 - model_id_index: 1 - request_id: 1 - device_time_ps: 100 - } - request_details { - start_time_ps: 4000 - end_time_ps: 5000 - model_id_index: 1 - request_id: 3 - device_time_ps: 100 - } - aggregated_request_detail { - request_id: -1 - start_time_ps: 0 - end_time_ps: 1000 - write_to_device_time_ps: 0 - read_from_device_time_ps: 0 - device_time_ps: 100 - batching_request_delay_ps: 0 - batching_request_size: 0 - host_preprocessing_ps: 0 - host_batch_formation_ps: 0 - host_runtime_ps: 0 - host_postprocessing_ps: 0 - idle_time_ps: 0 - } - aggregated_batch_detail {} - request_throughput: 666666666.66666663 - request_average_latency_us: 0.001 - batch_throughput: 0 - batch_average_latency_us: 0 - } - })pb")); -} - -TEST(InferenceStatsGroupingTest, TestTensorPatternPercentile) { - // Generates an inference stats for test, 6 requests have tensor events owned - // by REQUEST, 2 requests have tensor events owned by BATCH. 
- InferenceStats inference_stats = - ParseTextProto(R"pb( - inference_stats_per_host { - key: 0 - value { - request_details { - start_time_ps: 1000 - end_time_ps: 2000 - request_id: 0 - tensor_event_details { - tensor_pattern_index: 0 - owner: REQUEST - linearize_delinearize_time_ps: 600000 - } - } - request_details { - start_time_ps: 2000 - end_time_ps: 3000 - request_id: 1 - tensor_event_details { - tensor_pattern_index: 0 - owner: REQUEST - linearize_delinearize_time_ps: 500000 - } - } - request_details { - start_time_ps: 1000 - end_time_ps: 2000 - request_id: 2 - tensor_event_details { - tensor_pattern_index: 0 - owner: REQUEST - linearize_delinearize_time_ps: 400000 - } - } - request_details { - start_time_ps: 2000 - end_time_ps: 3000 - request_id: 3 - tensor_event_details { - tensor_pattern_index: 0 - owner: REQUEST - linearize_delinearize_time_ps: 300000 - } - } - request_details { - start_time_ps: 1000 - end_time_ps: 2000 - request_id: 4 - tensor_event_details { - tensor_pattern_index: 0 - owner: REQUEST - linearize_delinearize_time_ps: 200000 - } - } - request_details { - start_time_ps: 2000 - end_time_ps: 3000 - request_id: 5 - tensor_event_details { - tensor_pattern_index: 0 - owner: REQUEST - linearize_delinearize_time_ps: 100000 - } - } - request_details { - start_time_ps: 2000 - end_time_ps: 3000 - request_id: 6 - tensor_event_details { - tensor_pattern_index: 0 - owner: BATCH - linearize_delinearize_time_ps: 700000 - } - } - request_details { - start_time_ps: 2000 - end_time_ps: 3000 - request_id: 7 - tensor_event_details { - tensor_pattern_index: 0 - owner: BATCH - linearize_delinearize_time_ps: 800000 - } - } - } - } - )pb") - .value(); - - RegroupInferenceStatsByModel(&inference_stats); - - // Count equals to 6 because request tensor events owned by BATCH are ignored. - // Percentile selector selects linearize and delinearize time at 50.0, 75.0, - // 90.0, 95.0, 99.0, 99.9 percentiles. 
- EXPECT_THAT(inference_stats.inference_stats_per_model() - .at(0) - .tensor_transfer_aggregated_result(), - EqualsProto(R"pb( - tensor_pattern_results { - tensor_pattern_index: 0 - count: 6 - linearize_delinearize_percentile_time { - percentile: 50 - time_ps: 400000 - } - linearize_delinearize_percentile_time { - percentile: 75 - time_ps: 500000 - } - linearize_delinearize_percentile_time { - percentile: 90 - time_ps: 600000 - } - linearize_delinearize_percentile_time { - percentile: 95 - time_ps: 600000 - } - linearize_delinearize_percentile_time { - percentile: 99 - time_ps: 600000 - } - linearize_delinearize_percentile_time { - percentile: 99.9 - time_ps: 600000 - } - } - )pb")); -} - -TEST(InferenceStatsGroupingTest, TestWithoutModelId) { - // An inference stats with two hosts, no model ID data. - InferenceStats inference_stats = ParseTextProto(R"pb( - inference_stats_per_host { - key: 0 - value { - request_details { - start_time_ps: 1000 - end_time_ps: 2000 - request_id: 0 - related_batch_ids: 0 - host_runtime_ps: 100 - } - request_details { - start_time_ps: 2000 - end_time_ps: 4000 - request_id: 1 - related_batch_ids: 0 - host_runtime_ps: 100 - } - batch_details { - batch_id: 0 - related_request_ids: 0 - related_request_ids: 1 - start_time_ps: 1000 - end_time_ps: 2000 - batch_size_after_padding: 128 - } - } - } - inference_stats_per_host { - key: 1 - value { - request_details { - start_time_ps: 3000 - end_time_ps: 6000 - request_id: 2 - related_batch_ids: 1 - host_runtime_ps: 100 - } - request_details { - start_time_ps: 4000 - end_time_ps: 8000 - request_id: 3 - related_batch_ids: 1 - host_runtime_ps: 100 - } - batch_details { - batch_id: 1 - related_request_ids: 2 - related_request_ids: 3 - start_time_ps: 3000 - end_time_ps: 4000 - batch_size_after_padding: 256 - } - } - } - )pb") - .value(); - - RegroupInferenceStatsByModel(&inference_stats); - - // Verifies that all requests are grouped into a single model, and a "ALL" - // model ID is added. 
- EXPECT_THAT(inference_stats, EqualsProto(R"pb( - model_id_db { - ids: "ALL" - id_to_index { key: "ALL" value: 0 } - } - inference_stats_per_model { - key: 0 - value { - request_details { - start_time_ps: 1000 - end_time_ps: 2000 - request_id: 0 - related_batch_ids: 0 - host_runtime_ps: 100 - } - request_details { - start_time_ps: 2000 - end_time_ps: 4000 - request_id: 1 - related_batch_ids: 0 - host_runtime_ps: 100 - } - request_details { - start_time_ps: 3000 - end_time_ps: 6000 - request_id: 2 - related_batch_ids: 1 - host_runtime_ps: 100 - } - request_details { - start_time_ps: 4000 - end_time_ps: 8000 - request_id: 3 - related_batch_ids: 1 - host_runtime_ps: 100 - } - batch_details { - batch_id: 0 - related_request_ids: 0 - related_request_ids: 1 - start_time_ps: 1000 - end_time_ps: 2000 - batch_size_after_padding: 128 - } - batch_details { - batch_id: 1 - related_request_ids: 2 - related_request_ids: 3 - start_time_ps: 3000 - end_time_ps: 4000 - batch_size_after_padding: 256 - } - aggregated_request_detail { - request_id: -1 - start_time_ps: 0 - end_time_ps: 2500 - write_to_device_time_ps: 0 - read_from_device_time_ps: 0 - device_time_ps: 0 - batching_request_delay_ps: 0 - batching_request_size: 0 - host_preprocessing_ps: 0 - host_batch_formation_ps: 0 - host_runtime_ps: 100 - host_postprocessing_ps: 0 - idle_time_ps: 0 - } - aggregated_batch_detail { - batch_id: -1 - start_time_ps: 0 - end_time_ps: 1000 - batch_delay_ps: 0 - padding_amount: 0 - device_time_ps: 0 - batch_size_after_padding: 192 - } - per_batch_size_aggregated_result { - batch_size: 128 - aggregated_request_result { - start_time_ps: 0 - end_time_ps: 1500 - write_to_device_time_ps: 0 - read_from_device_time_ps: 0 - device_time_ps: 0 - request_id: -1 - batching_request_delay_ps: 0 - batching_request_size: 0 - host_preprocessing_ps: 0 - host_batch_formation_ps: 0 - host_runtime_ps: 100 - host_postprocessing_ps: 0 - idle_time_ps: 0 - } - aggregated_batch_result { - batch_id: -1 - start_time_ps: 0 
- end_time_ps: 1000 - batch_delay_ps: 0 - padding_amount: 0 - device_time_ps: 0 - batch_size_after_padding: 128 - } - request_throughput: 285714285.71428573 - batch_throughput: 333333333.33333331 - } - per_batch_size_aggregated_result { - batch_size: 256 - aggregated_request_result { - start_time_ps: 0 - end_time_ps: 3500 - write_to_device_time_ps: 0 - read_from_device_time_ps: 0 - device_time_ps: 0 - request_id: -1 - batching_request_delay_ps: 0 - batching_request_size: 0 - host_preprocessing_ps: 0 - host_batch_formation_ps: 0 - host_runtime_ps: 100 - host_postprocessing_ps: 0 - idle_time_ps: 0 - } - aggregated_batch_result { - batch_id: -1 - start_time_ps: 0 - end_time_ps: 1000 - batch_delay_ps: 0 - padding_amount: 0 - device_time_ps: 0 - batch_size_after_padding: 256 - } - request_throughput: 285714285.71428573 - batch_throughput: 333333333.33333331 - } - request_throughput: 571428571.42857146 - request_average_latency_us: 0.0025 - batch_throughput: 666666666.66666663 - batch_average_latency_us: 0.001 - } - })pb")); -} - -} // namespace -} // namespace tensorflow::profiler diff --git a/tensorflow/core/profiler/convert/inference_stats_sampler.cc b/tensorflow/core/profiler/convert/inference_stats_sampler.cc deleted file mode 100644 index be3e392ec85e78..00000000000000 --- a/tensorflow/core/profiler/convert/inference_stats_sampler.cc +++ /dev/null @@ -1,311 +0,0 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ -#include "tensorflow/core/profiler/convert/inference_stats_sampler.h" - -#include -#include -#include -#include -#include -#include - -#include "absl/strings/string_view.h" -#include "xla/tsl/profiler/utils/math_utils.h" -#include "tensorflow/core/profiler/protobuf/inference_stats.pb.h" - -namespace tensorflow::profiler { - -namespace { - -using ::tensorflow::profiler::BatchDetail; -using ::tensorflow::profiler::InferenceStats; -using ::tensorflow::profiler::PerModelInferenceStats; -using ::tensorflow::profiler::RequestDetail; - -// Column names that can be used to do percentile selection. -// For request: -constexpr char kColumnLatencyUs[] = "Latency"; -constexpr char kColumnBatchingRequestDelayUs[] = "Request delay for batching"; -constexpr char kColumnBatchingRequestSize[] = "Request size"; -constexpr char kColumnHostPreprocessing[] = "Host preprocess"; -constexpr char kColumnHostBatchFormation[] = "Host batch formation"; -constexpr char kColumnHostRuntime[] = "Host runtime"; -constexpr char kColumnHostToDevice[] = "Data transfer H2D"; -constexpr char kColumnDeviceToHost[] = "Data transfer D2H"; -constexpr char kColumnDeviceCompute[] = "Device compute"; -constexpr char kColumnHostPostprocessing[] = "Host postprocess"; -constexpr char kColumnIdleTime[] = "Idle time"; -// For batch: -constexpr char kColumnBatchingDelayUs[] = "Batching delay"; -constexpr char kColumnPaddingAmount[] = "Padding amount"; -constexpr char kColumnBatchSizeAfterPadding[] = "Batch size after padding"; -constexpr char kColumnBatchingEfficiency[] = "Batching efficiency"; - -double CalculateBatchingEfficiency(const BatchDetail& batch) { - return tsl::profiler::SafeDivide( - static_cast(batch.batch_size_after_padding() - - batch.padding_amount()), - static_cast(batch.batch_size_after_padding())); -} - -// Comparator for RequestDetail proto. 
-bool CompareByRequestLatency(const RequestDetail* a, const RequestDetail* b) { - return (a->end_time_ps() - a->start_time_ps()) < - (b->end_time_ps() - b->start_time_ps()); -} -bool CompareByBatchingRequestDelay(const RequestDetail* a, - const RequestDetail* b) { - return a->batching_request_delay_ps() < b->batching_request_delay_ps(); -} -bool CompareByBatchingRequestSize(const RequestDetail* a, - const RequestDetail* b) { - return a->batching_request_size() < b->batching_request_size(); -} -bool CompareByHostPreprocessing(const RequestDetail* a, - const RequestDetail* b) { - return a->host_preprocessing_ps() < b->host_preprocessing_ps(); -} -bool CompareByHostBatchFormation(const RequestDetail* a, - const RequestDetail* b) { - return a->host_batch_formation_ps() < b->host_batch_formation_ps(); -} -bool CompareByHostRuntime(const RequestDetail* a, const RequestDetail* b) { - return a->host_runtime_ps() < b->host_runtime_ps(); -} -bool CompareByHostToDevice(const RequestDetail* a, const RequestDetail* b) { - return a->write_to_device_time_ps() < b->write_to_device_time_ps(); -} -bool CompareByDeviceToHost(const RequestDetail* a, const RequestDetail* b) { - return a->read_from_device_time_ps() < b->read_from_device_time_ps(); -} -bool CompareByDeviceCompute(const RequestDetail* a, const RequestDetail* b) { - return a->device_time_ps() < b->device_time_ps(); -} -bool CompareByPostProcessing(const RequestDetail* a, const RequestDetail* b) { - return a->host_postprocessing_ps() < b->host_postprocessing_ps(); -} -bool CompareByIdleTime(const RequestDetail* a, const RequestDetail* b) { - return a->idle_time_ps() < b->idle_time_ps(); -} -// Use percentile column name to get the corresponding compare function. 
-std::function -GetRequestCompareFunction(absl::string_view column_name) { - if (column_name == kColumnBatchingRequestDelayUs) { - return CompareByBatchingRequestDelay; - } else if (column_name == kColumnBatchingRequestSize) { - return CompareByBatchingRequestSize; - } else if (column_name == kColumnHostPreprocessing) { - return CompareByHostPreprocessing; - } else if (column_name == kColumnHostBatchFormation) { - return CompareByHostBatchFormation; - } else if (column_name == kColumnHostRuntime) { - return CompareByHostRuntime; - } else if (column_name == kColumnHostToDevice) { - return CompareByHostToDevice; - } else if (column_name == kColumnDeviceToHost) { - return CompareByDeviceToHost; - } else if (column_name == kColumnDeviceCompute) { - return CompareByDeviceCompute; - } else if (column_name == kColumnHostPostprocessing) { - return CompareByPostProcessing; - } else if (column_name == kColumnIdleTime) { - return CompareByIdleTime; - } else { - // Return CompareByRequestLatency by default. - return CompareByRequestLatency; - } -} - -// Comparator for BatchDetail proto. -bool CompareByBatchLatency(const BatchDetail* a, const BatchDetail* b) { - return (a->end_time_ps() - a->start_time_ps()) < - (b->end_time_ps() - b->start_time_ps()); -} -bool CompareByBatchDelay(const BatchDetail* a, const BatchDetail* b) { - return a->batch_delay_ps() < b->batch_delay_ps(); -} -bool CompareByPaddingAmount(const BatchDetail* a, const BatchDetail* b) { - return a->padding_amount() < b->padding_amount(); -} -bool CompareByBatchSizeAfterPadding(const BatchDetail* a, - const BatchDetail* b) { - return a->batch_size_after_padding() < b->batch_size_after_padding(); -} -bool CompareByBatchingEfficiency(const BatchDetail* a, const BatchDetail* b) { - return CalculateBatchingEfficiency(*a) < CalculateBatchingEfficiency(*b); -} -// Use percentile column name to get the corresponding compare function. 
-std::function -GetBatchCompareFunction(absl::string_view column_name) { - if (column_name == kColumnBatchingDelayUs) { - return CompareByBatchDelay; - } else if (column_name == kColumnPaddingAmount) { - return CompareByPaddingAmount; - } else if (column_name == kColumnBatchSizeAfterPadding) { - return CompareByBatchSizeAfterPadding; - } else if (column_name == kColumnBatchingEfficiency) { - return CompareByBatchingEfficiency; - } else { - // Return CompareByBatchLatency by default. - return CompareByBatchLatency; - } -} - -// A static helper class to select a subset of inference data (request or batch) -// to show in the frontend. -// DataType can be either RequestDetail or BatchDetail. -template -class PercentileSelector { - public: - // The range of values in [percentile, perentile+error) are still regarded as - // percentile. - struct PercentileRange { - double percentile; - double error; - }; - - // The percentiles (with the corresponding error bounds) that will be included - // in inference profile result. - static constexpr std::array kWantedPercentiles = { - {{50.0, 1}, - {75.0, 1}, - {90.0, 1}, - {99.0, 0.5}, - {99.9, 0.05}, - {99.99, 0.005}}}; - - // Maximum number of values included for each percentile range. - static constexpr size_t kMaxNumDataSelectedPerPercentile = 10; - - // Select a subset of data from , return pointer to the original - // data and the percentile. - static std::vector> Select( - const std::vector& all_data) { - return SelectInternal(all_data); - } - - private: - static bool GreaterThan(double percentile, const PercentileRange& wanted) { - // Uses ">=" instead of ">" so that the round-up value is not included. 
- return percentile >= (wanted.percentile + wanted.error); - } - - static bool LessThan(double percentile, const PercentileRange& wanted) { - return percentile < wanted.percentile; - } - - static bool WithinRange(double percentile, const PercentileRange& wanted) { - return !GreaterThan(percentile, wanted) && !LessThan(percentile, wanted); - } - - static std::vector> SelectInternal( - const std::vector& all_data) { - std::vector> result; - // If the number of data points is too small (smaller than the result size - // when select by percentile, like in a unit test), it does not make sense - // to select by percentile, just select all the data points and the frontend - // is able to display all of them. - if (all_data.size() <= - kWantedPercentiles.size() * kMaxNumDataSelectedPerPercentile) { - for (size_t i = 0; i < all_data.size(); i++) { - double percentile = 100.0 * i / all_data.size(); - result.push_back(std::make_pair(all_data[i], percentile)); - } - return result; - } - - // Select by percentile. - size_t idx_to_next_data = 0; - for (size_t i = 0; i < kWantedPercentiles.size(); i++) { - const auto& wanted = kWantedPercentiles[i]; - size_t num_data_selected = 0; - for (size_t k = idx_to_next_data; k < all_data.size(); k++) { - double percentile = 100.0 * k / all_data.size(); - if (GreaterThan(percentile, wanted)) { - // Updates idx_to_next_data to k so that when we select data for the - // next percentile we don't need to consider the data with smaller - // latenices than that for the next percentile. - idx_to_next_data = k; - break; - } - if (WithinRange(percentile, wanted)) { - if (num_data_selected < kMaxNumDataSelectedPerPercentile) { - // Selects this data only if we have not hit the limit for this - // percentile. - result.push_back(std::make_pair(all_data[k], percentile)); - ++num_data_selected; - } - } - } - } - return result; - } -}; - -// Sample the requests and batches in using sampling column -// and . 
-void SamplePerModelInferenceStats( - absl::string_view request_percentile_column, - absl::string_view batch_percentile_column, - const PerModelInferenceStats& per_model_stats, - SampledPerModelInferenceStats* sampled_per_model_stats) { - // Select a subset of requests and batches based on percentile and generate - // final result. - std::vector requests( - per_model_stats.request_details_size()); - for (size_t i = 0; i < per_model_stats.request_details_size(); i++) { - requests[i] = &per_model_stats.request_details(i); - } - // Requests in per model stats are already sorted by latency. Only redo the - // sorting when percentile column is not latency. - if (request_percentile_column != kColumnLatencyUs) { - std::sort(requests.begin(), requests.end(), - GetRequestCompareFunction(request_percentile_column)); - } - sampled_per_model_stats->sampled_requests = - PercentileSelector::Select(requests); - - std::vector batches(per_model_stats.batch_details_size()); - for (size_t i = 0; i < per_model_stats.batch_details_size(); i++) { - batches[i] = &per_model_stats.batch_details(i); - } - // Batches in per model stats are already sorted by latency. Only redo the - // sorting when percentile column is not latency. 
- if (batch_percentile_column != kColumnLatencyUs) { - std::sort(batches.begin(), batches.end(), - GetBatchCompareFunction(batch_percentile_column)); - } - sampled_per_model_stats->sampled_batches = - PercentileSelector::Select(batches); -} - -} // namespace - -SampledInferenceStats SampleInferenceStats( - absl::string_view request_percentile_column, - absl::string_view batch_percentile_column, - const InferenceStats& inference_stats) { - SampledInferenceStats result; - for (const auto& [model_index, model_inference_stats] : - inference_stats.inference_stats_per_model()) { - SamplePerModelInferenceStats(request_percentile_column, - batch_percentile_column, model_inference_stats, - &(result[model_index])); - } - - return result; -} - -} // namespace tensorflow::profiler diff --git a/tensorflow/core/profiler/convert/inference_stats_sampler.h b/tensorflow/core/profiler/convert/inference_stats_sampler.h deleted file mode 100644 index 2706c16a8ff97a..00000000000000 --- a/tensorflow/core/profiler/convert/inference_stats_sampler.h +++ /dev/null @@ -1,53 +0,0 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ -#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_INFERENCE_STATS_SAMPLER_H_ -#define TENSORFLOW_CORE_PROFILER_CONVERT_INFERENCE_STATS_SAMPLER_H_ - -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/strings/string_view.h" -#include "tensorflow/core/profiler/protobuf/inference_stats.pb.h" - -namespace tensorflow::profiler { - -// Sampled inference stats of a model. -// The pointers of RequestDetail and BatchDetail point to the actual data stored -// in TfOpStats.InferenceStats. -struct SampledPerModelInferenceStats { - // Sampled requests and their percentile. - std::vector> - sampled_requests; - // Sampled batches and their percentile. - std::vector> - sampled_batches; -}; - -// All the sampled inference stats of a profile. -// TODO: Move to use SampledInferenceStatsProto if feasible. -using SampledInferenceStats = - absl::flat_hash_map; - -// Samples a subset of InferenceStats from based on sampling -// column and . -SampledInferenceStats SampleInferenceStats( - absl::string_view request_percentile_column, - absl::string_view batch_percentile_column, - const tensorflow::profiler::InferenceStats& inference_stats); - -} // namespace tensorflow::profiler - -#endif // TENSORFLOW_CORE_PROFILER_CONVERT_INFERENCE_STATS_SAMPLER_H_ diff --git a/tensorflow/core/profiler/convert/inference_stats_sampler_test.cc b/tensorflow/core/profiler/convert/inference_stats_sampler_test.cc deleted file mode 100644 index 72c35a520a4cf4..00000000000000 --- a/tensorflow/core/profiler/convert/inference_stats_sampler_test.cc +++ /dev/null @@ -1,131 +0,0 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/core/profiler/convert/inference_stats_sampler.h" - -#include "absl/status/statusor.h" -#include "xla/tests/test_utils.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/profiler/protobuf/inference_stats.pb.h" - -namespace tensorflow::profiler { -namespace { -using ::tensorflow::profiler::InferenceStats; -using xla::ParseTextProto; - -TEST(ConvertInferenceStatsToInferenceProfileTest, TestSort) { - // Generate an inference stats for test. - // Requests and batches are ordered by latency (end_time_ps - start_time_ps), - // this is guaranteed by inference_stats.cc - InferenceStats inference_stats = ParseTextProto( - R"pb( - inference_stats_per_model { - key: 1 - value { - request_details { - request_id: 0 - start_time_ps: 0 - end_time_ps: 10000 - batching_request_delay_ps: 2000 - batching_request_size: 200 - } - request_details { - request_id: 1 - start_time_ps: 0 - end_time_ps: 20000 - batching_request_delay_ps: 1000 - batching_request_size: 100 - } - request_details { - request_id: 2 - start_time_ps: 0 - end_time_ps: 30000 - batching_request_delay_ps: 3000 - batching_request_size: 300 - } - batch_details { - batch_id: 3 - start_time_ps: 0 - end_time_ps: 10000 - batch_delay_ps: 2000 - padding_amount: 20 - batch_size_after_padding: 200 - } - batch_details { - batch_id: 4 - start_time_ps: 0 - end_time_ps: 20000 - batch_delay_ps: 1000 - padding_amount: 10 - batch_size_after_padding: 100 - } - batch_details { - batch_id: 5 - start_time_ps: 0 - 
end_time_ps: 30000 - batch_delay_ps: 3000 - padding_amount: 30 - batch_size_after_padding: 300 - } - } - } - )pb") - .value(); - - // Sort by latency, the result does not change. - auto result_1 = SampleInferenceStats("Latency", "Latency", inference_stats); - const auto& per_model_1 = result_1.at(1); - EXPECT_EQ(per_model_1.sampled_requests.at(0).first->request_id(), 0); - EXPECT_EQ(per_model_1.sampled_requests.at(1).first->request_id(), 1); - EXPECT_EQ(per_model_1.sampled_requests.at(2).first->request_id(), 2); - EXPECT_EQ(per_model_1.sampled_batches.at(0).first->batch_id(), 3); - EXPECT_EQ(per_model_1.sampled_batches.at(1).first->batch_id(), 4); - EXPECT_EQ(per_model_1.sampled_batches.at(2).first->batch_id(), 5); - - // Sort requests by Request size, sort batches by Padding amount. - // Verifies the values are in increasing order. - auto result_2 = - SampleInferenceStats("Request size", "Padding amount", inference_stats); - const auto& per_model_2 = result_2.at(1); - EXPECT_EQ(per_model_2.sampled_requests.at(0).first->batching_request_size(), - 100); - EXPECT_EQ(per_model_2.sampled_requests.at(1).first->batching_request_size(), - 200); - EXPECT_EQ(per_model_2.sampled_requests.at(2).first->batching_request_size(), - 300); - EXPECT_EQ(per_model_2.sampled_batches.at(0).first->padding_amount(), 10); - EXPECT_EQ(per_model_2.sampled_batches.at(1).first->padding_amount(), 20); - EXPECT_EQ(per_model_2.sampled_batches.at(2).first->padding_amount(), 30); - - // Sort requests by Request delay for batching, sort batches by - // Batching delay. Verifies the values are in increasing order. 
- auto result_3 = SampleInferenceStats("Request delay for batching", - "Batching delay", inference_stats); - const auto& per_model_3 = result_3.at(1); - EXPECT_EQ( - per_model_3.sampled_requests.at(0).first->batching_request_delay_ps(), - 1000); - EXPECT_EQ( - per_model_3.sampled_requests.at(1).first->batching_request_delay_ps(), - 2000); - EXPECT_EQ( - per_model_3.sampled_requests.at(2).first->batching_request_delay_ps(), - 3000); - EXPECT_EQ(per_model_3.sampled_batches.at(0).first->batch_delay_ps(), 1000); - EXPECT_EQ(per_model_3.sampled_batches.at(1).first->batch_delay_ps(), 2000); - EXPECT_EQ(per_model_3.sampled_batches.at(2).first->batch_delay_ps(), 3000); -} - -} // namespace -} // namespace tensorflow::profiler diff --git a/tensorflow/core/profiler/convert/multi_xplanes_to_op_stats.cc b/tensorflow/core/profiler/convert/multi_xplanes_to_op_stats.cc deleted file mode 100644 index e01b645d3b19b6..00000000000000 --- a/tensorflow/core/profiler/convert/multi_xplanes_to_op_stats.cc +++ /dev/null @@ -1,70 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/core/profiler/convert/multi_xplanes_to_op_stats.h" - -#include -#include - -#include "absl/status/status.h" -#include "xla/tsl/platform/statusor.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/profiler/convert/op_stats_combiner.h" -#include "tensorflow/core/profiler/convert/preprocess_single_host_xplane.h" -#include "tensorflow/core/profiler/convert/repository.h" -#include "tensorflow/core/profiler/convert/xplane_to_op_stats.h" -#include "tensorflow/core/profiler/protobuf/op_stats.pb.h" -#include "tensorflow/core/profiler/utils/hardware_type_utils.h" -#include "tensorflow/core/profiler/utils/step_intersection.h" -#include "tsl/profiler/protobuf/xplane.pb.h" - -namespace tensorflow { -namespace profiler { - -absl::Status ConvertMultiXSpacesToCombinedOpStats( - const SessionSnapshot& session_snapshot, const OpStatsOptions& options, - OpStats* combined_op_stats) { - // Read multiple XSpaces and convert to multiple OpStats. - // TODO(profiler): Change the combiner to convert and combine one OpStats at a - // time, to reduce peak memory usage. - std::vector all_op_stats; - all_op_stats.reserve(session_snapshot.XSpaceSize()); - for (int i = 0; i < session_snapshot.XSpaceSize(); i++) { - TF_ASSIGN_OR_RETURN(std::unique_ptr xspace, - session_snapshot.GetXSpace(i)); - PreprocessSingleHostXSpace(xspace.get(), /*step_grouping=*/true, - /*derived_timeline=*/true); - all_op_stats.push_back(ConvertXSpaceToOpStats(*xspace, options)); - } - - // Combine OpStats. - std::vector all_op_stats_info; - all_op_stats_info.reserve(all_op_stats.size()); - for (int i = 0; i < all_op_stats.size(); i++) { - all_op_stats_info.emplace_back( - &all_op_stats[i], - ParseHardwareType(all_op_stats[i].run_environment().device_type()), i); - } - - // Do not limit the maximum number of steps during the merge of OpStats. 
- StepIntersection step_intersection = - ComputeStepIntersectionToMergeOpStats(all_op_stats_info, kuint32max); - CombineAllOpStats(all_op_stats_info, step_intersection, combined_op_stats); - - return absl::OkStatus(); -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/multi_xplanes_to_op_stats.h b/tensorflow/core/profiler/convert/multi_xplanes_to_op_stats.h deleted file mode 100644 index 51348097d321f3..00000000000000 --- a/tensorflow/core/profiler/convert/multi_xplanes_to_op_stats.h +++ /dev/null @@ -1,38 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_MULTI_XPLANES_TO_OP_STATS_H_ -#define TENSORFLOW_CORE_PROFILER_CONVERT_MULTI_XPLANES_TO_OP_STATS_H_ - -#include "tensorflow/core/platform/status.h" -#include "tensorflow/core/profiler/convert/repository.h" -#include "tensorflow/core/profiler/convert/xplane_to_op_stats.h" -#include "tensorflow/core/profiler/protobuf/op_stats.pb.h" - -namespace tensorflow { -namespace profiler { - -// Converts and combines multiple XSpace protos into a single OpStats -// . -// Return the first error status during conversion, or return OkStatus() if -// there is no error. 
-absl::Status ConvertMultiXSpacesToCombinedOpStats( - const SessionSnapshot& session_snapshot, const OpStatsOptions& options, - OpStats* combined_op_stats); - -} // namespace profiler -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_PROFILER_CONVERT_MULTI_XPLANES_TO_OP_STATS_H_ diff --git a/tensorflow/core/profiler/convert/multi_xspace_to_inference_stats.cc b/tensorflow/core/profiler/convert/multi_xspace_to_inference_stats.cc deleted file mode 100644 index e1a466665492dc..00000000000000 --- a/tensorflow/core/profiler/convert/multi_xspace_to_inference_stats.cc +++ /dev/null @@ -1,127 +0,0 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ -#include "tensorflow/core/profiler/convert/multi_xspace_to_inference_stats.h" - -#include -#include - -#include "absl/status/status.h" -#include "absl/strings/string_view.h" -#include "xla/tsl/platform/statusor.h" -#include "xla/tsl/profiler/utils/device_utils.h" -#include "xla/tsl/profiler/utils/group_events.h" -#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" -#include "xla/tsl/profiler/utils/tpu_xplane_utils.h" -#include "xla/tsl/profiler/utils/xplane_utils.h" -#include "tensorflow/core/profiler/convert/inference_stats.h" -#include "tensorflow/core/profiler/convert/inference_stats_combiner.h" -#include "tensorflow/core/profiler/convert/inference_stats_grouping.h" -#include "tensorflow/core/profiler/convert/inference_stats_sampler.h" -#include "tensorflow/core/profiler/convert/preprocess_single_host_xplane.h" -#include "tensorflow/core/profiler/convert/repository.h" -#include "tensorflow/core/profiler/convert/xplane_to_step_events.h" -#include "tensorflow/core/profiler/protobuf/inference_stats.pb.h" -#include "tensorflow/core/profiler/utils/event_span.h" -#include "tensorflow/core/profiler/utils/xplane_schema.h" -#include "tensorflow/core/profiler/utils/xplane_visitor.h" -#include "tsl/profiler/protobuf/xplane.pb.h" - -namespace tensorflow::profiler { - -namespace { -using tsl::profiler::FindMutablePlanesWithPrefix; -using tsl::profiler::FindMutablePlaneWithName; - -SampledInferenceStatsProto GetSampledInferenceStatsProto( - const InferenceStats& inference_stats, absl::string_view request_column, - absl::string_view batch_column) { - SampledInferenceStatsProto result; - SampledInferenceStats sampled_stats = - SampleInferenceStats(request_column, batch_column, inference_stats); - for (const auto& [model_index, samples] : sampled_stats) { - SampledPerModelInferenceStatsProto per_model_stats; - for (const auto& [request, percentile] : samples.sampled_requests) { - RequestDetail 
request_detail = *request; - request_detail.set_percentile(percentile); - *per_model_stats.add_sampled_requests() = request_detail; - } - for (const auto& [batch, percentile] : samples.sampled_batches) { - BatchDetail batch_detail = *batch; - batch_detail.set_percentile(percentile); - *per_model_stats.add_sampled_batches() = batch_detail; - } - result.mutable_sampled_inference_stats_per_model()->insert( - {model_index, per_model_stats}); - } - return result; -} -} // namespace - -StepEvents GetNonOverlappedStepEvents(XSpace* xspace) { - StepEvents non_overlapped_step_events; - - std::vector device_traces = - FindMutablePlanesWithPrefix(xspace, kGpuPlanePrefix); - if (device_traces.empty()) return non_overlapped_step_events; - - StepEvents device_step_events; - StepEvents host_step_events; - for (XPlane* device_trace : device_traces) { - StepEvents events = ConvertDeviceTraceXPlaneToStepEvents(*device_trace); - UnionCombineStepEvents(events, &device_step_events); - } - - XPlaneVisitor host_plane = tsl::profiler::CreateTfXPlaneVisitor( - FindMutablePlaneWithName(xspace, kHostThreadsPlaneName)); - - host_plane.ForEachLine([&](const XLineVisitor& line) { - StepEvents events = - ConvertHostThreadsXLineToStepEvents(line, &device_step_events); - UnionCombineStepEvents(events, &host_step_events); - }); - StepEvents overlapped_step_events; - UnionCombineStepEvents(device_step_events, &overlapped_step_events); - UnionCombineStepEvents(host_step_events, &overlapped_step_events); - non_overlapped_step_events = - ToNonOverlappedStepEvents(overlapped_step_events); - return non_overlapped_step_events; -} - -absl::Status ConvertMultiXSpaceToInferenceStats( - const SessionSnapshot& session_snapshot, absl::string_view request_column, - absl::string_view batch_column, InferenceStats* inference_stats) { - for (int i = 0; i < session_snapshot.XSpaceSize(); ++i) { - TF_ASSIGN_OR_RETURN(std::unique_ptr xspace, - session_snapshot.GetXSpace(i)); - tsl::profiler::GroupMetadataMap 
metadata_map; - InferenceStats inference_stats_per_host; - std::vector device_traces = - tsl::profiler::FindMutableTensorCorePlanes(xspace.get()); - PreprocessSingleHostXSpace(xspace.get(), /*step_grouping=*/true, - /*derived_timeline=*/false, &metadata_map); - StepEvents non_overlapped_step_events = - GetNonOverlappedStepEvents(xspace.get()); - GenerateInferenceStats( - device_traces, non_overlapped_step_events, metadata_map, *xspace, - tsl::profiler::DeviceType::kTpu, i, &inference_stats_per_host); - CombineInferenceStatsResult(i, inference_stats_per_host, inference_stats); - } - RegroupInferenceStatsByModel(inference_stats); - *inference_stats->mutable_sampled_inference_stats() = - GetSampledInferenceStatsProto(*inference_stats, request_column, - batch_column); - return absl::OkStatus(); -} -} // namespace tensorflow::profiler diff --git a/tensorflow/core/profiler/convert/multi_xspace_to_inference_stats.h b/tensorflow/core/profiler/convert/multi_xspace_to_inference_stats.h deleted file mode 100644 index 8214921600efea..00000000000000 --- a/tensorflow/core/profiler/convert/multi_xspace_to_inference_stats.h +++ /dev/null @@ -1,33 +0,0 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ -#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_MULTI_XSPACE_TO_INFERENCE_STATS_H_ -#define TENSORFLOW_CORE_PROFILER_CONVERT_MULTI_XSPACE_TO_INFERENCE_STATS_H_ - -#include "absl/status/status.h" -#include "absl/strings/string_view.h" -#include "tensorflow/core/profiler/convert/repository.h" -#include "tensorflow/core/profiler/protobuf/inference_stats.pb.h" -#include "tensorflow/core/profiler/utils/event_span.h" - -namespace tensorflow::profiler { -// Get non overlapped step events from xspace for GPU. -StepEvents GetNonOverlappedStepEvents(XSpace* xspace); - -absl::Status ConvertMultiXSpaceToInferenceStats( - const SessionSnapshot& session_snapshot, absl::string_view request_column, - absl::string_view batch_column, InferenceStats* inference_stats); -} // namespace tensorflow::profiler - -#endif // TENSORFLOW_CORE_PROFILER_CONVERT_MULTI_XSPACE_TO_INFERENCE_STATS_H_ diff --git a/tensorflow/core/profiler/convert/op_metrics_db_combiner.cc b/tensorflow/core/profiler/convert/op_metrics_db_combiner.cc deleted file mode 100644 index ab0c25b6f38c33..00000000000000 --- a/tensorflow/core/profiler/convert/op_metrics_db_combiner.cc +++ /dev/null @@ -1,143 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/core/profiler/convert/op_metrics_db_combiner.h" - -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/log/check.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" -#include "tsl/platform/protobuf.h" - -namespace tensorflow { -namespace profiler { -namespace { - -using OperationType = OpMetrics::MemoryAccessed::OperationType; - -void CombinePrecisionStats(const PrecisionStats& src, PrecisionStats* dst) { - dst->set_compute_16bit_ps(src.compute_16bit_ps() + dst->compute_16bit_ps()); - dst->set_compute_32bit_ps(src.compute_32bit_ps() + dst->compute_32bit_ps()); -} - -} // namespace - -void CopyOpMetricsMetadata(const OpMetrics& src, OpMetrics* dst) { - DCHECK(dst != nullptr); - DCHECK_EQ(src.hlo_module_id(), dst->hlo_module_id()); - DCHECK_EQ(src.name(), dst->name()); - if (dst->long_name().empty()) { - dst->set_long_name(src.long_name()); - } - if (dst->fingerprint() == 0) { - dst->set_fingerprint(src.fingerprint()); - } - if (dst->category().empty()) { - dst->set_category(src.category()); - } - if (dst->provenance().empty()) { - dst->set_provenance(src.provenance()); - } - if (dst->deduplicated_name().empty()) { - dst->set_deduplicated_name(src.deduplicated_name()); - } - if (!dst->has_layout() && src.has_layout()) { - *dst->mutable_layout() = src.layout(); - } - if (!dst->has_children() && src.has_children()) { - *dst->mutable_children() = src.children(); - } -} - -void CombineOpMetrics(const OpMetrics& src, OpMetrics* dst, - bool update_num_cores) { - DCHECK(dst != nullptr); - if (dst->occurrences() == 0) { - dst->set_min_time_ps(src.min_time_ps()); - } else { - dst->set_min_time_ps(std::min(src.min_time_ps(), dst->min_time_ps())); - } - dst->set_is_eager(dst->is_eager() || src.is_eager()); - dst->set_occurrences(src.occurrences() + dst->occurrences()); - 
dst->set_time_ps(src.time_ps() + dst->time_ps()); - dst->set_self_time_ps(src.self_time_ps() + dst->self_time_ps()); - dst->set_flops(src.flops() + dst->flops()); - dst->set_model_flops(src.model_flops() + dst->model_flops()); - dst->set_bytes_accessed(src.bytes_accessed() + dst->bytes_accessed()); - dst->set_autotuned(dst->autotuned() || src.autotuned()); - if (update_num_cores) { - dst->set_num_cores(src.num_cores() + dst->num_cores()); - } - CombineMemoryAccessedBreakdown(src.memory_accessed_breakdown(), - dst->mutable_memory_accessed_breakdown()); - dst->set_dma_stall_ps(src.dma_stall_ps() + dst->dma_stall_ps()); -} - -void CombineMemoryAccessedBreakdown( - const tsl::protobuf::RepeatedPtrField& src, - tsl::protobuf::RepeatedPtrField* dst) { - if (src.empty()) return; - absl::flat_hash_map, - OpMetrics_MemoryAccessed*> - dst_memory_accessed_map; - for (auto& dst_memory_accessed : *dst) { - dst_memory_accessed_map[{dst_memory_accessed.memory_space(), - dst_memory_accessed.operation_type()}] = - &dst_memory_accessed; - } - for (const auto& src_memory_accessed : src) { - uint64 memory_space = src_memory_accessed.memory_space(); - OperationType operation_type = src_memory_accessed.operation_type(); - auto*& dst_memory_accessed = - dst_memory_accessed_map[{memory_space, operation_type}]; - if (dst_memory_accessed == nullptr) { - dst_memory_accessed = dst->Add(); - dst_memory_accessed->set_memory_space(memory_space); - dst_memory_accessed->set_operation_type(operation_type); - } - dst_memory_accessed->set_bytes_accessed( - src_memory_accessed.bytes_accessed() + - dst_memory_accessed->bytes_accessed()); - } -} - -void OpMetricsDbCombiner::Combine(const OpMetricsDb& src, - bool update_num_cores) { - OpMetricsDb* dst = db(); - dst->set_total_host_infeed_enq_duration_ps( - src.total_host_infeed_enq_duration_ps() + - dst->total_host_infeed_enq_duration_ps()); - dst->set_total_host_infeed_enq_start_timestamp_ps_diff( - src.total_host_infeed_enq_start_timestamp_ps_diff() + 
- dst->total_host_infeed_enq_start_timestamp_ps_diff()); - dst->set_total_time_ps(src.total_time_ps() + dst->total_time_ps()); - dst->set_total_op_time_ps(src.total_op_time_ps() + dst->total_op_time_ps()); - dst->set_idle_time_ps(src.idle_time_ps() + dst->idle_time_ps()); - dst->set_busy_time_ps(src.busy_time_ps() + dst->busy_time_ps()); - CombinePrecisionStats(src.precision_stats(), dst->mutable_precision_stats()); - - for (const auto& src_metrics : src.metrics_db()) { - auto* dst_metrics = LookupOrInsertNewOpMetrics(src_metrics.hlo_module_id(), - src_metrics.name()); - CopyOpMetricsMetadata(src_metrics, dst_metrics); - CombineOpMetrics(src_metrics, dst_metrics, update_num_cores); - } -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/op_metrics_db_combiner.h b/tensorflow/core/profiler/convert/op_metrics_db_combiner.h deleted file mode 100644 index 76019da86cd467..00000000000000 --- a/tensorflow/core/profiler/convert/op_metrics_db_combiner.h +++ /dev/null @@ -1,54 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_OP_METRICS_DB_COMBINER_H_ -#define TENSORFLOW_CORE_PROFILER_CONVERT_OP_METRICS_DB_COMBINER_H_ - -#include "tensorflow/core/platform/protobuf.h" -#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" -#include "tensorflow/core/profiler/utils/op_metrics_db_utils.h" - -namespace tensorflow { -namespace profiler { - -// Copies OpMetrics metadata (e.g., category, provenance) from src to dst. -void CopyOpMetricsMetadata(const OpMetrics& src, OpMetrics* dst); - -// Combines OpMetrics data (e.g., occurrences, time) from src into dst. -// If is set to true, update the dst->num_cores to -// calculate the number of cores a certain op occurs. -void CombineOpMetrics(const OpMetrics& src, OpMetrics* dst, - bool update_num_cores); - -// Combines the memory access breakdown. -void CombineMemoryAccessedBreakdown( - const protobuf::RepeatedPtrField& src, - protobuf::RepeatedPtrField* dst); - -// Helper to combine op metrics databases. -class OpMetricsDbCombiner : public OpMetricsDbBuilder { - public: - explicit OpMetricsDbCombiner(OpMetricsDb* dst) : OpMetricsDbBuilder(dst) {} - - // Combine the OpMetrics in OpMetricsDb to current OpMetricsDbCombiner. - // If is set to true, update the OpMetrics.num_cores to - // calculate the number of cores a certain op occurs. - void Combine(const OpMetricsDb& src, bool update_num_cores = true); -}; - -} // namespace profiler -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_PROFILER_CONVERT_OP_METRICS_DB_COMBINER_H_ diff --git a/tensorflow/core/profiler/convert/op_metrics_to_record.cc b/tensorflow/core/profiler/convert/op_metrics_to_record.cc deleted file mode 100644 index b6f1cadb59388c..00000000000000 --- a/tensorflow/core/profiler/convert/op_metrics_to_record.cc +++ /dev/null @@ -1,50 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/profiler/convert/op_metrics_to_record.h" - -#include -#include - -#include "absl/algorithm/container.h" -#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" - -namespace tensorflow { -namespace profiler { - -std::vector SortedOpMetricsDb(const OpMetricsDb& metrics_db, - int max_records) { - std::vector result; - result.reserve(metrics_db.metrics_db_size()); - for (const OpMetrics& metrics : metrics_db.metrics_db()) { - result.push_back(&metrics); - } - - auto comp = [](const OpMetrics* a, const OpMetrics* b) { - return std::make_tuple(a->self_time_ps(), b->name()) > - std::make_tuple(b->self_time_ps(), a->name()); - }; - int result_size = result.size(); - if (max_records != -1 && result_size > max_records) { - absl::c_partial_sort(result, result.begin() + max_records, comp); - result.resize(max_records); - } else { - absl::c_sort(result, comp); - } - return result; -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/op_metrics_to_record.h b/tensorflow/core/profiler/convert/op_metrics_to_record.h deleted file mode 100644 index 4884fb64adc24c..00000000000000 --- a/tensorflow/core/profiler/convert/op_metrics_to_record.h +++ /dev/null @@ -1,341 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_OP_METRICS_TO_RECORD_H_ -#define TENSORFLOW_CORE_PROFILER_CONVERT_OP_METRICS_TO_RECORD_H_ - -#include -#include -#include - -#include "absl/strings/string_view.h" -#include "xla/tsl/profiler/utils/math_utils.h" -#include "tensorflow/core/profiler/protobuf/hardware_types.pb.h" -#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" -#include "tensorflow/core/profiler/protobuf/op_stats.pb.h" - -namespace tensorflow { -namespace profiler { - -std::vector SortedOpMetricsDb(const OpMetricsDb& metrics_db, - int max_records = -1); - -inline double GigaFlopsPerSecondPerCore(const OpMetrics& metrics) { - // flops and time_ps are accumulated across all occurrences on all cores. - // time_ps is used instead of self_time_ps because flops for an op includes - // the flops executed by children (nested) ops. - return tsl::profiler::SafeDivide( - metrics.flops(), tsl::profiler::PicoToNano(metrics.time_ps())); -} - -inline double GigaModelFlopsPerSecondPerCore(const OpMetrics& metrics) { - // flops and time_ps are accumulated across all occurrences on all cores. - // time_ps is used instead of self_time_ps because flops for an op includes - // the flops executed by children (nested) ops. 
- return tsl::profiler::SafeDivide( - metrics.model_flops(), tsl::profiler::PicoToNano(metrics.time_ps())); -} - -// Return ByteAccessed for memory_space and operation_type. -inline double BytesAccessedPerCore( - const OpMetrics& metrics, uint64_t memory_space, - OpMetrics::MemoryAccessed::OperationType operation_type) { - uint64_t bytes = 0; - if (memory_space == MemorySpace::MEMORY_SPACE_ALL) { - bytes = metrics.bytes_accessed(); - } else { - for (const auto& breakdown : metrics.memory_accessed_breakdown()) { - // Count either on-chip or off-chip bytes. - if ((breakdown.operation_type() != operation_type) && - (operation_type != OpMetrics::MemoryAccessed::UNKNOWN)) { - continue; - } - if (((memory_space == MemorySpace::MEMORY_SPACE_HBM) && - (breakdown.memory_space() == MemorySpace::MEMORY_SPACE_HBM)) || - ((memory_space == MemorySpace::MEMORY_SPACE_ON_CHIP) && - (breakdown.memory_space() != MemorySpace::MEMORY_SPACE_HBM))) { - bytes += breakdown.bytes_accessed(); - } - } - } - return bytes; -} - -inline double GigaBytesPerSecondPerCore( - const OpMetrics& metrics, uint64_t memory_space, - OpMetrics::MemoryAccessed::OperationType operation_type) { - // bytes_accessed and time_ps are accumulated across all occurrences on all - // cores. - // time_ps is used instead of self_time_ps because bytes_accessed for an op - // includes the bytes accessed by children (nested) ops. 
- return tsl::profiler::SafeDivide( - BytesAccessedPerCore(metrics, memory_space, operation_type), - tsl::profiler::PicoToNano(metrics.time_ps())); -} - -inline double GibiBytesPerSecondPerCore( - const OpMetrics& metrics, uint64_t memory_space, - OpMetrics::MemoryAccessed::OperationType op_type) { - return tsl::profiler::GigaToGibi( - GigaBytesPerSecondPerCore(metrics, memory_space, op_type)); -} - -template -inline void SetExecutionTimes(const OpMetrics& metrics, Record* record) { - record->set_occurrences(metrics.occurrences()); - record->set_total_time_in_us(tsl::profiler::PicoToMicro(metrics.time_ps())); - record->set_avg_time_in_us(tsl::profiler::SafeDivide( - record->total_time_in_us(), metrics.occurrences())); - record->set_total_self_time_in_us( - tsl::profiler::PicoToMicro(metrics.self_time_ps())); - record->set_avg_self_time_in_us(tsl::profiler::SafeDivide( - record->total_self_time_in_us(), metrics.occurrences())); -} - -template -inline void SetTpuUnitFractions(const OpMetrics& metrics, Record* record) { - record->set_dma_stall_fraction( - tsl::profiler::SafeDivide(metrics.dma_stall_ps(), metrics.time_ps())); -} - -template -inline void SetRankAndTimeFractions(double total_time_us, - const Record& prev_record, Record* record) { - record->set_rank(prev_record.rank() + 1); - record->set_total_self_time_as_fraction(tsl::profiler::SafeDivide( - record->total_self_time_in_us(), total_time_us)); - record->set_cumulative_total_self_time_as_fraction( - prev_record.cumulative_total_self_time_as_fraction() + - record->total_self_time_as_fraction()); -} - -template -inline void SetRankAndDeviceTimeFractions(double total_time_us, - const Record& prev_record, - Record* record) { - record->set_rank(prev_record.rank() + 1); - record->set_device_total_self_time_as_fraction(tsl::profiler::SafeDivide( - record->total_self_time_in_us(), total_time_us)); - record->set_device_cumulative_total_self_time_as_fraction( - 
prev_record.device_cumulative_total_self_time_as_fraction() + - record->device_total_self_time_as_fraction()); -} - -template -inline void SetRankAndHostTimeFractions(double total_time_us, - const Record& prev_record, - Record* record) { - record->set_rank(prev_record.rank() + 1); - record->set_host_total_self_time_as_fraction(tsl::profiler::SafeDivide( - record->total_self_time_in_us(), total_time_us)); - record->set_host_cumulative_total_self_time_as_fraction( - prev_record.host_cumulative_total_self_time_as_fraction() + - record->host_total_self_time_as_fraction()); -} - -// Returns the memory bandwidth in GigaBytes/s in the PerfEnv. -// memory space is chosen by index following order in xplane_to_op_stats.cc -static inline double GetMemoryPeakBandwidth(const PerfEnv& perf_env, - const int index) { - if (perf_env.peak_bws_giga_bytes_per_second_size() > index) { - return perf_env.peak_bws_giga_bytes_per_second(index); - } - return perf_env.peak_hbm_bw_giga_bytes_per_second(); -} - -template -inline void SetRooflineMetrics(const OpMetrics& metrics, const PerfEnv perf_env, - const RunEnvironment& run_env, Record* record) { - using ::tensorflow::profiler::MemorySpace; - using ::tensorflow::profiler::PerformanceInfo; - - // Set overall performance metrics. - record->set_measured_flop_rate(GigaFlopsPerSecondPerCore(metrics)); - record->set_model_flop_rate(GigaModelFlopsPerSecondPerCore(metrics)); - record->set_measured_memory_bw(GibiBytesPerSecondPerCore( - metrics, tensorflow::profiler::MemorySpace::MEMORY_SPACE_ALL, - OpMetrics::MemoryAccessed::UNKNOWN)); - record->set_flops(metrics.flops()); - record->set_bytes_accessed(metrics.bytes_accessed()); - record->set_operational_intensity( - tsl::profiler::SafeDivide(metrics.flops(), metrics.bytes_accessed())); - // Set performance metrics per memory access type. 
- uint64_t hbm_bytes = 0; - uint64_t cmem_read_bytes = 0; - uint64_t cmem_write_bytes = 0; - uint64_t vmem_read_bytes = 0; - uint64_t vmem_write_bytes = 0; - for (const auto& memory_access : metrics.memory_accessed_breakdown()) { - if (memory_access.memory_space() == PerformanceInfo::MemoryAccessed::HBM) { - hbm_bytes += memory_access.bytes_accessed(); - } else if (memory_access.memory_space() == - PerformanceInfo::MemoryAccessed::CMEM) { - if (memory_access.operation_type() == OpMetrics::MemoryAccessed::READ) { - cmem_read_bytes += memory_access.bytes_accessed(); - } else if (memory_access.operation_type() == - OpMetrics::MemoryAccessed::WRITE) { - cmem_write_bytes += memory_access.bytes_accessed(); - } - } else if (memory_access.memory_space() == - PerformanceInfo::MemoryAccessed::VMEM) { - if (memory_access.operation_type() == OpMetrics::MemoryAccessed::READ) { - vmem_read_bytes += memory_access.bytes_accessed(); - } else if (memory_access.operation_type() == - OpMetrics::MemoryAccessed::WRITE) { - vmem_write_bytes += memory_access.bytes_accessed(); - } - } - } - if (metrics.memory_accessed_breakdown_size() == 0) { - // For legacy profiles without memory access breakdown, consider all memory - // access as HBM access. 
- hbm_bytes = metrics.bytes_accessed(); - } - record->set_hbm_bw(tsl::profiler::GibibytesPerSecond( - hbm_bytes, tsl::profiler::PicoToNano(metrics.time_ps()))); - record->set_cmem_read_bw(tsl::profiler::GibibytesPerSecond( - cmem_read_bytes, tsl::profiler::PicoToNano(metrics.time_ps()))); - record->set_cmem_write_bw(tsl::profiler::GibibytesPerSecond( - cmem_write_bytes, tsl::profiler::PicoToNano(metrics.time_ps()))); - record->set_vmem_read_bw(tsl::profiler::GibibytesPerSecond( - vmem_read_bytes, tsl::profiler::PicoToNano(metrics.time_ps()))); - record->set_vmem_write_bw(tsl::profiler::GibibytesPerSecond( - vmem_write_bytes, tsl::profiler::PicoToNano(metrics.time_ps()))); - record->set_hbm_operational_intensity( - tsl::profiler::SafeDivide(metrics.flops(), hbm_bytes)); - record->set_cmem_read_operational_intensity( - tsl::profiler::SafeDivide(metrics.flops(), cmem_read_bytes)); - record->set_cmem_write_operational_intensity( - tsl::profiler::SafeDivide(metrics.flops(), cmem_write_bytes)); - record->set_vmem_read_operational_intensity( - tsl::profiler::SafeDivide(metrics.flops(), vmem_read_bytes)); - record->set_vmem_write_operational_intensity( - tsl::profiler::SafeDivide(metrics.flops(), vmem_write_bytes)); - // Resources considered for roofline analysis. - constexpr absl::string_view kUnknown = "Unknown"; - constexpr absl::string_view kCompute = "Compute"; - constexpr absl::string_view kHbm = "HBM"; - constexpr absl::string_view kCmemRead = "CMEM Read"; - constexpr absl::string_view kCmemWrite = "CMEM Write"; - constexpr absl::string_view kVmemRead = "VMEM Read"; - constexpr absl::string_view kVmemWrite = "VMEM Write"; - constexpr absl::string_view kShmL1 = "Shm/L1"; - // Compute the bound time assuming the peak capacity of each resource and - // choose the highest one as the bottleneck. See go/xprof-roofline-pxc for - // more details. 
- // NOTE: The roofline analysis result is the same for Megacore because every - // resource's capacity is doubled for Megacore so the comparison result is the - // same. - absl::string_view bottleneck_resource = kUnknown; - double bottleneck_utilization = 0; - double bottleneck_operational_intensity = 0; - double peak_flops = - tsl::profiler::TeraToGiga(perf_env.peak_tera_flops_per_second()); - double flops_utilization = - tsl::profiler::SafeDivide(record->measured_flop_rate(), peak_flops); - if (bottleneck_utilization < flops_utilization) { - bottleneck_resource = kCompute; - bottleneck_utilization = flops_utilization; - bottleneck_operational_intensity = record->operational_intensity(); - } - double peak_hbm_bw = GetMemoryPeakBandwidth(perf_env, 0); - double hbm_bw_utilization = tsl::profiler::SafeDivide( - record->hbm_bw(), tsl::profiler::GigaToGibi(peak_hbm_bw)); - if (bottleneck_utilization < hbm_bw_utilization) { - bottleneck_resource = kHbm; - bottleneck_utilization = hbm_bw_utilization; - bottleneck_operational_intensity = record->hbm_operational_intensity(); - } - tensorflow::profiler::HardwareType hardware_type = run_env.hardware_type(); - if (hardware_type == tensorflow::profiler::HardwareType::TPU) { - if (cmem_read_bytes) { - double peak_cmem_read_bw = GetMemoryPeakBandwidth(perf_env, 3); - if (peak_cmem_read_bw) { - double cmem_read_bw_utilization = tsl::profiler::SafeDivide( - record->cmem_read_bw(), - tsl::profiler::GigaToGibi(peak_cmem_read_bw)); - if (bottleneck_utilization < cmem_read_bw_utilization) { - bottleneck_resource = kCmemRead; - bottleneck_utilization = cmem_read_bw_utilization; - bottleneck_operational_intensity = - record->cmem_read_operational_intensity(); - } - } - } - if (cmem_write_bytes) { - double peak_cmem_write_bw = GetMemoryPeakBandwidth(perf_env, 4); - if (peak_cmem_write_bw) { - double cmem_write_bw_utilization = tsl::profiler::SafeDivide( - record->cmem_write_bw(), - tsl::profiler::GigaToGibi(peak_cmem_write_bw)); - if 
(bottleneck_utilization < cmem_write_bw_utilization) { - bottleneck_resource = kCmemWrite; - bottleneck_utilization = cmem_write_bw_utilization; - bottleneck_operational_intensity = - record->cmem_write_operational_intensity(); - } - } - } - if (vmem_read_bytes) { - double peak_vmem_read_bw = GetMemoryPeakBandwidth(perf_env, 5); - if (peak_vmem_read_bw) { - double vmem_read_bw_utilization = tsl::profiler::SafeDivide( - record->vmem_read_bw(), - tsl::profiler::GigaToGibi(peak_vmem_read_bw)); - if (bottleneck_utilization < vmem_read_bw_utilization) { - bottleneck_resource = kVmemRead; - bottleneck_utilization = vmem_read_bw_utilization; - bottleneck_operational_intensity = - record->vmem_read_operational_intensity(); - } - } - } - if (vmem_write_bytes) { - double peak_vmem_write_bw = GetMemoryPeakBandwidth(perf_env, 6); - if (peak_vmem_write_bw) { - double vmem_write_bw_utilization = tsl::profiler::SafeDivide( - record->vmem_write_bw(), - tsl::profiler::GigaToGibi(peak_vmem_write_bw)); - if (bottleneck_utilization < vmem_write_bw_utilization) { - bottleneck_resource = kVmemWrite; - bottleneck_utilization = vmem_write_bw_utilization; - bottleneck_operational_intensity = - record->vmem_write_operational_intensity(); - } - } - } - } - if (hardware_type == tensorflow::profiler::HardwareType::GPU) { - double peak_shm_l1_bw = GetMemoryPeakBandwidth(perf_env, 2); - if (peak_shm_l1_bw) { - // Currently, we only have general read/write bandwidth in record. 
- double shm_l1_bw_utilization = tsl::profiler::SafeDivide( - record->hbm_bw(), tsl::profiler::GigaToGibi(peak_shm_l1_bw)); - if (bottleneck_utilization < shm_l1_bw_utilization) { - bottleneck_resource = kShmL1; - bottleneck_utilization = shm_l1_bw_utilization; - bottleneck_operational_intensity = record->hbm_operational_intensity(); - } - } - } - record->set_bound_by(std::string(bottleneck_resource)); - record->set_bottleneck_operational_intensity( - bottleneck_operational_intensity); -} - -} // namespace profiler -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_PROFILER_CONVERT_OP_METRICS_TO_RECORD_H_ diff --git a/tensorflow/core/profiler/convert/op_profile_builder.cc b/tensorflow/core/profiler/convert/op_profile_builder.cc deleted file mode 100644 index 0fde4660dcd496..00000000000000 --- a/tensorflow/core/profiler/convert/op_profile_builder.cc +++ /dev/null @@ -1,455 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/core/profiler/convert/op_profile_builder.h" - -#include -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/container/node_hash_map.h" -#include "absl/log/check.h" -#include "absl/log/log.h" -#include "absl/strings/ascii.h" -#include "absl/strings/str_cat.h" -#include "xla/tsl/profiler/convert/xla_op_utils.h" -#include "xla/tsl/profiler/utils/math_utils.h" -#include "tensorflow/core/lib/gtl/top_n.h" -#include "tensorflow/core/profiler/convert/op_metrics_db_combiner.h" -#include "tensorflow/core/profiler/convert/op_metrics_to_record.h" -#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" -#include "tensorflow/core/profiler/protobuf/op_profile.pb.h" -#include "tensorflow/core/profiler/utils/op_metrics_db_utils.h" -#include "tsl/platform/protobuf.h" - -namespace tensorflow { -namespace profiler { -namespace { - -using op_profile::Metrics; -using op_profile::Node; -using tsl::profiler::IsFusion; - -double CapUtilization(double utilization) { return std::min(utilization, 1.0); } - -// Fill symbol details into a node. 
-void PopulateSymbolNode(const OpMetrics& op_metrics, Node* node) { - node->set_name(op_metrics.name()); - Node::XLAInstruction& xla = *node->mutable_xla(); - xla.set_program_id(op_metrics.hlo_module_id()); - xla.set_expression(op_metrics.long_name()); - xla.set_fingerprint(op_metrics.fingerprint()); - xla.set_category(op_metrics.category()); - xla.set_provenance(op_metrics.provenance()); - if (op_metrics.has_layout()) { - for (const auto& dimension : op_metrics.layout().dimensions()) { - auto* dim = xla.mutable_layout()->add_dimensions(); - dim->set_size(dimension.size()); - dim->set_alignment(dimension.alignment()); - dim->set_semantics(absl::AsciiStrToLower( - LayoutDimensionSemantics_Name(dimension.semantics()))); - } - } - xla.set_computation_primitive_size(op_metrics.computation_primitive_size()); -} - -// Sort the children and only keep the top K children. -template -Node TopKChildren(const Node* root, int k, Cmp cmp) { - tensorflow::gtl::TopN top_n(k, cmp); - for (const Node& node : root->children()) { - top_n.push(&node); - } - Node output; - std::unique_ptr> extracted_nodes(top_n.Extract()); - for (const Node* node : *extracted_nodes) { - *output.add_children() = *node; - } - return output; -} - -// Copy symbol details into a deduplicated node from the top child node. 
-void CopySymbolDetailsToDeduplicatedNode(Node* top_child_node, - Node* deduplicated_node) { - deduplicated_node->set_name( - absl::StrCat(top_child_node->name(), " and its duplicate(s)")); - Node::XLAInstruction& xla = *deduplicated_node->mutable_xla(); - const Node::XLAInstruction& top_child_node_xla = top_child_node->xla(); - xla.set_program_id(top_child_node_xla.program_id()); - xla.set_expression(top_child_node_xla.expression()); - xla.set_fingerprint(top_child_node_xla.fingerprint()); - xla.set_category(top_child_node_xla.category()); - if (IsFusion(top_child_node_xla.category())) return; - xla.set_provenance(top_child_node_xla.provenance()); - *xla.mutable_layout() = top_child_node_xla.layout(); -} - -void SortAndPruneChildren(int k, int level, Node* root) { - // Set the total number of children before pruning. - root->set_num_children(root->children_size()); - for (Node& node : *root->mutable_children()) { - SortAndPruneChildren(k, level - 1, &node); - } - k = level > 0 ? root->children_size() : k; - - if (root->children_size() > 1) { - if (root->has_xla() && IsFusion(root->xla().category())) { - // Sort the children under fusion node by raw flops. - *root->mutable_children() = - TopKChildren(root, k, [](const Node* a, const Node* b) { - return a->metrics().raw_flops() > b->metrics().raw_flops(); - }).children(); - } else { - *root->mutable_children() = - TopKChildren(root, k, [](const Node* a, const Node* b) { - return a->metrics().raw_time() > b->metrics().raw_time(); - }).children(); - } - } -} - -// Finalize deduplicated nodes by copying symbol details from the top child -// node. -void FinalizeDeduplicatedNodes(bool by_program, Node* root) { - if (by_program) { - for (Node& program_node : *root->mutable_children()) { - for (Node& category_node : *program_node.mutable_children()) { - for (Node& deduplicated_node : *category_node.mutable_children()) { - // Node with 1 child doesn't have deduplication, the child is itself. - // Removing the dedup layer. 
- if (deduplicated_node.children_size() == 1) { - Node child = *deduplicated_node.mutable_children(0); - deduplicated_node = child; - continue; - } - CopySymbolDetailsToDeduplicatedNode( - deduplicated_node.mutable_children(0), &deduplicated_node); - } - } - } - } else { - for (Node& category_node : *root->mutable_children()) { - for (Node& deduplicated_node : *category_node.mutable_children()) { - // Node with 1 child doesn't have deduplication, the child is itself. - // Removing the dedup layer. - if (deduplicated_node.children_size() == 1) { - Node child = *deduplicated_node.mutable_children(0); - deduplicated_node = child; - continue; - } - CopySymbolDetailsToDeduplicatedNode( - deduplicated_node.mutable_children(0), &deduplicated_node); - } - } - } -} - -// Fills op metrics into a node. -void PopulateOpMetricsNode( - const OpMetrics& op_metrics, double peak_gigaflops_per_second_per_core, - std::vector peak_mem_gibibytes_per_second_per_core, - uint64_t total_time_ps, Node* node) { - // TODO(dfinchel): remove this temporary change to avoid crash. - // This is only needed while we make an update to proto version that is not - // backwards compatible. - if (peak_mem_gibibytes_per_second_per_core.size() != - (MemBwType_MAX - MemBwType_MIN + 1)) { - peak_mem_gibibytes_per_second_per_core.clear(); - for (int i = MemBwType_MIN; i <= MemBwType_MAX; ++i) { - peak_mem_gibibytes_per_second_per_core.push_back(0); - } - } - - Metrics* metrics = node->mutable_metrics(); - // The UI computes flops_rate = raw_flops / raw_time - // and memory_bandwidth = raw_bytes_accessed / raw_time. 
See: - // https://github.com/tensorflow/profiler/blob/master/frontend/app/common/utils/utils.ts - metrics->set_raw_time(op_metrics.time_ps()); - metrics->set_raw_flops(op_metrics.model_flops()); - metrics->set_occurrences(op_metrics.occurrences()); - metrics->set_avg_time_ps(tsl::profiler::SafeDivide(op_metrics.time_ps(), - op_metrics.occurrences())); - - double flops_utilization = CapUtilization( - tsl::profiler::SafeDivide(GigaFlopsPerSecondPerCore(op_metrics), - peak_gigaflops_per_second_per_core)); - // The UI expects flops_utilization = flop_util / time_fraction. See: - // https://github.com/tensorflow/profiler/blob/master/frontend/app/common/utils/utils.ts - const double time_fraction = - tsl::profiler::SafeDivide(op_metrics.time_ps(), total_time_ps); - metrics->set_flops(flops_utilization * time_fraction); - - // Capture both on-chip and off-chip memory utilization. - const double hbm_gibibytes_per_second = - tsl::profiler::GigaToGibi( - GigaBytesPerSecondPerCore(op_metrics, MemorySpace::MEMORY_SPACE_HBM, - OpMetrics::MemoryAccessed::READ)) + - tsl::profiler::GigaToGibi( - GigaBytesPerSecondPerCore(op_metrics, MemorySpace::MEMORY_SPACE_HBM, - OpMetrics::MemoryAccessed::WRITE)); - const double hbm_bw_utilization = CapUtilization(tsl::profiler::SafeDivide( - hbm_gibibytes_per_second, - peak_mem_gibibytes_per_second_per_core[MemBwType::MEM_BW_TYPE_HBM_RW])); - metrics->add_bandwidth_utils(hbm_bw_utilization); - double hbm_bytes = tsl::profiler::GibiToGiga(hbm_gibibytes_per_second) * - tsl::profiler::PicoToNano(op_metrics.time_ps()); - - const double sram_rd_gibibytes_per_second = tsl::profiler::GigaToGibi( - GigaBytesPerSecondPerCore(op_metrics, MemorySpace::MEMORY_SPACE_ON_CHIP, - OpMetrics::MemoryAccessed::READ)); - const double sram_rd_bw_utilization = - CapUtilization(tsl::profiler::SafeDivide( - sram_rd_gibibytes_per_second, peak_mem_gibibytes_per_second_per_core - [MemBwType::MEM_BW_TYPE_SRAM_RD])); - metrics->add_bandwidth_utils(sram_rd_bw_utilization); 
- double sram_rd_bytes = - tsl::profiler::GibiToGiga(sram_rd_gibibytes_per_second) * - tsl::profiler::PicoToNano(op_metrics.time_ps()); - - const double sram_wr_gibibytes_per_second = tsl::profiler::GigaToGibi( - GigaBytesPerSecondPerCore(op_metrics, MemorySpace::MEMORY_SPACE_ON_CHIP, - OpMetrics::MemoryAccessed::WRITE)); - const double sram_wr_bw_utilization = - CapUtilization(tsl::profiler::SafeDivide( - sram_wr_gibibytes_per_second, peak_mem_gibibytes_per_second_per_core - [MemBwType::MEM_BW_TYPE_SRAM_WR])); - metrics->add_bandwidth_utils(sram_wr_bw_utilization); - double sram_wr_bytes = - tsl::profiler::GibiToGiga(sram_wr_gibibytes_per_second) * - tsl::profiler::PicoToNano(op_metrics.time_ps()); - - metrics->add_raw_bytes_accessed_array(hbm_bytes); - metrics->add_raw_bytes_accessed_array(sram_rd_bytes); - metrics->add_raw_bytes_accessed_array(sram_wr_bytes); -} - -// Recursively insert "fused instruction" nodes (with raw flops). -void InsertFusedInstructions(const OpMetrics& op_metrics, Node* node) { - if (!op_metrics.has_children()) return; - for (const auto& child : op_metrics.children().metrics_db()) { - Node* new_node = node->add_children(); - PopulateSymbolNode(child, new_node); - new_node->mutable_metrics()->set_raw_flops(child.flops()); - if (child.has_children()) { - InsertFusedInstructions(child, new_node); - } - } -} - -void UpdateNodeMetrics(const OpMetrics& child, OpMetrics* parent) { - DCHECK(parent != nullptr); - parent->set_time_ps(child.self_time_ps() + parent->time_ps()); - if (ChildrenTimePs(child) == 0) { - parent->set_flops(child.flops() + parent->flops()); - parent->set_model_flops(child.model_flops() + parent->model_flops()); - parent->set_bytes_accessed(child.bytes_accessed() + - parent->bytes_accessed()); - parent->set_dma_stall_ps(child.dma_stall_ps() + parent->dma_stall_ps()); - CombineMemoryAccessedBreakdown(child.memory_accessed_breakdown(), - parent->mutable_memory_accessed_breakdown()); - } -} - -} // namespace - -std::string 
OpProfileBuilder::GenerateProgramName(uint64_t program_id) const { - DCHECK(program_name_map_ != nullptr); - auto iter = program_name_map_->find(program_id); - if (iter == program_name_map_->end()) return "main"; - return tsl::profiler::HloModuleNameWithProgramId(iter->second, program_id); -} - -Node* OpProfileBuilder::AddOpNode(const OpMetrics& op_metrics, - Category* category, Node* deduplicated_node) { - Node* leaf; - if (deduplicated_node != nullptr) { - leaf = deduplicated_node->add_children(); - } else if (category != nullptr) { - leaf = category->node->add_children(); - } else { - leaf = root_->add_children(); - } - PopulateSymbolNode(op_metrics, leaf); - InsertFusedInstructions(op_metrics, leaf); - return leaf; -} - -// Function to create deduplicated aggregation layer. -// 1. Empty deduplicated_name in op_metrics means either: -// (1) a grouping op of a deduplicated op list. (fusion.3 in the example below) -// (2) an op that does not have duplicates. (fusion.4 in the example below) -// We create dedup layer for both cases due to lack of clue which case it is. -// The op name is used directly as the hash key for the dedup group. The dedup -// layer will be removed in the 2nd pass for case (2). -// 2. Non-empty deduplicated_name means this op can be grouped to a -// deduplicated op list (fusion.1 in the example below). 
-// Example: -// op_metrics { -// name: "fusion.1" -// deduplicated_name: "fusion.3" -// category: "convolution" -// } -// op_metrics { -// name: "fusion.3" -// deduplicated_name: "" -// category: "convolution" -// } -// op_metrics { -// name: "fusion.4" -// deduplicated_name: "" -// category: "convolution" -// } -// The data above will create the following tree after calling the function -// repeatedly: -// root(by_program) -// - jit.xx -// - convolution -// - fusion.3 -// - fusion.1 -// - fusion.2 -// - fusion.3 -// - fusion.4 -// - fusion.4 -// After finalization, the tree will look like: -// root(by_program) -// - jit.xx -// - convolution -// - fusion.3 and its duplicate(s) -// - fusion.1 -// - fusion.2 -// - fusion.3 -// - fusion.4 -Node* OpProfileBuilder::LookupOrAddDeduplicatedNode(const OpMetrics& op_metrics, - Category* category) { - std::string deduplicated_name = op_metrics.deduplicated_name().empty() - ? op_metrics.name() - : op_metrics.deduplicated_name(); - Node*& deduplicated_node = category->deduplicated_nodes[deduplicated_name]; - if (deduplicated_node == nullptr) { - deduplicated_node = category->node->add_children(); - // Set deduplicated name which is the hash key for the dedup group. - // Symbol details will be added in finalization step. 
- deduplicated_node->set_name(deduplicated_name); - } - return deduplicated_node; -} - -OpProfileBuilder::Category* OpProfileBuilder::LookupOrAddCategoryNode( - const OpMetrics& op_metrics, Program* program) { - Category* category; - Node* category_parent; - if (program != nullptr) { - category = &program->categories[op_metrics.category()]; - category_parent = program->node; - } else { - category = &category_map_[op_metrics.category()]; - category_parent = root_; - } - if (category->node == nullptr) { - category->node = category_parent->add_children(); - category->node->set_name(op_metrics.category()); - } - return category; -} - -OpProfileBuilder::Program* OpProfileBuilder::LookupOrAddProgramNode( - const OpMetrics& op_metrics) { - uint64_t program_id = op_metrics.hlo_module_id(); - Program* program = &programs_map_[program_id]; - if (program->node == nullptr) { - program->node = root_->add_children(); - program->node->set_name(GenerateProgramName(program_id)); - } - return program; -} - -void OpProfileBuilder::AddOp(const OpMetrics& op_metrics) { - // 1. Deal with nested parent nodes - // op_metrics.time_ps in root node will be reset to total_time_ps later - UpdateNodeMetrics(op_metrics, &metrics_[root_]); - Program* program = nullptr; - if (!IsIdleOp(op_metrics) && options_.group_by_program) { - program = LookupOrAddProgramNode(op_metrics); - UpdateNodeMetrics(op_metrics, &metrics_[program->node]); - } - - // 2. 
Deal with nested grouping nodes, only accumulate non-child ops - if (ChildrenTimePs(op_metrics) > 0) return; - std::vector nested_grouping_nodes; - if (IsIdleOp(op_metrics)) { - Node* leaf = AddOpNode(op_metrics); - nested_grouping_nodes.push_back(leaf); - } else { - Category* category = LookupOrAddCategoryNode(op_metrics, program); - nested_grouping_nodes.push_back(category->node); - - Node* deduplicated_node = nullptr; - if (options_.group_by_deduplicated_name) { - deduplicated_node = LookupOrAddDeduplicatedNode(op_metrics, category); - nested_grouping_nodes.push_back(deduplicated_node); - } - - Node* leaf = AddOpNode(op_metrics, category, deduplicated_node); - nested_grouping_nodes.push_back(leaf); - } - - for (auto* node : nested_grouping_nodes) { - // Per program combiner does not need to update OpMetrics.num_cores - CombineOpMetrics(op_metrics, &metrics_[node], /*update_num_cores=*/false); - } -} - -void OpProfileBuilder::Finalize( - double peak_gigaflops_per_second_per_core, - std::vector peak_mem_gibibytes_per_second_per_core, - uint64_t total_time_ps) { - // Call to `PopulateOpMetricsNode` depends on node time_ps to calculate - // flops, bandwidth_utils..etc. The root / program node time_ps might - // be off a bit, missing its own self_time when calling `UpdateNodeMetrics`. - // This is best effort to at least reset the time_ps for root node to be more - // precise. - metrics_[root_].set_time_ps(total_time_ps); - for (const auto& [node, op_metrics] : metrics_) { - PopulateOpMetricsNode(op_metrics, peak_gigaflops_per_second_per_core, - peak_mem_gibibytes_per_second_per_core, total_time_ps, - node); - } - // If grouping by program, we build a two-level pruned tree: the first level - // is per program and the second level is per category. Otherwise we build a - // single-level per category pruned tree. - int level = options_.group_by_program ? 
2 : 1; - SortAndPruneChildren(options_.children_per_node, level, root_); - if (options_.group_by_deduplicated_name) { - FinalizeDeduplicatedNodes(options_.group_by_program, root_); - } -} - -OpProfileBuilder::OpProfileBuilder( - const OpProfileOptions& options, - tensorflow::profiler::op_profile::Node* root, - const tsl::protobuf::Map* program_name_map) - : options_(options), root_(root), program_name_map_(program_name_map) { - if (root == nullptr) { - LOG(DFATAL) << "root is null."; - return; - } - DCHECK(!options_.group_by_program || program_name_map_ != nullptr); - root->set_name(options_.group_by_program ? "by_program" : "by_category"); -} -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/op_profile_builder.h b/tensorflow/core/profiler/convert/op_profile_builder.h deleted file mode 100644 index 3d4e7abd1f6b18..00000000000000 --- a/tensorflow/core/profiler/convert/op_profile_builder.h +++ /dev/null @@ -1,157 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_OP_PROFILE_BUILDER_H_ -#define TENSORFLOW_CORE_PROFILER_CONVERT_OP_PROFILE_BUILDER_H_ - -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/container/node_hash_map.h" -#include "tensorflow/core/platform/protobuf.h" -#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" -#include "tensorflow/core/profiler/protobuf/op_profile.pb.h" - -namespace tensorflow { -namespace profiler { - -struct OpProfileOptions { - bool group_by_program = true; - bool group_by_deduplicated_name = true; - int children_per_node = 100; -}; - -// The structure of an op profile tree may looks like below: -// 1. group "by_program" -// - It starts from the root node, named as "by_program", and this node does -// not show up in op profile. -// - The children of root node is a list of hlo program node, named as the -// program/module name (eg. cluster.xx). -// - The children of a program node is hlo op category node, named as the -// category name (eg. data formatting). -// - The children of a category node is a list of op node or deduplicated -// group node: -// - For op that has duplicates, the child will be a deduplicated node, -// named like "copy.1111 and its deduplicate(s)". Its children will be all op -// nodes that are deduplicated. -// - For op that does not have duplicates, the child will be an op node -// under the op category (eg. copy.2222). -// -// Example path: "by_program" -> "main(...)" -// -> "data_formatting" -> "copy.12345 and its duplicate(s) -> "copy.12345" -// -// 2. group "by_category" -// Similarly to how the `by_program` op profile tree is constructed, -// `by_category` just removed the "program_node" layer: -// - It starts from the root node, named as "by_category", this node also does -// not show up in op profile. 
-// - The children of root node is a list of op category node, everything below -// is similar to above. -// - ... -// -// Example path: "by_category" -> "data_formatting" -> "copy.12345 and its -// duplicate(s) -> "copy.12345" -// -// How the op profile metrics are calculated: -// 1. For parent node in the nested structure like root node and program node: -// - time_ps will be accumulated from the self_time of all op nodes under it -// (might still be off a bit if the parent node has self_time, more details in -// b/333608397#comment5) -// - flops and memory access will only be accumulated from leaf op node under -// it to avoid double counting -// - unable to get occurrences of program executions now -// 2. For conceptual horizontal grouping node (eg.category, deduplicated) -// - all op_metris fields will be accumulated from leaf op node only in the -// group, to avoid double counting -class OpProfileBuilder { - public: - OpProfileBuilder(const OpProfileOptions& options, op_profile::Node* root, - const tensorflow::protobuf::Map* - program_name_map = nullptr); - - // Accumulate the op_metrics to the op_profile node tree - void AddOp(const OpMetrics& op_metrics); - - // Finalize the op_profile proto in a few steps (inter-dependent): - // 1. Reset time_ps for root node for more precise total time - // 2. Loop over the node to op_metrics map, populate corresponding op_metrics - // to the node.metrics - // 3. `SortAndPruneChildren` given query param `op_profile_limit` - // 4. 
`FinalizeDeduplicatedNodes` by coping the first op node data to the - // deduplicated node - void Finalize(double peak_gigaflops_per_second_per_core, - std::vector peak_mem_gibibytes_per_second_per_core, - uint64_t total_time_ps); - - private: - struct Category { - op_profile::Node* node; - absl::flat_hash_map deduplicated_nodes; - }; - - struct Program { - op_profile::Node* node; - absl::flat_hash_map categories; - }; - - std::string GenerateProgramName(uint64_t program_id) const; - - // Adds and returns a node for op_metrics. - // If op_metrics corresponds to a fusion, adds children to the node for the - // fused instructions. - // If deduplicated_node is not null, adds the node under it. - // Otherwise, if category is not null, adds the node under category. - // Otherwise, adds the node under root. - op_profile::Node* AddOpNode(const OpMetrics& op_metrics, - Category* category = nullptr, - op_profile::Node* deduplicated_node = nullptr); - - // Returns a node for op_metrics.deduplicated_name(). - // Adds a node to the tree if necessary. - op_profile::Node* LookupOrAddDeduplicatedNode(const OpMetrics& op_metrics, - Category* category); - - // Returns a node for op_metrics.category(). - // Adds a node to the tree if necessary. - // If program is not null, the category node is added under program. - // Otherwise, the category node is added under root. - Category* LookupOrAddCategoryNode(const OpMetrics& op_metrics, - Program* program); - - // Returns a node for op_metrics.hlo_module_id(). - // Adds a node to the Node tree if necessary. - Program* LookupOrAddProgramNode(const OpMetrics& op_metrics); - - OpProfileOptions options_; - op_profile::Node* root_; - - // Map to look up and aggregate OpMetrics. - absl::node_hash_map metrics_; - - // Maps to look up if a category / program / deduplicated node has - // already been added to the tree. - absl::flat_hash_map programs_map_; - absl::flat_hash_map category_map_; - - // Map to look up program names by id. 
- const tensorflow::protobuf::Map* program_name_map_ = - nullptr; -}; -} // namespace profiler -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_PROFILER_CONVERT_OP_PROFILE_BUILDER_H_ diff --git a/tensorflow/core/profiler/convert/op_stack.h b/tensorflow/core/profiler/convert/op_stack.h deleted file mode 100644 index 6bfa4d776436da..00000000000000 --- a/tensorflow/core/profiler/convert/op_stack.h +++ /dev/null @@ -1,69 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_OP_STACK_H_ -#define TENSORFLOW_CORE_PROFILER_CONVERT_OP_STACK_H_ - -#include -#include -#include - -#include "tensorflow/core/platform/types.h" - -namespace tensorflow { -namespace profiler { - -template -class OpStack { - public: - // Pushes an Op onto the stack. - void Push(uint32 op_id, std::unique_ptr op_info) { - stack_.emplace_back(op_id, std::move(op_info)); - } - - // Pops the Op with the given op_id from the stack. - std::unique_ptr Pop(uint32 op_id) { - // Pop until match or stack_ is empty. - std::unique_ptr result; - while (!stack_.empty()) { - auto back = std::move(stack_.back()); - stack_.pop_back(); - if (op_id == back.first) { - result = std::move(back.second); - break; - } - } - return result; - } - - // Returns the Op at the top of the stack. - OpInfo* Top() const { - return stack_.empty() ? 
nullptr : stack_.back().second.get(); - } - - // Returns true if the stack is empty. - bool Empty() const { return stack_.empty(); } - - // Clears the stack. - void Clear() { stack_.clear(); } - - private: - std::vector>> stack_; -}; - -} // namespace profiler -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_PROFILER_CONVERT_OP_STACK_H_ diff --git a/tensorflow/core/profiler/convert/op_stats_combiner.cc b/tensorflow/core/profiler/convert/op_stats_combiner.cc deleted file mode 100644 index 8879eac65c4f04..00000000000000 --- a/tensorflow/core/profiler/convert/op_stats_combiner.cc +++ /dev/null @@ -1,318 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/core/profiler/convert/op_stats_combiner.h" - -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/profiler/convert/op_metrics_db_combiner.h" -#include "tensorflow/core/profiler/convert/xplane_to_tf_functions.h" -#include "tensorflow/core/profiler/protobuf/diagnostics.pb.h" -#include "tensorflow/core/profiler/protobuf/hardware_types.pb.h" -#include "tensorflow/core/profiler/protobuf/kernel_stats.pb.h" -#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" -#include "tensorflow/core/profiler/protobuf/op_stats.pb.h" -#include "tensorflow/core/profiler/protobuf/power_metrics.pb.h" -#include "tensorflow/core/profiler/protobuf/steps_db.pb.h" -#include "tensorflow/core/profiler/protobuf/topology.pb.h" -#include "tensorflow/core/profiler/utils/hardware_type_utils.h" -#include "tensorflow/core/profiler/utils/kernel_stats_utils.h" -#include "tensorflow/core/profiler/utils/step_intersection.h" - -namespace tensorflow { -namespace profiler { - -namespace { - -// Combines the src PerCoreStepInfo into the dst PerCoreStepInfo. -void CombinePerCoreStepInfo( - int src_host_id, const PerCoreStepInfo& src, bool use_incomplete_step, - PerCoreStepInfo* dst, - OpMetricsDbCombiner* hlo_metrics_db_complete_steps_only_combiner, - OpMetricsDbCombiner* hlo_metrics_db_per_step_combiner) { - CombineCoreIdMap(src_host_id, src.step_info_per_core(), - dst->mutable_step_info_per_core()); - - // Since we have assigned a new step number to the combined result, update - // the step number on each core to this new step number. 
- uint32 new_step_num = dst->step_num(); - for (auto& percore_stepinfo : *dst->mutable_step_info_per_core()) { - auto& stepinfo = percore_stepinfo.second; - stepinfo.set_step_num(new_step_num); - } - - if (!use_incomplete_step) { - hlo_metrics_db_complete_steps_only_combiner->Combine(src.hlo_metrics_db()); - } - hlo_metrics_db_per_step_combiner->Combine(src.hlo_metrics_db()); - CombineCoreIdMap(src_host_id, src.all_reduce_db_per_core(), - dst->mutable_all_reduce_db_per_core()); - CombineCoreIdMap(src_host_id, src.core_id_to_replica_id_map(), - dst->mutable_core_id_to_replica_id_map()); -} - -void CombineStepDatabase( - int src_host_id, const StepIntersection& step_intersection, - const StepDatabaseResult& src, StepDatabaseResult* dst, - OpMetricsDbCombiner* hlo_metrics_db_complete_steps_only_combiner, - std::vector* hlo_metrics_db_per_step_combiners) { - if (src.use_incomplete_step()) dst->set_use_incomplete_step(true); - uint32 src_first_step_idx = step_intersection.FirstStepIndex(src_host_id); - for (uint32 i = 0; i < step_intersection.NumSteps(); i++) { - CombinePerCoreStepInfo( - src_host_id, src.step_sequence(src_first_step_idx + i), - src.use_incomplete_step(), dst->mutable_step_sequence(i), - hlo_metrics_db_complete_steps_only_combiner, - &(*hlo_metrics_db_per_step_combiners)[i]); - } -} - -void CombinePowerMetrics(const RunEnvironment& src, RunEnvironment* dst) { - const size_t src_hosts = src.hostnames_size(); - const size_t dst_hosts = dst->hostnames_size(); - const double src_weight = src_hosts * 1.0 / (src_hosts + dst_hosts); - const double dst_weight = dst_hosts * 1.0 / (src_hosts + dst_hosts); - // Always assume src/dst have the same number of power components. 
- for (const auto& src_metric : src.power_metrics().power_component_metrics()) { - for (auto& dst_metric : - *dst->mutable_power_metrics()->mutable_power_component_metrics()) { - if (src_metric.component_name() != dst_metric.component_name()) continue; - dst_metric.set_max_power( - std::max(src_metric.max_power(), dst_metric.max_power())); - dst_metric.set_avg_power(src_metric.avg_power() * src_weight + - dst_metric.avg_power() * dst_weight); - } - } -} - -void CombineRunEnvironment(const RunEnvironment& src, RunEnvironment* dst) { - dst->mutable_hostnames()->insert(src.hostnames().begin(), - src.hostnames().end()); - dst->set_host_count(dst->hostnames_size()); - // Ignore CPU and Unknown Device type for device type selection if the - // destination does not have a device type already. - if (src.device_type() != "CPU" && src.device_type() != "Device") { - dst->set_device_type(src.device_type()); - dst->set_device_core_count(src.device_core_count() + - dst->device_core_count()); - // Replica count and num cores per replica must be same for all copies. - dst->set_replica_count(std::max(src.replica_count(), dst->replica_count())); - dst->set_num_cores_per_replica( - std::max(src.num_cores_per_replica(), dst->num_cores_per_replica())); - *dst->mutable_system_topology() = src.system_topology(); - } else if (dst->device_type().empty()) { - dst->set_device_type(src.device_type()); - } - if (src.hardware_type() != dst->hardware_type()) { - // Select the highest hardware type as TPU/GPU should override CPU_ONLY - // (e.g. coordinator). - dst->set_hardware_type(std::max(src.hardware_type(), dst->hardware_type())); - } - dst->set_task_count(src.task_count() + dst->task_count()); - // Only overwrite the dst if profile_duration_ms in dst is not defined or - // is zero and profile_duration_ms in src is greater than zero. 
- if (src.host_independent_job_info().profile_duration_ms() > 0) { - (*dst->mutable_host_independent_job_info()) = - src.host_independent_job_info(); - } - for (const auto& job_info : src.host_dependent_job_info()) { - *(dst->add_host_dependent_job_info()) = job_info; - } - dst->set_host_trace_level(src.host_trace_level()); - dst->set_is_training(src.is_training()); - CombinePowerMetrics(src, dst); -} - -// Combines the src PerfEnv into the dst PerfEnv. -void CombinePerfEnv(const PerfEnv& src, PerfEnv* dst) { - if (src.peak_tera_flops_per_second() > 0) { - dst->set_peak_tera_flops_per_second(src.peak_tera_flops_per_second()); - } - - if (src.peak_bws_giga_bytes_per_second_size() > 0 && - dst->peak_bws_giga_bytes_per_second_size() == 0) { - *dst->mutable_peak_bws_giga_bytes_per_second() = - src.peak_bws_giga_bytes_per_second(); - } - if (src.ridge_point() > 0) { - dst->set_ridge_point(src.ridge_point()); - } -} - -// Combines the src Diagnostics into the dst Diagnostics. -void CombineDiagnostics(const Diagnostics& src, Diagnostics* dst) { - dst->mutable_info()->MergeFrom(src.info()); - dst->mutable_warnings()->MergeFrom(src.warnings()); - dst->mutable_errors()->MergeFrom(src.errors()); -} - -// Combine the src OpStats into the dst OpStats. -void CombineOpStats( - bool no_accelerator_in_system, int src_host_id, HardwareType hardware_type, - const StepIntersection& step_intersection, const OpStats& src, OpStats* dst, - OpMetricsDbCombiner* host_op_metrics_db_combiner, - OpMetricsDbCombiner* device_op_metrics_db_combiner, - OpMetricsDbCombiner* hlo_metrics_db_complete_steps_only_combiner, - std::vector* hlo_metrics_db_per_step_combiners) { - // Combine host_metrics_db. - // Host OpMetricsDb does not need to update the number of cores a certain op - // occurs. - host_op_metrics_db_combiner->Combine(src.host_op_metrics_db(), - /*update_num_cores=*/false); - // Combine device_metrics_db. 
- device_op_metrics_db_combiner->Combine(src.device_op_metrics_db()); - - // Combine step_db. - if (!IsCoordinator(no_accelerator_in_system, hardware_type)) { - CombineStepDatabase(src_host_id, step_intersection, src.step_db(), - dst->mutable_step_db(), - hlo_metrics_db_complete_steps_only_combiner, - hlo_metrics_db_per_step_combiners); - } - - // Combine run environment info. - CombineRunEnvironment(src.run_environment(), dst->mutable_run_environment()); - - // Combine the perf environment info. - CombinePerfEnv(src.perf_env(), dst->mutable_perf_env()); - - // Combine diagnostics. - CombineDiagnostics(src.diagnostics(), dst->mutable_diagnostics()); - - // Combine kernel stats. - dst->mutable_kernel_stats_db()->mutable_reports()->MergeFrom( - src.kernel_stats_db().reports()); - - // Combine tf-function stats. - CombineTfFunctionDb(src.tf_function_db(), dst->mutable_tf_function_db()); - - // Combine the mapping from core ID to details. - CombineCoreIdMap(src_host_id, src.core_id_to_details(), - dst->mutable_core_id_to_details()); - - // Combine performance counter result. - dst->mutable_performance_counter_result() - ->set_matrix_unit_utilization_percent( - dst->performance_counter_result().matrix_unit_utilization_percent() + - src.performance_counter_result().matrix_unit_utilization_percent()); -} - -} // namespace - -bool IsCoordinator(bool no_accelerator_in_system, HardwareType hardware_type) { - // A host is a coordinator if: - // (1) The host doesn't have a device, and - // (2) The system does use accelerator (if not, it uses CPU only and so this - // host should be regarded as a worker as well). 
- return !HasDevice(hardware_type) && !no_accelerator_in_system; -} - -bool NoAcceleratorInSystem(const std::vector& all_op_stats_info) { - for (const auto& op_stats_info : all_op_stats_info) { - if (HasDevice(op_stats_info.hardware_type)) { - return false; - } - } - return true; -} - -uint32 GlobalCoreId(int host_id, uint32 device_ordinal) { - constexpr uint32 kMaxDevicesPerHost = 1000; // power-of-10 for debuggability - return host_id * kMaxDevicesPerHost + device_ordinal; -} - -StepIntersection ComputeStepIntersectionToMergeOpStats( - const std::vector& all_op_stats_info, - uint32 max_step_per_host) { - bool no_accelerator_in_system = NoAcceleratorInSystem(all_op_stats_info); - - absl::flat_hash_map per_host_step_db; - for (const auto& op_stats_info : all_op_stats_info) { - if (IsCoordinator(no_accelerator_in_system, op_stats_info.hardware_type)) - continue; - // Includes only workers in per_host_step_db. - per_host_step_db[op_stats_info.src_host_id] = - &op_stats_info.op_stats->step_db(); - } - - return StepIntersection(max_step_per_host, per_host_step_db); -} - -void CombineAllOpStats(const std::vector& all_op_stats_info, - const StepIntersection& step_intersection, - OpStats* combined_op_stats) { - // A shortcut code path for a single OpStats. There is no need to merge. - if (all_op_stats_info.size() == 1) { - *combined_op_stats = *all_op_stats_info[0].op_stats; - return; - } - - StepDatabaseResult* combined_step_db = combined_op_stats->mutable_step_db(); - // Initialize the StepDatabaseResult field that depends on the number of - // steps. - for (uint32 dst_step_num : step_intersection.DstStepNumbers()) { - combined_step_db->add_step_sequence()->set_step_num(dst_step_num); - } - // Record the number of steps that are dropped. - combined_step_db->set_num_steps_dropped(step_intersection.StepsDropped()); - - combined_step_db->set_empty_intersect(step_intersection.EmptyIntersect()); - - // Initialize all the OpMetricsDbCombiners. 
- OpMetricsDbCombiner host_op_metrics_db_combiner( - combined_op_stats->mutable_host_op_metrics_db()); - OpMetricsDbCombiner device_op_metrics_db_combiner( - combined_op_stats->mutable_device_op_metrics_db()); - OpMetricsDbCombiner hlo_metrics_db_complete_steps_only_combiner( - combined_op_stats->mutable_hlo_metrics_db_complete_steps_only()); - std::vector hlo_metrics_db_per_step_combiners; - hlo_metrics_db_per_step_combiners.reserve( - combined_step_db->step_sequence_size()); - for (PerCoreStepInfo& step_info : - *combined_step_db->mutable_step_sequence()) { - hlo_metrics_db_per_step_combiners.emplace_back( - step_info.mutable_hlo_metrics_db()); - } - - bool no_accelerator_in_system = NoAcceleratorInSystem(all_op_stats_info); - - for (const auto& op_stats_info : all_op_stats_info) { - CombineOpStats(no_accelerator_in_system, op_stats_info.src_host_id, - op_stats_info.hardware_type, step_intersection, - *op_stats_info.op_stats, combined_op_stats, - &host_op_metrics_db_combiner, &device_op_metrics_db_combiner, - &hlo_metrics_db_complete_steps_only_combiner, - &hlo_metrics_db_per_step_combiners); - } - - // Sorts all the kernel reports that have been merged by CombineTfOpStats and - // keeps only the top kernel reports with long kernel duration. - SortAndKeepTopKDurationKernelReportsInDb( - combined_op_stats->mutable_kernel_stats_db()); - - // Process performance counter results. - combined_op_stats->mutable_performance_counter_result() - ->set_matrix_unit_utilization_percent( - combined_op_stats->performance_counter_result() - .matrix_unit_utilization_percent() / - all_op_stats_info.size()); -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/op_stats_combiner.h b/tensorflow/core/profiler/convert/op_stats_combiner.h deleted file mode 100644 index a8cb3c62c4087a..00000000000000 --- a/tensorflow/core/profiler/convert/op_stats_combiner.h +++ /dev/null @@ -1,86 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. 
All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_COMBINER_H_ -#define TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_COMBINER_H_ - -#include - -#include "absl/container/flat_hash_map.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/profiler/convert/op_metrics_db_combiner.h" -#include "tensorflow/core/profiler/protobuf/hardware_types.pb.h" -#include "tensorflow/core/profiler/protobuf/op_stats.pb.h" -#include "tensorflow/core/profiler/utils/step_intersection.h" - -namespace tensorflow { -namespace profiler { - -// Whether a host is a coordinator. -bool IsCoordinator(bool no_accelerator_in_system, HardwareType hardware_type); - -// Translates the core id from single host to the one for multiple-host. -// We need this translation because the device_ordinal was assigned when a -// single host response was given. Now, we need a global core_id to distinguish -// it with multiple hosts. -uint32 GlobalCoreId(int host_id, uint32 device_ordinal); - -// Combines the src map into the dst map. -// The src map keys are local core_ids. The src_host_id is used to convert them -// into global core_ids used as keys in the dst map. -// REQUIRED: cores from src_host_id are not already in dst. 
-template -void CombineCoreIdMap(int src_host_id, const CoreIdMap& src, CoreIdMap* dst) { - for (const auto& core_id_and_value : src) { - uint32 global_core_id = GlobalCoreId(src_host_id, core_id_and_value.first); - auto iter_and_inserted = - dst->insert({global_core_id, core_id_and_value.second}); - DCHECK(iter_and_inserted.second) - << "Duplicated core_id: " << iter_and_inserted.first->first; - } -} - -// A struct that contains all the information that is needed to combine OpStats. -struct OpStatsInfo { - OpStatsInfo(const OpStats* op_stats, HardwareType hardware_type, - int src_host_id) - : op_stats(op_stats), - hardware_type(hardware_type), - src_host_id(src_host_id) {} - const OpStats* op_stats; - HardwareType hardware_type; - int src_host_id; -}; - -// Returns true if there is no device (accelerator) in any of the hosts. -bool NoAcceleratorInSystem(const std::vector& all_op_stats_info); - -// Compute the StepIntersection to merge OpStats. -// Profiler will limit the number of steps to be at most . -StepIntersection ComputeStepIntersectionToMergeOpStats( - const std::vector& all_op_stats_info, - uint32 max_step_per_host); - -// Combine all the OpStats in using the steps in range -// . The result is stored in . -void CombineAllOpStats(const std::vector& all_op_stats_info, - const StepIntersection& step_intersection, - OpStats* combined_op_stats); - -} // namespace profiler -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_COMBINER_H_ diff --git a/tensorflow/core/profiler/convert/op_stats_combiner_test.cc b/tensorflow/core/profiler/convert/op_stats_combiner_test.cc deleted file mode 100644 index cd5e97fe3c7e18..00000000000000 --- a/tensorflow/core/profiler/convert/op_stats_combiner_test.cc +++ /dev/null @@ -1,124 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/profiler/convert/op_stats_combiner.h" - -#include - -#include "absl/container/flat_hash_map.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/profiler/protobuf/hardware_types.pb.h" -#include "tensorflow/core/profiler/protobuf/op_stats.pb.h" -#include "tensorflow/core/profiler/protobuf/steps_db.pb.h" -#include "tensorflow/core/profiler/utils/step_intersection.h" - -namespace tensorflow { -namespace profiler { -namespace { - -// Tests that the run_environment field of the combined op stats is set -// correctly. -TEST(CombineAllOpStatsTest, CombineRunEnvironment) { - // Construct OpStatsInfo and all_op_stats_info. - OpStats dst_op_stats, op_stats_1, op_stats_2; - op_stats_1.mutable_run_environment() - ->mutable_host_independent_job_info() - ->set_profile_duration_ms(100); - op_stats_2.mutable_run_environment() - ->mutable_host_independent_job_info() - ->set_profile_duration_ms(0); - OpStatsInfo op_stats_info_1(&op_stats_1, TPU, 0), - op_stats_info_2(&op_stats_2, TPU, 0); - std::vector all_op_stats_info = {op_stats_info_1, - op_stats_info_2}; - - // Construct dummy step_intersection. - StepDatabaseResult dummy_step_db_result; - absl::flat_hash_map result; - result.insert({0, &dummy_step_db_result}); - StepIntersection dummy_step_intersection = StepIntersection(1, result); - - // Combine all op stats. 
- CombineAllOpStats(all_op_stats_info, dummy_step_intersection, &dst_op_stats); - - // Verify that the profile_duration_ms field of the second object is now set. - EXPECT_EQ(100, dst_op_stats.run_environment() - .host_independent_job_info() - .profile_duration_ms()); -} - -TEST(CombineAllOpStatsTest, CombineRunEnvironmentWithUnknownDevice) { - OpStats dst_op_stats, op_stats_1, op_stats_2; - op_stats_1.mutable_run_environment()->set_device_type("TPU"); - op_stats_2.mutable_run_environment()->set_device_type("Device"); - OpStatsInfo op_stats_info_1(&op_stats_1, TPU, 0), - op_stats_info_2(&op_stats_2, TPU, 0); - std::vector all_op_stats_info = {op_stats_info_1, - op_stats_info_2}; - - // Construct dummy step_intersection. - StepDatabaseResult dummy_step_db_result; - absl::flat_hash_map result; - result.insert({0, &dummy_step_db_result}); - StepIntersection dummy_step_intersection = StepIntersection(1, result); - - CombineAllOpStats(all_op_stats_info, dummy_step_intersection, &dst_op_stats); - - EXPECT_EQ("TPU", dst_op_stats.run_environment().device_type()); -} - -TEST(CombineAllOpStatsTest, CombinePerfEnvOrderZero) { - // Ensure CombinePerfEnv behaves consistently regardless of order of op stats. - OpStats dst_op_stats1, dst_op_stats2, op_stats_1, op_stats_2; - op_stats_1.mutable_perf_env()->set_peak_tera_flops_per_second(100); - op_stats_2.mutable_perf_env()->set_peak_tera_flops_per_second(0); - // Construct dummy step_intersection which is required by CombineAllOpStats(). - absl::flat_hash_map result; - StepIntersection dummy_step_intersection = StepIntersection(1, result); - - OpStatsInfo op_stats_info_1(&op_stats_1, TPU, 0), - op_stats_info_2(&op_stats_2, TPU, 0); - - // Test order 1. - std::vector all_op_stats_info = {op_stats_info_1, - op_stats_info_2}; - CombineAllOpStats(all_op_stats_info, dummy_step_intersection, &dst_op_stats1); - EXPECT_EQ(100, dst_op_stats1.perf_env().peak_tera_flops_per_second()); - - // Test order 2. 
- all_op_stats_info = { - op_stats_info_2, - op_stats_info_1, - }; - CombineAllOpStats(all_op_stats_info, dummy_step_intersection, &dst_op_stats2); - EXPECT_EQ(100, dst_op_stats2.perf_env().peak_tera_flops_per_second()); -} - -TEST(CombineAllOpStatsTest, CombineRunEnvironmentWithMismatchHardwareType) { - OpStats coordinator_op_stats, device_op_stats, dst_op_stats; - coordinator_op_stats.mutable_run_environment()->set_hardware_type( - HardwareType::CPU_ONLY); - device_op_stats.mutable_run_environment()->set_hardware_type( - HardwareType::TPU); - CombineAllOpStats({OpStatsInfo(&coordinator_op_stats, CPU_ONLY, 0), - OpStatsInfo(&device_op_stats, TPU, 1)}, - StepIntersection(1, {}), &dst_op_stats); - EXPECT_EQ(dst_op_stats.run_environment().hardware_type(), HardwareType::TPU); -} - -} // namespace -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/op_stats_to_hlo_stats.cc b/tensorflow/core/profiler/convert/op_stats_to_hlo_stats.cc deleted file mode 100644 index 66ceccf0af3efc..00000000000000 --- a/tensorflow/core/profiler/convert/op_stats_to_hlo_stats.cc +++ /dev/null @@ -1,178 +0,0 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/core/profiler/convert/op_stats_to_hlo_stats.h" - -#include -#include -#include - -#include "absl/strings/str_cat.h" -#include "absl/strings/str_split.h" -#include "xla/tsl/profiler/convert/xla_op_utils.h" -#include "xla/tsl/profiler/utils/math_utils.h" -#include "xla/tsl/profiler/utils/tf_op_utils.h" -#include "tensorflow/core/profiler/convert/data_table_utils.h" -#include "tensorflow/core/profiler/convert/op_metrics_to_record.h" -#include "tensorflow/core/profiler/protobuf/hlo_stats.pb.h" -#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" -#include "tensorflow/core/profiler/protobuf/op_stats.pb.h" - -namespace tensorflow { -namespace profiler { -namespace { - -using tensorflow::profiler::OpMetrics; -using tensorflow::profiler::OpMetricsDb; -using ::tensorflow::profiler::OpStats; -using ::tensorflow::profiler::PerfEnv; -using ::tensorflow::profiler::RunEnvironment; -using tensorflow::profiler::hlo_stats::HloStatsDatabase; -using tensorflow::profiler::hlo_stats::HloStatsRecord; -using tsl::profiler::IsOutsideCompilationOp; - -HloStatsRecord ConvertOpMetricsToHloStatsRecord(const OpMetrics& metrics, - const PerfEnv& perf_env, - const RunEnvironment& run_env) { - HloStatsRecord record; - record.set_program_id(metrics.hlo_module_id()); - record.set_hlo_expression(metrics.long_name()); - record.set_tf_op_name(metrics.provenance()); - record.set_hlo_category(metrics.category()); - record.set_autotuned(metrics.autotuned()); - tensorflow::profiler::SetExecutionTimes(metrics, &record); - tensorflow::profiler::SetTpuUnitFractions(metrics, &record); - SetRooflineMetrics(metrics, perf_env, run_env, &record); - record.set_rematerialization(tsl::profiler::IsRematerialization( - /*hlo_expression=*/metrics.long_name(), - /*framework_op_name=*/metrics.provenance())); - record.set_outside_compilation( - IsOutsideCompilationOp(metrics.provenance(), 
metrics.long_name())); - return record; -} - -} // namespace - -HloStatsDatabase ConvertOpStatsToHloStats(const OpStats& op_stats) { - HloStatsDatabase hlo_stats_db; - const OpMetricsDb& hlo_metrics_db = op_stats.device_op_metrics_db(); - double total_device_time_us = - tsl::profiler::PicoToMicro(hlo_metrics_db.total_time_ps()); - HloStatsRecord sentinel; - sentinel.set_rank(0); - sentinel.set_cumulative_total_self_time_as_fraction(0.0); - const HloStatsRecord* prev_record = &sentinel; - for (const OpMetrics* metrics : - tensorflow::profiler::SortedOpMetricsDb(hlo_metrics_db)) { - if (metrics->occurrences() == 0) continue; - HloStatsRecord* record = hlo_stats_db.add_hlo_stats_record(); - *record = ConvertOpMetricsToHloStatsRecord(*metrics, op_stats.perf_env(), - op_stats.run_environment()); - tensorflow::profiler::SetRankAndTimeFractions(total_device_time_us, - *prev_record, record); - prev_record = record; - } - return hlo_stats_db; -} - -// The parse logic based on the assumption that the hlo op text is in format of -// '%op_name = ' -std::string GetHloOpNameFromExpression(std::string expression) { - std::vector<::std::string> parts = absl::StrSplit(expression, " = "); - std::string hlo_op_name = parts[0]; - if (hlo_op_name[0] == '%') { - hlo_op_name = hlo_op_name.substr(1); - } - return hlo_op_name; -} - -std::vector> HloStatsDataTableColumns() { - const std::vector> kColumns = { - {"rank", "number", "Rank"}, - {"program_id", "string", "Program id"}, - {"category", "string", "HLO op category"}, - {"hlo_op_name", "string", "HLO op name"}, - {"hlo_op_expression", "string", "HLO op text"}, - {"tf_op_name", "string", "Framework op name"}, - {"occurrences", "number", "#Occurrences"}, - {"total_time", "number", "Total time (us)"}, - {"avg_time", "number", "Avg. time (us)"}, - {"total_self_time", "number", "Total self time (us)"}, - {"avg_self_time", "number", "Avg. 
self time (us)"}, - {"total_self_time_percent", "number", "Total self time (%)"}, - { - "cumulative_total_self_time_percent", - "number", - "Cumulative total self time (%)", - }, - {"dma_stall_percent", "number", "%time stalled by DMA"}, - {"model_flop_rate", "number", "Model GFLOP/s"}, - {"normalized_flop_rate", "number", "Normalized GFLOP/s"}, - {"measured_memory_bw", "number", "Measured memory BW (GiB/s)"}, - {"hbm_bw", "number", "HBM BW (GiB/s)"}, - {"cmem_read_bw", "number", "CMEM Read BW (GiB/s)"}, - {"cmem_write_bw", "number", "CMEM Write BW (GiB/s)"}, - {"operational_intensity", "number", "Operational intensity (FLOPS/Byte)"}, - {"bound_by", "string", "Bound by"}, - {"hlo_rematerialization", "string", "Rematerialization"}, - {"outside_compilation", "string", "Outside Compilation"}, - {"autotuned", "string", "Autotuned"}, - }; - return kColumns; -} - -std::unique_ptr CreateHloStatsDataTable( - const HloStatsDatabase& hlo_stats_db) { - auto data_table = std::make_unique(); - for (const std::vector& col : HloStatsDataTableColumns()) { - data_table->AddColumn(TableColumn(col[0], col[1], col[2])); - } - for (const HloStatsRecord& record : hlo_stats_db.hlo_stats_record()) { - TableRow* row = data_table->AddRow(); - row->AddCell(record.rank()); - row->AddCell(absl::StrCat(record.program_id())); - row->AddCell(record.hlo_category()); - row->AddCell(GetHloOpNameFromExpression(record.hlo_expression())); - row->AddCell(record.hlo_expression()); - row->AddCell(record.tf_op_name()); - row->AddCell(record.occurrences()); - row->AddCell(record.total_time_in_us()); - row->AddCell(record.avg_time_in_us()); - row->AddCell(record.total_self_time_in_us()); - row->AddCell(record.avg_self_time_in_us()); - row->AddCell(record.total_self_time_as_fraction()); - row->AddCell(record.cumulative_total_self_time_as_fraction()); - row->AddCell(record.dma_stall_fraction()); - row->AddCell(record.model_flop_rate()); - row->AddCell(record.measured_flop_rate()); - 
row->AddCell(record.measured_memory_bw()); - row->AddCell(record.hbm_bw()); - row->AddCell(record.cmem_read_bw()); - row->AddCell(record.cmem_write_bw()); - row->AddCell(record.operational_intensity()); - row->AddCell(absl::StrCat(record.bound_by())); - row->AddCell(record.rematerialization() ? "Yes" : "No"); - row->AddCell(record.outside_compilation() ? "Yes" : "No"); - row->AddCell(record.autotuned() ? "Yes" : "No"); - } - return data_table; -} - -std::string HloStatsToDataTableJson(const HloStatsDatabase& hlo_stats_db) { - return CreateHloStatsDataTable(hlo_stats_db)->ToJson(); -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/op_stats_to_hlo_stats.h b/tensorflow/core/profiler/convert/op_stats_to_hlo_stats.h deleted file mode 100644 index 359024df04b221..00000000000000 --- a/tensorflow/core/profiler/convert/op_stats_to_hlo_stats.h +++ /dev/null @@ -1,42 +0,0 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_HLO_STATS_H_ -#define TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_HLO_STATS_H_ - -#include -#include - -#include "tensorflow/core/profiler/convert/data_table_utils.h" -#include "tensorflow/core/profiler/protobuf/hlo_stats.pb.h" -#include "tensorflow/core/profiler/protobuf/op_stats.pb.h" - -namespace tensorflow { -namespace profiler { -tensorflow::profiler::hlo_stats::HloStatsDatabase ConvertOpStatsToHloStats( - const tensorflow::profiler::OpStats& op_stats); - -// Converts to JSON align with current DataTable JSON format. -std::string HloStatsToDataTableJson( - const hlo_stats::HloStatsDatabase& hlo_stats_db); - -// Construct a DataTable object from HloStatsDatabase. -std::unique_ptr CreateHloStatsDataTable( - const hlo_stats::HloStatsDatabase& hlo_stats_db); - -} // namespace profiler -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_HLO_STATS_H_ diff --git a/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.cc b/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.cc deleted file mode 100644 index e77a8a21cc9c73..00000000000000 --- a/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.cc +++ /dev/null @@ -1,1646 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.h" - -#include - -#include -#include -#include -#include -#include -#include - -#include "google/protobuf/any.pb.h" -#include "absl/container/flat_hash_map.h" -#include "absl/log/check.h" -#include "absl/log/log.h" -#include "absl/strings/match.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "xla/hlo/ir/hlo_opcode.h" -#include "xla/tsl/profiler/convert/xla_op_utils.h" -#include "xla/tsl/profiler/utils/format_utils.h" -#include "xla/tsl/profiler/utils/math_utils.h" -#include "xla/tsl/profiler/utils/tf_op_utils.h" -#include "xla/tsl/util/stats_calculator.h" -#include "tensorflow/core/lib/gtl/map_util.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/profiler/convert/op_metrics_to_record.h" -#include "tensorflow/core/profiler/convert/profile_time_breakdown.h" -#include "tensorflow/core/profiler/convert/step_events_to_steps_db.h" -#include "tensorflow/core/profiler/convert/tpu_input_pipeline_analysis_constants.h" -#include "tensorflow/core/profiler/protobuf/hardware_types.pb.h" -#include "tensorflow/core/profiler/protobuf/input_pipeline.pb.h" -#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" -#include "tensorflow/core/profiler/protobuf/op_stats.pb.h" -#include "tensorflow/core/profiler/protobuf/steps_db.pb.h" -#include "tensorflow/core/profiler/utils/diagnostics.h" -#include "tensorflow/core/profiler/utils/event_span.h" -#include "tensorflow/core/profiler/utils/html_utils.h" -#include "tensorflow/core/profiler/utils/op_metrics_db_utils.h" -#include "tensorflow/core/profiler/utils/tpu_step_breakdown_utils.h" -#include "tensorflow/core/profiler/utils/tpu_step_details_utils.h" -#include "tsl/platform/protobuf.h" - -namespace tensorflow { -namespace 
profiler { - -namespace { - -using tsl::profiler::OneDigit; - -// If the percentage of step time that spends on SparseCoreV0 is more than -// kModeratelySparseCoreV0BoundThresholdInPercent, it is considered highly -// SparseCoreV0 bound. -constexpr double kModeratelySparseCoreV0BoundThresholdInPercent = 10; -// If the percentage of step time that spends on all-reduce is more than -// kAllReduceBoundThresholdInPercent, it is considered all-reduce bound. -constexpr double kAllReduceBoundThresholdInPercent = 6; -// If the percentage of step time that is idle due to host overhead (but not -// input-related) is >= kTcIdleThresholdInPercent, it will be highlighted in the -// recommendation section of the Overview Page. -constexpr double kTcIdleThresholdInPercent = 3; -// Public doc on how to run multiple steps in a tf-function. -constexpr absl::string_view kMultipleStepsInTffunctionDoc = - "https://www.tensorflow.org/guide/" - "tpu#improving_performance_by_multiple_steps_within_tffunction"; - -const double kNumPsPerMs = 1000000000.0; - -// If the percentage of step time that is due to infeed is less than -// kModeratelyInfeedBoundThresholdInPercent, it is considered NOT -// input-bound; else if it is less than -// kHighlyInfeedBoundThresholdInPercent, it is considered MODERATELY -// input-bound; else if it is considered HIGHLY input-bound. -constexpr double kModeratelyInfeedBoundThresholdInPercent = 5; -constexpr double kHighlyInfeedBoundThresholdInPercent = 20; - -// If the percentage of step time that is due to outfeed is less than -// kModeratelyOutfeedBoundThresholdInPercent, it is considered NOT -// output-bound; else if it is less than -// kHighlyOutfeedBoundThresholdInPercent, it is considered MODERATELY -// output-bound; else if it is considered HIGHLY output-bound. 
-constexpr double kModeratelyOutfeedBoundThresholdInPercent = 5; -constexpr double kHighlyOutfeedBoundThresholdInPercent = 20; - -// If the percentage of step time that is due to kernel launch is less than -// kModeratelyKernelLaunchBoundThresholdInPercent, it is considered NOT -// kernel-launch bound; else if it is less than -// kHighlyKernelLaunchBoundThresholdInPercent, it is considered MODERATELY -// kernel-launch bound; else if it is considered HIGHLY kernel-launch bound. -constexpr double kModeratelyKernelLaunchBoundThresholdInPercent = 3; -constexpr double kHighlyKernelLaunchBoundThresholdInPercent = 15; - -// If the percentage of step time that is due to all other time is less than -// kModeratelyAllOtherBoundThresholdInPercent, it is considered NOT -// all-other bound; else if it is less than -// kHighlyAllOtherBoundThresholdInPercent, it is considered MODERATELY -// all-other bound; else if it is considered HIGHLY all-other bound. -constexpr double kModeratelyAllOtherBoundThresholdInPercent = 3; -constexpr double kHighlyAllOtherBoundThresholdInPercent = 15; - -// If the percentage of step time that is due to device collectives is less than -// kModeratelyDeviceCollectivesBoundThresholdInPercent, it is considered NOT -// device-collectives bound; else if it is less than -// kHighlyDeviceCollectivesBoundThresholdInPercent, it is considered MODERATELY -// device-collectives bound; else if it is considered HIGHLY device-collectives -// bound. -constexpr double kModeratelyDeviceCollectivesBoundThresholdInPercent = 3; -constexpr double kHighlyDeviceCollectivesBoundThresholdInPercent = 15; - -// Section number of the host-analysis section in the input-pipeline analysis. -constexpr int kHostAnalysisSectionNumber = 3; -// Python-only explanation for "All Others" time. -const char* kAllOthersPythonExplanation = - " % of the total step time sampled is spent on 'All Others' time. 
" - "This could be due to Python execution overhead."; -// Explanation for "Kernel Launch" time due to CPU contention with tf.data. -const char* kKernelLaunchTfDataContention = - " It could be due to CPU contention with tf.data. In this case, you may " - "try to set the environment variable TF_GPU_THREAD_MODE=gpu_private."; - -template -double GetTimeInMs(const Collection& type_ps, EventType event_type) { - return tsl::profiler::PicoToMilli( - gtl::FindWithDefault(type_ps, event_type, /*value=*/0)); -} - -GenericStepTimeBreakdown ComputeGenericStepTimeBreakdownInMs( - const InputPipelineAnalysisResult& analysis) { - tsl::Stat unknown_time_ms; - tsl::Stat host_wait_input_ms; - tsl::Stat host_to_device_ms; - tsl::Stat input_ms; - tsl::Stat output_ms; - tsl::Stat device_compute_ms; - tsl::Stat device_to_device_ms; - tsl::Stat device_collectives_ms; - tsl::Stat host_compute_ms; - tsl::Stat host_prepare_ms; - tsl::Stat host_compile_ms; - GenericStepTimeBreakdown result; - - for (const google::protobuf::Any& step_details : analysis.step_details()) { - PerGenericStepDetails details; - bool success = step_details.UnpackTo(&details); - if (!success && !step_details.type_url().empty()) { - LOG(ERROR) << "Unable to unpack step_breakdown. 
Expected: generic" - << std::endl; - return {}; - } - unknown_time_ms.UpdateStat(details.unknown_time_ms()); - host_wait_input_ms.UpdateStat(details.host_wait_input_ms()); - host_to_device_ms.UpdateStat(details.host_to_device_ms()); - input_ms.UpdateStat(details.host_wait_input_ms() + - details.host_to_device_ms()); - output_ms.UpdateStat(details.output_ms()); - device_compute_ms.UpdateStat(details.device_compute_ms()); - device_to_device_ms.UpdateStat(details.device_to_device_ms()); - device_collectives_ms.UpdateStat(details.device_collectives_ms()); - host_compute_ms.UpdateStat(details.host_compute_ms()); - host_prepare_ms.UpdateStat(details.host_prepare_ms()); - host_compile_ms.UpdateStat(details.host_compile_ms()); - } - *result.mutable_unknown_time_ms_summary() = - GetStepSummaryForSampleStats(unknown_time_ms); - *result.mutable_host_wait_input_ms_summary() = - GetStepSummaryForSampleStats(host_wait_input_ms); - *result.mutable_host_to_device_ms_summary() = - GetStepSummaryForSampleStats(host_to_device_ms); - *result.mutable_input_ms_summary() = GetStepSummaryForSampleStats(input_ms); - *result.mutable_output_ms_summary() = GetStepSummaryForSampleStats(output_ms); - *result.mutable_device_compute_ms_summary() = - GetStepSummaryForSampleStats(device_compute_ms); - *result.mutable_device_to_device_ms_summary() = - GetStepSummaryForSampleStats(device_to_device_ms); - *result.mutable_device_collectives_ms_summary() = - GetStepSummaryForSampleStats(device_collectives_ms); - *result.mutable_host_compute_ms_summary() = - GetStepSummaryForSampleStats(host_compute_ms); - *result.mutable_host_prepare_ms_summary() = - GetStepSummaryForSampleStats(host_prepare_ms); - *result.mutable_host_compile_ms_summary() = - GetStepSummaryForSampleStats(host_compile_ms); - return result; -} - -InputPipelineAnalysisResult ComputeGenericInputPipelineAnalysisResult( - const tsl::protobuf::RepeatedPtrField& grouped_by_step) { - InputPipelineAnalysisResult result; - result.set_tag(false); 
- - // Computes the summary of step time in ms. - *result.mutable_step_time_summary() = - ComputeStepTimeSummaryInMs(grouped_by_step); - - tsl::Stat input_summary_stats_in_percent; - for (const auto& coreid_stepinfo_map : grouped_by_step) { - // Iterates over each step. - const auto* ptr = gtl::FindOrNull(coreid_stepinfo_map.step_info_per_core(), - kDefaultGpuLocalCoreId); - if (ptr == nullptr) { - // For generic hardware, all step-info is put under core-0. If ptr - // is nullptr, it means there is no step at all. - continue; - } - const StepInfoResult& step_info = *ptr; - // Adds the details for a new step. - PerGenericStepDetails details; - details.set_step_number(step_info.step_num()); - if (step_info.step_name().empty()) { - details.set_step_name(absl::StrCat(step_info.step_num())); - } else { - details.set_step_name(step_info.step_name()); - } - details.set_step_time_ms( - tsl::profiler::PicoToMilli(step_info.duration_ps())); - GenericStepBreakdown generic; - bool success = step_info.step_breakdown().UnpackTo(&generic); - if (!success && !step_info.step_breakdown().type_url().empty()) { - LOG(ERROR) << "Unable to unpack step_breakdown. 
Expected: generic" - << std::endl; - return {}; - } - const auto& type_ps = generic.type_ps(); - details.set_unknown_time_ms(GetTimeInMs(type_ps, UNKNOWN_TIME)); - details.set_host_wait_input_ms(GetTimeInMs(type_ps, HOST_WAIT_INPUT)); - details.set_host_to_device_ms(GetTimeInMs(type_ps, HOST_TO_DEVICE) + - GetTimeInMs(type_ps, DEVICE_WAIT_HOST)); - details.set_output_ms(GetTimeInMs(type_ps, DEVICE_TO_HOST)); - details.set_device_compute_ms(GetTimeInMs(type_ps, DEVICE_COMPUTE_16) + - GetTimeInMs(type_ps, DEVICE_COMPUTE_32)); - details.set_device_to_device_ms(GetTimeInMs(type_ps, DEVICE_TO_DEVICE) + - GetTimeInMs(type_ps, DEVICE_WAIT_DEVICE)); - details.set_device_collectives_ms(GetTimeInMs(type_ps, DEVICE_COLLECTIVES)); - details.set_host_compute_ms(GetTimeInMs(type_ps, HOST_COMPUTE)); - details.set_host_prepare_ms(GetTimeInMs(type_ps, HOST_PREPARE)); - details.set_host_compile_ms(GetTimeInMs(type_ps, HOST_COMPILE)); - result.add_step_details()->PackFrom(details); - - const double input_percent_of_step_time = - 100.0 * tsl::profiler::SafeDivide( - details.host_wait_input_ms() + details.host_to_device_ms(), - details.step_time_ms()); - input_summary_stats_in_percent.UpdateStat(input_percent_of_step_time); - } - - // Computes the summary of input time as percentage of step time. - *result.mutable_input_percent_summary() = - GetStepSummaryForSampleStats(input_summary_stats_in_percent); - - // Computes the breakdown of step time. - GenericStepTimeBreakdown generic_step_time_breakdown = - ComputeGenericStepTimeBreakdownInMs(result); - result.mutable_step_time_breakdown()->PackFrom(generic_step_time_breakdown); - - return result; -} - -// Classification of input processing on the host. -enum class InputOpCategory { - kEnqueue, // enqueue data to be transferred to device. - kDemandedFileRead, // demanded read from file. - kAdvancedFileRead, // advanced read from file (including cached, - // prefetch, parallel-map, interleave). - kPreprocessing // data preprocessing. 
-}; - -std::string InputOpCategoryString(InputOpCategory category) { - switch (category) { - case InputOpCategory::kEnqueue: - return "Enqueue"; - case InputOpCategory::kDemandedFileRead: - return "Demanded file read"; - case InputOpCategory::kAdvancedFileRead: - return "Advanced file read"; - case InputOpCategory::kPreprocessing: - return "Preprocessing"; - } -} - -inline bool IsInputOp(absl::string_view category) { - // Do not include "IteratorGetNext*" here, because IteratorGetNext is an Op - // that experiences the install stall, not an Op that causes the input stall. - return tsl::profiler::IsInfeedEnqueueOp(category) || - tsl::profiler::IsDatasetOp(category) || - tsl::profiler::IsMemcpyHToDOp(category); -} - -// TODO(ckluk): -// Confirm with the tf.data team if the classification below is correct. -InputOpCategory CategorizeInputOp(absl::string_view name, - absl::string_view category) { - if (tsl::profiler::IsInfeedEnqueueOp(category) || - tsl::profiler::IsMemcpyHToDOp(category)) { - // Ops for sending input from host to device. - return InputOpCategory::kEnqueue; - } - DCHECK(tsl::profiler::IsDatasetOp(category)); - if (absl::EndsWith(name, "::TFRecord") || - absl::EndsWith(name, "::TextLine") || - absl::EndsWith(name, "::FixedLengthRecord") || - absl::EndsWith(name, "::SSTable") || absl::EndsWith(name, "::RecordIO")) { - // Ops that read files. - if (absl::StrContains(name, "::MemoryReader") || - absl::StrContains(name, "::MemoryWriter") || - absl::StrContains(name, "::Interleave") || - absl::StrContains(name, "::Prefetch") || - absl::StrContains(name, "::ParallelMap")) { - // Ops that read files in advance, including caching, interleaving, and - // prefetching. - return InputOpCategory::kAdvancedFileRead; - } else { - // Ops that read files on demand. - return InputOpCategory::kDemandedFileRead; - } - } else { - // All other ops are classified as preprocessing. 
- return InputOpCategory::kPreprocessing; - } -} - -struct InputOpMetrics { - std::vector input_op_metrics; - uint64 input_op_time_ps = 0; -}; - -InputOpMetrics SelectInputOpMetrics(const OpMetricsDb& all_op_metrics) { - InputOpMetrics input_op_metrics; - for (const OpMetrics* op_metrics : SortedOpMetricsDb(all_op_metrics)) { - if (IsInputOp(op_metrics->category())) { - input_op_metrics.input_op_metrics.push_back(op_metrics); - input_op_metrics.input_op_time_ps += op_metrics->self_time_ps(); - } - } - return input_op_metrics; -} - -InputOpDetails ConvertOpMetricsToInputOpDetails(const OpMetrics& op_metrics, - uint64 input_op_time_ps, - InputOpCategory category) { - InputOpDetails details; - details.set_op_name(op_metrics.name()); - details.set_count(op_metrics.occurrences()); - details.set_time_in_ms(tsl::profiler::PicoToMilli(op_metrics.time_ps())); - details.set_self_time_in_ms( - tsl::profiler::PicoToMilli(op_metrics.self_time_ps())); - details.set_time_in_percent( - 100.0 * - tsl::profiler::SafeDivide(op_metrics.time_ps(), input_op_time_ps)); - details.set_self_time_in_percent( - 100.0 * - tsl::profiler::SafeDivide(op_metrics.self_time_ps(), input_op_time_ps)); - details.set_category(InputOpCategoryString(category)); - return details; -} - -// Returns the ratio of the host-to-device time in each step to the step-time. -double RatioOfHostToDeviceTimeToStepTime( - const OpMetricsDb& host_tf_metrics_db, - const InputPipelineAnalysisResult& input_pipeline_analysis) { - // For TPU execution that uses infeed. - std::optional host_infeed_enqueue_ratio = - HostInfeedEnqueueRatio(host_tf_metrics_db); - if (host_infeed_enqueue_ratio.has_value()) { - return host_infeed_enqueue_ratio.value(); - } - // For GPU and TPU execution that do not use infeed. - double avg_step_time_ms = - input_pipeline_analysis.step_time_summary().average(); - if (avg_step_time_ms > 0) { - // Uses the on-device step time. 
- GenericStepTimeBreakdown generic_breakdown; - if (input_pipeline_analysis.step_time_breakdown().UnpackTo( - &generic_breakdown)) { - double avg_host_to_device_time_ms = - generic_breakdown.host_to_device_ms_summary().average(); - return tsl::profiler::SafeDivide(avg_host_to_device_time_ms, - avg_step_time_ms); - } - } - return 0.0; -} - -void DeviceCollectivesAnalysis(double device_collectives_percent, - std::string* device_collectives_classification, - std::string* device_collectives_statement) { - if (device_collectives_percent >= - kHighlyDeviceCollectivesBoundThresholdInPercent) { - *device_collectives_classification = "high"; - *device_collectives_statement = - absl::StrCat(OneDigit(device_collectives_percent), - " % of the total step time sampled is spent on 'Device " - "Collective Communication'."); - } else if (device_collectives_percent >= - kModeratelyDeviceCollectivesBoundThresholdInPercent) { - *device_collectives_classification = "moderate"; - *device_collectives_statement = - absl::StrCat(OneDigit(device_collectives_percent), - " % of the total step time sampled is spent on 'Device " - "Collective Communication'."); - } else { - *device_collectives_classification = "no"; - *device_collectives_statement = ""; - } -} - -void KernelLaunchAnalysis(bool tfdata_used, double kernel_launch_percent, - std::string* kernel_launch_classification, - std::string* kernel_launch_statement) { - if (kernel_launch_percent >= kHighlyKernelLaunchBoundThresholdInPercent) { - *kernel_launch_classification = "high"; - *kernel_launch_statement = absl::StrCat( - OneDigit(kernel_launch_percent), - " % of the total step time sampled is spent on 'Kernel Launch'."); - if (tfdata_used) { - absl::StrAppend(kernel_launch_statement, kKernelLaunchTfDataContention); - } - } else if (kernel_launch_percent >= - kModeratelyKernelLaunchBoundThresholdInPercent) { - *kernel_launch_classification = "moderate"; - *kernel_launch_statement = absl::StrCat( - OneDigit(kernel_launch_percent), - " 
% of the total step time sampled is spent on 'Kernel Launch'."); - if (tfdata_used) { - absl::StrAppend(kernel_launch_statement, kKernelLaunchTfDataContention); - } - } else { - *kernel_launch_classification = "no"; - *kernel_launch_statement = ""; - } -} - -void AllOtherAnalysis(bool all_other_reported, double all_other_percent, - std::string* all_other_classification, - std::string* all_other_statement) { - if (all_other_reported) { - *all_other_classification = "no"; - *all_other_statement = ""; - return; - } - if (all_other_percent >= kHighlyAllOtherBoundThresholdInPercent) { - *all_other_classification = "high"; - *all_other_statement = - absl::StrCat(OneDigit(all_other_percent), kAllOthersPythonExplanation); - } else if (all_other_percent >= kModeratelyAllOtherBoundThresholdInPercent) { - *all_other_classification = "moderate"; - *all_other_statement = - absl::StrCat(OneDigit(all_other_percent), kAllOthersPythonExplanation); - } else { - *all_other_classification = "no"; - *all_other_statement = ""; - } -} - -// Tests if tf.data API is in use. -bool TfDataInUse(const InputTimeBreakdown& breakdown) { - // Do not include enqueue_us because the "enqueue" Op that Xprof recognizes is - // not part of tf.data. - return breakdown.demanded_file_read_us() > 0 || - breakdown.advanced_file_read_us() > 0 || - breakdown.preprocessing_us() > 0; -} - -// Returns a HTML link with the given text. -std::string MakeDocLink(absl::string_view doc_link, absl::string_view text) { - return absl::StrCat("", text, - ""); -} - -// Returns the HTML link to the introduction to the tf.data API. 
-std::string DatasetIntroDoc() { - return "https://www.tensorflow.org/guide/data"; -} - -struct WaitForScV0Breakdown { - uint64_t DurationPs() const { - return scv0_infeed_duration_ps + scv0_compute_duration_ps; - } - - uint64_t scv0_infeed_duration_ps = 0; - uint64_t scv0_compute_duration_ps = 0; -}; - -struct TcInfeed { - std::optional core_id; - uint64_t duration_ps = 0; -}; - -void ConvertGenericStepBreakdownToTpuStepBreakdown( - const tensorflow::profiler::GenericStepBreakdown& generic_step_breakdown, - uint64_t step_time_ps, TpuStepBreakdown& tpu_step_breakdown) { - auto& category_ps = generic_step_breakdown.category_ps(); - tensorflow::profiler::ProfileTimeBreakdown time_breakdown; - for (const auto& [category, time_ps] : category_ps) { - // Don't add idle time to time_breakdown as the idle time is inferred. - if (category == "IDLE") continue; - time_breakdown.IncrementCategoryTimePs(category, time_ps); - } - time_breakdown.SetProfileTimePs(step_time_ps); - time_breakdown.BreakdownSparseCoreV0Infeed(); - - tpu_step_breakdown.set_infeed_duration_ps(time_breakdown.InfeedTimePs()); - tpu_step_breakdown.set_host_outfeed_ps(time_breakdown.OutfeedTimePs()); - tpu_step_breakdown.set_wait_for_scv0_duration_ps( - time_breakdown.SparseCoreV0InfeedWaitTimePs()); - tpu_step_breakdown.set_scv0_infeed_transform_ps( - time_breakdown.SparseCoreV0InfeedTransformTimePs()); - tpu_step_breakdown.set_scv0_outfeed_ps( - time_breakdown.SparseCoreV0OutfeedTimePs()); - tpu_step_breakdown.set_crs_duration_ps( - time_breakdown.AllReduceOrAllToAllTimePs()); - tpu_step_breakdown.set_send_duration_ps(time_breakdown.SendTimePs()); - tpu_step_breakdown.set_recv_duration_ps(time_breakdown.RecvTimePs()); - tpu_step_breakdown.set_host_send_duration_ps(time_breakdown.HostSendTimePs()); - tpu_step_breakdown.set_host_recv_duration_ps(time_breakdown.HostRecvTimePs()); - tpu_step_breakdown.set_wait_for_megacore_fusion_peer_duration_ps( - time_breakdown.MegacoreFusionTimePs()); - 
tpu_step_breakdown.set_high_flops_compute_ps( - time_breakdown.HighFlopsComputeTimePs()); - tpu_step_breakdown.set_tc_idle_ps(time_breakdown.IdleTimePs()); - tpu_step_breakdown.set_tc_busy_ps(time_breakdown.TensorCoreBusyTimePs()); -} - -TpuStepTimeBreakdown ComputeTpuStepTimeBreakdownInMs( - const InputPipelineAnalysisResult& analysis, bool has_sparse_core) { - tsl::Stat tc_compute_ms; - tsl::Stat tc_infeed_ms; - tsl::Stat tc_outfeed_ms; - tsl::Stat tc_idle_ms; - tsl::Stat scv0_compute_ms; - tsl::Stat scv0_infeed_ms; - tsl::Stat host_transfer_ms; - tsl::Stat sc_compute_ms; - tsl::Stat sc_infeed_ms; - tsl::Stat sc_outfeed_ms; - tsl::Stat sc_idle_ms; - tsl::Stat sc_step_time_ms; - TpuStepTimeBreakdown result; - - for (const google::protobuf::Any& step_details : analysis.step_details()) { - PerTpuStepDetails details; - if (!step_details.UnpackTo(&details)) { - LOG(ERROR) << "Unable to unpack step_details. Expected: tpu"; - // TODO(b/302086111): Switch back to DFATAL once absl is updated. - DCHECK(false); - return result; - } - tc_compute_ms.UpdateStat(details.tc_compute_time_ms()); - tc_idle_ms.UpdateStat(details.tc_idle_time_ms()); - tc_infeed_ms.UpdateStat(details.tc_infeed_time_ms()); - tc_outfeed_ms.UpdateStat(details.tc_outfeed_time_ms()); - scv0_compute_ms.UpdateStat(details.scv0_compute_time_ms()); - scv0_infeed_ms.UpdateStat(details.scv0_infeed_time_ms()); - host_transfer_ms.UpdateStat(details.host_transfer_ms()); - sc_compute_ms.UpdateStat(details.sc_compute_time_ms()); - sc_idle_ms.UpdateStat(details.sc_idle_time_ms()); - sc_infeed_ms.UpdateStat(details.sc_infeed_time_ms()); - sc_outfeed_ms.UpdateStat(details.sc_outfeed_time_ms()); - sc_step_time_ms.UpdateStat(details.sc_step_time_ms()); - } - *result.mutable_tc_compute_ms_summary() = - GetStepSummaryForSampleStats(tc_compute_ms); - *result.mutable_scv0_compute_ms_summary() = - GetStepSummaryForSampleStats(scv0_compute_ms); - *result.mutable_tc_infeed_ms_summary() = - 
GetStepSummaryForSampleStats(tc_infeed_ms); - *result.mutable_tc_outfeed_ms_summary() = - GetStepSummaryForSampleStats(tc_outfeed_ms); - *result.mutable_scv0_infeed_ms_summary() = - GetStepSummaryForSampleStats(scv0_infeed_ms); - *result.mutable_tc_idle_ms_summary() = - GetStepSummaryForSampleStats(tc_idle_ms); - *result.mutable_host_transfer_ms_summary() = - GetStepSummaryForSampleStats(host_transfer_ms); - if (has_sparse_core) { - auto* sparse_core_step_summary = result.mutable_sparse_core_step_summary(); - *sparse_core_step_summary->mutable_sc_compute_ms_summary() = - GetStepSummaryForSampleStats(sc_compute_ms); - *sparse_core_step_summary->mutable_sc_infeed_ms_summary() = - GetStepSummaryForSampleStats(sc_infeed_ms); - *sparse_core_step_summary->mutable_sc_outfeed_ms_summary() = - GetStepSummaryForSampleStats(sc_outfeed_ms); - *sparse_core_step_summary->mutable_sc_idle_ms_summary() = - GetStepSummaryForSampleStats(sc_idle_ms); - *sparse_core_step_summary->mutable_sc_step_time_ms_summary() = - GetStepSummaryForSampleStats(sc_step_time_ms); - } - return result; -} - -// Given the step sequence on each core, computes the result proto of the -// input-pipeline analysis tool (the InputPipelineAnalysisResult defined in -// input_pipeline.proto). -// Note on grouped_by_step: There is one element for each step executed (on -// multiple cores). Each element is a map from the core_id to the information -// of the step that runs on that core. Elements are in the same order that the -// steps are executed over time. -InputPipelineAnalysisResult ComputeTpuInputPipelineAnalysisResult( - const tsl::protobuf::RepeatedPtrField& grouped_by_step, - const tsl::protobuf::Map& - core_details_map) { - InputPipelineAnalysisResult result; - bool has_sparse_core = false; - for (const auto& [core_id, core_details] : core_details_map) { - has_sparse_core |= core_details.is_sparse_core(); - } - - // Computes the summary of step time in ms. 
- *result.mutable_step_time_summary() = - ComputeStepTimeSummaryInMs(grouped_by_step); - - // Summary of the statistics of infeed time as percentage of the step - // time. - tsl::Stat infeed_summary_stats_in_percent; - for (const auto& coreid_stepinfo_map : grouped_by_step) { - // Compute each TPU step stats. - const PerTpuStepDetails& per_step_data = - ComputeTpuPerStepDataAcrossCores(coreid_stepinfo_map, core_details_map); - result.add_step_details()->PackFrom(per_step_data); - - // The infeed summary is based on the maximum infeed time across cores at - // each step. - infeed_summary_stats_in_percent.UpdateStat( - per_step_data.infeed_percent_maximum()); - } - - // Computes the summary of infeed time as percentage of step time. - *result.mutable_input_percent_summary() = - GetStepSummaryForSampleStats(infeed_summary_stats_in_percent); - - // Computes the breakdown of step time - TpuStepTimeBreakdown tpu_step_time_breakdown = - ComputeTpuStepTimeBreakdownInMs(result, has_sparse_core); - result.mutable_step_time_breakdown()->PackFrom(tpu_step_time_breakdown); - result.set_tag(true); - - return result; -} - -// Returns true if device_op_metrics_db contains an infeed op. -bool HasTpuInfeedOp(const OpMetricsDb& device_op_metrics_db) { - for (const OpMetrics& metrics : device_op_metrics_db.metrics_db()) { - if (tsl::profiler::IsHostOrSparseCoreV0Infeed(metrics.category())) { - return true; - } - } - return false; -} - -// Returns the time spent waiting for input for generic hardware. -uint64_t TotalInputPs(const StepDetails& step_details) { - uint64_t total_input_ps = 0; - for (const auto& event : step_details.Events()) { - if (event.type == HOST_WAIT_INPUT || event.type == HOST_TO_DEVICE) { - // Includes both the time where the host was waiting input and the time - // where the host was sending data to the device. 
- total_input_ps += event.span.duration_ps(); - } - } - return total_input_ps; -} - -void TensorCoreIdleAnalysis(bool all_cores_profiled, double tc_idle_percent, - std::string* input_classification, - std::string* input_statement, - std::string* tc_idle_classification, - std::string* tc_idle_statement) { - // In MayFixTpuStepAnalysis(), we have already separated the idle time from - // the input time. So, we don't need to substract the input time from the - // idle time here. - if (tc_idle_percent < kTcIdleThresholdInPercent) { - *tc_idle_classification = "no"; - *tc_idle_statement = ""; - return; - } - std::string idle_percent_str = absl::StrFormat("%.1lf", tc_idle_percent); - if (all_cores_profiled) { - // Significant idle time with all cores profiled. - *tc_idle_classification = "yes"; - *tc_idle_statement = - absl::StrCat(idle_percent_str, - " % of the total step time sampled is due to host " - "overhead that is not input-related. For TF 2.x, you may " - "want to use a ", - AnchorElement(kMultipleStepsInTffunctionDoc, - "host-training loop (i.e. running multiple " - "steps within a tf.function).")); - return; - } - - // Significant idle time without all cores profiled. - if (*input_classification == "host") { - // We've already identified that it is input bound. So, no need to issue - // more warnings. - *tc_idle_classification = "no"; - *tc_idle_statement = ""; - return; - } - - *input_classification = "host"; // focuses on "host" first. - *input_statement = absl::StrCat( - "Your program COULD be input-bound because ", idle_percent_str, - "% of the total step time is idle. This may be a manifestation of an " - "input issue on a worker " - "machine that was not profiled. 
To be certain, please profile ALL " - "worker machines in your job by following ", - AnchorElement(kProfileAllHostsDoc, "this instruction.")); - *tc_idle_classification = "no"; - *tc_idle_statement = ""; -} - -void AllReduceAnalysis(bool all_cores_profiled, - double all_reduce_compute_percent, - double all_reduce_sync_percent, double input_percent, - std::string* input_classification, - std::string* input_statement, - std::string* all_reduce_classification, - std::string* all_reduce_statement) { - double all_reduce_percent = - all_reduce_compute_percent + all_reduce_sync_percent; - // Since all-reduce time is overlapped with the input time, we consider the - // all-reduce time that is not input related. - double all_reduce_not_input_related_percent = - all_reduce_percent - input_percent; - - if (all_reduce_not_input_related_percent < - kAllReduceBoundThresholdInPercent) { - // Insignificant time spent on all-reduce. - *all_reduce_classification = "no"; - *all_reduce_statement = ""; - return; - } - - if (all_cores_profiled) { - // Significant time spent on all-reduce with all cores profiled. - std::string all_reduce_compute_percent_str = - absl::StrFormat("%.1lf", all_reduce_compute_percent); - std::string all_reduce_sync_percent_str = - absl::StrFormat("%.1lf", all_reduce_sync_percent); - *all_reduce_classification = "yes"; - *all_reduce_statement = absl::StrCat( - "Also, ", all_reduce_sync_percent_str, - " % of the total step time sampled is spent on synchronization with " - "other TPU cores, and ", - all_reduce_compute_percent_str, - " % of the total step time sampled is spent on actual AllReduce."); - return; - } - - // Significant time spent on all-reduce and not all cores were profiled. - std::string all_reduce_percent_str = - absl::StrFormat("%.1lf", all_reduce_percent); - - if (*input_classification != "device") { - // InputAnalysis() already indicates some potential input issue. So, we - // can focus on all-reduce performance. 
- *all_reduce_classification = "yes"; - *all_reduce_statement = absl::StrCat( - "Also, ", all_reduce_percent_str, - " % of the total step time sampled is spent on synchronization " - "with " - "other TPU cores and AllReduce. Not all worker machines are " - "profiled, " - "therefore " - "we " - "cannot disambiguate the actual time for AllReduce from the " - "synchronization. To be certain, please profile ALL " - "worker machines in your job by following ", - AnchorElement(kProfileAllHostsDoc, "this instruction.")); - return; - } - - // InputAnalysis() indicates that it is NOT input-bound. However, it may - // be because the input delay is manifested as all-reduce time. So, - // attribute it to a possible input issue. - *input_classification = "host"; // focuses on "host" first. - *input_statement = absl::StrCat( - "Your program COULD be input-bound because ", all_reduce_percent_str, - "% of the total step time is spent on synchronization with other " - "TPU cores. This may be a manifestation of an input issue on a " - "worker " - "machine that was not profiled. To be certain, please profile ALL " - "worker machines in your job by following ", - AnchorElement(kProfileAllHostsDoc, "this instruction.")); - *all_reduce_classification = "no"; - *all_reduce_statement = ""; -} - -void ScV0Analysis(double scv0_percent, std::string* scv0_classification, - std::string* scv0_statement) { - if (scv0_percent == 0) { - *scv0_classification = "no"; - *scv0_statement = ""; - return; - } - std::string scv0_percent_str = absl::StrFormat("%.1lf", scv0_percent); - if (scv0_percent < kModeratelySparseCoreV0BoundThresholdInPercent) { - *scv0_classification = "moderate"; - *scv0_statement = absl::StrCat( - "Also, ", scv0_percent_str, - " % of the total step time sampled is spent on the ", kSparseCoreV0Name, - " compute. 
You may also want to reduce the ", kSparseCoreV0Name, - " compute time."); - return; - } - *scv0_classification = "high"; - *scv0_statement = absl::StrCat( - "Also, ", scv0_percent_str, - " % of the total step time sampled is spent on the ", kSparseCoreV0Name, - " compute. You should focus on reducing the ", kSparseCoreV0Name, - " compute time as well."); -} - -// A map keeps track of the minimum value associated with an id. -class MinMap { - public: - void Observe(uint64_t id, uint64_t value) { - auto [iter, inserted] = min_map_.try_emplace(id, value); - if (!inserted && iter->second > value) { - iter->second = value; - } - } - - uint64_t Min(uint64_t id) const { - auto iter = min_map_.find(id); - return (iter != min_map_.end()) ? iter->second : 0; - } - - private: - absl::flat_hash_map min_map_; -}; - -} // namespace - -PerTpuStepDetails ComputeTpuPerStepDataAcrossCores( - const PerCoreStepInfo& coreid_stepinfo_map, - const tsl::protobuf::Map& - core_details_map) { - PerTpuStepDetails per_step_data; - - PerCoreAllReduceBreakdown all_reduce_breakdown = - ComputePerStepAllReduceBreakdownAcrossCores(coreid_stepinfo_map); - - tsl::Stat infeed_percent_stats; - tsl::Stat step_stats_in_ps; - tsl::Stat optimal_step_time_ps; - // Take the average TC outfeed time in result. - tsl::Stat tc_outfeed_time_in_ps; - tsl::Stat sc_compute_time_ps; - tsl::Stat sc_step_stats_in_ps; - tsl::Stat sc_outfeed_time_in_ps; - tsl::Stat sc_infeed_time_in_ps; - tsl::Stat sc_idle_time_in_ps; - - tsl::Stat host_send_recv_time_ps; - - // For the core with the max wait-for-scv0 duration, breakdown to compute and - // infeed time. - WaitForScV0Breakdown max_wait_for_scv0; - - TcInfeed max_infeed; - - // For the core with the max all reduce duration, breakdown to compute and - // synchronization time. 
- AllReduceBreakdown max_all_reduce; - - per_step_data.set_step_number(-1); - auto process_step_for_sc = - [&](const tensorflow::profiler::StepInfoResult& step_info, - const SparseCoreStepBreakdown& sc_step) { - if (per_step_data.step_number() < 0) { - per_step_data.set_step_number(step_info.step_num()); - } else { - if (per_step_data.step_number() != step_info.step_num()) { - VLOG(1) << "Inconsistent step numbers across cores (" - << per_step_data.step_number() << " vs. " - << step_info.step_num() << ")."; - } - } - sc_step_stats_in_ps.UpdateStat(step_info.duration_ps()); - sc_outfeed_time_in_ps.UpdateStat(sc_step.sc_outfeed_ps()); - sc_infeed_time_in_ps.UpdateStat(sc_step.sc_infeed_ps()); - sc_compute_time_ps.UpdateStat(step_info.duration_ps() - - sc_step.sc_infeed_ps() - - sc_step.sc_outfeed_ps()); - sc_idle_time_in_ps.UpdateStat(sc_step.sc_idle_ps()); - }; - for (const auto& [core_id, step_info] : - coreid_stepinfo_map.step_info_per_core()) { - // iterates over each core. - TpuStepBreakdown tpu; - if (!step_info.step_breakdown().UnpackTo(&tpu)) { - VLOG(1) << "Unable to unpack step_breakdown from tpu, try unpacking from " - "generic"; - tensorflow::profiler::GenericStepBreakdown generic_step_breakdown; - if (!step_info.step_breakdown().UnpackTo(&generic_step_breakdown)) { - SparseCoreStepBreakdown sc_step; - if (step_info.step_breakdown().UnpackTo(&sc_step)) { - process_step_for_sc(step_info, sc_step); - continue; - } else { - LOG(ERROR) << "Unable to unpack step_breakdown from " - "GenericStepBreakdown or SparseCoreStepBreakdown"; - // TODO(b/302086111): Switch back to DFATAL once absl is updated. - DCHECK(false); - return per_step_data; - } - } - if (core_id >= kSparseCoreIndexStart) { - // Sparse core step breakdown from xspace. 
- uint64_t idle_time_ps = 0; - uint64_t busy_time_ps = 0; - for (const auto& [category, time_ps] : - generic_step_breakdown.category_ps()) { - if (category == kIdle) { - idle_time_ps = time_ps; - } else if (category == "sparse_core_busy_ops") { - busy_time_ps = time_ps; - } - } - sc_step_stats_in_ps.UpdateStat(step_info.duration_ps()); - sc_compute_time_ps.UpdateStat(busy_time_ps); - sc_idle_time_in_ps.UpdateStat(idle_time_ps); - continue; - } else { - // Tensor core step breakdown from xspace. - ConvertGenericStepBreakdownToTpuStepBreakdown( - generic_step_breakdown, step_info.duration_ps(), tpu); - } - } - step_stats_in_ps.UpdateStat(step_info.duration_ps()); - if (tpu.wait_for_scv0_duration_ps() > max_wait_for_scv0.DurationPs()) { - max_wait_for_scv0.scv0_infeed_duration_ps = ScV0InfeedDurationPs(tpu); - max_wait_for_scv0.scv0_compute_duration_ps = ScV0ComputeDurationPs(tpu); - } - - tc_outfeed_time_in_ps.UpdateStat(tpu.host_outfeed_ps()); - - const AllReduceBreakdown& breakdown = all_reduce_breakdown[core_id]; - if (breakdown.DurationPs() > max_all_reduce.DurationPs()) { - max_all_reduce = breakdown; - } - - infeed_percent_stats.UpdateStat(100.0 * TcPlusScV0InfeedDurationPs(tpu) / - step_info.duration_ps()); - // The optimal step time is the actual step time minus the time tensor - // core spends waiting for host or sparsecorev0 (but not other tensor - // cores). - optimal_step_time_ps.UpdateStat(step_info.duration_ps() - - WaitForHostOrScV0DurationPs(tpu)); - host_send_recv_time_ps.UpdateStat(HostSendRecvDurationPs(tpu)); - - if (per_step_data.step_number() < 0) { - // Sets the step number of the current step from the first core. - per_step_data.set_step_number(step_info.step_num()); - } else { - // The step number of the current step is already set. Checks if it is - // the same across cores. In case of multi-host tracing, we may have - // some inconsistent steps as tracing is not exactly guaranteed to be - // synchronized across all hosts. 
- if (per_step_data.step_number() != step_info.step_num()) { - VLOG(1) << "Inconsistent step numbers across cores (" - << per_step_data.step_number() << " vs. " - << step_info.step_num() << ")."; - } - } - if (tpu.infeed_duration_ps() > max_infeed.duration_ps) { - max_infeed.core_id = core_id; - max_infeed.duration_ps = tpu.infeed_duration_ps(); - } - } - - per_step_data.set_tc_outfeed_time_ms( - tsl::profiler::PicoToMilli(tc_outfeed_time_in_ps.avg())); - // The TC compute time is the minimum of the optimal step time across cores. - per_step_data.set_tc_compute_time_ms( - tsl::profiler::PicoToMilli(optimal_step_time_ps.min())); - per_step_data.set_host_transfer_ms( - tsl::profiler::PicoToMilli(host_send_recv_time_ps.max())); - // TODO(b/153730997): Use the maximum step time. - // The infeed time is the step time across cores minus all other times. - // Previously, we used the maximum step time but changed to use the minimum - // step time to work around b/153730997. - // Uses the max TC infeed duration across cores as the step's TC infeed - // duration. - per_step_data.set_tc_infeed_time_ms( - tsl::profiler::PicoToMilli(max_infeed.duration_ps)); - if (max_infeed.core_id.has_value()) { - per_step_data.set_coreid_max_infeed_time(max_infeed.core_id.value()); - if (core_details_map.contains(max_infeed.core_id.value())) { - const CoreDetails& core_details = - core_details_map.at(max_infeed.core_id.value()); - per_step_data.set_max_infeed_time_core_name(absl::StrCat( - core_details.hostname(), ":", core_details.device_ordinal())); - } - } - - per_step_data.set_scv0_compute_time_ms( - tsl::profiler::PicoToMilli(max_wait_for_scv0.scv0_compute_duration_ps)); - per_step_data.set_scv0_infeed_time_ms( - tsl::profiler::PicoToMilli(max_wait_for_scv0.scv0_infeed_duration_ps)); - - // The TC idle time is the time TC spends waiting for the host but not - // waiting for input. 
- per_step_data.set_tc_idle_time_ms( - tsl::profiler::PicoToMilli(step_stats_in_ps.min()) - - NonIdleTimeMs(per_step_data)); - if (per_step_data.tc_idle_time_ms() < 0) { - per_step_data.set_tc_idle_time_ms(0); - } - - per_step_data.set_all_reduce_compute_time_ms( - tsl::profiler::PicoToMilli(max_all_reduce.compute_duration_ps)); - per_step_data.set_all_reduce_sync_time_ms( - tsl::profiler::PicoToMilli(max_all_reduce.sync_duration_ps)); - - per_step_data.set_infeed_percent_average(infeed_percent_stats.avg()); - per_step_data.set_infeed_percent_minimum(infeed_percent_stats.min()); - per_step_data.set_infeed_percent_maximum(infeed_percent_stats.max()); - - per_step_data.set_sc_infeed_time_ms( - tsl::profiler::PicoToMilli(sc_infeed_time_in_ps.avg())); - per_step_data.set_sc_outfeed_time_ms( - tsl::profiler::PicoToMilli(sc_outfeed_time_in_ps.avg())); - per_step_data.set_sc_compute_time_ms( - tsl::profiler::PicoToMilli(sc_compute_time_ps.min())); - per_step_data.set_sc_idle_time_ms( - tsl::profiler::PicoToMilli(sc_idle_time_in_ps.avg())); - per_step_data.set_sc_step_time_ms( - tsl::profiler::PicoToMilli(sc_step_stats_in_ps.avg())); - if (per_step_data.sc_idle_time_ms() < 0) { - per_step_data.set_sc_idle_time_ms(0); - } - return per_step_data; -} - -StepSummary GetStepSummaryForSampleStats( - const tsl::Stat& sample_stats) { - StepSummary step_time_summary; - double avg, sdv, min, max; - if (sample_stats.empty()) { - // If sample_stats is empty, sample_stats.avg() will return NaN. However, we - // prefer to show an 0 instead. 
- avg = sdv = min = max = 0.0; - } else { - avg = sample_stats.avg(); - sdv = sqrt(sample_stats.sample_variance()); - min = sample_stats.min(); - max = sample_stats.max(); - } - step_time_summary.set_average(avg); - step_time_summary.set_standard_deviation(sdv); - step_time_summary.set_minimum(min); - step_time_summary.set_maximum(max); - return step_time_summary; -} - -PerCoreAllReduceBreakdown ComputePerStepAllReduceBreakdownAcrossCores( - const PerCoreStepInfo& coreid_stepinfo_map) { - PerCoreAllReduceBreakdown result; - MinMap min_duration_map; - for (const auto& [core_id, all_reduce_db] : - coreid_stepinfo_map.all_reduce_db_per_core()) { - for (const auto& all_reduce : all_reduce_db.all_reduce_info()) { - uint64_t duration_ps = - all_reduce.end_time_ps() - all_reduce.start_time_ps(); - min_duration_map.Observe(all_reduce.id(), duration_ps); - } - } - for (const auto& [core_id, all_reduce_db] : - coreid_stepinfo_map.all_reduce_db_per_core()) { - AllReduceBreakdown& breakdown = result[core_id]; - for (const auto& all_reduce : all_reduce_db.all_reduce_info()) { - uint64_t duration_ps = - all_reduce.end_time_ps() - all_reduce.start_time_ps(); - uint64_t min_duration_ps = min_duration_map.Min(all_reduce.id()); - breakdown.compute_duration_ps += min_duration_ps; - breakdown.sync_duration_ps += duration_ps - min_duration_ps; - } - } - return result; -} - -void MayFixTpuStepAnalysis( - const StepEvents& host_step_events, const OpMetricsDb& device_op_metrics_db, - StepDatabaseResult& step_db, - const tsl::protobuf::Map& - core_details_map) { - // This code is only applicable when input is received by the tensor core - // from the host without the use of infeed. If the tensor core receives - // input via host infeed or via sparsecorev0 infeed, there's nothing to do. 
- if (HasTpuInfeedOp(device_op_metrics_db)) return; - - for (PerCoreStepInfo& per_core_step_info : - *(step_db.mutable_step_sequence())) { - uint32_t step_num = per_core_step_info.step_num(); - // TODO(ckluk): step_num is obtained from tf_op_stats, which is based on the - // step-tracking mechanism with the on-device training loop. However, this - // step_num is different from the group_id. So, what we are doing here is - // only an approximation, assuming that all steps exhibit similar - // breakdown. Once grouping works on TPU device, we need to replace step_num - // by the group_id from TPU device. - const StepDetails* step_details = - gtl::FindOrNull(host_step_events, step_num); - if (step_details == nullptr) { - continue; // step_num not in host_step_events, we don't know how to fix. - } - uint64_t total_input_ps = TotalInputPs(*step_details); - if (total_input_ps == 0) { - continue; // no host input events. - } - PerTpuStepDetails tpu_step_data = - ComputeTpuPerStepDataAcrossCores(per_core_step_info, core_details_map); - double tc_idle_ms = tpu_step_data.tc_idle_time_ms(); - double adjusted_input_ratio = - std::min(tsl::profiler::SafeDivide( - tsl::profiler::PicoToMilli(total_input_ps), tc_idle_ms), - 1.0); - for (auto& [core_id, step_info] : - *per_core_step_info.mutable_step_info_per_core()) { - // skip sparse cores for this. - if (core_id >= kSparseCoreIndexStart) continue; - TpuStepBreakdown tpu; - if (TpuStepBreakdown tpu; step_info.step_breakdown().UnpackTo(&tpu)) { - DCHECK_EQ(tpu.infeed_duration_ps(), 0); - if (tpu.tc_idle_ps() > 0) { - // Extract the infeed fraction of idle time. 
- tpu.set_infeed_duration_ps(tpu.tc_idle_ps() * adjusted_input_ratio); - tpu.set_tc_idle_ps(tpu.tc_idle_ps() - tpu.infeed_duration_ps()); - step_info.mutable_step_breakdown()->PackFrom(tpu); - } - } else if (tensorflow::profiler::GenericStepBreakdown generic; - step_info.step_breakdown().UnpackTo(&generic)) { - uint64_t& infeed_time_ps = - (*generic.mutable_category_ps())[xla::HloOpcodeString( - xla::HloOpcode::kInfeed)]; - uint64_t& idle_time_ps = - (*generic.mutable_category_ps())[tensorflow::profiler::kIdle]; - DCHECK_EQ(infeed_time_ps, 0); - if (idle_time_ps > 0) { - infeed_time_ps = idle_time_ps * adjusted_input_ratio; - idle_time_ps -= infeed_time_ps; - step_info.mutable_step_breakdown()->PackFrom(generic); - } - } else { - // Likely encountered an ScStepBreakdown instance which can be skipped - // as we only care about attributing TC idle time to host. - LOG(INFO) << "Unable to unpack step_breakdown."; - } - } - } -} - -TpuBottleneckAnalysis ComputeTpuBottleneckAnalysis( - bool all_cores_profiled, const InputPipelineAnalysisResult& result) { - double total_step_time_ms = 0; - double total_infeed_time_ms = 0; - double total_tc_outfeed_time_ms = 0; - double total_scv0_compute_time_ms = 0; - double total_all_reduce_compute_time_ms = 0; - double total_all_reduce_sync_time_ms = 0; - double total_tc_idle_time_ms = 0; - - TpuBottleneckAnalysis analysis; - for (const google::protobuf::Any& step_details : result.step_details()) { - PerTpuStepDetails details; - if (!step_details.UnpackTo(&details)) { - LOG(ERROR) << "Unable to unpack step_details. Expected: tpu"; - // TODO(b/302086111): Switch back to DFATAL once absl is updated. 
- DCHECK(false); - return analysis; - } - total_step_time_ms += StepTimeMs(details); - total_infeed_time_ms += InfeedTimeMs(details); - total_tc_outfeed_time_ms += details.tc_outfeed_time_ms(); - total_scv0_compute_time_ms += details.scv0_compute_time_ms(); - total_all_reduce_compute_time_ms += details.all_reduce_compute_time_ms(); - total_all_reduce_sync_time_ms += details.all_reduce_sync_time_ms(); - total_tc_idle_time_ms += details.tc_idle_time_ms(); - } - if (total_step_time_ms == 0) { - analysis.set_input_classification("unknown"); - analysis.set_input_statement( - "No step time measured. Therefore we cannot tell where the performance " - "bottleneck is."); - analysis.set_tc_idle_classification("no"), - analysis.set_tc_idle_statement(""); - analysis.set_scv0_classification("no"); - analysis.set_scv0_statement(""); - analysis.set_all_reduce_classification("no"); - analysis.set_all_reduce_statement(""); - return analysis; - } - - double infeed_percent = 100.0 * total_infeed_time_ms / total_step_time_ms; - std::string input_classification; - std::string input_statement; - InputAnalysis(infeed_percent, /*all_other_percent=*/0, &input_classification, - &input_statement); - - double tc_outfeed_percent = - 100.0 * total_tc_outfeed_time_ms / total_step_time_ms; - std::string output_classification; - std::string output_statement; - OutputAnalysis(tc_outfeed_percent, &output_classification, &output_statement); - - double tc_idle_percent = 100.0 * total_tc_idle_time_ms / total_step_time_ms; - std::string tc_idle_classification; - std::string tc_idle_statement; - TensorCoreIdleAnalysis(all_cores_profiled, tc_idle_percent, - &input_classification, &input_statement, - &tc_idle_classification, &tc_idle_statement); - - double all_reduce_compute_percent = - 100.0 * total_all_reduce_compute_time_ms / total_step_time_ms; - double all_reduce_sync_percent = - 100.0 * total_all_reduce_sync_time_ms / total_step_time_ms; - std::string all_reduce_classification; - std::string 
all_reduce_statement; - AllReduceAnalysis(all_cores_profiled, all_reduce_compute_percent, - all_reduce_sync_percent, infeed_percent, - &input_classification, &input_statement, - &all_reduce_classification, &all_reduce_statement); - - double scv0_percent = 100.0 * total_scv0_compute_time_ms / total_step_time_ms; - std::string scv0_classification; - std::string scv0_statement; - ScV0Analysis(scv0_percent, &scv0_classification, &scv0_statement); - - // compute_percent includes both TC and ScV0 compute. - double compute_percent = std::max( - 0.0, 100.0 - infeed_percent - tc_outfeed_percent - tc_idle_percent); - - analysis.set_compute_percent(compute_percent); - analysis.set_input_percent(infeed_percent); - analysis.set_output_percent(tc_outfeed_percent); - analysis.set_tc_idle_percent(tc_idle_percent); - analysis.set_input_classification(input_classification); - analysis.set_input_statement(input_statement); - analysis.set_output_statement(output_statement); - analysis.set_tc_idle_classification(tc_idle_classification), - analysis.set_tc_idle_statement(tc_idle_statement); - analysis.set_scv0_classification(scv0_classification); - analysis.set_scv0_statement(scv0_statement); - analysis.set_all_reduce_classification(all_reduce_classification); - analysis.set_all_reduce_statement(all_reduce_statement); - return analysis; -} - -void GenerateHostResult(const OpMetricsDb& host_tf_metrics_db, - InputPipelineAnalysisResult* result) { - InputOpMetrics input_op_metrics = SelectInputOpMetrics(host_tf_metrics_db); - // Returns if the program is not using an input pipeline with - // instrumentation and hence no input ops are found. 
- if (input_op_metrics.input_op_metrics.empty()) return; - - absl::flat_hash_map aggregated_input_op_times_us; - for (const OpMetrics* op_metrics : input_op_metrics.input_op_metrics) { - InputOpCategory category = - CategorizeInputOp(op_metrics->name(), op_metrics->category()); - *result->add_input_op_details() = ConvertOpMetricsToInputOpDetails( - *op_metrics, input_op_metrics.input_op_time_ps, category); - aggregated_input_op_times_us[category] += - tsl::profiler::PicoToMicro(op_metrics->self_time_ps()); - } - - double enqueue_time_us = - aggregated_input_op_times_us[InputOpCategory::kEnqueue]; - double total_input_op_time_us = - aggregated_input_op_times_us[InputOpCategory::kDemandedFileRead] + - aggregated_input_op_times_us[InputOpCategory::kAdvancedFileRead] + - aggregated_input_op_times_us[InputOpCategory::kPreprocessing]; - - double ratio = std::min( - 1.0, RatioOfHostToDeviceTimeToStepTime(host_tf_metrics_db, *result)); - DCHECK_GE(ratio, 0.0); - double non_enqueue_time_us = (ratio != 0.0) - ? (enqueue_time_us * (1.0 - ratio) / ratio) - : total_input_op_time_us; - - // Scales the various input-time components wrt to non_enqueue_time_us. 
- double scaled_demanded_fileread_time_us = tsl::profiler::SafeDivide( - non_enqueue_time_us * - aggregated_input_op_times_us[InputOpCategory::kDemandedFileRead], - total_input_op_time_us); - double scaled_advanced_fileread_time_us = tsl::profiler::SafeDivide( - non_enqueue_time_us * - aggregated_input_op_times_us[InputOpCategory::kAdvancedFileRead], - total_input_op_time_us); - double scaled_preprocessing_time_us = tsl::profiler::SafeDivide( - non_enqueue_time_us * - aggregated_input_op_times_us[InputOpCategory::kPreprocessing], - total_input_op_time_us); - double unclassified_non_enqueue_time_us = std::max( - 0.0, non_enqueue_time_us - scaled_demanded_fileread_time_us - - scaled_advanced_fileread_time_us - scaled_preprocessing_time_us); - - InputTimeBreakdown* input_time_breakdown = - result->mutable_input_time_breakdown(); - input_time_breakdown->set_enqueue_us(enqueue_time_us); - input_time_breakdown->set_demanded_file_read_us( - scaled_demanded_fileread_time_us); - input_time_breakdown->set_advanced_file_read_us( - scaled_advanced_fileread_time_us); - input_time_breakdown->set_preprocessing_us(scaled_preprocessing_time_us); - input_time_breakdown->set_unclassified_non_enqueue_us( - unclassified_non_enqueue_time_us); -} - -InputPipelineAnalysisRecommendation GenerateRecommendation() { - const absl::string_view kDatasetIntro = - "https://www.tensorflow.org/programmers_guide/datasets"; - - const absl::string_view kDatasetTopic = - "https://www.tensorflow.org/api_docs/python/tf/data/Dataset#"; - - const absl::string_view kTfRecordDataset = - "https://www.tensorflow.org/api_docs/python/tf/data/" - "TFRecordDataset#class_tfrecorddataset"; - - InputPipelineAnalysisRecommendation recommendation; - *recommendation.add_details() = - "Enqueuing data: you may want to combine small input data chunks " - "into fewer " - "but larger chunks."; - *recommendation.add_details() = absl::StrCat( - "Data preprocessing: you may increase num_parallel_calls in ", - 
AnchorElement(absl::StrCat(kDatasetTopic, "map"), "Dataset map()"), - " or preprocess the data OFFLINE."); - *recommendation.add_details() = absl::StrCat( - "Reading data from files in advance: you may tune parameters in the " - "following tf.data API (", - AnchorElement(absl::StrCat(kDatasetTopic, "prefetch"), "prefetch size"), - ", ", - AnchorElement(absl::StrCat(kDatasetTopic, "interleave"), - "interleave cycle_length"), - ", ", AnchorElement(kTfRecordDataset, "reader buffer_size"), ")"); - *recommendation.add_details() = absl::StrCat( - "Reading data from files on demand: you should read data IN ADVANCE " - "using the following tf.data API (", - AnchorElement(absl::StrCat(kDatasetTopic, "prefetch"), "prefetch"), ", ", - AnchorElement(absl::StrCat(kDatasetTopic, "interleave"), "interleave"), - ", ", AnchorElement(kTfRecordDataset, "reader buffer"), ")"); - *recommendation.add_details() = absl::StrCat( - "Other data reading or processing: you may consider using the ", - AnchorElement(kDatasetIntro, "tf.data API"), - " (if you are not using it now)"); - - return recommendation; -} - -StepSummary ComputeStepTimeSummaryInMs( - const tsl::protobuf::RepeatedPtrField& grouped_by_step) { - tsl::Stat total_step_stats_in_ms; - // iterates over each step. - for (const auto& coreid_stepinfo_map : grouped_by_step) { - double max_per_step_stats_in_ms = 0.0; - // iterates over each core. - for (const auto& coreid_and_stepinfo : - coreid_stepinfo_map.step_info_per_core()) { - if (coreid_and_stepinfo.first >= kSparseCoreIndexStart) continue; - const auto& step_info = coreid_and_stepinfo.second; - max_per_step_stats_in_ms = std::max(step_info.duration_ps() / kNumPsPerMs, - max_per_step_stats_in_ms); - } - // Step time of each step is determined by the slowest core. 
- total_step_stats_in_ms.UpdateStat(max_per_step_stats_in_ms); - } - - return GetStepSummaryForSampleStats(total_step_stats_in_ms); -} - -InputPipelineAnalysisResult ConvertOpStatsToInputPipelineAnalysis( - const OpStats& op_stats) { - const HardwareType hardware_type = op_stats.run_environment().hardware_type(); - - InputPipelineAnalysisResult result; - if (hardware_type == tensorflow::profiler::TPU) { - result = ComputeTpuInputPipelineAnalysisResult( - op_stats.step_db().step_sequence(), op_stats.core_id_to_details()); - } else { - result = ComputeGenericInputPipelineAnalysisResult( - op_stats.step_db().step_sequence()); - } - result.set_hardware_type(HardwareType_Name(hardware_type)); - - PopulateStepDiagnostics(op_stats, result.mutable_diagnostics()); - GenerateHostResult(op_stats.host_op_metrics_db(), &result); - - InputPipelineAnalysisRecommendation recommendation = GenerateRecommendation(); - if (hardware_type == tensorflow::profiler::TPU) { - TpuBottleneckAnalysis bottleneck_analysis = ComputeTpuBottleneckAnalysis( - /*all_cores_profiled=*/true, result); - result.set_input_percent(bottleneck_analysis.input_percent()); - result.set_output_percent(bottleneck_analysis.output_percent()); - result.set_idle_percent(bottleneck_analysis.tc_idle_percent()); - result.set_compute_percent(bottleneck_analysis.compute_percent()); - - recommendation.mutable_bottleneck_analysis()->PackFrom(bottleneck_analysis); - *recommendation.mutable_summary_next_step() = - GetSummaryNextStep(bottleneck_analysis.input_classification(), - result.input_time_breakdown()); - } else { - BottleneckAnalysis bottleneck_analysis = ComputeBottleneckAnalysis( - result.input_time_breakdown(), result.step_details()); - result.set_input_percent(bottleneck_analysis.input_percent()); - result.set_output_percent(bottleneck_analysis.output_percent()); - result.set_idle_percent(bottleneck_analysis.idle_percent()); - result.set_compute_percent(bottleneck_analysis.compute_percent()); - 
recommendation.mutable_bottleneck_analysis()->PackFrom(bottleneck_analysis); - *recommendation.mutable_summary_next_step() = - GetSummaryNextStep(bottleneck_analysis.input_classification(), - result.input_time_breakdown()); - } - - *result.mutable_recommendation() = recommendation; - return result; -} - -bool InputAnalysis(double input_percent, double all_other_percent, - std::string* input_classification, - std::string* input_statement) { - absl::string_view non_input_time = "other time"; - if (input_percent >= kHighlyInfeedBoundThresholdInPercent) { - *input_classification = "host"; - *input_statement = absl::StrCat( - "Your program is HIGHLY input-bound because ", OneDigit(input_percent), - "% of the total step time sampled is waiting for input. Therefore, you " - "should first focus on reducing the input time."); - return false; - } else if (input_percent >= kModeratelyInfeedBoundThresholdInPercent) { - *input_classification = "both"; - *input_statement = absl::StrCat( - "Your program is MODERATELY input-bound because ", - OneDigit(input_percent), - "% of the total step time sampled is waiting for input. Therefore, " - "you would need to reduce both the input time and ", - non_input_time, "."); - return false; - } else if (all_other_percent >= kModeratelyAllOtherBoundThresholdInPercent) { - // Input analysis says it is not input-bound, but "All-Other" time - // is significant. It could still be input-bound (or Python overhead). - *input_classification = "both"; - *input_statement = absl::StrCat( - "Your program is POTENTIALLY input-bound because ", - OneDigit(all_other_percent), - "% of the total step time sampled is spent on 'All Others' time (which " - "could be due to I/O or Python execution or both)."); - return true; - } else { - // Definitely not input-bound. 
- *input_classification = "device"; - *input_statement = - absl::StrCat("Your program is NOT input-bound because only ", - OneDigit(input_percent), - "% of the total step time sampled is waiting for " - "input. Therefore, you should focus on " - "reducing ", - non_input_time, "."); - return false; - } -} - -void OutputAnalysis(double output_percent, std::string* output_classification, - std::string* output_statement) { - if (output_percent >= kHighlyOutfeedBoundThresholdInPercent) { - *output_classification = "host"; - *output_statement = absl::StrCat( - "Your program is HIGHLY output-bound because ", - OneDigit(output_percent), - "% of the total step time sampled is spent on output. Therefore, you " - "should first focus on reducing the output time."); - } else if (output_percent >= kModeratelyOutfeedBoundThresholdInPercent) { - *output_classification = "both"; - *output_statement = absl::StrCat( - "Your program is MODERATELY output-bound because ", - OneDigit(output_percent), - "% of the total step time sampled is spent on output. 
Therefore, " - "you would need to reduce both the output time and other time."); - } else { - *output_classification = "device"; - *output_statement = ""; - } -} - -BottleneckAnalysis ComputeBottleneckAnalysis( - const InputTimeBreakdown& input_time_breakdown, - const tsl::protobuf::RepeatedPtrField<::google::protobuf::Any>& - any_step_details) { - double total_step_time_ms = 0; - double total_input_ms = 0; - double total_output_ms = 0; - double total_host_compute_ms = 0; - double total_host_prepare_ms = 0; - double total_host_compile_ms = 0; - double total_device_compute_ms = 0; - double total_device_to_device_ms = 0; - double total_device_collectives_ms = 0; - double total_unknown_ms = 0; - - for (const google::protobuf::Any& step_details : any_step_details) { - PerGenericStepDetails details; - bool success = step_details.UnpackTo(&details); - if (!success && !step_details.type_url().empty()) { - LOG(ERROR) << "Unable to unpack step_breakdown. Expected: generic" - << std::endl; - return {}; - } - total_step_time_ms += details.step_time_ms(); - total_input_ms += - details.host_wait_input_ms() + details.host_to_device_ms(); - total_output_ms += details.output_ms(); - total_host_prepare_ms += details.host_prepare_ms(); - total_device_compute_ms += details.device_compute_ms(); - total_device_to_device_ms += details.device_to_device_ms(); - total_device_collectives_ms += details.device_collectives_ms(); - total_host_compute_ms += details.host_compute_ms(); - total_host_compile_ms += details.host_compile_ms(); - total_unknown_ms += details.unknown_time_ms(); - } - - if (total_step_time_ms == 0) { - BottleneckAnalysis analysis; - analysis.set_input_classification("unknown"); - analysis.set_input_statement( - "No step time measured. 
Therefore we cannot tell where the " - "performance bottleneck is."); - analysis.set_kernel_launch_classification("no"); - analysis.set_kernel_launch_statement(""); - analysis.set_all_other_classification("no"); - analysis.set_all_other_statement(""); - analysis.set_device_collectives_classification("no"); - analysis.set_device_collectives_statement(""); - return analysis; - } - double input_percent = 100.0 * total_input_ms / total_step_time_ms; - double output_percent = 100.0 * total_output_ms / total_step_time_ms; - double compute_percent = 100.0 * total_device_compute_ms / total_step_time_ms; - double device_collectives_percent = - 100.0 * total_device_collectives_ms / total_step_time_ms; - - // idle_percent includes host_prepare (i.e. kernel launch, device-to-device, - // host compute, host compile, and unknown. - double idle_percent = - std::max(0.0, 100.0 - input_percent - output_percent - compute_percent - - device_collectives_percent); - double kernel_launch_percent = - 100.0 * total_host_prepare_ms / total_step_time_ms; - double all_other_percent = 100.0 * total_unknown_ms / total_step_time_ms; - - std::string input_classification; - std::string input_statement; - bool all_other_reported = - InputAnalysis(input_percent, all_other_percent, &input_classification, - &input_statement); - - std::string device_collectives_classification; - std::string device_collectives_statement; - DeviceCollectivesAnalysis(device_collectives_percent, - &device_collectives_classification, - &device_collectives_statement); - - std::string kernel_launch_classification; - std::string kernel_launch_statement; - KernelLaunchAnalysis(TfDataInUse(input_time_breakdown), kernel_launch_percent, - &kernel_launch_classification, &kernel_launch_statement); - - std::string all_other_classification; - std::string all_other_statement; - AllOtherAnalysis(all_other_reported, all_other_percent, - &all_other_classification, &all_other_statement); - - BottleneckAnalysis analysis; - 
analysis.set_input_percent(input_percent); - analysis.set_output_percent(output_percent); - analysis.set_idle_percent(idle_percent); - analysis.set_compute_percent(compute_percent); - - analysis.set_input_classification(input_classification); - analysis.set_input_statement(input_statement); - analysis.set_kernel_launch_classification(kernel_launch_classification); - analysis.set_kernel_launch_statement(kernel_launch_statement); - analysis.set_all_other_classification(all_other_classification); - analysis.set_all_other_statement(all_other_statement); - analysis.set_device_collectives_classification( - device_collectives_classification); - analysis.set_device_collectives_statement(device_collectives_statement); - - return analysis; -} - -std::string GetSummaryNextStep(absl::string_view input_classification, - const InputTimeBreakdown& breakdown) { - std::string summary_next_step; - if (input_classification == "host" || input_classification == "both") { - if (!TfDataInUse(breakdown)) { - summary_next_step = absl::StrCat( - "Consider using ", MakeDocLink(DatasetIntroDoc(), "the tf.data API"), - " to enable profiler's host-side analysis for input pipeline. " - "Profiler currently does not support custom input pipeline (please " - "ignore " - "Section ", - kHostAnalysisSectionNumber, " below)."); - } else { - summary_next_step = - absl::StrCat("Look at Section ", kHostAnalysisSectionNumber, - " for the breakdown of input time on the host."); - } - } else { - summary_next_step = "You may skip the rest of this page."; - } - - return summary_next_step; -} - -double HostToDeviceTransferAsPercentOfInputTime( - const InputTimeBreakdown& breakdown) { - // Thanks to the scaling trick we did in GenerateHostResult(), we can - // estimate the percentage of input-time spent on host-to-device transfer in - // the following way. 
- double total_input_time_us = - breakdown.demanded_file_read_us() + breakdown.advanced_file_read_us() + - breakdown.preprocessing_us() + breakdown.enqueue_us() + - breakdown.unclassified_non_enqueue_us(); - return 100.0 * - tsl::profiler::SafeDivide(breakdown.enqueue_us(), total_input_time_us); -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.h b/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.h deleted file mode 100644 index 53ebe189eaa324..00000000000000 --- a/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.h +++ /dev/null @@ -1,130 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_INPUT_PIPELINE_ANALYSIS_H_ -#define TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_INPUT_PIPELINE_ANALYSIS_H_ - -#include -#include - -#include "google/protobuf/any.pb.h" -#include "absl/container/flat_hash_map.h" -#include "absl/strings/string_view.h" -#include "xla/tsl/util/stats_calculator.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/profiler/protobuf/hardware_types.pb.h" -#include "tensorflow/core/profiler/protobuf/input_pipeline.pb.h" -#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" -#include "tensorflow/core/profiler/protobuf/op_stats.pb.h" -#include "tensorflow/core/profiler/protobuf/steps_db.pb.h" -#include "tensorflow/core/profiler/protobuf/tpu_input_pipeline.pb.h" -#include "tensorflow/core/profiler/utils/event_span.h" -#include "tsl/platform/protobuf.h" - -namespace tensorflow { -namespace profiler { - -struct AllReduceBreakdown { - uint64_t compute_duration_ps = 0; - uint64_t sync_duration_ps = 0; - - uint64_t DurationPs() const { return compute_duration_ps + sync_duration_ps; } -}; - -// Used to store AllReduceBreakdown per core id. Just an alias for user -// convenience. -using PerCoreAllReduceBreakdown = - absl::flat_hash_map; - -// Breakdown AllReduce time into synchronization time and actual compute time -// for each core and step. -PerCoreAllReduceBreakdown ComputePerStepAllReduceBreakdownAcrossCores( - const PerCoreStepInfo& coreid_stepinfo_map); - -// Computes the fields in PerStepData by considering the different StepInfos -// of the same step across cores. 
-PerTpuStepDetails ComputeTpuPerStepDataAcrossCores( - const PerCoreStepInfo& coreid_stepinfo_map, - const tsl::protobuf::Map& - core_details_map); - -StepSummary GetStepSummaryForSampleStats(const tsl::Stat& sample_stats); - -// If the percent of input-time spent on host-to-device transfer is greater than -// kHostToDeviceTimePercentAsSignificant, we should advise the -// user to optimize this transfer. -constexpr double kHostToDeviceTimePercentAsSignificant = 10.0; - -// If the percent of input-time spent on host-to-device transfer is greater than -// kHostToDeviceTimePercentAsDominant, we should ONLY advise the -// user to optimize this transfer; we won't bother to suggest optimization for -// tf.data. -constexpr double kHostToDeviceTimePercentAsDominant = 90.0; - -// Computes the summary of step time in milliseconds. -StepSummary ComputeStepTimeSummaryInMs( - const tsl::protobuf::RepeatedPtrField& grouped_by_step); - -void GenerateHostResult(const OpMetricsDb& host_tf_metrics_db, - InputPipelineAnalysisResult* result); - -InputPipelineAnalysisRecommendation GenerateRecommendation(); - -// For TPU, we may have mis-regarded some host overhead as idle time. -// This function checks if this is the case using host_step_events. If this is, -// it will do the correction in op_stats. -void MayFixTpuStepAnalysis( - const StepEvents& host_step_events, const OpMetricsDb& device_op_metrics_db, - StepDatabaseResult& step_db, - const tsl::protobuf::Map& core_details_map); - -// Returns a struct that describes the performance bottleneck of the -// program executed on TPU. -TpuBottleneckAnalysis ComputeTpuBottleneckAnalysis( - bool all_cores_profiled, const InputPipelineAnalysisResult& result); - -// Returns the performance bottleneck of the program executed. 
-BottleneckAnalysis ComputeBottleneckAnalysis( - const InputTimeBreakdown& input_time_breakdown, - const tsl::protobuf::RepeatedPtrField<::google::protobuf::Any>& - any_step_details); - -InputPipelineAnalysisResult ConvertOpStatsToInputPipelineAnalysis( - const OpStats& op_stats); - -// Returns true if explanation for "All Others" time is also included in -// input_statement. -bool InputAnalysis(double input_percent, double all_other_percent, - std::string* input_classification, - std::string* input_statement); - -void OutputAnalysis(double output_percent, std::string* output_classification, - std::string* output_statement); - -string GetSummaryNextStep(absl::string_view input_classification, - const InputTimeBreakdown& breakdown); - -// Returns the percentage of the input time that is spent on transferring the -// data from host to device. -double HostToDeviceTransferAsPercentOfInputTime( - const InputTimeBreakdown& breakdown); - -void AddErrorMessages(const OpStats& op_stats, - InputPipelineAnalysisResult* result); - -} // namespace profiler -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_INPUT_PIPELINE_ANALYSIS_H_ diff --git a/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis_test.cc b/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis_test.cc deleted file mode 100644 index 3b9cff76410794..00000000000000 --- a/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis_test.cc +++ /dev/null @@ -1,204 +0,0 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.h" - -#include -#include - -#include "google/protobuf/any.pb.h" -#include "xla/hlo/ir/hlo_opcode.h" -#include "xla/tsl/profiler/utils/timespan.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/profiler/protobuf/steps_db.pb.h" -#include "tensorflow/core/profiler/utils/event_span.h" -#include "tensorflow/core/profiler/utils/op_metrics_db_utils.h" - -namespace tensorflow { -namespace profiler { -namespace { - -using ::tensorflow::profiler::CoreDetails; -using ::tensorflow::profiler::OpMetricsDb; -using ::tensorflow::profiler::StepDatabaseResult; -using ::tensorflow::profiler::StepEvents; - -TEST(TfOpStatsToInputPipelineAnalysisTest, - AttributeHostInputTimeToTCWhenInfeedMissing) { - uint64_t step_num = 1; - tensorflow::profiler::StepDetails step_details; - step_details.AddEvent(tensorflow::profiler::EventTypeSpan( - tensorflow::profiler::EventType::HOST_WAIT_INPUT, - tsl::profiler::Timespan::FromEndPoints(50, 100))); - step_details.AddEvent(tensorflow::profiler::EventTypeSpan( - tensorflow::profiler::EventType::HOST_TO_DEVICE, - tsl::profiler::Timespan::FromEndPoints(110, 200))); - step_details.AddEvent(tensorflow::profiler::EventTypeSpan( - tensorflow::profiler::EventType::HOST_TO_DEVICE, - tsl::profiler::Timespan::FromEndPoints(430, 500))); - StepEvents host_step_events = {{step_num, step_details}}; - StepDatabaseResult step_db; - tensorflow::profiler::PerCoreStepInfo* pcsi = 
step_db.add_step_sequence(); - pcsi->set_step_num(step_num); - auto& sipc_map = *pcsi->mutable_step_info_per_core(); - tensorflow::profiler::StepInfoResult& sir = sipc_map[/* core_id= */ 2]; - sir.set_step_num(step_num); - sir.set_begin_ps(40); - sir.set_duration_ps(1000); - tensorflow::profiler::GenericStepBreakdown step_breakdown; - tsl::protobuf::Map& category_ps = - *step_breakdown.mutable_category_ps(); - category_ps[tensorflow::profiler::kIdle] = 300; - category_ps[xla::HloOpcodeString(xla::HloOpcode::kMultiply)] = 300; - category_ps[xla::HloOpcodeString(xla::HloOpcode::kAllGather)] = 300; - category_ps[xla::HloOpcodeString(xla::HloOpcode::kAsyncStart)] = 50; - category_ps[xla::HloOpcodeString(xla::HloOpcode::kAsyncDone)] = 50; - sir.mutable_step_breakdown()->PackFrom(step_breakdown); - tsl::protobuf::Map core_details_map; - MayFixTpuStepAnalysis(host_step_events, OpMetricsDb(), step_db, - core_details_map); - tensorflow::profiler::GenericStepBreakdown updated_step_breakdown; - sir.step_breakdown().UnpackTo(&updated_step_breakdown); - const tsl::protobuf::Map& updated_category_ps = - updated_step_breakdown.category_ps(); - EXPECT_EQ(updated_category_ps.at(tensorflow::profiler::kIdle), 90); - ASSERT_TRUE(updated_category_ps.contains( - xla::HloOpcodeString(xla::HloOpcode::kInfeed))); - EXPECT_EQ( - updated_category_ps.at(xla::HloOpcodeString(xla::HloOpcode::kInfeed)), - 210); -} - -TEST(TfOpStatsToInputPipelineAnalysisTest, - AttributeHostInputTimeToTCWhenInfeedMissingMultiCore) { - uint64_t step_num = 1; - tensorflow::profiler::StepDetails step_details; - step_details.AddEvent(tensorflow::profiler::EventTypeSpan( - tensorflow::profiler::EventType::HOST_WAIT_INPUT, - tsl::profiler::Timespan::FromEndPoints(50, 100))); - step_details.AddEvent(tensorflow::profiler::EventTypeSpan( - tensorflow::profiler::EventType::HOST_TO_DEVICE, - tsl::profiler::Timespan::FromEndPoints(110, 200))); - step_details.AddEvent(tensorflow::profiler::EventTypeSpan( - 
tensorflow::profiler::EventType::HOST_TO_DEVICE, - tsl::profiler::Timespan::FromEndPoints(430, 500))); - StepEvents host_step_events = {{step_num, step_details}}; - StepDatabaseResult step_db; - tensorflow::profiler::PerCoreStepInfo* pcsi = step_db.add_step_sequence(); - pcsi->set_step_num(step_num); - tsl::protobuf::Map& sipc_map = - *pcsi->mutable_step_info_per_core(); - tensorflow::profiler::StepInfoResult& sir = sipc_map[/* core_id= */ 2]; - sir.set_step_num(step_num); - sir.set_begin_ps(40); - sir.set_duration_ps(1000); - tensorflow::profiler::GenericStepBreakdown step_breakdown; - tsl::protobuf::Map& category_ps = - *step_breakdown.mutable_category_ps(); - category_ps[tensorflow::profiler::kIdle] = 300; - category_ps[xla::HloOpcodeString(xla::HloOpcode::kMultiply)] = 300; - category_ps[xla::HloOpcodeString(xla::HloOpcode::kAllGather)] = 300; - category_ps[xla::HloOpcodeString(xla::HloOpcode::kAsyncStart)] = 50; - category_ps[xla::HloOpcodeString(xla::HloOpcode::kAsyncDone)] = 50; - sir.mutable_step_breakdown()->PackFrom(step_breakdown); - tensorflow::profiler::StepInfoResult& sir2 = sipc_map[/* core_id= */ 1]; - sir2.set_step_num(step_num); - sir2.set_begin_ps(45); - sir2.set_duration_ps(900); - tensorflow::profiler::GenericStepBreakdown step_breakdown2; - tsl::protobuf::Map& category_ps2 = - *step_breakdown2.mutable_category_ps(); - category_ps2[tensorflow::profiler::kIdle] = 250; - category_ps2[xla::HloOpcodeString(xla::HloOpcode::kMultiply)] = 300; - category_ps2[xla::HloOpcodeString(xla::HloOpcode::kAllGather)] = 250; - category_ps2[xla::HloOpcodeString(xla::HloOpcode::kAsyncStart)] = 50; - category_ps2[xla::HloOpcodeString(xla::HloOpcode::kAsyncDone)] = 50; - sir2.mutable_step_breakdown()->PackFrom(step_breakdown2); - tsl::protobuf::Map core_details_map; - OpMetricsDb device_op_metrics_db; - MayFixTpuStepAnalysis(host_step_events, device_op_metrics_db, step_db, - core_details_map); - tensorflow::profiler::GenericStepBreakdown updated_step_breakdown; - 
sir.step_breakdown().UnpackTo(&updated_step_breakdown); - const tsl::protobuf::Map& updated_category_ps = - updated_step_breakdown.category_ps(); - EXPECT_EQ(updated_category_ps.at(tensorflow::profiler::kIdle), 48); - ASSERT_TRUE(updated_category_ps.contains( - xla::HloOpcodeString(xla::HloOpcode::kInfeed))); - EXPECT_EQ( - updated_category_ps.at(xla::HloOpcodeString(xla::HloOpcode::kInfeed)), - 252); - tensorflow::profiler::GenericStepBreakdown updated_step_breakdown2; - sir2.step_breakdown().UnpackTo(&updated_step_breakdown2); - const tsl::protobuf::Map& updated_category_ps2 = - updated_step_breakdown2.category_ps(); - EXPECT_EQ(updated_category_ps2.at(tensorflow::profiler::kIdle), 40); - ASSERT_TRUE(updated_category_ps2.contains( - xla::HloOpcodeString(xla::HloOpcode::kInfeed))); - EXPECT_EQ( - updated_category_ps2.at(xla::HloOpcodeString(xla::HloOpcode::kInfeed)), - 210); -} - -TEST(TfOpStatsToInputPipelineAnalysisTest, - SkipMayFixTpuStepAnalysisWhenInfeedExists) { - uint64_t step_num = 1; - tensorflow::profiler::StepDetails step_details; - step_details.AddEvent(tensorflow::profiler::EventTypeSpan( - tensorflow::profiler::EventType::HOST_WAIT_INPUT, - tsl::profiler::Timespan::FromEndPoints(50, 100))); - step_details.AddEvent(tensorflow::profiler::EventTypeSpan( - tensorflow::profiler::EventType::HOST_TO_DEVICE, - tsl::profiler::Timespan::FromEndPoints(110, 200))); - step_details.AddEvent(tensorflow::profiler::EventTypeSpan( - tensorflow::profiler::EventType::HOST_TO_DEVICE, - tsl::profiler::Timespan::FromEndPoints(430, 500))); - StepEvents host_step_events = {{step_num, step_details}}; - StepDatabaseResult step_db; - tensorflow::profiler::PerCoreStepInfo* pcsi = step_db.add_step_sequence(); - pcsi->set_step_num(step_num); - tsl::protobuf::Map& sipc_map = - *pcsi->mutable_step_info_per_core(); - tensorflow::profiler::StepInfoResult& sir = sipc_map[/* core_id= */ 2]; - sir.set_step_num(step_num); - sir.set_begin_ps(40); - sir.set_duration_ps(1000); - 
tensorflow::profiler::GenericStepBreakdown step_breakdown; - tsl::protobuf::Map& category_ps = - *step_breakdown.mutable_category_ps(); - category_ps[tensorflow::profiler::kIdle] = 300; - category_ps[xla::HloOpcodeString(xla::HloOpcode::kMultiply)] = 300; - category_ps[xla::HloOpcodeString(xla::HloOpcode::kAllGather)] = 300; - category_ps[xla::HloOpcodeString(xla::HloOpcode::kAsyncStart)] = 50; - category_ps[xla::HloOpcodeString(xla::HloOpcode::kInfeed)] = 50; - sir.mutable_step_breakdown()->PackFrom(step_breakdown); - tsl::protobuf::Map core_details_map; - OpMetricsDb device_op_metrics_db; - device_op_metrics_db.add_metrics_db()->set_category( - std::string(xla::HloOpcodeString(xla::HloOpcode::kInfeed))); - MayFixTpuStepAnalysis(host_step_events, device_op_metrics_db, step_db, - core_details_map); - tensorflow::profiler::GenericStepBreakdown updated_step_breakdown; - sir.step_breakdown().UnpackTo(&updated_step_breakdown); - const tsl::protobuf::Map& updated_category_ps = - updated_step_breakdown.category_ps(); - EXPECT_EQ(updated_category_ps.at(tensorflow::profiler::kIdle), 300); - EXPECT_EQ( - updated_category_ps.at(xla::HloOpcodeString(xla::HloOpcode::kInfeed)), - 50); -} - -} // namespace -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/op_stats_to_op_profile.cc b/tensorflow/core/profiler/convert/op_stats_to_op_profile.cc deleted file mode 100644 index 59ac8ca086bd4a..00000000000000 --- a/tensorflow/core/profiler/convert/op_stats_to_op_profile.cc +++ /dev/null @@ -1,101 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/profiler/convert/op_stats_to_op_profile.h" - -#include -#include - -#include "absl/log/check.h" -#include "absl/strings/match.h" -#include "xla/tsl/profiler/utils/math_utils.h" -#include "tensorflow/core/profiler/convert/op_profile_builder.h" -#include "tensorflow/core/profiler/protobuf/hardware_types.pb.h" -#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" -#include "tensorflow/core/profiler/protobuf/op_profile.pb.h" -#include "tensorflow/core/profiler/protobuf/op_stats.pb.h" -#include "tensorflow/core/profiler/utils/op_metrics_db_utils.h" - -namespace tensorflow { -namespace profiler { -namespace { - -using ::tensorflow::profiler::IsIdleOp; -using ::tensorflow::profiler::OpMetrics; -using ::tensorflow::profiler::OpProfileBuilder; -using ::tensorflow::profiler::OpProfileOptions; -using ::tensorflow::profiler::OpStats; -using ::tensorflow::profiler::TotalTimePs; -using ::tensorflow::profiler::op_profile::Node; - -void BuildOpProfileNodeTree(const OpStats& op_stats, bool group_by_program, - bool exclude_idle_ops, int op_profile_limit, - Node* root) { - const auto& metrics_db = op_stats.device_op_metrics_db(); - if (metrics_db.metrics_db().empty()) return; - - OpProfileOptions options = {group_by_program, - /*group_by_deduplicated_name=*/true, - /*children_per_node=*/op_profile_limit}; - OpProfileBuilder builder(options, root, &op_stats.program_id_to_name_map()); - - for (const OpMetrics& op_metrics : metrics_db.metrics_db()) { - 
DCHECK(!op_metrics.name().empty()); - // Don't add ops that cannot be symbolized. - if (absl::StartsWith(op_metrics.name(), "region")) continue; - if (exclude_idle_ops && IsIdleOp(op_metrics)) continue; - builder.AddOp(op_metrics); - } - - const auto& perf_env = op_stats.perf_env(); - double max_gigaflops_per_second_per_core = - tsl::profiler::TeraToGiga(perf_env.peak_tera_flops_per_second()); - std::vector peak_bws; - for (auto bw : perf_env.peak_bws_giga_bytes_per_second()) { - peak_bws.push_back(tsl::profiler::GigaToGibi(bw)); - } - builder.Finalize(max_gigaflops_per_second_per_core, peak_bws, - TotalTimePs(metrics_db, exclude_idle_ops)); -} - -} // namespace - -void ConvertOpStatsToOpProfile( - const OpStats& op_stats, tensorflow::profiler::HardwareType hardware_type, - tensorflow::profiler::op_profile::Profile& profile, int op_profile_limit) { - profile.set_device_type(HardwareType_Name(hardware_type)); - BuildOpProfileNodeTree(op_stats, - /*group_by_program=*/false, - /*exclude_idle_ops=*/false, op_profile_limit, - profile.mutable_by_category()); - - BuildOpProfileNodeTree(op_stats, - /*group_by_program=*/false, - /*exclude_idle_ops=*/true, op_profile_limit, - profile.mutable_by_category_exclude_idle()); - - BuildOpProfileNodeTree(op_stats, - /*group_by_program=*/true, - /*exclude_idle_ops=*/false, op_profile_limit, - profile.mutable_by_program()); - - BuildOpProfileNodeTree(op_stats, - /*group_by_program=*/true, - /*exclude_idle_ops=*/true, op_profile_limit, - profile.mutable_by_program_exclude_idle()); -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/op_stats_to_op_profile.h b/tensorflow/core/profiler/convert/op_stats_to_op_profile.h deleted file mode 100644 index 1fcfefb510d454..00000000000000 --- a/tensorflow/core/profiler/convert/op_stats_to_op_profile.h +++ /dev/null @@ -1,56 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_OP_PROFILE_H_ -#define TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_OP_PROFILE_H_ - -#include "tensorflow/core/profiler/protobuf/hardware_types.pb.h" -#include "tensorflow/core/profiler/protobuf/op_profile.pb.h" -#include "tensorflow/core/profiler/protobuf/op_stats.pb.h" - -namespace tensorflow { -namespace profiler { - -// Assembles a hierarchical performance profile based on HLOs in the op metrics -// db. -// The node hierarchy is as following: -// by_category -// - combined_root -// - category 1 -// - category 2 -// - ... -// - idle -// by_program -// - program_1_root -// - category 1 -// - category 2 -// - ... -// - program_2_root -// - category 1 -// - ... -// - idle -// The nodes in the profile are sorted by time in decreasing order and pruned -// to reduce the profile size. Only 100 nodes are kept for level >= 3. -// See op_profile.proto for the detailed semantics of the returned profile. 
-void ConvertOpStatsToOpProfile( - const tensorflow::profiler::OpStats& op_stats, - tensorflow::profiler::HardwareType hardware_type, - tensorflow::profiler::op_profile::Profile& profile, - int op_profile_limit = 100); - -} // namespace profiler -} // namespace tensorflow - -#endif // THIRD_PARTY_TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_OP_PROFILE_H_ diff --git a/tensorflow/core/profiler/convert/op_stats_to_overview_page.cc b/tensorflow/core/profiler/convert/op_stats_to_overview_page.cc deleted file mode 100644 index 73af4c71436627..00000000000000 --- a/tensorflow/core/profiler/convert/op_stats_to_overview_page.cc +++ /dev/null @@ -1,405 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/core/profiler/convert/op_stats_to_overview_page.h" - -#include -#include -#include -#include - -#include "google/protobuf/any.pb.h" -#include "absl/algorithm/container.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" -#include "xla/tsl/profiler/utils/format_utils.h" -#include "xla/tsl/profiler/utils/math_utils.h" -#include "xla/tsl/profiler/utils/tf_op_utils.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/profiler/convert/op_metrics_to_record.h" -#include "tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.h" -#include "tensorflow/core/profiler/protobuf/hardware_types.pb.h" -#include "tensorflow/core/profiler/protobuf/input_pipeline.pb.h" -#include "tensorflow/core/profiler/protobuf/kernel_stats.pb.h" -#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" -#include "tensorflow/core/profiler/protobuf/op_stats.pb.h" -#include "tensorflow/core/profiler/protobuf/overview_page.pb.h" -#include "tensorflow/core/profiler/protobuf/power_metrics.pb.h" -#include "tensorflow/core/profiler/protobuf/steps_db.pb.h" -#include "tensorflow/core/profiler/protobuf/tf_function.pb.h" -#include "tensorflow/core/profiler/utils/diagnostics.h" -#include "tensorflow/core/profiler/utils/hardware_type_utils.h" -#include "tensorflow/core/profiler/utils/html_utils.h" -#include "tensorflow/core/profiler/utils/kernel_stats_utils.h" -#include "tensorflow/core/profiler/utils/op_metrics_db_utils.h" - -namespace tensorflow { -namespace profiler { - -namespace { - -using tsl::profiler::OneDigit; - -// If the use of low-precision ops is less than this percentage threshold, a -// statement of suggestion will be made. 
-constexpr double kLowPrecisionPercentThreshold = 10; - -struct TfFunctionInfo { - absl::string_view function_name; - double expensive_call_percent; -}; - -OverviewPageTip MakeOverviewPageTip(std::string text) { - OverviewPageTip tip; - tip.set_link(std::move(text)); - return tip; -} - -// Makes a recommendation for looking up a document. -// doc_url is expected to be already be escaped suitably for use in an HTML -// attribute. -OverviewPageTip MakeOverviewPageTipDocLink(absl::string_view doc_url, - absl::string_view text) { - return MakeOverviewPageTip(AnchorElement(doc_url, text)); -} - -void ComputeHostTips(OverviewPageRecommendation* re) { - *re->add_host_tips() = MakeOverviewPageTip( - "input_pipeline_analyzer (especially Section 3 for the breakdown of " - "input operations on the Host)"); - *re->add_host_tips() = MakeOverviewPageTip( - "trace_viewer (look at the activities on the timeline of each Host " - "Thread near the bottom of the trace view)"); -} - -void ComputeDeviceTips(HardwareType hardware_type, - OverviewPageRecommendation* re) { - absl::string_view device_name = HardwareType_Name(hardware_type); - absl::string_view timeline_name = device_name; - absl::string_view op_stats_toolname = "framework_op_stats"; - if (hardware_type == tensorflow::profiler::TPU) { - timeline_name = "TPU core"; - op_stats_toolname = "op_profile"; - } - *re->add_device_tips() = MakeOverviewPageTip( - absl::StrCat(op_stats_toolname, - " (identify the time-consuming operations " - "executed on the ", - device_name, ")")); - *re->add_device_tips() = MakeOverviewPageTip(absl::StrCat( - "trace_viewer (look at the activities on the timeline of each ", - timeline_name, " in the trace view)")); -} - -void ComputeFaqTips(OverviewPageRecommendation* re) { - *re->add_faq_tips() = MakeOverviewPageTip("Refer to the TF2 Profiler FAQ"); -} - -void ComputeDocumentationTips(OverviewPageRecommendation* re) { - *re->add_documentation_tips() = MakeOverviewPageTipDocLink( - 
"https://www.tensorflow.org/guide/data_performance_analysis", - "Analyze tf.data performance with the TF Profiler"); - *re->add_documentation_tips() = MakeOverviewPageTipDocLink( - "https://www.tensorflow.org/guide/" - "data_performance", - "Better performance with the tf.data API"); -} - -std::string GeneratePrecisionStatement(const PrecisionStats& precision_stats) { - uint64 total_compute_ps = - precision_stats.compute_16bit_ps() + precision_stats.compute_32bit_ps(); - if (total_compute_ps > 0) { - double percent_16bit = - (100.0 * precision_stats.compute_16bit_ps()) / total_compute_ps; - if (percent_16bit < kLowPrecisionPercentThreshold) { - return absl::StrCat( - "Only ", OneDigit(percent_16bit), - "% of device computation is 16 bit. So you might want to replace " - "more 32-bit Ops by 16-bit Ops to improve performance (if the " - "reduced accuracy is acceptable)."); - } - } - return ""; -} - -} // namespace - -void SetCommonRecommendation( - absl::string_view input_classification, absl::string_view input_statement, - absl::string_view output_statement, HardwareType hardware_type, - absl::string_view tf_function_statement_html, - absl::string_view eager_statement_html, - absl::string_view outside_compilation_statement_html, - OverviewPageRecommendation* re) { - re->set_bottleneck(std::string(input_classification)); - re->set_statement(std::string(input_statement)); - re->set_output_statement(std::string(output_statement)); - re->set_tf_function_statement_html(std::string(tf_function_statement_html)); - re->set_eager_statement_html(std::string(eager_statement_html)); - re->set_outside_compilation_statement_html( - std::string(outside_compilation_statement_html)); - ComputeHostTips(re); - ComputeDeviceTips(hardware_type, re); - ComputeDocumentationTips(re); - ComputeFaqTips(re); -} - -OverviewPageRecommendation ComputeGenericRecommendation( - const BottleneckAnalysis& bottleneck, - const PrecisionStats& precision_stats) { - OverviewPageRecommendation re; - 
GenericRecommendation generic; - generic.set_device_collectives_bottleneck( - bottleneck.device_collectives_classification()); - generic.set_device_collectives_statement( - bottleneck.device_collectives_statement()); - generic.set_kernel_launch_bottleneck( - bottleneck.kernel_launch_classification()); - generic.set_kernel_launch_statement(bottleneck.kernel_launch_statement()); - generic.set_all_other_bottleneck(bottleneck.all_other_classification()); - generic.set_all_other_statement(bottleneck.all_other_statement()); - generic.set_precision_statement(GeneratePrecisionStatement(precision_stats)); - re.mutable_recommendation()->PackFrom(generic); - return re; -} - -OverviewPageAnalysis ComputeAnalysisResult(const OpStats& op_stats) { - OverviewPageAnalysis analysis; - OpMetricsDb device_tf_op_metrics_db = CreateTfMetricsDbFromDeviceOpMetricsDb( - op_stats.device_op_metrics_db(), /*with_idle=*/false); - KernelStatsByOpName kernel_stats_by_op_name = - GroupKernelReportsByOpName(op_stats.kernel_stats_db()); - uint64 total_device_time_ps = device_tf_op_metrics_db.total_time_ps(); - constexpr int kNumTopOpsShown = 10; - double device_cumulative_fraction = 0.0; - for (const OpMetrics* metrics : - SortedOpMetricsDb(device_tf_op_metrics_db, kNumTopOpsShown)) { - OverviewTfOp* op = analysis.add_top_device_ops(); - op->set_name(metrics->name()); - op->set_category(metrics->category()); - op->set_self_time_fraction(tsl::profiler::SafeDivide( - metrics->self_time_ps(), total_device_time_ps)); - device_cumulative_fraction += op->self_time_fraction(); - op->set_cumulative_time_fraction(device_cumulative_fraction); - op->set_flop_rate(tsl::profiler::SafeDivide( - metrics->flops(), tsl::profiler::PicoToNano(metrics->time_ps()))); - auto iter = kernel_stats_by_op_name.find(op->name()); - if (iter != kernel_stats_by_op_name.end()) { - op->set_is_op_tensorcore_eligible( - iter->second.is_op_tensor_core_eligible); - op->set_is_op_using_tensorcore(iter->second.tensor_core_duration_ns != 
0); - } - } - uint64 total_device_compute_ps = - op_stats.device_op_metrics_db().precision_stats().compute_16bit_ps() + - op_stats.device_op_metrics_db().precision_stats().compute_32bit_ps(); - analysis.set_device_compute_16bit_percent( - 100.0 * - tsl::profiler::SafeDivide( - op_stats.device_op_metrics_db().precision_stats().compute_16bit_ps(), - total_device_compute_ps)); - analysis.set_device_compute_32bit_percent( - 100.0 * - tsl::profiler::SafeDivide( - op_stats.device_op_metrics_db().precision_stats().compute_32bit_ps(), - total_device_compute_ps)); - - uint64 num_host_tf_ops = 0; - uint64 total_host_op_time_ps_exclude_idle = 0; - uint64 eager_host_op_time_ps = 0; - for (const OpMetrics& metrics : op_stats.host_op_metrics_db().metrics_db()) { - num_host_tf_ops += metrics.occurrences(); - if (!IsIdleOp(metrics)) { - total_host_op_time_ps_exclude_idle += metrics.self_time_ps(); - if (metrics.is_eager()) eager_host_op_time_ps += metrics.self_time_ps(); - } - } - uint64 num_device_tf_ops = 0; - uint64 total_device_op_time_ps_exclude_idle = 0; - uint64 eager_device_op_time_ps = 0; - for (const OpMetrics& metrics : device_tf_op_metrics_db.metrics_db()) { - num_device_tf_ops += metrics.occurrences(); - if (!IsIdleOp(metrics)) { - total_device_op_time_ps_exclude_idle += metrics.self_time_ps(); - if (metrics.is_eager()) eager_device_op_time_ps += metrics.self_time_ps(); - } - } - // Figures out outside_compilation time from - // op_stats.device_op_metrics_db().metrics_db(). We don't use the - // {metrics.provenance(), metrics.name()} from - // device_tf_op_metrics_db.metrics_db(), because metrics.provenance() there is - // not set and metrics.name() can be either HLO-Op name or TF-Op name, which - // will confuse tsl::profiler::IsOutsideCompilationOp(). 
- uint64 outside_compilation_device_op_time_ps = 0; - for (const OpMetrics& metrics : - op_stats.device_op_metrics_db().metrics_db()) { - if (!tsl::profiler::IsOutsideCompilationOp(metrics.provenance(), - metrics.long_name())) - continue; - outside_compilation_device_op_time_ps += metrics.self_time_ps(); - } - uint64 num_total_tf_ops = num_host_tf_ops + num_device_tf_ops; - analysis.set_host_tf_op_percent( - 100.0 * tsl::profiler::SafeDivide(num_host_tf_ops, num_total_tf_ops)); - analysis.set_device_tf_op_percent( - 100.0 * tsl::profiler::SafeDivide(num_device_tf_ops, num_total_tf_ops)); - analysis.set_host_trace_level(op_stats.run_environment().host_trace_level()); - analysis.set_host_op_time_eager_percent( - 100.0 * tsl::profiler::SafeDivide(eager_host_op_time_ps, - total_host_op_time_ps_exclude_idle)); - analysis.set_device_op_time_eager_percent( - 100.0 * tsl::profiler::SafeDivide(eager_device_op_time_ps, - total_device_op_time_ps_exclude_idle)); - analysis.set_device_op_time_outside_compilation_percent( - 100.0 * tsl::profiler::SafeDivide(outside_compilation_device_op_time_ps, - total_device_op_time_ps_exclude_idle)); - return analysis; -} - -// Converts from HostIndependentJobInfo to OverviewPageHostIndependentJobInfo. -OverviewPageHostIndependentJobInfo ToOverviewPageHostIndependentJobInfo( - const HostIndependentJobInfoResult& host_independent_job_info) { - OverviewPageHostIndependentJobInfo result; - result.set_change_list(host_independent_job_info.change_list()); - result.set_build_time(host_independent_job_info.build_time()); - result.set_build_target(host_independent_job_info.build_target()); - result.set_profile_duration_ms( - host_independent_job_info.profile_duration_ms()); - return result; -} - -// Converts from HostDependentJobInfo to OverviewPageHostDependentJobInfo. 
-OverviewPageHostDependentJobInfo ToOverviewPageHostDependentJobInfo( - const HostDependentJobInfoResult& host_dependent_job_info) { - OverviewPageHostDependentJobInfo result; - result.set_host_id(host_dependent_job_info.host_id()); - result.set_command_line(host_dependent_job_info.command_line()); - result.set_start_time(host_dependent_job_info.start_time()); - result.set_bns_address(host_dependent_job_info.bns_address()); - result.set_profile_time_ns(host_dependent_job_info.profile_time_ns()); - return result; -} - -OverviewPageRunEnvironment ComputeRunEnvironment( - const RunEnvironment& run_environment) { - OverviewPageRunEnvironment re; - re.set_host_count(run_environment.host_count()); - re.set_task_count(run_environment.task_count()); - re.set_device_type(run_environment.device_type()); - re.set_device_core_count(run_environment.device_core_count()); - re.set_replica_count(run_environment.replica_count()); - re.set_num_cores_per_replica(run_environment.num_cores_per_replica()); - re.set_is_training(run_environment.is_training()); - if (run_environment.has_power_metrics()) { - *re.mutable_power_metrics() = run_environment.power_metrics(); - } - *re.mutable_host_independent_job_info() = - ToOverviewPageHostIndependentJobInfo( - run_environment.host_independent_job_info()); - for (const auto& host_dependent_job_info : - run_environment.host_dependent_job_info()) { - *re.add_host_dependent_job_info() = - ToOverviewPageHostDependentJobInfo(host_dependent_job_info); - } - return re; -} - -std::string TfFunctionRecommendationHtml(const TfFunctionDb& tf_function_db) { - std::vector candidates; - for (const auto& name_fun : tf_function_db.tf_functions()) { - const auto& fun = name_fun.second; - if (fun.expensive_call_percent() >= kTfFunctionReportThresholdInPercent) { - candidates.push_back({name_fun.first, fun.expensive_call_percent()}); - } - } - if (candidates.empty()) return ""; - auto cmp = [](const TfFunctionInfo& a, const TfFunctionInfo& b) { - return 
a.expensive_call_percent > b.expensive_call_percent; - }; - // Sorts candidates in descending order of expensive_call_percent. - absl::c_sort(candidates, cmp); - std::string expensive_functions = ""; - auto num_functions_shown = std::min( - static_cast(3), candidates.size()); - - for (decltype(candidates)::size_type i = 0; i < num_functions_shown; i++) { - if (i > 0) absl::StrAppend(&expensive_functions, ", "); - absl::StrAppend(&expensive_functions, "\"", candidates[i].function_name, - "\""); - } - if (candidates.size() > num_functions_shown) - absl::StrAppend(&expensive_functions, " and more"); - return absl::StrCat("Expensive tf-functions detected (", expensive_functions, - ") due to either retracing or eager execution."); -} - -std::string EagerRecommendationHtml(double host_op_time_eager_percent, - double device_op_time_eager_percent) { - std::string recommendation = ""; - if (host_op_time_eager_percent > kEagerReportThresholdInPercent) - absl::StrAppend(&recommendation, OneDigit(host_op_time_eager_percent), - "% of Op time on the host used eager execution. "); - if (device_op_time_eager_percent > kEagerReportThresholdInPercent) - absl::StrAppend(&recommendation, OneDigit(device_op_time_eager_percent), - "% of Op time on the device used eager execution. "); - if (!recommendation.empty()) - absl::StrAppend(&recommendation, "Performance could be improved with ", - AnchorElement("https://www.tensorflow.org/guide/function", - "tf.function.")); - return recommendation; -} - -std::string OutsideCompilationRecommendationHtml( - double device_op_time_outside_compilation_percent) { - if (device_op_time_outside_compilation_percent <= - kOutsideCompilationThresholdInPercent) - return ""; - return absl::StrCat( - OneDigit(device_op_time_outside_compilation_percent), - " % of Op time on the device are for outside compilation. 
Performance " - "could be improved by avoiding outside compilation."); -} - -OverviewPage ConvertOpStatsToOverviewPage(const OpStats& op_stats) { - OverviewPage overview_page; - *overview_page.mutable_run_environment() = - ComputeRunEnvironment(op_stats.run_environment()); - *overview_page.mutable_analysis() = ComputeAnalysisResult(op_stats); - *overview_page.mutable_input_analysis() = - ConvertOpStatsToInputPipelineAnalysis(op_stats); - BottleneckAnalysis bottleneck = ComputeBottleneckAnalysis( - overview_page.input_analysis().input_time_breakdown(), - overview_page.input_analysis().step_details()); - *overview_page.mutable_recommendation() = ComputeGenericRecommendation( - bottleneck, op_stats.device_op_metrics_db().precision_stats()); - SetCommonRecommendation( - bottleneck.input_classification(), bottleneck.input_statement(), "", - ParseHardwareType(op_stats.run_environment().device_type()), - TfFunctionRecommendationHtml(op_stats.tf_function_db()), - EagerRecommendationHtml( - overview_page.analysis().host_op_time_eager_percent(), - overview_page.analysis().device_op_time_eager_percent()), - OutsideCompilationRecommendationHtml( - overview_page.analysis() - .device_op_time_outside_compilation_percent()), - overview_page.mutable_recommendation()); - PopulateOverviewDiagnostics(op_stats, overview_page.mutable_diagnostics()); - overview_page.mutable_analysis()->set_mxu_utilization_percent( - op_stats.performance_counter_result().matrix_unit_utilization_percent()); - return overview_page; -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/op_stats_to_overview_page.h b/tensorflow/core/profiler/convert/op_stats_to_overview_page.h deleted file mode 100644 index ba6d906e325d96..00000000000000 --- a/tensorflow/core/profiler/convert/op_stats_to_overview_page.h +++ /dev/null @@ -1,81 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_OVERVIEW_PAGE_H_ -#define TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_OVERVIEW_PAGE_H_ - -#include - -#include "absl/strings/string_view.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/profiler/protobuf/hardware_types.pb.h" -#include "tensorflow/core/profiler/protobuf/input_pipeline.pb.h" -#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" -#include "tensorflow/core/profiler/protobuf/op_stats.pb.h" -#include "tensorflow/core/profiler/protobuf/overview_page.pb.h" -#include "tsl/profiler/protobuf/xplane.pb.h" - -namespace tensorflow { -namespace profiler { - -// Reports tf-function optimization opportunity in the Overview Page if the -// expensive-call-time percentage is over this threshold for at least one of -// the tf-functions profiled. -const double kTfFunctionReportThresholdInPercent = 20; - -// Reports eager-mode optimization opportunity in the Overview Page if the -// percent of Op time on host (or device) that is spent on eager mode is over -// this threshold. -const double kEagerReportThresholdInPercent = 10; - -// Reports outside-compilation opportunity in the Overview Page if the -// percent of Op time on device that is for outside compilation is over -// this threshold. 
-const double kOutsideCompilationThresholdInPercent = 5; - -void SetCommonRecommendation( - absl::string_view input_classification, absl::string_view input_statement, - absl::string_view output_statement, HardwareType hardware_type, - absl::string_view tf_function_statement_html, - absl::string_view eager_statement_html, - absl::string_view outside_compilation_statement_html, - OverviewPageRecommendation* re); - -OverviewPageRecommendation ComputeGenericRecommendation( - const BottleneckAnalysis& bottleneck, - const PrecisionStats& precision_stats); - -OverviewPageAnalysis ComputeAnalysisResult(const OpStats& op_stats); - -OverviewPageRunEnvironment ComputeRunEnvironment( - const RunEnvironment& run_environment); - -OverviewPage ConvertOpStatsToOverviewPage(const OpStats& op_stats); - -// Returns a html which provides tf-function related recommendation. -std::string TfFunctionRecommendationHtml(const TfFunctionDb& tf_function_db); - -// Returns a html which provides eager-mode related recommendation. -std::string EagerRecommendationHtml(double host_op_time_eager_percent, - double device_op_time_eager_percent); - -// Returns a html which provides outside-compilation related recommendation. -std::string OutsideCompilationRecommendationHtml( - double device_op_time_outside_compilation_percent); - -} // namespace profiler -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_OVERVIEW_PAGE_H_ diff --git a/tensorflow/core/profiler/convert/op_stats_to_pod_stats.cc b/tensorflow/core/profiler/convert/op_stats_to_pod_stats.cc deleted file mode 100644 index 3735c2a188bc19..00000000000000 --- a/tensorflow/core/profiler/convert/op_stats_to_pod_stats.cc +++ /dev/null @@ -1,108 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/profiler/convert/op_stats_to_pod_stats.h" - -#include -#include -#include -#include - -#include "google/protobuf/any.pb.h" -#include "absl/log/check.h" -#include "absl/log/log.h" -#include "absl/strings/string_view.h" -#include "xla/tsl/profiler/utils/math_utils.h" -#include "tensorflow/core/lib/gtl/map_util.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/profiler/protobuf/steps_db.pb.h" -#include "tensorflow/core/profiler/utils/diagnostics.h" -#include "tensorflow/core/profiler/utils/event_span.h" - -namespace tensorflow { -namespace profiler { - -namespace { - -PodStatsRecord CreatePodStatsRecord(absl::string_view host_name, - const StepInfoResult& step_info) { - PodStatsRecord record; - GenericStepBreakdown generic; - bool success = step_info.step_breakdown().UnpackTo(&generic); - DCHECK(success); - record.set_host_name(string(host_name)); - record.set_step_num(step_info.step_num()); - record.set_total_duration_us( - tsl::profiler::PicoToMicro(step_info.duration_ps())); - auto& step_breakdown_map = *record.mutable_step_breakdown_us(); - std::vector> metrics; - - auto add_event = [&](GenericEventType type, - std::initializer_list event_list) { - uint64 ps = 0; - for (const auto& event_type : event_list) { - ps += gtl::FindWithDefault(generic.type_ps(), event_type, /*value=*/0); - } - step_breakdown_map[type] = tsl::profiler::PicoToMicro(ps); - metrics.emplace_back(ps, GetGenericEventTypeStr(type)); - }; - - add_event(kDeviceCompute, 
{DEVICE_COMPUTE_32, DEVICE_COMPUTE_16}); - add_event(kDeviceToDevice, {DEVICE_TO_DEVICE, DEVICE_WAIT_DEVICE}); - add_event(kDeviceCollectives, {DEVICE_COLLECTIVES}); - add_event(kHostCompute, {HOST_COMPUTE}); - add_event(kHostPrepare, {HOST_PREPARE}); - add_event(kInput, {HOST_WAIT_INPUT, HOST_TO_DEVICE, DEVICE_WAIT_HOST}); - add_event(kOutput, {DEVICE_TO_HOST}); - add_event(kCompile, {HOST_COMPILE}); - add_event(kAllOthers, {UNKNOWN_TIME}); - - std::sort(metrics.begin(), metrics.end()); - record.set_bottleneck(metrics.back().second.data(), - metrics.back().second.size()); - return record; -} - -} // namespace - -PodStatsDatabase ConvertOpStatsToPodStats(const OpStats& op_stats) { - PodStatsDatabase pod_stats_db; - const auto& core_id_map = op_stats.core_id_to_details(); - for (int i = GenericEventType::kFirstGenericEventType; - i <= GenericEventType::kLastGenericEventType; i++) { - auto& event = *pod_stats_db.add_step_breakdown_events(); - event.set_id(i); - absl::string_view type_str = - GetGenericEventTypeStr(static_cast(i)); - event.set_name(type_str.data(), type_str.size()); - } - - for (const auto& step_sequence : op_stats.step_db().step_sequence()) { - for (const auto& entry : step_sequence.step_info_per_core()) { - if (!core_id_map.contains(entry.first)) { - LOG(WARNING) << "core_id_map does not contain " << entry.first; - continue; - } - const CoreDetails& details = core_id_map.at(entry.first); - *pod_stats_db.add_pod_stats_record() = - CreatePodStatsRecord(details.hostname(), entry.second); - } - } - PopulateStepDiagnostics(op_stats, pod_stats_db.mutable_diagnostics()); - return pod_stats_db; -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/op_stats_to_pod_stats.h b/tensorflow/core/profiler/convert/op_stats_to_pod_stats.h deleted file mode 100644 index bd3d74068d8a6c..00000000000000 --- a/tensorflow/core/profiler/convert/op_stats_to_pod_stats.h +++ /dev/null @@ -1,30 +0,0 @@ -/* Copyright 2020 The 
TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_POD_STATS_H_ -#define TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_POD_STATS_H_ - -#include "tensorflow/core/profiler/protobuf/op_stats.pb.h" -#include "tensorflow/core/profiler/protobuf/pod_stats.pb.h" - -namespace tensorflow { -namespace profiler { - -PodStatsDatabase ConvertOpStatsToPodStats(const OpStats& op_stats); - -} // namespace profiler -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_POD_STATS_H_ diff --git a/tensorflow/core/profiler/convert/op_stats_to_pod_stats_test.cc b/tensorflow/core/profiler/convert/op_stats_to_pod_stats_test.cc deleted file mode 100644 index 899b8ade54ca9e..00000000000000 --- a/tensorflow/core/profiler/convert/op_stats_to_pod_stats_test.cc +++ /dev/null @@ -1,124 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/profiler/convert/op_stats_to_pod_stats.h" - -#include "google/protobuf/any.pb.h" -#include "xla/tsl/profiler/utils/math_utils.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/profiler/protobuf/diagnostics.pb.h" -#include "tensorflow/core/profiler/protobuf/op_stats.pb.h" -#include "tensorflow/core/profiler/protobuf/steps_db.pb.h" -#include "tensorflow/core/profiler/utils/diagnostics.h" -#include "tensorflow/core/profiler/utils/event_span.h" - -namespace tensorflow { -namespace profiler { -namespace { - -const double kMaxError = 1e-6; -constexpr int kStepNum = 2; -constexpr int kCoreId = 1001; -constexpr int kStepTimePs = 1000; -constexpr int kHostComputePs = 50; -constexpr int kHostCompilePs = 50; -constexpr int kHostToHostPs = 50; -constexpr int kHostToDevicePs = 50; -constexpr int kHostPreparePs = 50; -constexpr int kDeviceCollectivePs = 350; -constexpr int kHostWaitInputPs = 50; -constexpr int kDeviceToDevicePs = 50; -constexpr int kDeviceToHostPs = 50; -constexpr int kDeviceCompute32Ps = 50; -constexpr int kDeviceCompute16Ps = 50; -constexpr int kDeviceWaitDevicePs = 50; -constexpr int kDeviceWaitHostPs = 50; -constexpr int kUnknownTimePs = 50; -static constexpr char kHostname[] = "host:123"; - -void CreateOpStats(OpStats* op_stats) { - PerCoreStepInfo* info = op_stats->mutable_step_db()->add_step_sequence(); - info->set_step_num(kStepNum); - StepInfoResult& step_info = (*info->mutable_step_info_per_core())[kCoreId]; - step_info.set_step_num(kStepNum); - step_info.set_duration_ps(kStepTimePs); - GenericStepBreakdown breakdown; - auto& type_ps = *breakdown.mutable_type_ps(); - type_ps[HOST_COMPUTE] = kHostComputePs; - type_ps[HOST_COMPILE] = kHostCompilePs; - type_ps[HOST_TO_HOST] = kHostToHostPs; - type_ps[HOST_TO_DEVICE] = 
kHostToDevicePs; - type_ps[HOST_PREPARE] = kHostPreparePs; - type_ps[DEVICE_COLLECTIVES] = kDeviceCollectivePs; - type_ps[HOST_WAIT_INPUT] = kHostWaitInputPs; - type_ps[DEVICE_TO_DEVICE] = kDeviceToDevicePs; - type_ps[DEVICE_TO_HOST] = kDeviceToHostPs; - type_ps[DEVICE_COMPUTE_32] = kDeviceCompute32Ps; - type_ps[DEVICE_COMPUTE_16] = kDeviceCompute16Ps; - type_ps[DEVICE_WAIT_DEVICE] = kDeviceWaitDevicePs; - type_ps[DEVICE_WAIT_HOST] = kDeviceWaitHostPs; - type_ps[UNKNOWN_TIME] = kUnknownTimePs; - step_info.mutable_step_breakdown()->PackFrom(breakdown); - CoreDetails& details = (*op_stats->mutable_core_id_to_details())[kCoreId]; - details.set_hostname(kHostname); -} - -TEST(OpStatsToPodStats, GpuPodStats) { - OpStats op_stats; - CreateOpStats(&op_stats); - PodStatsDatabase pod_stats_db = ConvertOpStatsToPodStats(op_stats); - EXPECT_EQ(1, pod_stats_db.pod_stats_record_size()); - const PodStatsRecord& record = pod_stats_db.pod_stats_record(0); - EXPECT_EQ(kStepNum, record.step_num()); - EXPECT_EQ(kHostname, record.host_name()); - EXPECT_NEAR(tsl::profiler::PicoToMicro(kStepTimePs), - record.total_duration_us(), kMaxError); - const auto& breakdown = record.step_breakdown_us(); - EXPECT_NEAR( - tsl::profiler::PicoToMicro(kDeviceCompute32Ps + kDeviceCompute16Ps), - breakdown.at(kDeviceCompute), kMaxError); - EXPECT_NEAR( - tsl::profiler::PicoToMicro(kDeviceToDevicePs + kDeviceWaitDevicePs), - breakdown.at(kDeviceToDevice), kMaxError); - EXPECT_NEAR(tsl::profiler::PicoToMicro(kDeviceCollectivePs), - breakdown.at(kDeviceCollectives), kMaxError); - EXPECT_NEAR(tsl::profiler::PicoToMicro(kHostComputePs), - breakdown.at(kHostCompute), kMaxError); - EXPECT_NEAR(tsl::profiler::PicoToMicro(kHostPreparePs), - breakdown.at(kHostPrepare), kMaxError); - EXPECT_NEAR(tsl::profiler::PicoToMicro(kHostWaitInputPs + kHostToDevicePs + - kDeviceWaitHostPs), - breakdown.at(kInput), kMaxError); - EXPECT_NEAR(tsl::profiler::PicoToMicro(kDeviceToHostPs), - breakdown.at(kOutput), kMaxError); - 
EXPECT_NEAR(tsl::profiler::PicoToMicro(kHostCompilePs), - breakdown.at(kCompile), kMaxError); - EXPECT_NEAR(tsl::profiler::PicoToMicro(kUnknownTimePs), - breakdown.at(kAllOthers), kMaxError); - - EXPECT_EQ(GetGenericEventTypeStr(kDeviceCollectives), record.bottleneck()); -} - -TEST(OpStatsToPodStats, Diagnostics) { - OpStats op_stats; - op_stats.mutable_step_db()->set_use_incomplete_step(true); - PodStatsDatabase pod_stats_db = ConvertOpStatsToPodStats(op_stats); - EXPECT_EQ(1, pod_stats_db.diagnostics().warnings_size()); - EXPECT_EQ(kErrorIncompleteStep, pod_stats_db.diagnostics().warnings(0)); -} - -} // namespace -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/op_stats_to_pod_viewer.cc b/tensorflow/core/profiler/convert/op_stats_to_pod_viewer.cc deleted file mode 100644 index aad1e1ca79fd95..00000000000000 --- a/tensorflow/core/profiler/convert/op_stats_to_pod_viewer.cc +++ /dev/null @@ -1,64 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/core/profiler/convert/op_stats_to_pod_viewer.h" - -#include - -#include "absl/log/check.h" -#include "tensorflow/core/profiler/convert/op_stats_to_pod_stats.h" -#include "tensorflow/core/profiler/protobuf/pod_stats.pb.h" -#include "tensorflow/core/profiler/protobuf/steps_db.pb.h" -#include "tensorflow/core/profiler/utils/diagnostics.h" - -namespace tensorflow { -namespace profiler { -namespace { - -PodStatsSequence ConvertOpStatsToPodStatsSequence(const OpStats& op_stats, - PodStatsDatabase pod_stats) { - PodStatsSequence result_db; - // PodStatsDatabase is created using the same iteration order below. - // Thus, we just need to move one record at a time. - int i = 0; - for (const auto& step_sequence : op_stats.step_db().step_sequence()) { - PodStatsMap* pod_stats_map = result_db.add_pod_stats_map(); - pod_stats_map->set_step_num(step_sequence.step_num()); - for (const auto& entry : step_sequence.step_info_per_core()) { - PodStatsRecord& record = - (*pod_stats_map->mutable_pod_stats_per_core())[entry.first]; - DCHECK_LE(i, pod_stats.pod_stats_record_size()); - record = std::move(*pod_stats.mutable_pod_stats_record(i++)); - } - } - return result_db; -} - -} // namespace - -PodViewerDatabase ConvertOpStatsToPodViewer(const OpStats& op_stats) { - PodViewerDatabase database; - database.set_device_type(op_stats.run_environment().device_type()); - PodStatsDatabase pod_stats = ConvertOpStatsToPodStats(op_stats); - database.mutable_step_breakdown_events()->Swap( - pod_stats.mutable_step_breakdown_events()); - *database.mutable_pod_stats_sequence() = - ConvertOpStatsToPodStatsSequence(op_stats, std::move(pod_stats)); - PopulateStepDiagnostics(op_stats, database.mutable_diagnostics()); - return database; -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/op_stats_to_pod_viewer.h 
b/tensorflow/core/profiler/convert/op_stats_to_pod_viewer.h deleted file mode 100644 index c45c99393758b0..00000000000000 --- a/tensorflow/core/profiler/convert/op_stats_to_pod_viewer.h +++ /dev/null @@ -1,30 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_POD_VIEWER_H_ -#define TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_POD_VIEWER_H_ - -#include "tensorflow/core/profiler/protobuf/op_stats.pb.h" -#include "tensorflow/core/profiler/protobuf/pod_viewer.pb.h" - -namespace tensorflow { -namespace profiler { - -PodViewerDatabase ConvertOpStatsToPodViewer(const OpStats& op_stats); - -} // namespace profiler -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_POD_VIEWER_H_ diff --git a/tensorflow/core/profiler/convert/op_stats_to_pod_viewer_test.cc b/tensorflow/core/profiler/convert/op_stats_to_pod_viewer_test.cc deleted file mode 100644 index 2273bce70fb228..00000000000000 --- a/tensorflow/core/profiler/convert/op_stats_to_pod_viewer_test.cc +++ /dev/null @@ -1,135 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/profiler/convert/op_stats_to_pod_viewer.h" - -#include "google/protobuf/any.pb.h" -#include "xla/tsl/profiler/utils/math_utils.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/profiler/protobuf/diagnostics.pb.h" -#include "tensorflow/core/profiler/protobuf/op_stats.pb.h" -#include "tensorflow/core/profiler/protobuf/pod_stats.pb.h" -#include "tensorflow/core/profiler/protobuf/steps_db.pb.h" -#include "tensorflow/core/profiler/utils/diagnostics.h" -#include "tensorflow/core/profiler/utils/event_span.h" - -namespace tensorflow { -namespace profiler { -namespace { - -const double kMaxError = 1e-6; -constexpr int kStepNum = 2; -constexpr int kCoreId = 1001; -constexpr int kStepTimePs = 1000; -constexpr int kHostComputePs = 50; -constexpr int kHostCompilePs = 50; -constexpr int kHostToHostPs = 50; -constexpr int kHostToDevicePs = 50; -constexpr int kHostPreparePs = 50; -constexpr int kDeviceCollectivePs = 350; -constexpr int kHostWaitInputPs = 50; -constexpr int kDeviceToDevicePs = 50; -constexpr int kDeviceToHostPs = 50; -constexpr int kDeviceCompute32Ps = 50; -constexpr int kDeviceCompute16Ps = 50; -constexpr int kDeviceWaitDevicePs = 50; -constexpr int kDeviceWaitHostPs = 50; -constexpr int kUnknownTimePs = 50; -static constexpr char kHostname[] = "host:123"; - -void CreateOpStats(OpStats* op_stats) { - PerCoreStepInfo* info = op_stats->mutable_step_db()->add_step_sequence(); - info->set_step_num(kStepNum); - StepInfoResult& step_info = 
(*info->mutable_step_info_per_core())[kCoreId]; - step_info.set_step_num(kStepNum); - step_info.set_duration_ps(kStepTimePs); - GenericStepBreakdown breakdown; - auto& type_ps = *breakdown.mutable_type_ps(); - type_ps[HOST_COMPUTE] = kHostComputePs; - type_ps[HOST_COMPILE] = kHostCompilePs; - type_ps[HOST_TO_HOST] = kHostToHostPs; - type_ps[HOST_TO_DEVICE] = kHostToDevicePs; - type_ps[HOST_PREPARE] = kHostPreparePs; - type_ps[DEVICE_COLLECTIVES] = kDeviceCollectivePs; - type_ps[HOST_WAIT_INPUT] = kHostWaitInputPs; - type_ps[DEVICE_TO_DEVICE] = kDeviceToDevicePs; - type_ps[DEVICE_TO_HOST] = kDeviceToHostPs; - type_ps[DEVICE_COMPUTE_32] = kDeviceCompute32Ps; - type_ps[DEVICE_COMPUTE_16] = kDeviceCompute16Ps; - type_ps[DEVICE_WAIT_DEVICE] = kDeviceWaitDevicePs; - type_ps[DEVICE_WAIT_HOST] = kDeviceWaitHostPs; - type_ps[UNKNOWN_TIME] = kUnknownTimePs; - step_info.mutable_step_breakdown()->PackFrom(breakdown); - CoreDetails& details = (*op_stats->mutable_core_id_to_details())[kCoreId]; - details.set_hostname(kHostname); -} - -TEST(OpStatsToPodViewer, GpuPodViewer) { - OpStats op_stats; - CreateOpStats(&op_stats); - PodViewerDatabase pod_viewer_db = ConvertOpStatsToPodViewer(op_stats); - EXPECT_EQ(1, pod_viewer_db.pod_stats_sequence().pod_stats_map_size()); - const PodStatsMap& pod_stats_map = - pod_viewer_db.pod_stats_sequence().pod_stats_map(0); - EXPECT_EQ(kStepNum, pod_stats_map.step_num()); - const PodStatsRecord& record = pod_stats_map.pod_stats_per_core().at(kCoreId); - EXPECT_EQ(kStepNum, record.step_num()); - EXPECT_EQ(kHostname, record.host_name()); - EXPECT_NEAR(tsl::profiler::PicoToMicro(kStepTimePs), - record.total_duration_us(), kMaxError); - const auto& breakdown = record.step_breakdown_us(); - EXPECT_NEAR( - tsl::profiler::PicoToMicro(kDeviceCompute32Ps + kDeviceCompute16Ps), - breakdown.at(kDeviceCompute), kMaxError); - EXPECT_NEAR( - tsl::profiler::PicoToMicro(kDeviceToDevicePs + kDeviceWaitDevicePs), - breakdown.at(kDeviceToDevice), kMaxError); - 
EXPECT_NEAR(tsl::profiler::PicoToMicro(kDeviceCollectivePs), - breakdown.at(kDeviceCollectives), kMaxError); - EXPECT_NEAR(tsl::profiler::PicoToMicro(kHostComputePs), - breakdown.at(kHostCompute), kMaxError); - EXPECT_NEAR(tsl::profiler::PicoToMicro(kHostPreparePs), - breakdown.at(kHostPrepare), kMaxError); - EXPECT_NEAR(tsl::profiler::PicoToMicro(kHostWaitInputPs + kHostToDevicePs + - kDeviceWaitHostPs), - breakdown.at(kInput), kMaxError); - EXPECT_NEAR(tsl::profiler::PicoToMicro(kDeviceToHostPs), - breakdown.at(kOutput), kMaxError); - EXPECT_NEAR(tsl::profiler::PicoToMicro(kHostCompilePs), - breakdown.at(kCompile), kMaxError); - EXPECT_NEAR(tsl::profiler::PicoToMicro(kUnknownTimePs), - breakdown.at(kAllOthers), kMaxError); - - EXPECT_EQ(GetGenericEventTypeStr(kDeviceCollectives), record.bottleneck()); -} - -TEST(OpStatsToPodViewer, Diagnostics) { - OpStats op_stats; - op_stats.mutable_step_db()->set_use_incomplete_step(true); - PodViewerDatabase pod_viewer_db = ConvertOpStatsToPodViewer(op_stats); - EXPECT_EQ(1, pod_viewer_db.diagnostics().warnings_size()); - EXPECT_EQ(kErrorIncompleteStep, pod_viewer_db.diagnostics().warnings(0)); -} - -TEST(OpStatsToPodViewer, DeviceType) { - OpStats op_stats; - op_stats.mutable_run_environment()->set_device_type("GPU"); - PodViewerDatabase pod_viewer_db = ConvertOpStatsToPodViewer(op_stats); - EXPECT_EQ("GPU", pod_viewer_db.device_type()); -} - -} // namespace -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/op_stats_to_roofline_model.cc b/tensorflow/core/profiler/convert/op_stats_to_roofline_model.cc deleted file mode 100644 index 58ebbc10ec9571..00000000000000 --- a/tensorflow/core/profiler/convert/op_stats_to_roofline_model.cc +++ /dev/null @@ -1,266 +0,0 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/profiler/convert/op_stats_to_roofline_model.h" - -#include -#include - -#include "absl/log/check.h" -#include "xla/tsl/profiler/convert/xla_op_utils.h" -#include "xla/tsl/profiler/utils/math_utils.h" -#include "tensorflow/core/profiler/convert/op_metrics_db_combiner.h" -#include "tensorflow/core/profiler/convert/op_metrics_to_record.h" -#include "tensorflow/core/profiler/protobuf/hardware_types.pb.h" -#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" -#include "tensorflow/core/profiler/protobuf/op_stats.pb.h" -#include "tensorflow/core/profiler/protobuf/roofline_model.pb.h" -#include "tensorflow/core/profiler/protobuf/steps_db.pb.h" -#include "tensorflow/core/profiler/utils/diagnostics.h" -#include "tsl/platform/protobuf.h" - -namespace tensorflow { -namespace profiler { -namespace { - -using tensorflow::profiler::OpMetrics; -using tensorflow::profiler::OpMetricsDb; -using tensorflow::profiler::PerfEnv; -using tensorflow::profiler::roofline_model::RecordType; -using tensorflow::profiler::roofline_model::RooflineModelDatabase; -using tensorflow::profiler::roofline_model::RooflineModelRecord; - -// The maximum number of records to generate. 
-const uint32_t kMaxNumRecords = 1000; -} // namespace - -RooflineModelRecord ConvertOpMetricsToRooflineModelRecord( - const OpStats& op_stats, const OpMetrics& metrics, RecordType record_type, - uint32_t step_num, uint64_t total_time_ps, - const RooflineModelDatabase& roofline_model_db, - bool include_infeed_outfeed) { - RooflineModelRecord record; - record.set_hlo_name(metrics.name()); - record.set_hlo_category(metrics.category()); - record.set_hlo_module_id(metrics.hlo_module_id()); - record.set_record_type(record_type); - record.set_step_num(step_num); - SetExecutionTimes(metrics, &record); - if (record_type == RecordType::AVERAGE_STEP) { - // For RecordType::AVERAGE_STEP, divide by num_steps to show per-step - // numbers when appropriate. - int num_steps = op_stats.step_db().step_sequence_size(); - record.set_total_time_in_us(record.total_time_in_us() / num_steps); - record.set_total_self_time_in_us(record.total_self_time_in_us() / - num_steps); - } - record.set_total_time_per_core_in_us(tsl::profiler::SafeDivide( - record.total_time_in_us(), - op_stats.run_environment().device_core_count())); - record.set_total_time_in_percentage( - tsl::profiler::SafeDivide(metrics.time_ps(), total_time_ps)); - - tensorflow::profiler::SetTpuUnitFractions(metrics, &record); - - // Set the roofline-specific fields. - SetRooflineMetrics(metrics, op_stats.perf_env(), op_stats.run_environment(), - &record); - const double cmem_wr_utilization = - roofline_model_db.has_cmem() - ? tsl::profiler::SafeDivide(record.cmem_write_bw(), - roofline_model_db.peak_cmem_write_bw()) - : 0; - const double cmem_rd_utilization = - roofline_model_db.has_cmem() - ? tsl::profiler::SafeDivide(record.cmem_read_bw(), - roofline_model_db.peak_cmem_read_bw()) - : 0; - const double vmem_rd_utilization = - roofline_model_db.has_merged_vmem() - ? 
tsl::profiler::SafeDivide(record.vmem_read_bw(), - roofline_model_db.peak_vmem_read_bw()) - : 0; - const double vmem_wr_utilization = - roofline_model_db.has_merged_vmem() - ? tsl::profiler::SafeDivide(record.vmem_write_bw(), - roofline_model_db.peak_vmem_write_bw()) - : 0; - const double flops_utilization = tsl::profiler::SafeDivide( - record.measured_flop_rate(), roofline_model_db.peak_flop_rate()); - const double hbm_utilization = tsl::profiler::SafeDivide( - record.hbm_bw(), roofline_model_db.peak_hbm_bw()); - - const double max_mem_utilization = - std::max({cmem_wr_utilization, cmem_rd_utilization, hbm_utilization, - vmem_wr_utilization, vmem_rd_utilization}); - const double roofline_efficiency = - std::max({max_mem_utilization, flops_utilization}); - // Note, copy-start/done can have utilizations above 1.0 since their - // bytes/time are not accurate as they are asynchronous. - record.set_optimal_flop_rate(tsl::profiler::SafeDivide( - record.measured_flop_rate(), roofline_efficiency)); - record.set_roofline_efficiency(roofline_efficiency); - record.set_flop_rate_relative_to_hw_limit(flops_utilization); - record.set_memory_bw_relative_to_hw_limit(max_mem_utilization); - - record.set_include_infeed_outfeed(include_infeed_outfeed); - - return record; -} - -RooflineModelRecord GenerateRooflineModelProgramRecord( - const OpStats& op_stats, const OpMetricsDb& db, RecordType record_type, - uint32_t step_num, const RooflineModelDatabase& roofline_model_db, - bool include_infeed_outfeed) { - OpMetrics program_metrics; - program_metrics.set_name("Program"); - program_metrics.set_category("Program"); - program_metrics.set_occurrences(1); - uint64_t infeed_outfeed_time = 0; - for (const OpMetrics& metrics : db.metrics_db()) { - // Aggregate innermost ops only to avoid redundant counting. 
- if (tsl::profiler::MayHaveInnerOps(metrics.category())) continue; - if (!include_infeed_outfeed && - tsl::profiler::IsInfeedOrOutfeed(metrics.category())) { - infeed_outfeed_time += metrics.time_ps(); - continue; - } - program_metrics.set_flops(program_metrics.flops() + metrics.flops()); - program_metrics.set_model_flops(program_metrics.model_flops() + - metrics.model_flops()); - program_metrics.set_bytes_accessed(program_metrics.bytes_accessed() + - metrics.bytes_accessed()); - CombineMemoryAccessedBreakdown( - metrics.memory_accessed_breakdown(), - program_metrics.mutable_memory_accessed_breakdown()); - } - uint64_t total_time_ps = db.total_time_ps(); - if (!include_infeed_outfeed) total_time_ps -= infeed_outfeed_time; - program_metrics.set_time_ps(total_time_ps); - RooflineModelRecord program_record = ConvertOpMetricsToRooflineModelRecord( - op_stats, program_metrics, record_type, step_num, total_time_ps, - roofline_model_db, include_infeed_outfeed); - program_record.set_rank(0); - program_record.set_total_self_time_as_fraction(0.0); - program_record.set_cumulative_total_self_time_as_fraction(0.0); - return program_record; -} - -tsl::protobuf::RepeatedPtrField -ConvertOpMetricsDbToRooflineModelRecords( - const OpStats& op_stats, const OpMetricsDb& db, RecordType record_type, - uint32_t step_num, const RooflineModelDatabase& roofline_model_db, - bool include_infeed_outfeed) { - tsl::protobuf::RepeatedPtrField roofline_model_records; - RooflineModelRecord* program_record = roofline_model_records.Add(); - *program_record = GenerateRooflineModelProgramRecord( - op_stats, db, record_type, step_num, roofline_model_db, - include_infeed_outfeed); - const RooflineModelRecord* prev_record = program_record; - uint64_t infeed_outfeed_time = 0; - if (!include_infeed_outfeed) { - // Calculate the total time spent on infeed and outfeed ops. 
- for (const OpMetrics& metrics : db.metrics_db()) { - if (tsl::profiler::IsInfeedOrOutfeed(metrics.category())) { - infeed_outfeed_time += metrics.time_ps(); - } - } - } - uint64_t total_time_ps = db.total_time_ps() - infeed_outfeed_time; - double total_time_us = tsl::profiler::PicoToMicro(total_time_ps); - for (const auto* metrics : SortedOpMetricsDb(db, kMaxNumRecords)) { - if (metrics->occurrences() == 0) continue; - if (!include_infeed_outfeed && - tsl::profiler::IsInfeedOrOutfeed(metrics->category())) { - continue; - } - RooflineModelRecord* record = roofline_model_records.Add(); - *record = ConvertOpMetricsToRooflineModelRecord( - op_stats, *metrics, record_type, step_num, total_time_ps, - roofline_model_db, include_infeed_outfeed); - SetRankAndTimeFractions(total_time_us, *prev_record, record); - prev_record = record; - } - return roofline_model_records; -} - -RooflineModelDatabase InitializeRooflineModelDatabaseFromOpStats( - const OpStats& op_stats, bool include_infeed_outfeed) { - tensorflow::profiler::HardwareType hardware_type = - op_stats.run_environment().hardware_type(); - DCHECK(hardware_type == GPU || hardware_type == TPU); - - RooflineModelDatabase roofline_model_db; - const PerfEnv& perf_env = op_stats.perf_env(); - roofline_model_db.set_device_type(op_stats.run_environment().device_type()); - - // Set peak flop rate in GFLOPs/s. 
- roofline_model_db.set_peak_flop_rate( - tsl::profiler::TeraToGiga((perf_env.peak_tera_flops_per_second()))); - roofline_model_db.set_peak_hbm_bw( - tsl::profiler::GigaToGibi(GetMemoryPeakBandwidth(perf_env, 0))); - - if (hardware_type == HardwareType::TPU) { - roofline_model_db.set_megacore(perf_env.has_megacore()); - - roofline_model_db.set_has_cmem(perf_env.has_cmem()); - roofline_model_db.set_has_merged_vmem(perf_env.has_merged_vmem()); - if (roofline_model_db.has_cmem()) { - roofline_model_db.set_peak_cmem_read_bw( - tsl::profiler::GigaToGibi(GetMemoryPeakBandwidth(perf_env, 3))); - roofline_model_db.set_peak_cmem_write_bw( - tsl::profiler::GigaToGibi(GetMemoryPeakBandwidth(perf_env, 4))); - } else if (roofline_model_db.has_merged_vmem()) { - roofline_model_db.set_peak_vmem_read_bw( - tsl::profiler::GigaToGibi(GetMemoryPeakBandwidth(perf_env, 5))); - roofline_model_db.set_peak_vmem_write_bw( - tsl::profiler::GigaToGibi(GetMemoryPeakBandwidth(perf_env, 6))); - } - } else if (hardware_type == HardwareType::GPU) { - roofline_model_db.set_megacore(false); - roofline_model_db.set_has_cmem(false); - roofline_model_db.set_has_merged_vmem(true); - roofline_model_db.set_peak_vmem_read_bw( - tsl::profiler::GigaToGibi(GetMemoryPeakBandwidth(perf_env, 1))); - roofline_model_db.set_peak_vmem_write_bw( - tsl::profiler::GigaToGibi(GetMemoryPeakBandwidth(perf_env, 2))); - } - - return roofline_model_db; -} - -RooflineModelDatabase ConvertOpStatsToRooflineModel( - const OpStats& op_stats, bool include_infeed_outfeed) { - HardwareType hardware_type = op_stats.run_environment().hardware_type(); - if (hardware_type != GPU && hardware_type != TPU) { - return RooflineModelDatabase(); - } - - RooflineModelDatabase roofline_model_db = - InitializeRooflineModelDatabaseFromOpStats(op_stats, - include_infeed_outfeed); - - AddRooflineModelRecordForProfileDuration(op_stats, roofline_model_db, - include_infeed_outfeed); - AddRooflineModelRecordsForCompleteSteps(op_stats, 
roofline_model_db, - include_infeed_outfeed); - AddRooflineModelRecordsPerStep(op_stats, roofline_model_db, - include_infeed_outfeed); - PopulateStepDiagnostics(op_stats, roofline_model_db.mutable_diagnostics()); - return roofline_model_db; -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/op_stats_to_roofline_model.h b/tensorflow/core/profiler/convert/op_stats_to_roofline_model.h deleted file mode 100644 index f2ed42f783d86e..00000000000000 --- a/tensorflow/core/profiler/convert/op_stats_to_roofline_model.h +++ /dev/null @@ -1,98 +0,0 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_ROOFLINE_MODEL_H_ -#define TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_ROOFLINE_MODEL_H_ - -#include - -#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" -#include "tensorflow/core/profiler/protobuf/op_stats.pb.h" -#include "tensorflow/core/profiler/protobuf/roofline_model.pb.h" -#include "tensorflow/core/profiler/protobuf/steps_db.pb.h" -#include "tsl/platform/protobuf.h" - -namespace tensorflow { -namespace profiler { - -using tensorflow::profiler::OpMetrics; -using tensorflow::profiler::roofline_model::RecordType; -using tensorflow::profiler::roofline_model::RooflineModelDatabase; -using tensorflow::profiler::roofline_model::RooflineModelRecord; - -RooflineModelRecord ConvertOpMetricsToRooflineModelRecord( - const OpStats& op_stats, const OpMetrics& metrics, RecordType record_type, - uint32_t step_num, uint64_t total_time_ps, - const RooflineModelDatabase& roofline_model_db, - bool include_infeed_outfeed); - -RooflineModelRecord GenerateRooflineModelProgramRecord( - const OpStats& op_stats, const OpMetricsDb& db, RecordType record_type, - uint32_t step_num, const RooflineModelDatabase& roofline_model_db, - bool include_infeed_outfeed); - -tsl::protobuf::RepeatedPtrField -ConvertOpMetricsDbToRooflineModelRecords( - const OpStats& op_stats, const OpMetricsDb& db, RecordType record_type, - uint32_t step_num, const RooflineModelDatabase& roofline_model_db, - bool include_infeed_outfeed); - -tensorflow::profiler::roofline_model::RooflineModelDatabase -ConvertOpStatsToRooflineModel(const tensorflow::profiler::OpStats& tf_op_stats, - bool include_infeed_outfeed); - -tensorflow::profiler::roofline_model::RooflineModelDatabase -InitializeRooflineModelDatabaseFromOpStats(const OpStats& op_stats, - bool include_infeed_outfeed); -// Generate RooflineModelRecord for the HLO DB over the entire profiling -// duration including 
incomplete steps. -inline void AddRooflineModelRecordForProfileDuration( - const OpStats& op_stats, RooflineModelDatabase& roofline_model_db, - bool include_infeed_outfeed) { - *roofline_model_db.mutable_roofline_model_record() = - ConvertOpMetricsDbToRooflineModelRecords( - op_stats, op_stats.device_op_metrics_db(), RecordType::ALL, - /*step_num=*/0, roofline_model_db, include_infeed_outfeed); -} - -// Generate RooflineModelRecord for the HLO DB over complete steps only. -inline void AddRooflineModelRecordsForCompleteSteps( - const OpStats& op_stats, RooflineModelDatabase& roofline_model_db, - bool include_infeed_outfeed) { - if (op_stats.has_hlo_metrics_db_complete_steps_only()) { - *roofline_model_db.add_roofline_model_record() = - GenerateRooflineModelProgramRecord( - op_stats, op_stats.hlo_metrics_db_complete_steps_only(), - RecordType::AVERAGE_STEP, /*step_num=*/0, roofline_model_db, - include_infeed_outfeed); - } -} - -// Generate RooflineModelRecords for the per-step DBs. -inline void AddRooflineModelRecordsPerStep( - const OpStats& op_stats, RooflineModelDatabase& roofline_model_db, - bool include_infeed_outfeed) { - for (const auto& step_info : op_stats.step_db().step_sequence()) { - *roofline_model_db.add_roofline_model_record() = - GenerateRooflineModelProgramRecord( - op_stats, step_info.hlo_metrics_db(), RecordType::PER_STEP, - step_info.step_num(), roofline_model_db, include_infeed_outfeed); - } -} - -} // namespace profiler -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_ROOFLINE_MODEL_H_ diff --git a/tensorflow/core/profiler/convert/op_stats_to_tf_stats.cc b/tensorflow/core/profiler/convert/op_stats_to_tf_stats.cc deleted file mode 100644 index 841a7b58be9d4c..00000000000000 --- a/tensorflow/core/profiler/convert/op_stats_to_tf_stats.cc +++ /dev/null @@ -1,123 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/profiler/convert/op_stats_to_tf_stats.h" - -#include "xla/tsl/profiler/utils/math_utils.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/profiler/convert/op_metrics_to_record.h" -#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" -#include "tensorflow/core/profiler/protobuf/op_stats.pb.h" -#include "tensorflow/core/profiler/protobuf/tf_stats.pb.h" -#include "tensorflow/core/profiler/utils/kernel_stats_utils.h" -#include "tensorflow/core/profiler/utils/op_metrics_db_utils.h" - -namespace tensorflow { -namespace profiler { -namespace { - -// The maximum number of Tensorflow Ops displayed on Tensorflow Stats page. -// 500 device side ops and 500 host side ops. -const int kMaxNumOfOps = 500; - -TfStatsRecord ConvertOpMetricsToTfStatsRecord(bool on_device, - const OpMetrics& metrics, - const PerfEnv& perf_env, - const RunEnvironment& run_env) { - TfStatsRecord record; - record.set_host_or_device(on_device ? 
"Device" : "Host"); - record.set_is_eager(metrics.is_eager()); - record.set_op_type(metrics.category()); - record.set_op_name(metrics.name()); - SetExecutionTimes(metrics, &record); - SetRooflineMetrics(metrics, perf_env, run_env, &record); - return record; -} - -TfStatsTable GenerateTfStatsTable( - const OpMetricsDb& host_tf_metrics_db, - const OpMetricsDb& device_tf_metrics_db, - const KernelStatsByOpName& kernel_stats_by_op_name, const PerfEnv& perf_env, - const RunEnvironment& run_env, bool exclude_idle) { - TfStatsTable tf_stats_table; - TfStatsRecord sentinel; - sentinel.set_rank(0); - sentinel.set_device_cumulative_total_self_time_as_fraction(0.0); - sentinel.set_host_cumulative_total_self_time_as_fraction(0.0); - const TfStatsRecord* prev_record = &sentinel; - - // Sets device-side TF stats. - uint64 total_device_time_ps = TotalTimePs(device_tf_metrics_db, exclude_idle); - double total_device_time_us = - tsl::profiler::PicoToMicro(total_device_time_ps); - for (const OpMetrics* metrics : - SortedOpMetricsDb(device_tf_metrics_db, kMaxNumOfOps)) { - if (exclude_idle && IsIdleOp(*metrics)) continue; - TfStatsRecord* record = tf_stats_table.add_tf_stats_record(); - *record = ConvertOpMetricsToTfStatsRecord( - /*on_device=*/true, *metrics, perf_env, run_env); - // Compute TensorCore utilization only on device side. - auto iter = kernel_stats_by_op_name.find(record->op_name()); - if (iter != kernel_stats_by_op_name.end()) { - record->set_gpu_tensorcore_utilization( - tsl::profiler::SafeDivide(iter->second.tensor_core_duration_ns, - iter->second.total_duration_ns)); - } else { - record->set_gpu_tensorcore_utilization(0.0); - } - SetRankAndDeviceTimeFractions(total_device_time_us, *prev_record, record); - prev_record = record; - } - - // Sets host-side TF stats. 
- uint64 total_host_time_ps = TotalTimePs(host_tf_metrics_db, exclude_idle); - double total_host_time_us = tsl::profiler::PicoToMicro(total_host_time_ps); - for (const OpMetrics* metrics : tensorflow::profiler::SortedOpMetricsDb( - host_tf_metrics_db, kMaxNumOfOps)) { - if (exclude_idle && IsIdleOp(*metrics)) continue; - TfStatsRecord* record = tf_stats_table.add_tf_stats_record(); - *record = ConvertOpMetricsToTfStatsRecord( - /*on_device=*/false, *metrics, perf_env, run_env); - // Host side TensorCore utilization is always 0.0 - record->set_gpu_tensorcore_utilization(0.0); - SetRankAndHostTimeFractions(total_host_time_us, *prev_record, record); - prev_record = record; - } - return tf_stats_table; -} - -} // namespace - -TfStatsDatabase ConvertOpStatsToTfStats(const OpStats& op_stats) { - const OpMetricsDb& host_tf_metrics_db = op_stats.host_op_metrics_db(); - OpMetricsDb device_tf_metrics_db = - CreateTfMetricsDbFromDeviceOpMetricsDb(op_stats.device_op_metrics_db()); - const PerfEnv perf_env = op_stats.perf_env(); - const RunEnvironment run_env = op_stats.run_environment(); - KernelStatsByOpName kernel_stats_by_op_name = - GroupKernelReportsByOpName(op_stats.kernel_stats_db()); - TfStatsDatabase tf_stats_db; - *tf_stats_db.mutable_with_idle() = GenerateTfStatsTable( - host_tf_metrics_db, device_tf_metrics_db, kernel_stats_by_op_name, - perf_env, run_env, /*exclude_idle=*/false); - *tf_stats_db.mutable_without_idle() = GenerateTfStatsTable( - host_tf_metrics_db, device_tf_metrics_db, kernel_stats_by_op_name, - perf_env, run_env, /*exclude_idle=*/true); - tf_stats_db.set_device_type(op_stats.run_environment().device_type()); - return tf_stats_db; -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/op_stats_to_tf_stats.h b/tensorflow/core/profiler/convert/op_stats_to_tf_stats.h deleted file mode 100644 index 3b8a06ef1c6619..00000000000000 --- a/tensorflow/core/profiler/convert/op_stats_to_tf_stats.h +++ /dev/null @@ 
-1,30 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_TF_STATS_H_ -#define TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_TF_STATS_H_ - -#include "tensorflow/core/profiler/protobuf/op_stats.pb.h" -#include "tensorflow/core/profiler/protobuf/tf_stats.pb.h" - -namespace tensorflow { -namespace profiler { - -TfStatsDatabase ConvertOpStatsToTfStats(const OpStats& op_stats); - -} // namespace profiler -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_TF_STATS_H_ diff --git a/tensorflow/core/profiler/convert/op_stats_to_tf_stats_test.cc b/tensorflow/core/profiler/convert/op_stats_to_tf_stats_test.cc deleted file mode 100644 index abe9d599d971a9..00000000000000 --- a/tensorflow/core/profiler/convert/op_stats_to_tf_stats_test.cc +++ /dev/null @@ -1,167 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/profiler/convert/op_stats_to_tf_stats.h" - -#include -#include -#include - -#include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" -#include "xla/tsl/profiler/utils/math_utils.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/profiler/convert/xplane_to_op_stats.h" -#include "tensorflow/core/profiler/protobuf/op_stats.pb.h" -#include "tensorflow/core/profiler/protobuf/tf_stats.pb.h" -#include "tensorflow/core/profiler/utils/xplane_builder.h" -#include "tensorflow/core/profiler/utils/xplane_schema.h" -#include "tensorflow/core/profiler/utils/xplane_test_utils.h" -#include "tsl/profiler/protobuf/xplane.pb.h" - -namespace tensorflow { -namespace profiler { -namespace { - -XEventBuilder AddTensorFlowOpEvent(std::string&& tf_op_fullname, - int64_t start_timestamp_ns, - int64_t duration_ns, bool on_device, - absl::string_view kernel_name, - XPlaneBuilder* plane, XLineBuilder* line) { - absl::string_view name = on_device ? 
kernel_name : tf_op_fullname; - XEventBuilder event = line->AddEvent(*plane->GetOrCreateEventMetadata(name)); - event.SetTimestampNs(start_timestamp_ns); - event.SetDurationNs(duration_ns); - if (!on_device) return event; - event.AddStatValue( - *plane->GetOrCreateStatMetadata(GetStatTypeStr(StatType::kTfOp)), - *plane->GetOrCreateStatMetadata(std::move(tf_op_fullname))); - return event; -} - -void AddTensorFlowOpEventWithKernelDetails(std::string&& tf_op_fullname, - int64_t start_timestamp_ns, - int64_t duration_ns, bool on_device, - absl::string_view kernel_name, - absl::string_view kernel_details, - XPlaneBuilder* plane, - XLineBuilder* line) { - XEventBuilder event = - AddTensorFlowOpEvent(std::move(tf_op_fullname), start_timestamp_ns, - duration_ns, on_device, kernel_name, plane, line); - if (!on_device) return; - event.ParseAndAddStatValue(*plane->GetOrCreateStatMetadata("kernel_details"), - kernel_details); -} - -TEST(OpStatsToTfStats, GpuTfStats) { - // TfOp1 has kernel1 and kernel2; TfOp2 has kernel3; - // TfOp3 has kernel4 and kernel5 and is TensorCore eligible. 
- static constexpr char kTfOp1[] = "TfOp1"; - static constexpr char kTfOp2[] = "TfOp2"; - static constexpr char kTfOp3[] = "Conv2D"; - static constexpr char kKernel1[] = "kernel1"; - static constexpr char kKernel2[] = "kernel2"; - static constexpr char kKernel3[] = "kernel3"; - // Kernel4 is a kernel using TensorCore - static constexpr char kKernel4[] = "volta_fp16_s884gemm"; - static constexpr char kKernel5[] = "kernel5"; - constexpr int64_t kKernel1StartNs = 100000; - constexpr int64_t kKernel1DurationNs = 8000; - constexpr int64_t kKernel2StartNs = 110000; - constexpr int64_t kKernel2DurationNs = 10000; - constexpr int64_t kKernel3StartNs = 120000; - constexpr int64_t kKernel3DurationNs = 10000; - constexpr int64_t kKernel4StartNs = 130000; - constexpr int64_t kKernel4DurationNs = 10000; - constexpr int64_t kKernel5StartNs = 150000; - constexpr int64_t kKernel5DurationNs = 10000; - - // Mock kernel details for both kernel4 and kernel5. - const std::string kKernelDetails = R"MULTI(regs:32 -static_shared:0 -dynamic_shared:16384 -grid:2,1,1 -block:32,1,1 -occ_pct:100)MULTI"; - - XSpace space; - XPlaneBuilder device_plane( - GetOrCreateGpuXPlane(&space, /*device_ordinal=*/0)); - XLineBuilder stream1 = device_plane.GetOrCreateLine(/*line_id=*/10); - AddTensorFlowOpEvent(absl::StrCat(kTfOp1, ":", kTfOp1), kKernel1StartNs, - kKernel1DurationNs, /*on_device=*/true, kKernel1, - &device_plane, &stream1); - AddTensorFlowOpEvent(absl::StrCat(kTfOp1, ":", kTfOp1), kKernel2StartNs, - kKernel2DurationNs, /*on_device=*/true, kKernel2, - &device_plane, &stream1); - XLineBuilder stream2 = device_plane.GetOrCreateLine(/*line_id=*/20); - AddTensorFlowOpEvent(absl::StrCat(kTfOp1, ":", kTfOp1), kKernel1StartNs, - kKernel1DurationNs, /*on_device=*/true, kKernel1, - &device_plane, &stream2); - AddTensorFlowOpEvent(absl::StrCat(kTfOp1, ":", kTfOp1), kKernel2StartNs, - kKernel2DurationNs, /*on_device=*/true, kKernel2, - &device_plane, &stream2); - 
AddTensorFlowOpEvent(absl::StrCat(kTfOp2, ":", kTfOp2), kKernel3StartNs, - kKernel3DurationNs, /*on_device=*/true, kKernel3, - &device_plane, &stream2); - AddTensorFlowOpEventWithKernelDetails( - absl::StrCat(kTfOp3, ":", kTfOp3), kKernel4StartNs, kKernel4DurationNs, - /*on_device=*/true, kKernel4, kKernelDetails, &device_plane, &stream2); - AddTensorFlowOpEventWithKernelDetails( - absl::StrCat(kTfOp3, ":", kTfOp3), kKernel5StartNs, kKernel5DurationNs, - /*on_device=*/true, kKernel5, kKernelDetails, &device_plane, &stream2); - - OpStatsOptions options; - options.generate_kernel_stats_db = true; - options.generate_op_metrics_db = true; - const OpStats op_stats = ConvertXSpaceToOpStats(space, options); - const TfStatsDatabase tf_stats = ConvertOpStatsToTfStats(op_stats); - - EXPECT_EQ(tf_stats.device_type(), op_stats.run_environment().device_type()); - - // TfOp1, TfOp3, TfOp2, Idle - EXPECT_EQ(4, tf_stats.with_idle().tf_stats_record_size()); - - const TfStatsRecord& record_0 = tf_stats.with_idle().tf_stats_record(0); - EXPECT_EQ(kTfOp1, record_0.op_name()); - EXPECT_EQ(kTfOp1, record_0.op_type()); - EXPECT_EQ(2, record_0.occurrences()); - EXPECT_EQ(tsl::profiler::NanoToMicro(kKernel1DurationNs) * 2 + - tsl::profiler::NanoToMicro(kKernel2DurationNs) * 2, - record_0.total_self_time_in_us()); - - const TfStatsRecord& record_1 = tf_stats.with_idle().tf_stats_record(1); - EXPECT_EQ(kTfOp3, record_1.op_name()); - EXPECT_EQ(kTfOp3, record_1.op_type()); - EXPECT_EQ(1, record_1.occurrences()); - EXPECT_EQ(tsl::profiler::NanoToMicro(kKernel4DurationNs) + - tsl::profiler::NanoToMicro(kKernel5DurationNs), - record_1.total_self_time_in_us()); - // GPU TensorCore utilization is 0.5 because kernel4 is using TensorCore and - // kernel5 is not using TensorCore, and they have the same duration. 
- EXPECT_DOUBLE_EQ(0.5, record_1.gpu_tensorcore_utilization()); - - const TfStatsRecord& record_2 = tf_stats.with_idle().tf_stats_record(2); - EXPECT_EQ(kTfOp2, record_2.op_name()); - EXPECT_EQ(kTfOp2, record_2.op_type()); - EXPECT_EQ(1, record_2.occurrences()); - EXPECT_EQ(tsl::profiler::NanoToMicro(kKernel3DurationNs), - record_2.total_self_time_in_us()); -} - -} // namespace -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/preprocess_single_host_xplane.cc b/tensorflow/core/profiler/convert/preprocess_single_host_xplane.cc deleted file mode 100644 index 7d8f22914421a7..00000000000000 --- a/tensorflow/core/profiler/convert/preprocess_single_host_xplane.cc +++ /dev/null @@ -1,69 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ -#include "tensorflow/core/profiler/convert/preprocess_single_host_xplane.h" - -#include - -#include "absl/strings/match.h" -#include "xla/tsl/profiler/utils/group_events.h" -#include "xla/tsl/profiler/utils/preprocess_xplane.h" -#include "xla/tsl/profiler/utils/xplane_utils.h" -#include "tensorflow/core/profiler/utils/derived_timeline.h" -#include "tensorflow/core/profiler/utils/xplane_schema.h" - -namespace tensorflow { -namespace profiler { - -void PreprocessSingleHostXSpace( - XSpace* space, bool step_grouping, bool derived_timeline, - tsl::profiler::GroupMetadataMap* group_metadata_map) { - if (step_grouping && !tsl::profiler::IsXSpaceGrouped(*space)) { - // Grouping (i.e. marking step number) events in the XSpace. - std::vector device_traces; - bool isTpu = false; - for (XPlane& plane : *space->mutable_planes()) { - if (tsl::profiler::IsDevicePlane(plane)) { - device_traces.push_back(&plane); - } - // Preprocess XPlane to convert stats to Traceme2 semantics - tsl::profiler::PreprocessXPlane(&plane); - - if (!isTpu && absl::StartsWith(plane.name(), kTpuPlanePrefix)) { - isTpu = true; - } - } - - tsl::profiler::EventForest event_forest; - if (isTpu) { - // group TPU events - GroupTpuEventsOSS(space, device_traces, &event_forest); - } else { - // group GPU events - tsl::profiler::GroupTfEvents(space, &event_forest); - } - - if (derived_timeline) { - // Generated miscellaneous derived time lines for device planes. 
- GenerateDerivedTimeLines(event_forest.GetGroupMetadataMap(), space); - } - - if (group_metadata_map != nullptr) { - *group_metadata_map = event_forest.GetGroupMetadataMap(); - } - } -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/preprocess_single_host_xplane.h b/tensorflow/core/profiler/convert/preprocess_single_host_xplane.h deleted file mode 100644 index 4c86ed8758bc4a..00000000000000 --- a/tensorflow/core/profiler/convert/preprocess_single_host_xplane.h +++ /dev/null @@ -1,35 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_PREPROCESS_SINGLE_HOST_XPLANE_H_ -#define TENSORFLOW_CORE_PROFILER_CONVERT_PREPROCESS_SINGLE_HOST_XPLANE_H_ - -#include "xla/tsl/profiler/utils/group_events.h" -#include "tsl/profiler/protobuf/xplane.pb.h" - -namespace tensorflow { -namespace profiler { - -// Preprocess XSpaces before tools conversion. -// If step_grouping = true, perform events grouping for step tracking. -// If derived_timeline, generate derived timeline (XLines). -// If group_metadata_map is not nullptr, populate the group metadata map. 
-void PreprocessSingleHostXSpace( - XSpace* space, bool step_grouping, bool derived_timeline, - tsl::profiler::GroupMetadataMap* group_metadata_map = nullptr); - -} // namespace profiler -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_PROFILER_CONVERT_PREPROCESS_SINGLE_HOST_XPLANE_H_ diff --git a/tensorflow/core/profiler/convert/process_megascale_dcn.cc b/tensorflow/core/profiler/convert/process_megascale_dcn.cc deleted file mode 100644 index ab2b9fbefe60bd..00000000000000 --- a/tensorflow/core/profiler/convert/process_megascale_dcn.cc +++ /dev/null @@ -1,56 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ -#include "tensorflow/core/profiler/convert/process_megascale_dcn.h" - -#include - -#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" -#include "xla/tsl/profiler/utils/tpu_xplane_utils.h" -#include "xla/tsl/profiler/utils/xplane_schema.h" -#include "tensorflow/core/profiler/convert/dcn_analysis.h" -#include "tensorflow/core/profiler/utils/xplane_utils.h" -#include "tensorflow/core/profiler/utils/xplane_visitor.h" - -namespace tensorflow { -namespace profiler { - -using tsl::profiler::CreateTfXPlaneVisitor; -using tsl::profiler::FindMutableTensorCorePlanes; - -void ProcessMegascaleDcn(XSpace* space) { - std::vector device_xplanes = FindMutableTensorCorePlanes(space); - int num_tpu_cores = device_xplanes.size(); - // DCN TraceMe's are in the Host XPlane - XPlane* host_plane = - FindMutablePlaneWithName(space, tsl::profiler::kHostThreadsPlaneName); - const XPlaneVisitor plane_visitor = CreateTfXPlaneVisitor(host_plane); - // TODO(yashjs): Update parameter value for `is_megacore`. - DcnEventsProcessor dcn_events_processor(num_tpu_cores, false); - dcn_events_processor.SetupMessageInfo(plane_visitor); - if (dcn_events_processor.HasDcnMessages( - tsl::profiler::kMegaScaleDcnReceive)) { - dcn_events_processor.ProcessReceiveMessages(plane_visitor); - } - // Update host XPlane with DCN traffic - dcn_events_processor.AddHostDcnTrafficToXPlane(host_plane); - // Update device XPlanes with per collective TPU traffic. 
- for (XPlane* device_xplane : device_xplanes) { - dcn_events_processor.AddTpuCollectiveDcnTrafficToXPlane(device_xplane); - } - - SortXSpace(space); -} -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/process_megascale_dcn.h b/tensorflow/core/profiler/convert/process_megascale_dcn.h deleted file mode 100644 index 794c2bea66462a..00000000000000 --- a/tensorflow/core/profiler/convert/process_megascale_dcn.h +++ /dev/null @@ -1,29 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_PROCESS_MEGASCALE_DCN_H_ -#define TENSORFLOW_CORE_PROFILER_CONVERT_PROCESS_MEGASCALE_DCN_H_ - -#include "tsl/profiler/protobuf/xplane.pb.h" - -namespace tensorflow { -namespace profiler { - -// Process Dcn Megascale TraceMe info. -void ProcessMegascaleDcn(XSpace* space); - -} // namespace profiler -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_PROFILER_CONVERT_PROCESS_MEGASCALE_DCN_H_ diff --git a/tensorflow/core/profiler/convert/profile_time_breakdown.cc b/tensorflow/core/profiler/convert/profile_time_breakdown.cc deleted file mode 100644 index e1826a7119f9a2..00000000000000 --- a/tensorflow/core/profiler/convert/profile_time_breakdown.cc +++ /dev/null @@ -1,79 +0,0 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/core/profiler/convert/profile_time_breakdown.h" - -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" -#include "xla/tsl/profiler/convert/xla_op_utils.h" -#include "xla/tsl/profiler/utils/math_utils.h" - -namespace tensorflow { -namespace profiler { - -void ProfileTimeBreakdown::SetCategoryTimePs(absl::string_view category, - uint64_t time_ps) { - time_ps_by_category_.insert_or_assign(category, time_ps); -} - -uint64_t ProfileTimeBreakdown::PopCategoryTimePs(absl::string_view category) { - uint64_t time_ps = 0; - auto iter = time_ps_by_category_.find(category); - if (iter != time_ps_by_category_.end()) { - time_ps = iter->second; - time_ps_by_category_.erase(iter); - } - return time_ps; -} - -void ProfileTimeBreakdown::BreakdownSparseCoreV0Infeed() { - // Infeed from SparseCoreV0 and outfeed to SparseCoreV0 are mostly identical - // in compute since they do the same transformation. We can subtract out the - // outfeed time from the infeed time to know how much time the TensorCore - // actually spent waiting on SparseCoreV0. 
- uint64_t bc_infeed_ps = - PopCategoryTimePs(tsl::profiler::kHloSparseCoreV0Infeed); - if (bc_infeed_ps == 0) return; - uint64_t bc_outfeed_ps = - CategoryTimePs(tsl::profiler::kHloSparseCoreV0Outfeed); - - uint64_t bc_infeed_transform_ps = std::min(bc_infeed_ps, bc_outfeed_ps); - uint64_t bc_infeed_wait_ps = bc_infeed_ps - bc_infeed_transform_ps; - - SetCategoryTimePs(tsl::profiler::kHloSparseCoreV0InfeedWait, - bc_infeed_wait_ps); - SetCategoryTimePs(tsl::profiler::kHloSparseCoreV0InfeedTransform, - bc_infeed_transform_ps); -} - -std::string ProfileTimeBreakdown::DebugString() const { - std::string str; - for (const auto& [category, time_ps] : time_ps_by_category_) { - absl::StrAppend(&str, category, ": ", tsl::profiler::PicoToUni(time_ps), - "\n"); - } - absl::StrAppend( - &str, "total_time: ", tsl::profiler::PicoToUni(total_time_ps_), "\n"); - absl::StrAppend( - &str, "profile_time: ", tsl::profiler::PicoToUni(profile_time_ps_), "\n"); - return str; -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/profile_time_breakdown.h b/tensorflow/core/profiler/convert/profile_time_breakdown.h index 1e3379beb4c457..9b68baad5ecf79 100644 --- a/tensorflow/core/profiler/convert/profile_time_breakdown.h +++ b/tensorflow/core/profiler/convert/profile_time_breakdown.h @@ -15,230 +15,6 @@ limitations under the License. #ifndef TENSORFLOW_CORE_PROFILER_CONVERT_PROFILE_TIME_BREAKDOWN_H_ #define TENSORFLOW_CORE_PROFILER_CONVERT_PROFILE_TIME_BREAKDOWN_H_ -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/log/check.h" -#include "absl/strings/string_view.h" -#include "xla/tsl/profiler/convert/xla_op_utils.h" -#include "xla/tsl/profiler/utils/math_utils.h" - -namespace tensorflow { -namespace profiler { - -// Allows accumulating time spent in different HLO instruction categories to -// breakdown the total profile time and compute metrics of interest. 
-class ProfileTimeBreakdown { - public: - // Category should be the operator category disambiguated by xprof instead of - // the original category from XLA. - // For a correct time breakdown, we need to use the self time of operators, - // instead of total time to avoid double counting. Note that for leaf ops, - // self time and total time are the same. - void IncrementCategoryTimePs(absl::string_view category, - uint64_t self_time_ps) { - time_ps_by_category_[category] += self_time_ps; - total_time_ps_ += self_time_ps; - } - - // Profile time cannot be smaller than the total time in all categories. - // If combining profiles across multiple cores, profile time should be the - // profiling duration multiplied by the number of cores that were profiled. - // go/autograppler_profile_time - void SetProfileTimePs(uint64_t profile_time_ps) { - DCHECK_LE(total_time_ps_, profile_time_ps); - profile_time_ps_ = profile_time_ps; - } - - // Breaks down "sparsecorev0 infeed" into two components: - // 1) "sparsecorev0 infeed wait": Time spent waiting on the SparseCoreV0. - // 2) "sparsecorev0 infeed transform": Time spent transforming activations in - // SparseCoreV0 layout into XLA layout. - // Even though 2) is part of the overall embedding computation, it is time - // spent doing work on the TensorCore. - void BreakdownSparseCoreV0Infeed(); - - // Duty cycle is the fraction of time an accelerator is being actively used. 
- // go/accelerator-metrics-definitions#common-accelerator-metrics - // go/ag-tpu-duty-cycle - double DutyCycle() const { return TimeFraction(OnDutyTimePs()); } - - double IdleFraction() const { return TimeFraction(IdleTimePs()); } - - double InfeedFraction() const { - return CategoryFraction(tsl::profiler::kHloInfeed); - } - - double OutfeedFraction() const { - return CategoryFraction(tsl::profiler::kHloOutfeed); - } - - double SparseCoreV0InfeedFraction() const { - return CategoriesFraction({tsl::profiler::kHloSparseCoreV0Infeed, - tsl::profiler::kHloSparseCoreV0InfeedWait, - tsl::profiler::kHloSparseCoreV0InfeedTransform}); - } - - double SparseCoreV0OutfeedFraction() const { - return CategoryFraction(tsl::profiler::kHloSparseCoreV0Outfeed); - } - - double AllReduceFraction() const { - return CategoryFraction(tsl::profiler::kHloAllReduce); - } - - double AllReduceFusionFraction() const { - return CategoryFraction(tsl::profiler::kHloAllReduceFusion); - } - - double SendRecvFraction() const { - return CategoriesFraction( - {tsl::profiler::kHloSend, tsl::profiler::kHloSendDone, - tsl::profiler::kHloRecv, tsl::profiler::kHloRecvDone}); - } - - double HostSendRecvFraction() const { - return CategoriesFraction( - {tsl::profiler::kHloHostSend, tsl::profiler::kHloHostSendDone, - tsl::profiler::kHloHostRecv, tsl::profiler::kHloHostRecvDone}); - } - - double CategoriesFraction( - const std::initializer_list& categories) const { - return TimeFraction(CategoriesTimePs(categories)); - } - - double CategoryFraction(absl::string_view category) const { - return TimeFraction(CategoryTimePs(category)); - } - - uint64_t ProfileTimePs() const { return profile_time_ps_; } - - uint64_t TotalTimePs() const { return total_time_ps_; } - - uint64_t IdleTimePs() const { return profile_time_ps_ - total_time_ps_; } - - uint64_t OnDutyTimePs() const { return profile_time_ps_ - OffDutyTimePs(); } - - uint64_t OffDutyTimePs() const { - return IdleTimePs() + - CategoriesTimePs( - 
{tsl::profiler::kHloInfeed, tsl::profiler::kHloOutfeed, - tsl::profiler::kHloHostSend, tsl::profiler::kHloHostSendDone, - tsl::profiler::kHloHostRecv, tsl::profiler::kHloHostRecvDone, - tsl::profiler::kHloMegacoreFusion}); - } - - uint64_t InfeedTimePs() const { - return CategoryTimePs(tsl::profiler::kHloInfeed); - } - - uint64_t OutfeedTimePs() const { - return CategoryTimePs(tsl::profiler::kHloOutfeed); - } - - uint64_t SparseCoreV0InfeedWaitTimePs() const { - return CategoryTimePs(tsl::profiler::kHloSparseCoreV0InfeedWait); - } - - uint64_t SparseCoreV0InfeedTransformTimePs() const { - return CategoryTimePs(tsl::profiler::kHloSparseCoreV0InfeedTransform); - } - - uint64_t SparseCoreV0OutfeedTimePs() const { - return CategoryTimePs(tsl::profiler::kHloSparseCoreV0Outfeed); - } - - uint64_t AllReduceOrAllToAllTimePs() const { - return CategoriesTimePs({tsl::profiler::kHloAllReduce, - tsl::profiler::kHloAllReduceFusion, - tsl::profiler::kHloAllToAll}); - } - - uint64_t SendTimePs() const { - return CategoriesTimePs( - {tsl::profiler::kHloSend, tsl::profiler::kHloSendDone}); - } - - uint64_t RecvTimePs() const { - return CategoriesTimePs( - {tsl::profiler::kHloRecv, tsl::profiler::kHloRecvDone}); - } - - uint64_t HostSendTimePs() const { - return CategoriesTimePs( - {tsl::profiler::kHloHostSend, tsl::profiler::kHloHostSendDone}); - } - - uint64_t HostRecvTimePs() const { - return CategoriesTimePs( - {tsl::profiler::kHloHostRecv, tsl::profiler::kHloHostRecvDone}); - } - - // Megacore fusion runs different operations on each core, e.g., a convolution - // on one core and an all-reduce on the other core. In a trace, megacore - // fusion is the parent operation, and its self time is the time that the core - // executing the faster operation waits for the core executing the slower - // operation to reach the synchronization point. 
- uint64_t MegacoreFusionTimePs() const { - return CategoryTimePs(tsl::profiler::kHloMegacoreFusion); - } - - uint64_t HighFlopsComputeTimePs() const { - return CategoriesTimePs({tsl::profiler::kHloConvolution, - tsl::profiler::kHloConvolutionBaseDilated, - tsl::profiler::kHloConvolutionWindowDilated, - tsl::profiler::kHloConvolutionFusion, - tsl::profiler::kHloOutputFusion}); - } - - // Calculated according to the "TC busy time" defined in go/tpu_kpis - uint64_t TensorCoreBusyTimePs() const { - return profile_time_ps_ - OffDutyTimePs() - SparseCoreV0InfeedWaitTimePs(); - } - - uint64_t CategoriesTimePs( - const std::initializer_list& categories) const { - uint64_t time_ps = 0; - for (auto category : categories) { - time_ps += CategoryTimePs(category); - } - return time_ps; - } - - uint64_t CategoryTimePs(absl::string_view category) const { - auto iter = time_ps_by_category_.find(category); - return (iter == time_ps_by_category_.end()) ? 0 : iter->second; - } - - template - void ComputeCategoryFractions(Map& category_fractions) { - for (const auto& [category, time_ps] : time_ps_by_category_) { - category_fractions[category] = TimeFraction(time_ps); - } - } - - std::string DebugString() const; - - private: - // Overwrites the time attributed to the given category. - void SetCategoryTimePs(absl::string_view category, uint64_t time_ps); - - // Removes and returns the time attributed to the given category. - uint64_t PopCategoryTimePs(absl::string_view category); - - double TimeFraction(uint64_t time_ps) const { - return tsl::profiler::SafeDivide(time_ps, profile_time_ps_); - } - - absl::flat_hash_map time_ps_by_category_; - uint64_t total_time_ps_ = 0; // Sum of values in time_ps_by_category_. 
- uint64_t profile_time_ps_ = 0; -}; - -} // namespace profiler -} // namespace tensorflow +#include "xprof/convert/profile_time_breakdown.h" // from @org_xprof // IWYU pragma: export #endif // TENSORFLOW_CORE_PROFILER_CONVERT_PROFILE_TIME_BREAKDOWN_H_ diff --git a/tensorflow/core/profiler/convert/repository.cc b/tensorflow/core/profiler/convert/repository.cc deleted file mode 100644 index 6fcadd8caf65c0..00000000000000 --- a/tensorflow/core/profiler/convert/repository.cc +++ /dev/null @@ -1,179 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/core/profiler/convert/repository.h" - -#include -#include -#include -#include -#include -#include - -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/match.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" -#include "absl/strings/strip.h" -#include "xla/tsl/platform/env.h" -#include "xla/tsl/platform/errors.h" -#include "xla/tsl/platform/statusor.h" -#include "xla/tsl/profiler/utils/file_system_utils.h" -#include "tensorflow/core/platform/errors.h" -#include "tsl/platform/path.h" -#include "tsl/profiler/protobuf/xplane.pb.h" - -namespace tensorflow { -namespace profiler { -namespace { -std::string GetHostnameByPath(absl::string_view xspace_path) { - std::string_view file_name = tsl::io::Basename(xspace_path); - // Remove suffix from file_name, preserving entire prefix. - absl::ConsumeSuffix(&file_name, ".xplane.pb"); - return std::string(file_name); -} -} // namespace - -absl::StatusOr SessionSnapshot::Create( - std::vector xspace_paths, - std::optional>> xspaces) { - if (xspace_paths.empty()) { - return errors::InvalidArgument("Can not find XSpace path."); - } - - if (xspaces.has_value()) { - if (xspaces->size() != xspace_paths.size()) { - return errors::InvalidArgument( - "The size of the XSpace paths: ", xspace_paths.size(), - " is not equal ", - "to the size of the XSpace proto: ", xspaces->size()); - } - for (size_t i = 0; i < xspace_paths.size(); ++i) { - auto host_name = GetHostnameByPath(xspace_paths.at(i)); - if (xspaces->at(i)->hostnames_size() > 0 && !host_name.empty()) { - if (!absl::StrContains(host_name, xspaces->at(i)->hostnames(0))) { - return errors::InvalidArgument( - "The hostname of xspace path and preloaded xpace don't match at " - "index: ", - i, ". 
\nThe host name of xpace path is ", host_name, - " but the host name of preloaded xpace is ", - xspaces->at(i)->hostnames(0), "."); - } - } - } - } - - return SessionSnapshot(std::move(xspace_paths), std::move(xspaces)); -} - -absl::StatusOr> SessionSnapshot::GetXSpace( - size_t index) const { - if (index >= xspace_paths_.size()) { - return errors::InvalidArgument("Can not get the ", index, - "th XSpace. The total number of XSpace is ", - xspace_paths_.size()); - } - - // Return the pre-loaded XSpace proto. - if (xspaces_.has_value()) { - if (xspaces_->at(index) == nullptr) { - return errors::Internal(""); - } - return std::move(xspaces_->at(index)); - } - - // Return the XSpace proto from file. - auto xspace_from_file = std::make_unique(); - TF_RETURN_IF_ERROR(tsl::ReadBinaryProto( - tsl::Env::Default(), xspace_paths_.at(index), xspace_from_file.get())); - return xspace_from_file; -} - -absl::StatusOr> SessionSnapshot::GetXSpaceByName( - absl::string_view name) const { - if (auto it = hostname_map_.find(name); it != hostname_map_.end()) { - return GetXSpace(it->second); - } - - return errors::InvalidArgument("Can not find the XSpace by name: ", name, - ". 
The total number of XSpace is ", - xspace_paths_.size()); -} - -std::string SessionSnapshot::GetHostname(size_t index) const { - return GetHostnameByPath(xspace_paths_.at(index)); -} - -std::optional SessionSnapshot::GetFilePath( - absl::string_view toolname, absl::string_view hostname) const { - if (!has_accessible_run_dir_) return std::nullopt; - std::string file_name = ""; - if (toolname == "trace_viewer@") - file_name = absl::StrCat(hostname, ".", "SSTABLE"); - if (!file_name.empty()) return tsl::io::JoinPath(session_run_dir_, file_name); - return std::nullopt; -} - -absl::StatusOr SessionSnapshot::GetHostDataFileName( - const StoredDataType data_type, const std::string host) const { - for (const auto& format : *kHostDataSuffixes) { - if (data_type == format.first) return absl::StrCat(host, format.second); - } - return absl::InternalError(&"Unknown StoredDataType: "[data_type]); -} - -absl::StatusOr> SessionSnapshot::GetHostDataFilePath( - const StoredDataType data_type, const std::string host) const { - // Gets all the files in session run directory. 
- std::vector results; - TF_RETURN_IF_ERROR(::tsl::Env::Default()->GetChildren( - std::string(GetSessionRunDir()), &results)); - - TF_ASSIGN_OR_RETURN(std::string filename, - GetHostDataFileName(data_type, host)); - - for (const std::string& path : results) { - if (absl::EndsWith(path, filename)) { - return ::tsl::profiler::ProfilerJoinPath(GetSessionRunDir(), filename); - } - } - - return std::nullopt; -} - -absl::StatusOr> SessionSnapshot::HasCacheFile( - const StoredDataType data_type) const { - std::optional filepath; - TF_ASSIGN_OR_RETURN(filepath, - GetHostDataFilePath(data_type, kNoHostIdentifier)); - if (filepath) { - // cache file is present but file contains no data_type events - return std::pair(true, std::string()); - } - - TF_ASSIGN_OR_RETURN(filepath, - GetHostDataFilePath(data_type, kAllHostsIdentifier)); - if (filepath) { - // cache file is present and file contains data_type events - return std::pair(true, filepath.value()); - } - - // no cache file present - return std::pair(false, std::string()); -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/repository.h b/tensorflow/core/profiler/convert/repository.h deleted file mode 100644 index f6d4f78277d592..00000000000000 --- a/tensorflow/core/profiler/convert/repository.h +++ /dev/null @@ -1,199 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_REPOSITORY_H_ -#define TENSORFLOW_CORE_PROFILER_CONVERT_REPOSITORY_H_ - -#include -#include -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" -#include "xla/tsl/platform/env.h" -#include "xla/tsl/platform/statusor.h" -#include "xla/tsl/profiler/utils/file_system_utils.h" -#include "tensorflow/core/profiler/utils/hlo_module_map.h" -#include "tsl/platform/path.h" -#include "tsl/profiler/protobuf/xplane.pb.h" - -namespace tensorflow { -namespace profiler { - -constexpr char kAllHostsIdentifier[] = "ALL_HOSTS"; -constexpr char kNoHostIdentifier[] = "NO_HOST"; - -enum StoredDataType { - DCN_COLLECTIVE_STATS, -}; - -static auto* kHostDataSuffixes = - new std::vector>( - {{StoredDataType::DCN_COLLECTIVE_STATS, ".dcn_collective_stats.pb"}}); - -// File system directory snapshot of a profile session. -class SessionSnapshot { - public: - // Performs validation and creates SessionSnapshot. - // are the file paths to XSpace protos. - // Optionally, can contain the XSpace protos pre-loaded by the - // profiler plugin. - static absl::StatusOr Create( - std::vector xspace_paths, - std::optional>> xspaces); - - // Returns the number of XSpaces in the profile session. - size_t XSpaceSize() const { return xspace_paths_.size(); } - - // Gets XSpace proto. - // The caller of this function will take ownership of the XSpace. - absl::StatusOr> GetXSpace(size_t index) const; - - // Gets XSpace proto. - // The caller of this function will take ownership of the XSpace. - absl::StatusOr> GetXSpaceByName( - absl::string_view name) const; - - // Gets host name. - std::string GetHostname(size_t index) const; - - // Gets the run directory of the profile session. 
- absl::string_view GetSessionRunDir() const { return session_run_dir_; } - - // Gets whether the session has an accessible run dir. If false, any - // path-based file read will be disabled in this mode. - bool HasAccessibleRunDir() const { return has_accessible_run_dir_; } - - // Gets the path of the fast file for a given tool. - std::optional GetFilePath(absl::string_view toolname, - absl::string_view host) const; - - // Gets the name of the host data file. - absl::StatusOr GetHostDataFileName(StoredDataType data_type, - std::string host) const; - - // Gets the path of the host data file. - absl::StatusOr> GetHostDataFilePath( - StoredDataType data_type, std::string host) const; - - /* Gets whether the cache file is present in run dir. First value indicates - whether cache file is present or not. Second value indicates the path of cache - file. Possible cases are: - 1. : If no cache file is present - 2. : If cache file is present but file contains no data_type - events - 3. : If cache file is present and file contains data_type - events - */ - absl::StatusOr> HasCacheFile( - StoredDataType data_type) const; - - template - absl::Status WriteBinaryProto(const StoredDataType data_type, - const std::string host, T& proto) const { - // Gets name for host data file. - TF_ASSIGN_OR_RETURN(std::string filename, - GetHostDataFileName(data_type, host)); - - std::string filepath = - tsl::profiler::ProfilerJoinPath(GetSessionRunDir(), filename); - - return tsl::WriteBinaryProto(tsl::Env::Default(), filepath, proto); - } - - template - absl::Status ReadBinaryProto(const StoredDataType data_type, - const std::string host, T* proto) const { - // Gets file path for host data. 
- TF_ASSIGN_OR_RETURN(std::optional filepath, - GetHostDataFilePath(data_type, host)); - if (filepath) { - return tsl::ReadBinaryProto(tsl::Env::Default(), filepath.value(), proto); - } - - return absl::NotFoundError( - absl::StrCat("No binary proto found for ", host, " and ", data_type)); - } - - private: - SessionSnapshot(std::vector xspace_paths, - std::optional>> xspaces) - : xspace_paths_(std::move(xspace_paths)), - // If the snapshot was initialized by xspaces, the file path and run dir - // is a path tensorflow can't read from or write to so any file IO - // encapsulated in this class will be disabled in this mode. - has_accessible_run_dir_(!xspaces.has_value()), - xspaces_(std::move(xspaces)) { - session_run_dir_ = tsl::io::Dirname(xspace_paths_.at(0)); - for (size_t i = 0; i < xspace_paths_.size(); ++i) { - std::string host_name = GetHostname(i); - hostname_map_[host_name] = i; - } - } - - // File paths to XSpace protos. - std::vector xspace_paths_; - // The run directory of the profile session. - absl::string_view session_run_dir_; - - absl::flat_hash_map - hostname_map_; - - const bool has_accessible_run_dir_; - - // XSpace protos pre-loaded by the profiler plugin. - // TODO(profiler): Use blobstore paths to initialize SessionSnapshot instead - // of using pre-loaded XSpaces. - mutable std::optional>> xspaces_; -}; - -// Writes binary proto format T for a host and data_type to a session. -template -absl::Status WriteBinaryProto(const SessionSnapshot& session_snapshot, - const StoredDataType data_type, - const std::string& host, T& proto) { - return session_snapshot.WriteBinaryProto(data_type, host, proto); -} - -// Reads binary proto format T for a host and data_type to a session. 
-template -absl::Status ReadBinaryProto(const SessionSnapshot& session_snapshot, - const StoredDataType data_type, - const std::string& host, T* proto) { - return session_snapshot.ReadBinaryProto(data_type, host, proto); -} - -// Process HloModuleMap from all XSpaces in a session. -inline absl::StatusOr ProcessHloModuleMap( - const SessionSnapshot& session_snapshot) { - HloModuleMap hlo_module_map; - for (int i = 0; i < session_snapshot.XSpaceSize(); i++) { - TF_ASSIGN_OR_RETURN(std::unique_ptr xspace, - session_snapshot.GetXSpace(i)); - ProcessHloModuleMapFromXSpace(hlo_module_map, xspace.get()); - } - return hlo_module_map; -} - -} // namespace profiler -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_PROFILER_CONVERT_REPOSITORY_H_ diff --git a/tensorflow/core/profiler/convert/repository_test.cc b/tensorflow/core/profiler/convert/repository_test.cc deleted file mode 100644 index 3f3872bd13fd8b..00000000000000 --- a/tensorflow/core/profiler/convert/repository_test.cc +++ /dev/null @@ -1,138 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/core/profiler/convert/repository.h" - -#include -#include -#include -#include - -#include -#include -#include "xla/tsl/platform/status.h" -#include "tensorflow/core/platform/errors.h" -#include "tsl/profiler/protobuf/xplane.pb.h" - -namespace tensorflow { -namespace profiler { -namespace { - -using ::testing::Eq; - -TEST(Repository, GetHostName) { - auto session_snapshot_or = - SessionSnapshot::Create({"log/plugins/profile/hostname0.xplane.pb", - "log/plugins/profile/hostname1.xplane.pb"}, - /*xspaces=*/std::nullopt); - TF_CHECK_OK(session_snapshot_or.status()); - EXPECT_THAT(session_snapshot_or.value().GetHostname(0), Eq("hostname0")); - EXPECT_THAT(session_snapshot_or.value().GetHostname(1), Eq("hostname1")); - EXPECT_TRUE(session_snapshot_or.value().HasAccessibleRunDir()); -} - -TEST(Repository, GetHostNameWithPeriods) { - auto session_snapshot_or = - SessionSnapshot::Create({"log/plugins/profile/127.0.0.1_6009.xplane.pb"}, - /*xspaces=*/std::nullopt); - TF_CHECK_OK(session_snapshot_or.status()); - EXPECT_THAT(session_snapshot_or.value().GetHostname(0), Eq("127.0.0.1_6009")); - EXPECT_TRUE(session_snapshot_or.value().HasAccessibleRunDir()); -} - -TEST(Repository, GetSpaceByHostName) { - std::vector> xspaces; - // prepare host 1. - auto space1 = std::make_unique(); - *(space1->add_hostnames()) = "hostname1"; - // with index 0 which shouldn't impact the space finding by name. - xspaces.push_back(std::move(space1)); - - // prepare host 0. - auto space0 = std::make_unique(); - *(space0->add_hostnames()) = "hostname0"; - // with index 1 which shouldn't impact the space finding by name. 
- xspaces.push_back(std::move(space0)); - - auto session_snapshot_or = - SessionSnapshot::Create({"log/plugins/profile/hostname1.xplane.pb", - "log/plugins/profile/hostname0.xplane.pb"}, - std::move(xspaces)); - TF_CHECK_OK(session_snapshot_or.status()); - auto xspace0_or = session_snapshot_or.value().GetXSpaceByName("hostname0"); - TF_CHECK_OK(xspace0_or.status()); - auto xspace1_or = session_snapshot_or.value().GetXSpaceByName("hostname1"); - EXPECT_FALSE(session_snapshot_or.value().HasAccessibleRunDir()); - TF_CHECK_OK(xspace1_or.status()); - EXPECT_THAT(xspace0_or.value()->hostnames(0), Eq("hostname0")); - EXPECT_THAT(xspace1_or.value()->hostnames(0), Eq("hostname1")); -} - -TEST(Repository, GetSSTableFile) { - auto session_snapshot_or = - SessionSnapshot::Create({"log/plugins/profile/hostname0.xplane.pb"}, - /*xspaces=*/std::nullopt); - TF_CHECK_OK(session_snapshot_or.status()); - auto sstable_path = - session_snapshot_or.value().GetFilePath("trace_viewer@", "hostname0"); - auto not_found_path = - session_snapshot_or.value().GetFilePath("memory_viewer", "hostname0"); - EXPECT_THAT(sstable_path, Eq("log/plugins/profile/hostname0.SSTABLE")); - EXPECT_THAT(not_found_path, Eq(std::nullopt)); -} - -TEST(Repository, GetSSTableFileWithXSpace) { - std::vector> xspaces; - // prepare host 0. - auto space0 = std::make_unique(); - *(space0->add_hostnames()) = "hostname0"; - // with index 1 which shouldn't impact the space finding by name. - xspaces.push_back(std::move(space0)); - auto session_snapshot_or = SessionSnapshot::Create( - {"log/plugins/profile/hostname0.xplane.pb"}, std::move(xspaces)); - TF_CHECK_OK(session_snapshot_or.status()); - auto file_path_init_by_xspace = - session_snapshot_or.value().GetFilePath("trace_viewer@", "hostname0"); - // The file path should be disabled in this mode. - EXPECT_THAT(file_path_init_by_xspace, Eq(std::nullopt)); -} - -TEST(Repository, MismatchedXSpaceAndPath) { - std::vector> xspaces; - // prepare host 1. 
- auto space1 = std::make_unique(); - *(space1->add_hostnames()) = "hostname1"; - // with index 0 which shouldn't impact the space finding by name. - xspaces.push_back(std::move(space1)); - - // prepare host 0. - auto space0 = std::make_unique(); - *(space0->add_hostnames()) = "hostname0"; - // with index 1 which shouldn't impact the space finding by name. - xspaces.push_back(std::move(space0)); - - auto session_snapshot_or = - SessionSnapshot::Create({"log/plugins/profile/hostname0.xplane.pb", - "log/plugins/profile/hostname1.xplane.pb"}, - std::move(xspaces)); - auto error = - R"(The hostname of xspace path and preloaded xpace don't match at index: 0. -The host name of xpace path is hostname0 but the host name of preloaded xpace is hostname1.)"; - EXPECT_THAT(session_snapshot_or.status(), Eq(errors::InvalidArgument(error))); -} - -} // namespace -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/step_events_to_steps_db.cc b/tensorflow/core/profiler/convert/step_events_to_steps_db.cc deleted file mode 100644 index 0d8f90bcbbbb76..00000000000000 --- a/tensorflow/core/profiler/convert/step_events_to_steps_db.cc +++ /dev/null @@ -1,224 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ -#include "tensorflow/core/profiler/convert/step_events_to_steps_db.h" - -#include -#include -#include -#include -#include - -#include "google/protobuf/any.pb.h" -#include "absl/algorithm/container.h" -#include "absl/container/flat_hash_map.h" -#include "absl/log/log.h" -#include "xla/tsl/profiler/utils/timespan.h" -#include "tensorflow/core/lib/gtl/map_util.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/profiler/convert/op_metrics_db_combiner.h" -#include "tensorflow/core/profiler/protobuf/steps_db.pb.h" -#include "tensorflow/core/profiler/utils/event_span.h" -#include "tensorflow/core/profiler/utils/op_metrics_db_utils.h" - -namespace tensorflow { -namespace profiler { - -// Local core id should start from 1. -const uint32 kDefaultGpuLocalCoreId = 1; - -namespace { - -void StepEventsToPerCoreStepInfo(uint32_t step_num, StepDetails& step_details, - PerCoreStepInfo& per_core_step_info) { - per_core_step_info.set_step_num(step_num); - OpMetricsDbCombiner combiner(per_core_step_info.mutable_hlo_metrics_db()); - auto step_time = step_details.StepTime(); - if (step_time.duration_ps() == 0) { - // In case no step markers are observed for the particular step, Skip the - // step. - VLOG(1) << "Skipping step " << step_details.StepName() - << "with no step markers"; - return; - } - for (auto& [core_id, metrics_db] : step_details.PerCoreOpMetricsDb()) { - SetTotalTimePs(metrics_db, step_time.duration_ps()); - AddIdleOp(metrics_db); - // TODO(b/397774568): Remove this once the SparseCore OpMetricsDb is - // implemented. 
- if (core_id < kSparseCoreIndexStart) combiner.Combine(metrics_db); - - GenericStepBreakdown step_breakdown; - auto& category_ps = *(step_breakdown.mutable_category_ps()); - for (auto& metric : metrics_db.metrics_db()) { - category_ps[metric.category()] += metric.self_time_ps(); - } - - StepInfoResult step_info; - step_info.set_step_num(step_num); - step_info.set_step_name(step_details.StepName()); - step_info.set_begin_ps(step_time.begin_ps()); - step_info.set_duration_ps(step_time.duration_ps()); - step_info.mutable_step_breakdown()->PackFrom(step_breakdown); - (*per_core_step_info.mutable_step_info_per_core())[core_id] = - std::move(step_info); - } - auto& all_reduce_db_per_core_map = - *per_core_step_info.mutable_all_reduce_db_per_core(); - for (const auto& [core_id, all_reduce_db] : step_details.Collectives()) { - all_reduce_db_per_core_map[core_id].CopyFrom(all_reduce_db); - } -} - -// Converts from StepDetails to StepInfoResult. -StepInfoResult ConvertStepDetailsToStepInfo(bool has_device, int64_t step_num, - StepDetails& step_details) { - GenericStepBreakdown generic; - tsl::profiler::Timespan step_time = step_details.StepTime(); - auto& type_ps = *(generic.mutable_type_ps()); - uint64 total_event_duration = 0; - for (const auto& event : step_details.Events()) { - // Ignore event duration outside the step marker. - uint64 event_duration = step_time.OverlappedDurationPs(event.span); - type_ps[event.type] += event_duration; - total_event_duration += event_duration; - } - if (total_event_duration < step_time.duration_ps()) { - // Some time in the step is not associated with any event. Classify them as - // "unknown time". - type_ps[UNKNOWN_TIME] += step_time.duration_ps() - total_event_duration; - } - // Determines if this particular step is a well-formed one. - bool well_formed_step = has_device ? 
type_ps.contains(DEVICE_COMPUTE_16) || - type_ps.contains(DEVICE_COMPUTE_32) - : type_ps.contains(HOST_COMPUTE); - StepInfoResult step_info; - step_info.mutable_step_breakdown()->PackFrom(generic); - if (well_formed_step) { - step_info.set_step_num(step_num); - step_info.set_step_name(step_details.StepName()); - step_info.set_begin_ps(step_time.begin_ps()); - step_info.set_duration_ps(step_time.duration_ps()); - } else { - // For a non-well-formed step, sets its duration to 0 so that it will be - // ignored by the caller of this function. - step_info.set_duration_ps(0); - } - return step_info; -} - -string DebugGenericStepBreakdown(const GenericStepBreakdown& generic) { - std::ostringstream out; - uint64 total_ps = 0; - const auto& type_ps_map = generic.type_ps(); - for (const auto& type_ps : type_ps_map) { - total_ps += type_ps.second; - } - out << "Total ps = " << total_ps << std::endl; - for (int type = LAST_EVENT_TYPE; type >= 0; --type) { - const auto* ps = gtl::FindOrNull(type_ps_map, type); - if (ps == nullptr) continue; - double percent = (*ps * 100.0) / total_ps; - auto event_type = static_cast(type); - out << PrintEventType(event_type) << ": " << percent << "%" - << ", ps = " << *ps << std::endl; - } - return out.str(); -} - -string DebugStepInfo(const StepInfoResult& step_info) { - std::ostringstream out; - out << "step_num=" << step_info.step_num() - << ", duration_ps=" << step_info.duration_ps() - << ", begin_ps=" << step_info.begin_ps() << std::endl; - GenericStepBreakdown generic; - if (step_info.step_breakdown().UnpackTo(&generic)) { - out << "Generic step breakdown:" << std::endl; - out << DebugGenericStepBreakdown(generic) << std::endl; - } else { - out << step_info.step_breakdown().DebugString() << std::endl; - } - return out.str(); -} - -} // namespace - -StepDatabaseResult ConvertStepEventsToStepDb( - bool has_device, bool maybe_drop_incomplete_steps, - StepEvents& nonoverlapped_step_events) { - StepDatabaseResult step_db; - // Gets sorted step 
numbers. - std::vector step_numbers; - step_numbers.reserve(nonoverlapped_step_events.size()); - for (const auto& step_events : nonoverlapped_step_events) { - step_numbers.push_back(step_events.first); - } - absl::c_sort(step_numbers); - for (const auto& step : step_numbers) { - auto* step_details = gtl::FindOrNull(nonoverlapped_step_events, step); - if (step_details == nullptr) continue; - PerCoreStepInfo per_core_step_info; - per_core_step_info.set_step_num(step); - if (!step_details->PerCoreOpMetricsDb().empty()) { - StepEventsToPerCoreStepInfo(step, *step_details, per_core_step_info); - } else { - StepInfoResult step_info = - ConvertStepDetailsToStepInfo(has_device, step, *step_details); - if (step_info.duration_ps() == 0) - continue; // Do not include non-well-formed steps. - // When we generated StepEvents, we already put events from all device - // cores and cpu threads on this host into a single event stream, - // therefore we can't separate them anymore. Simply assigns all events to - // Core-0. - (*per_core_step_info - .mutable_step_info_per_core())[kDefaultGpuLocalCoreId] = - std::move(step_info); - VLOG(2) - << std::endl - << "step_id: " << step << ", step_info:" << std::endl - << DebugStepInfo( - (*per_core_step_info - .mutable_step_info_per_core())[kDefaultGpuLocalCoreId]); - // Populates the collective ops information. - auto& collectives = *per_core_step_info.mutable_all_reduce_db_per_core(); - for (const auto& it : step_details->Collectives()) { - collectives[it.first] = it.second; - } - // Populates the device transfer stats for this step. - auto& device_memory_transfers = - *per_core_step_info.mutable_device_memory_transfers(); - for (const auto& dma : step_details->DeviceMemoryTransfers()) { - *device_memory_transfers.Add() = dma; - } - } - // The remaining fields in PerCoreStepInfo are not filled. 
- *step_db.add_step_sequence() = per_core_step_info; - } - - // If we are using sampling mode and we get enough steps, we would like to - // drop the incomplete steps at the beginning and the end. - // (Sometimes CUTPI instrumentation will prolong the first step too). - int kDropIncomplteteStepThreshold = 5; - if (maybe_drop_incomplete_steps && - step_db.step_sequence_size() > kDropIncomplteteStepThreshold) { - step_db.mutable_step_sequence()->erase( - step_db.mutable_step_sequence()->begin()); - step_db.mutable_step_sequence()->RemoveLast(); - } - return step_db; -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/step_events_to_steps_db.h b/tensorflow/core/profiler/convert/step_events_to_steps_db.h deleted file mode 100644 index 9764c46cfca6de..00000000000000 --- a/tensorflow/core/profiler/convert/step_events_to_steps_db.h +++ /dev/null @@ -1,37 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_STEP_EVENTS_TO_STEPS_DB_H_ -#define TENSORFLOW_CORE_PROFILER_CONVERT_STEP_EVENTS_TO_STEPS_DB_H_ - -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/profiler/protobuf/steps_db.pb.h" -#include "tensorflow/core/profiler/utils/event_span.h" - -namespace tensorflow { -namespace profiler { - -TF_CONST_INIT extern const uint32 kDefaultGpuLocalCoreId; - -// Converts from overlapped Step-Events to StepDatabaseResult. -StepDatabaseResult ConvertStepEventsToStepDb( - bool has_device, bool maybe_drop_incomplete_steps, - StepEvents& overlapped_step_events); - -} // namespace profiler -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_PROFILER_CONVERT_STEP_EVENTS_TO_STEPS_DB_H_ diff --git a/tensorflow/core/profiler/convert/tool_options.h b/tensorflow/core/profiler/convert/tool_options.h deleted file mode 100644 index b3f787df943058..00000000000000 --- a/tensorflow/core/profiler/convert/tool_options.h +++ /dev/null @@ -1,71 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_TOOL_OPTIONS_H_ -#define TENSORFLOW_CORE_PROFILER_CONVERT_TOOL_OPTIONS_H_ - -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/strings/str_cat.h" - -namespace tensorflow { -namespace profiler { - -using ToolOptions = - absl::flat_hash_map>; - -// Helper function to get parameter from tool options. -template -std::optional GetParam(const ToolOptions& options, const std::string& key) { - const auto iter = options.find(key); - if (iter == options.end()) { - return std::nullopt; - } - - const T* result = std::get_if(&iter->second); - if (!result) { - return std::nullopt; - } - return *result; -} - -// Helper function to get parameter from tool options with default value. -template -T GetParamWithDefault(const ToolOptions& options, const std::string& key, - const T& default_param) { - if (auto param = GetParam(options, key)) { - return *param; - } - return default_param; -} - -inline std::string DebugString(const ToolOptions& options) { - std::string output; - for (const auto& [k, v] : options) { - absl::StrAppend( - &output, k, ":", - std::visit([](const auto& value) { return absl::StrCat(value); }, v), - ":", v.index(), ";"); - } - return absl::StrCat("{", output, "}"); -} - -} // namespace profiler -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_PROFILER_CONVERT_TOOL_OPTIONS_H_ diff --git a/tensorflow/core/profiler/convert/tpu_input_pipeline_analysis_constants.h b/tensorflow/core/profiler/convert/tpu_input_pipeline_analysis_constants.h deleted file mode 100644 index ba0fcf1919e414..00000000000000 --- a/tensorflow/core/profiler/convert/tpu_input_pipeline_analysis_constants.h +++ /dev/null @@ -1,30 +0,0 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_TPU_INPUT_PIPELINE_ANALYSIS_CONSTANTS_H_ -#define TENSORFLOW_CORE_PROFILER_CONVERT_TPU_INPUT_PIPELINE_ANALYSIS_CONSTANTS_H_ - -#include "absl/strings/string_view.h" -#include "xla/tsl/platform/macros.h" - -namespace tensorflow { -namespace profiler { - -TF_CONST_INIT extern const absl::string_view kProfileAllHostsDoc; -TF_CONST_INIT extern const absl::string_view kSparseCoreV0Name; - -} // namespace profiler -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_PROFILER_CONVERT_TPU_INPUT_PIPELINE_ANALYSIS_CONSTANTS_H_ diff --git a/tensorflow/core/profiler/convert/xplane_to_dcn_collective_stats.cc b/tensorflow/core/profiler/convert/xplane_to_dcn_collective_stats.cc deleted file mode 100644 index ad3bea87341162..00000000000000 --- a/tensorflow/core/profiler/convert/xplane_to_dcn_collective_stats.cc +++ /dev/null @@ -1,160 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/profiler/convert/xplane_to_dcn_collective_stats.h" - -#include -#include -#include - -#include "absl/status/statusor.h" -#include "absl/strings/match.h" -#include "xla/tsl/platform/errors.h" -#include "xla/tsl/platform/statusor.h" -#include "xla/tsl/profiler/utils/xplane_schema.h" -#include "xla/tsl/profiler/utils/xplane_visitor.h" -#include "tensorflow/core/profiler/convert/dcn_slack_analysis_combiner.h" -#include "tensorflow/core/profiler/convert/repository.h" -#include "tensorflow/core/profiler/convert/xspace_to_dcn_slack_analysis.h" -#include "tensorflow/core/profiler/protobuf/dcn_slack_analysis.pb.h" -#include "tensorflow/core/profiler/utils/xplane_utils.h" -#include "tsl/profiler/protobuf/xplane.pb.h" - -namespace tensorflow { -namespace profiler { - -namespace { - -bool HasDcnCollectiveStatsInXSpace(const XSpace& xspace) { - if (const tsl::profiler::XPlane* xplane = - FindPlaneWithName(xspace, tsl::profiler::kHostThreadsPlaneName); - xplane != nullptr) { - for (const auto& [_, metadata] : xplane->event_metadata()) { - if (absl::StartsWith(metadata.name(), "MegaScale:")) { - return true; - } - } - } - return false; -} - -absl::StatusOr GetDcnCollectiveStatsFromMultiXSpaceAndSaveToFile( - const SessionSnapshot& session_snapshot) { - DcnSlackAnalysisCombiner combiner; - for (int idx = 0; idx < session_snapshot.XSpaceSize(); idx++) { - std::string hostname = session_snapshot.GetHostname(idx); - TF_ASSIGN_OR_RETURN(std::unique_ptr xspace, - session_snapshot.GetXSpace(idx)); - - // The profile does not have dcn collective stats. 
- if (!HasDcnCollectiveStatsInXSpace(*xspace)) { - DcnSlackAnalysis dcnSlackAnalysis; - TF_RETURN_IF_ERROR(WriteBinaryProto(session_snapshot, - StoredDataType::DCN_COLLECTIVE_STATS, - kNoHostIdentifier, dcnSlackAnalysis)); - return false; - } - - DcnSlackAnalysis dcnSlackAnalysis = - ConvertXSpaceToDcnSlackAnalysis(*xspace, nullptr, nullptr); - - TF_RETURN_IF_ERROR(WriteBinaryProto(session_snapshot, - StoredDataType::DCN_COLLECTIVE_STATS, - hostname, dcnSlackAnalysis)); - - combiner.Combine(dcnSlackAnalysis); - } - - DcnSlackAnalysis dcnSlackAnalysis = combiner.Finalize(); - TF_RETURN_IF_ERROR(WriteBinaryProto(session_snapshot, - StoredDataType::DCN_COLLECTIVE_STATS, - kAllHostsIdentifier, dcnSlackAnalysis)); - - // The profile has dcn collective stats. - return true; -} - -} // namespace - -absl::StatusOr HasDcnCollectiveStatsInMultiXSpace( - const SessionSnapshot& session_snapshot) { - std::pair hasCacheFile; - TF_ASSIGN_OR_RETURN(hasCacheFile, session_snapshot.HasCacheFile( - StoredDataType::DCN_COLLECTIVE_STATS)); - - // Cache file not present, check if trace contains dcn collective stats. - if (!hasCacheFile.first) { - for (int idx = 0; idx < session_snapshot.XSpaceSize(); idx++) { - std::string hostname = session_snapshot.GetHostname(idx); - TF_ASSIGN_OR_RETURN(std::unique_ptr xspace, - session_snapshot.GetXSpace(idx)); - - if (HasDcnCollectiveStatsInXSpace(*xspace)) { - return true; - } - } - return false; - } - - if (hasCacheFile.second.empty()) { - // If the profiler finds a file NO_HOST.dcn_collective_stats.pb, this means - // dcn collective stats are not present in the profile. - return false; - } else { - // If the profiler finds a file ALL_HOSTS.dcn_collective_stats.pb, this - // means dcn collective stats are present in the profile. 
- return true; - } -} - -absl::StatusOr ConvertMultiXSpaceToDcnCollectiveStats( - const SessionSnapshot& session_snapshot) { - std::pair hasCacheFile; - TF_ASSIGN_OR_RETURN(hasCacheFile, session_snapshot.HasCacheFile( - StoredDataType::DCN_COLLECTIVE_STATS)); - - // Cache file not present, generate dcn collective stats. - if (!hasCacheFile.first) { - return GetDcnCollectiveStatsFromMultiXSpaceAndSaveToFile(session_snapshot); - } - - if (hasCacheFile.second.empty()) { - // If the profiler finds a file NO_HOST.dcn_collective_stats.pb, this means - // dcn collective stats are not present in the profile. - return false; - } else { - // If the profiler finds a file ALL_HOSTS.dcn_collective_stats.pb, this - // means dcn collective stats are present in the profile. - return true; - } -} - -absl::StatusOr GetDcnSlackAnalysisByHostName( - const SessionSnapshot& session_snapshot, const std::string hostname) { - TF_ASSIGN_OR_RETURN(bool hasDcnCollectiveStats, - ConvertMultiXSpaceToDcnCollectiveStats(session_snapshot)); - - DcnSlackAnalysis dcnSlackAnalysis; - if (hasDcnCollectiveStats) { - TF_RETURN_IF_ERROR(ReadBinaryProto(session_snapshot, - StoredDataType::DCN_COLLECTIVE_STATS, - hostname, &dcnSlackAnalysis)); - } - - return dcnSlackAnalysis; -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/xplane_to_dcn_collective_stats.h b/tensorflow/core/profiler/convert/xplane_to_dcn_collective_stats.h deleted file mode 100644 index 68e0b491331bdd..00000000000000 --- a/tensorflow/core/profiler/convert/xplane_to_dcn_collective_stats.h +++ /dev/null @@ -1,43 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_DCN_COLLECTIVE_STATS_H_ -#define TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_DCN_COLLECTIVE_STATS_H_ - -#include "tensorflow/core/platform/statusor.h" -#include "tensorflow/core/profiler/convert/repository.h" -#include "tensorflow/core/profiler/protobuf/dcn_slack_analysis.pb.h" - -namespace tensorflow { -namespace profiler { - -// Converts multiple XSpaces to dcn collective stats. -// Stores the dcn collective stats as files in the same directory -// as the xspace files. -absl::StatusOr ConvertMultiXSpaceToDcnCollectiveStats( - const SessionSnapshot& session_snapshot); - -// Returns whether there are dcn collective stats in the profile. -absl::StatusOr HasDcnCollectiveStatsInMultiXSpace( - const SessionSnapshot& session_snapshot); - -// Gets DcnSlackAnalysis proto for a host. -absl::StatusOr GetDcnSlackAnalysisByHostName( - const SessionSnapshot& session_snapshot, std::string hostname); - -} // namespace profiler -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_DCN_COLLECTIVE_STATS_H_ diff --git a/tensorflow/core/profiler/convert/xplane_to_dcn_collective_stats_test.cc b/tensorflow/core/profiler/convert/xplane_to_dcn_collective_stats_test.cc deleted file mode 100644 index 2d73bbf8b929d6..00000000000000 --- a/tensorflow/core/profiler/convert/xplane_to_dcn_collective_stats_test.cc +++ /dev/null @@ -1,207 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/profiler/convert/xplane_to_dcn_collective_stats.h" - -#include -#include -#include -#include -#include - -#include -#include "absl/status/statusor.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" -#include "xla/tsl/lib/core/status_test_util.h" -#include "xla/tsl/platform/env.h" -#include "xla/tsl/platform/status.h" -#include "tensorflow/core/platform/file_system.h" -#include "tensorflow/core/profiler/convert/repository.h" -#include "tensorflow/core/profiler/protobuf/dcn_slack_analysis.pb.h" -#include "tensorflow/core/profiler/utils/xplane_builder.h" -#include "tensorflow/core/profiler/utils/xplane_utils.h" -#include "tsl/profiler/protobuf/xplane.pb.h" - -namespace tensorflow { -namespace profiler { -namespace { - -DcnSlackAnalysis CreateDcnSlackAnalysisProto() { - DcnSlackAnalysis dcn_slack_analysis; - DcnSlackSummary* dcn_slack_summary = - dcn_slack_analysis.add_dcn_slack_summary(); - dcn_slack_summary->set_rendezvous("collective"); - dcn_slack_summary->set_recv_op_name("recv-done"); - dcn_slack_summary->set_send_op_name("send"); - dcn_slack_summary->set_slack_us(2); - dcn_slack_summary->set_observed_duration_us(12); - dcn_slack_summary->set_stall_duration_us(5); - dcn_slack_summary->set_occurrences(4); - dcn_slack_summary->set_bytes_transmitted_over_network(819200); - return dcn_slack_analysis; -} - 
-SessionSnapshot CreateSessionSnapshot(bool create_cache_file, - bool has_dcn_collective_stats) { - std::string test_name = - ::testing::UnitTest::GetInstance()->current_test_info()->name(); - std::string path = absl::StrCat("ram://", test_name, "/"); - std::unique_ptr xplane_file; - std::vector paths = {absl::StrCat(path, "hostname.xplane.pb")}; - - auto xspace = std::make_unique(); - XPlane* xplane = FindOrAddMutablePlaneWithName(xspace.get(), "/host:CPU"); - if (has_dcn_collective_stats) { - XPlaneBuilder xplane_builder(xplane); - xplane_builder.GetOrCreateEventMetadata("MegaScale:"); - } - - if (create_cache_file) { - if (has_dcn_collective_stats) { - tsl::Env::Default() - ->NewAppendableFile( - absl::StrCat(path, "hostname.dcn_collective_stats.pb"), - &xplane_file) - .IgnoreError(); - tsl::Env::Default() - ->NewAppendableFile( - absl::StrCat(path, "ALL_HOSTS.dcn_collective_stats.pb"), - &xplane_file) - .IgnoreError(); - } else { - tsl::Env::Default() - ->NewAppendableFile( - absl::StrCat(path, "NO_HOST.dcn_collective_stats.pb"), - &xplane_file) - .IgnoreError(); - } - } - - std::vector> xspaces; - xspaces.push_back(std::move(xspace)); - - absl::StatusOr session_snapshot_status = - SessionSnapshot::Create(paths, std::move(xspaces)); - TF_CHECK_OK(session_snapshot_status.status()); - SessionSnapshot session_snapshot = std::move(session_snapshot_status.value()); - if (has_dcn_collective_stats) { - DcnSlackAnalysis dcn_slack_analysis = CreateDcnSlackAnalysisProto(); - TF_CHECK_OK(session_snapshot.WriteBinaryProto( - DCN_COLLECTIVE_STATS, "hostname", dcn_slack_analysis)); - TF_CHECK_OK(session_snapshot.WriteBinaryProto( - DCN_COLLECTIVE_STATS, kAllHostsIdentifier, dcn_slack_analysis)); - } - return session_snapshot; -} - -TEST(ConvertXplaneToDcnCollectiveStats, - HasAllHostsDcnCollectiveStatsCacheFile) { - SessionSnapshot session_snapshot = CreateSessionSnapshot(true, true); - - absl::StatusOr status = - HasDcnCollectiveStatsInMultiXSpace(session_snapshot); - 
EXPECT_EQ(status.value(), true); -} - -TEST(ConvertXplaneToDcnCollectiveStats, HasNoHostDcnCollectiveStatsCacheFile) { - SessionSnapshot session_snapshot = CreateSessionSnapshot(true, false); - - absl::StatusOr status = - HasDcnCollectiveStatsInMultiXSpace(session_snapshot); - EXPECT_EQ(status.value(), false); -} - -TEST(ConvertXplaneToDcnCollectiveStats, - NoCacheFileButTraceHasDcnCollectiveStats) { - SessionSnapshot session_snapshot = CreateSessionSnapshot(false, true); - - absl::StatusOr status = - HasDcnCollectiveStatsInMultiXSpace(session_snapshot); - EXPECT_EQ(status.value(), true); -} - -TEST(ConvertXplaneToDcnCollectiveStats, - NoCacheFileNoDcnCollectiveStatsPresent) { - SessionSnapshot session_snapshot = CreateSessionSnapshot(false, false); - - absl::StatusOr status = - HasDcnCollectiveStatsInMultiXSpace(session_snapshot); - EXPECT_EQ(status.value(), false); -} - -TEST(ConvertXplaneToDcnCollectiveStats, - ConvertXSpaceToDcnCollectiveStatsWhenStatsPresent) { - SessionSnapshot session_snapshot = CreateSessionSnapshot(false, true); - - absl::StatusOr status = - ConvertMultiXSpaceToDcnCollectiveStats(session_snapshot); - absl::StatusOr> all_hosts_filepath = - session_snapshot.GetHostDataFilePath(StoredDataType::DCN_COLLECTIVE_STATS, - kAllHostsIdentifier); - absl::StatusOr> host_filepath = - session_snapshot.GetHostDataFilePath(StoredDataType::DCN_COLLECTIVE_STATS, - "hostname"); - - EXPECT_EQ(status.value(), true); - TF_EXPECT_OK(all_hosts_filepath.status()); - EXPECT_TRUE(all_hosts_filepath.value().has_value()); - EXPECT_FALSE(all_hosts_filepath.value().value().empty()); - TF_EXPECT_OK(host_filepath.status()); - EXPECT_TRUE(host_filepath.value().has_value()); - EXPECT_FALSE(host_filepath.value().value().empty()); -} - -TEST(ConvertXplaneToDcnCollectiveStats, - ConvertXSpaceToDcnCollectiveStatsWhenStatsNotPresent) { - SessionSnapshot session_snapshot = CreateSessionSnapshot(false, false); - - absl::StatusOr status = - 
ConvertMultiXSpaceToDcnCollectiveStats(session_snapshot); - absl::StatusOr> filepath = - session_snapshot.GetHostDataFilePath(StoredDataType::DCN_COLLECTIVE_STATS, - kNoHostIdentifier); - - EXPECT_EQ(status.value(), false); - TF_EXPECT_OK(filepath.status()); - EXPECT_TRUE(filepath.value().has_value()); - EXPECT_FALSE(filepath.value().value().empty()); -} - -TEST(ConvertXplaneToDcnCollectiveStats, - GetHostDcnSlackAnalysisWhenStatsNotPresent) { - SessionSnapshot session_snapshot = CreateSessionSnapshot(false, false); - - absl::StatusOr host_dcn_slack_analysis = - GetDcnSlackAnalysisByHostName(session_snapshot, "hostname"); - - TF_EXPECT_OK(host_dcn_slack_analysis.status()); - EXPECT_EQ(host_dcn_slack_analysis.value().dcn_slack_summary_size(), 0); -} - -TEST(ConvertXplaneToDcnCollectiveStats, - GetHostDcnSlackAnalysisWhenStatsPresent) { - SessionSnapshot session_snapshot = CreateSessionSnapshot(true, true); - - absl::StatusOr host_dcn_slack_analysis = - GetDcnSlackAnalysisByHostName(session_snapshot, "hostname"); - - TF_EXPECT_OK(host_dcn_slack_analysis.status()); - EXPECT_EQ(host_dcn_slack_analysis.value().dcn_slack_summary_size(), 1); -} - -} // namespace -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/xplane_to_hlo.cc b/tensorflow/core/profiler/convert/xplane_to_hlo.cc deleted file mode 100644 index 62ee1c487b41a7..00000000000000 --- a/tensorflow/core/profiler/convert/xplane_to_hlo.cc +++ /dev/null @@ -1,132 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/profiler/convert/xplane_to_hlo.h" - -#include -#include -#include - -#include "absl/status/statusor.h" -#include "absl/strings/match.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" -#include "xla/service/hlo.pb.h" -#include "xla/tsl/platform/env.h" -#include "xla/tsl/platform/errors.h" -#include "xla/tsl/platform/statusor.h" -#include "xla/tsl/profiler/utils/file_system_utils.h" -#include "tensorflow/core/platform/errors.h" -#include "tensorflow/core/profiler/convert/repository.h" -#include "tensorflow/core/profiler/utils/hlo_proto_map.h" -#include "tsl/profiler/protobuf/xplane.pb.h" - -namespace tensorflow { -namespace profiler { - -namespace { - -using tsl::profiler::ProfilerJoinPath; - -constexpr char kNoModuleIdentifier[] = "NO_MODULE"; -constexpr char kHloProtoSuffix[] = ".hlo_proto.pb"; - -// Extracts and deduplicates the HLO protos from all the XSpaces. -// Stores the HLO protos as files in the same directory as the xspace files. -absl::StatusOr GetHloProtoFromMultiXSpaceAndSaveToFile( - const SessionSnapshot& session_snapshot) { - // Get all HLO protos from XSpaces and deduplicate. - HloProtoMap hlo_proto_map; - for (int i = 0; i < session_snapshot.XSpaceSize(); i++) { - TF_ASSIGN_OR_RETURN(std::unique_ptr xspace, - session_snapshot.GetXSpace(i)); - hlo_proto_map.AddHloProtosFromXSpace(*xspace); - } - - std::vector module_list = hlo_proto_map.GetModuleList(); - // Write an empty identifier if there is no HLO module. 
- if (module_list.empty()) { - std::string file_name = - ProfilerJoinPath(session_snapshot.GetSessionRunDir(), - absl::StrCat(kNoModuleIdentifier, kHloProtoSuffix)); - xla::HloProto empty_hlo; - TF_RETURN_IF_ERROR( - tsl::WriteBinaryProto(tsl::Env::Default(), file_name, empty_hlo)); - // The profile does not have HLO proto. - return false; - } - - // Save HLO protos to session run directory. - for (const absl::string_view module_name : module_list) { - auto hlo_proto_or = hlo_proto_map.GetHloProtoByModuleName(module_name); - if (!hlo_proto_or.ok()) { - return errors::Internal(hlo_proto_or.status().message()); - } - std::string file_name = - ProfilerJoinPath(session_snapshot.GetSessionRunDir(), - absl::StrCat(module_name, kHloProtoSuffix)); - TF_RETURN_IF_ERROR(tsl::WriteBinaryProto(tsl::Env::Default(), file_name, - *hlo_proto_or.value())); - } - - // The profile has HLO proto. - return true; -} - -} // namespace - -absl::StatusOr GetHloProtoByModuleName( - const SessionSnapshot& session_snapshot, - const absl::string_view module_name) { - std::string file_name = - ProfilerJoinPath(session_snapshot.GetSessionRunDir(), - absl::StrCat(module_name, kHloProtoSuffix)); - xla::HloProto hlo_proto; - TF_RETURN_IF_ERROR( - tsl::ReadBinaryProto(tsl::Env::Default(), file_name, &hlo_proto)); - return hlo_proto; -} - -absl::StatusOr ConvertMultiXSpaceToHloProto( - const SessionSnapshot& session_snapshot) { - // Gets all the files in session run directory. - // TODO(profiler): Move this glob to SessionSnapshot and build a map from file - // type to file paths. - std::vector results; - TF_RETURN_IF_ERROR(tsl::Env::Default()->GetChildren( - std::string(session_snapshot.GetSessionRunDir()), &results)); - - // If the profiler finds a filename with hlo proto suffix, this means HLO - // proto was already generated previously. 
- for (const std::string& path : results) { - if (absl::EndsWith(path, kHloProtoSuffix)) { - if (absl::EndsWith(path, - absl::StrCat(kNoModuleIdentifier, kHloProtoSuffix))) { - return false; - } else { - return true; - } - } - } - - // Generate HLO proto. - // TODO(jiesun): Maybe generate a tag file at profile collection time, so - // don't need to read XSpace files for checking whether HLO proto exists or - // not. - return GetHloProtoFromMultiXSpaceAndSaveToFile(session_snapshot); -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/xplane_to_hlo.h b/tensorflow/core/profiler/convert/xplane_to_hlo.h deleted file mode 100644 index 2361ba6e13d194..00000000000000 --- a/tensorflow/core/profiler/convert/xplane_to_hlo.h +++ /dev/null @@ -1,42 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_HLO_H_ -#define TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_HLO_H_ - -#include - -#include "absl/strings/string_view.h" -#include "xla/service/hlo.pb.h" -#include "tensorflow/core/platform/statusor.h" -#include "tensorflow/core/profiler/convert/repository.h" - -namespace tensorflow { -namespace profiler { - -// Get HLO proto by module name. 
-absl::StatusOr GetHloProtoByModuleName( - const SessionSnapshot& session_snapshot, absl::string_view module_name); - -// Converts multiple XSpaces to HLO protos. -// Stores the HLO protos as files in the same directory as the xspace files. -// Returns whether there are HLO protos in this profile. -absl::StatusOr ConvertMultiXSpaceToHloProto( - const SessionSnapshot& session_snapshot); - -} // namespace profiler -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_HLO_H_ diff --git a/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.cc b/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.cc deleted file mode 100644 index 733185a2747624..00000000000000 --- a/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.cc +++ /dev/null @@ -1,90 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.h" - -#include -#include -#include - -#include "absl/log/log.h" -#include "absl/strings/string_view.h" -#include "xla/tsl/profiler/utils/tf_op_utils.h" -#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" -#include "tensorflow/core/profiler/protobuf/kernel_stats.pb.h" -#include "tensorflow/core/profiler/utils/gpu_event_stats.h" -#include "tensorflow/core/profiler/utils/kernel_stats_utils.h" -#include "tensorflow/core/profiler/utils/trace_utils.h" -#include "tensorflow/core/profiler/utils/xplane_visitor.h" -#include "tsl/profiler/protobuf/xplane.pb.h" - -namespace tensorflow { -namespace profiler { - -void ConvertDeviceTraceXPlaneToKernelReports( - const XPlane& device_trace, - const std::function& - on_kernel_fn, - KernelReportMap* reports) { - XPlaneVisitor plane = tsl::profiler::CreateTfXPlaneVisitor(&device_trace); - plane.ForEachLine([&](const XLineVisitor& line) { - if (IsDerivedThreadId(line.Id())) { - return; - } - line.ForEachEvent([&](const XEventVisitor& event) { - if (event.DurationNs() == 0) return; - KernelReport kernel; - GpuEventStats stats(&event); - if (!stats.IsKernel()) return; - - kernel.set_name(std::string(event.Name())); - kernel.set_is_kernel_using_tensor_core( - IsKernelUsingTensorCore(event.Name())); - kernel.set_total_duration_ns(event.DurationNs()); - kernel.set_min_duration_ns(event.DurationNs()); - kernel.set_max_duration_ns(event.DurationNs()); - ParseKernelLaunchParams(stats.kernel_details, &kernel); - - if (stats.IsTfOp()) { - tsl::profiler::TfOp tf_op = - tsl::profiler::ParseTfOpFullname(stats.tf_op_fullname); - kernel.set_op_name(std::string(tf_op.name)); - bool tensor_core_eligible = - IsEinsumTensorCoreEligible(stats.equation) || - IsOpTensorCoreEligible(kernel.op_name()); - if (!tensor_core_eligible && kernel.is_kernel_using_tensor_core()) { - VLOG(1) << "Detected new Op 
using TensorCores: " << kernel.op_name() - << std::endl; - tensor_core_eligible = true; - } - kernel.set_is_op_tensor_core_eligible(tensor_core_eligible); - } - - if (on_kernel_fn) { - on_kernel_fn(stats, &kernel); - } - - KernelReportValue value; - value.total_duration_ns = event.DurationNs(); - value.min_duration_ns = event.DurationNs(); - value.max_duration_ns = event.DurationNs(); - value.occurrences = 1; - InsertOrUpdateKernelReport(kernel, value, reports); - }); - }); -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.h b/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.h deleted file mode 100644 index 57607d06337b52..00000000000000 --- a/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.h +++ /dev/null @@ -1,41 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_KERNEL_STATS_DB_H_ -#define TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_KERNEL_STATS_DB_H_ - -#include -#include - -#include "absl/log/log.h" -#include "tensorflow/core/profiler/protobuf/kernel_stats.pb.h" -#include "tensorflow/core/profiler/utils/gpu_event_stats.h" -#include "tensorflow/core/profiler/utils/hlo_module_map.h" -#include "tensorflow/core/profiler/utils/kernel_stats_utils.h" -#include "tsl/profiler/protobuf/xplane.pb.h" - -namespace tensorflow { -namespace profiler { - -void ConvertDeviceTraceXPlaneToKernelReports( - const XPlane& device_trace, - const std::function& - on_kernel_fn, - KernelReportMap* reports); - -} // namespace profiler -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_KERNEL_STATS_DB_H_ diff --git a/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db_test.cc b/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db_test.cc deleted file mode 100644 index a675e69248a81a..00000000000000 --- a/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db_test.cc +++ /dev/null @@ -1,138 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.h" - -#include "absl/strings/string_view.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/profiler/protobuf/kernel_stats.pb.h" -#include "tensorflow/core/profiler/utils/kernel_stats_utils.h" -#include "tensorflow/core/profiler/utils/xplane_builder.h" -#include "tensorflow/core/profiler/utils/xplane_schema.h" -#include "tensorflow/core/profiler/utils/xplane_test_utils.h" -#include "tsl/profiler/protobuf/xplane.pb.h" - -namespace tensorflow { -namespace profiler { -namespace { - -TEST(ConvertXplaneToKernelStats, MultiKernels) { - XSpace space; - XPlane* device_trace = space.add_planes(); - XPlaneBuilder device_trace_builder(device_trace); - - // Empty default stream - device_trace_builder.GetOrCreateLine(0); - - XLineBuilder line_builder = device_trace_builder.GetOrCreateLine(0); - CreateXEvent(&device_trace_builder, &line_builder, "kernel_name_shortest", - /*offset_ps=*/10000, /*duration_ps=*/1000, - {{StatType::kTfOp, "mul_786"}, - {StatType::kKernelDetails, R"MULTI(regs:16 -static_shared:0 -dynamic_shared:0 -grid:1,1,1 -block:1,1,1 -occ_pct:50.0)MULTI"}, - {StatType::kEquation, ""}}); - - CreateXEvent(&device_trace_builder, &line_builder, "kernel_name_middle", - /*offset_ps=*/20000, /*duration_ps=*/2000, - {{StatType::kTfOp, "Conv2D"}, - {StatType::kKernelDetails, R"MULTI(regs:32 -static_shared:0 -dynamic_shared:16384 -grid:2,1,1 -block:32,1,1 -occ_pct=13.0)MULTI"}, - {StatType::kEquation, ""}}); - - CreateXEvent(&device_trace_builder, &line_builder, - "volta_fp16_s884gemm_fp16_128x128_ldg8_f2f_tn", - /*offset_ps=*/30000, /*duration_ps=*/3000, - {{StatType::kTfOp, "Einsum_80"}, - {StatType::kKernelDetails, R"MULTI(regs:32 -static_shared:0 -dynamic_shared:16384 -grid:3,1,1 -block:64,1,1 -occ_pct:25.0)MULTI"}, - {StatType::kEquation, ""}}); - - KernelReportMap reports; - 
ConvertDeviceTraceXPlaneToKernelReports(*device_trace, {}, &reports); - KernelStatsDb kernel_stats; - CopyTopKDurationKernelReportsToDb(reports, &kernel_stats); - - EXPECT_EQ(kernel_stats.reports_size(), 3); - - { - const auto& kernel = kernel_stats.reports().at(2); - EXPECT_EQ(kernel.name(), "kernel_name_shortest"); - EXPECT_EQ(kernel.registers_per_thread(), 16); - EXPECT_EQ(kernel.static_shmem_bytes(), 0); - EXPECT_EQ(kernel.dynamic_shmem_bytes(), 0); - EXPECT_EQ(kernel.grid_dim().at(0), 1); - EXPECT_EQ(kernel.grid_dim().at(1), 1); - EXPECT_EQ(kernel.grid_dim().at(2), 1); - EXPECT_EQ(kernel.block_dim().at(0), 1); - EXPECT_EQ(kernel.block_dim().at(1), 1); - EXPECT_EQ(kernel.block_dim().at(2), 1); - EXPECT_EQ(kernel.total_duration_ns(), 1); - EXPECT_FALSE(kernel.is_kernel_using_tensor_core()); - EXPECT_FALSE(kernel.is_op_tensor_core_eligible()); - EXPECT_EQ(kernel.op_name(), "mul_786"); - } - - { - const auto& kernel = kernel_stats.reports().at(1); - EXPECT_EQ(kernel.name(), "kernel_name_middle"); - EXPECT_EQ(kernel.registers_per_thread(), 32); - EXPECT_EQ(kernel.static_shmem_bytes(), 0); - EXPECT_EQ(kernel.dynamic_shmem_bytes(), 16384); - EXPECT_EQ(kernel.grid_dim().at(0), 2); - EXPECT_EQ(kernel.grid_dim().at(1), 1); - EXPECT_EQ(kernel.grid_dim().at(2), 1); - EXPECT_EQ(kernel.block_dim().at(0), 32); - EXPECT_EQ(kernel.block_dim().at(1), 1); - EXPECT_EQ(kernel.block_dim().at(2), 1); - EXPECT_EQ(kernel.total_duration_ns(), 2); - EXPECT_FALSE(kernel.is_kernel_using_tensor_core()); - EXPECT_TRUE(kernel.is_op_tensor_core_eligible()); - EXPECT_EQ(kernel.op_name(), "Conv2D"); - } - - { - const auto& kernel = kernel_stats.reports().at(0); - EXPECT_EQ(kernel.name(), "volta_fp16_s884gemm_fp16_128x128_ldg8_f2f_tn"); - EXPECT_EQ(kernel.registers_per_thread(), 32); - EXPECT_EQ(kernel.static_shmem_bytes(), 0); - EXPECT_EQ(kernel.dynamic_shmem_bytes(), 16384); - EXPECT_EQ(kernel.grid_dim().at(0), 3); - EXPECT_EQ(kernel.grid_dim().at(1), 1); - EXPECT_EQ(kernel.grid_dim().at(2), 
1); - EXPECT_EQ(kernel.block_dim().at(0), 64); - EXPECT_EQ(kernel.block_dim().at(1), 1); - EXPECT_EQ(kernel.block_dim().at(2), 1); - EXPECT_EQ(kernel.total_duration_ns(), 3); - EXPECT_TRUE(kernel.is_kernel_using_tensor_core()); - EXPECT_TRUE(kernel.is_op_tensor_core_eligible()); - EXPECT_EQ(kernel.op_name(), "Einsum_80"); - } -} - -} // namespace -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/xplane_to_memory_profile.cc b/tensorflow/core/profiler/convert/xplane_to_memory_profile.cc deleted file mode 100644 index f996000579301a..00000000000000 --- a/tensorflow/core/profiler/convert/xplane_to_memory_profile.cc +++ /dev/null @@ -1,571 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/core/profiler/convert/xplane_to_memory_profile.h" - -#include -#include -#include -#include -#include -#include - -#include "absl/algorithm/container.h" -#include "absl/container/flat_hash_map.h" -#include "absl/log/check.h" -#include "absl/log/log.h" -#include "absl/status/status.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" -#include "xla/tsl/platform/errors.h" -#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" -#include "tensorflow/core/framework/types.h" -#include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/core/platform/errors.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/profiler/protobuf/memory_profile.pb.h" -#include "tensorflow/core/profiler/utils/xplane_schema.h" -#include "tensorflow/core/profiler/utils/xplane_utils.h" -#include "tensorflow/core/profiler/utils/xplane_visitor.h" -#include "tsl/platform/protobuf.h" - -namespace tensorflow { -namespace profiler { - -namespace { - -constexpr int64_t kInvalidStepId = -1; - -// Index of the time-sorted memory_profile_snapshots list, and the -// MemoryActivityMetadata proto it contains. -using IndexMetaPair = - std::pair; - -bool IsMemoryAllocation(int64_t event_type) { - return event_type == HostEventType::kMemoryAllocation; -} - -bool IsMemoryDeallocation(int64_t event_type) { - return event_type == HostEventType::kMemoryDeallocation; -} - -void UpdateProfileSummary(const MemoryAggregationStats& stats, - int64_t time_offset_ps, - MemoryProfileSummary* summary) { - // Update the peak memory usage over allocator's lifetime. - summary->set_peak_bytes_usage_lifetime(stats.peak_bytes_in_use()); - MemoryAggregationStats* peak_stats = summary->mutable_peak_stats(); - // If we reach (or stay at) peak memory usage within the profiling window, - // update memory profile summary. 
- if (stats.stack_reserved_bytes() + stats.heap_allocated_bytes() >= - peak_stats->peak_bytes_in_use()) { - *peak_stats = stats; - peak_stats->set_peak_bytes_in_use(stats.stack_reserved_bytes() + - stats.heap_allocated_bytes()); - summary->set_peak_stats_time_ps(time_offset_ps); - summary->set_memory_capacity(stats.stack_reserved_bytes() + - stats.heap_allocated_bytes() + - stats.free_memory_bytes()); - } -} - -// Generate memory profile proto by processing host trace XPlane. -MemoryProfile GenerateMemoryProfile(const XPlane* host_trace) { - XPlaneVisitor plane = tsl::profiler::CreateTfXPlaneVisitor(host_trace); - MemoryProfile memory_profile; - // Iterate over all XEvents in the XPlane, and add the XStats to a new - // MemoryProfileSnapshot if the EventType is kMemoryAllocation or - // kMemoryDeallocation. - plane.ForEachLine([&](const XLineVisitor& line) { - line.ForEachEvent([&](const XEventVisitor& event) { - int64_t event_type = - event.Type().value_or(HostEventType::kUnknownHostEventType); - if (!(IsMemoryAllocation(event_type) || - IsMemoryDeallocation(event_type))) { - return; - } - - MemoryAggregationStats stats; - MemoryActivityMetadata metadata; - if (IsMemoryAllocation(event_type)) { - metadata.set_memory_activity(ALLOCATION); - } else if (IsMemoryDeallocation(event_type)) { - metadata.set_memory_activity(DEALLOCATION); - } - metadata.set_step_id(kInvalidStepId); - - std::string memory_id; - event.ForEachStat([&](const XStatVisitor& stat) { - if (!stat.Type().has_value()) return; - switch (stat.Type().value()) { - case StatType::kIndexOnHost: - case StatType::kDeviceOrdinal: - memory_id = absl::StrCat(stat.IntValue()); - break; - case StatType::kAllocatorName: - memory_id = std::string(stat.StrOrRefValue()); - break; - case StatType::kBytesReserved: - stats.set_stack_reserved_bytes(stat.IntValue()); - break; - case StatType::kBytesAllocated: - stats.set_heap_allocated_bytes(stat.IntValue()); - break; - case StatType::kBytesAvailable: - 
stats.set_free_memory_bytes(stat.IntValue()); - break; - case StatType::kFragmentation: - stats.set_fragmentation(stat.DoubleValue()); - break; - case StatType::kPeakBytesInUse: - stats.set_peak_bytes_in_use(stat.IntValue()); - break; - case StatType::kRequestedBytes: - metadata.set_requested_bytes(stat.IntValue()); - break; - case StatType::kAllocationBytes: - metadata.set_allocation_bytes(stat.IntValue()); - break; - case StatType::kAddress: - metadata.set_address(stat.IntValue()); - break; - case StatType::kTfOp: - metadata.set_tf_op_name(std::string(stat.StrOrRefValue())); - break; - case StatType::kGroupId: - metadata.set_step_id(stat.IntValue()); - break; - case StatType::kRegionType: - metadata.set_region_type(std::string(stat.StrOrRefValue())); - break; - case StatType::kDataType: - metadata.set_data_type(tensorflow::DataTypeString( - static_cast(stat.IntValue()))); - break; - case StatType::kTensorShapes: - metadata.set_tensor_shape(std::string(stat.StrOrRefValue())); - break; - } - }); - - MemoryProfileSummary* summary = - (*memory_profile.mutable_memory_profile_per_allocator())[memory_id] - .mutable_profile_summary(); - UpdateProfileSummary(stats, event.OffsetPs(), summary); - - MemoryProfileSnapshot* snapshot = - (*memory_profile.mutable_memory_profile_per_allocator())[memory_id] - .add_memory_profile_snapshots(); - snapshot->set_time_offset_ps(event.OffsetPs()); - *snapshot->mutable_aggregation_stats() = std::move(stats); - *snapshot->mutable_activity_metadata() = std::move(metadata); - }); - }); - return memory_profile; -} - -// Fix invalid step ids of snapshots at the beginning/end of the profile or at -// the step boundaries. The snapshots with invalid step ids at the beginning get -// 0 for their step ids. Those at the step boundaries or at the end get the -// previous snapshot's step id + 1. -void UpdateStepId(PerAllocatorMemoryProfile* memory_profile) { - int64_t last_valid_step_id = -1; - // Snapshots are already sorted in time. 
- for (auto& snapshot : *memory_profile->mutable_memory_profile_snapshots()) { - DCHECK(snapshot.has_activity_metadata()); - if (snapshot.mutable_activity_metadata()->step_id() == kInvalidStepId) { - snapshot.mutable_activity_metadata()->set_step_id(last_valid_step_id + 1); - } else { - last_valid_step_id = snapshot.mutable_activity_metadata()->step_id(); - } - } -} - -// Update the MemoryActivityMetadata for each deallocation event by copying from -// matching allocation. -void UpdateDeallocation(PerAllocatorMemoryProfile* memory_profile) { - absl::flat_hash_map - addr_metadata_map; - for (auto& snapshot : *memory_profile->mutable_memory_profile_snapshots()) { - // Match the deallocation with previous allocation based on address. - uint64 address = snapshot.activity_metadata().address(); - if (snapshot.activity_metadata().memory_activity() == DEALLOCATION) { - if (addr_metadata_map.contains(address)) { - const MemoryActivityMetadata* alloc_meta = addr_metadata_map[address]; - snapshot.mutable_activity_metadata()->set_tf_op_name( - alloc_meta->tf_op_name()); - snapshot.mutable_activity_metadata()->set_region_type( - alloc_meta->region_type()); - snapshot.mutable_activity_metadata()->set_data_type( - alloc_meta->data_type()); - snapshot.mutable_activity_metadata()->set_tensor_shape( - alloc_meta->tensor_shape()); - // In case of following (unexpected) deallocations to the same chunk - // address, leave the metadata as it is (empty or already captured). - addr_metadata_map.erase(address); - } else { - VLOG(2) - << "Can't find matching memory allocation for this deallocation: " - << snapshot.DebugString(); - } - } else if (!addr_metadata_map.contains(address)) { // Allocation. - addr_metadata_map[address] = &snapshot.activity_metadata(); - } else { - VLOG(2) << "There are two allocations recorded for the same address: " - << address - << ". 
The later allocation event is: " << snapshot.DebugString(); - } - } - VLOG(2) << "Number of allocations that cannot find matching dealloctions: " - << addr_metadata_map.size(); -} - -// Return the step id for the peak memory usage data point. -int64_t GetPeakMemoryStep(int64_t peak_bytes_profile, - const PerAllocatorMemoryProfile* memory_profile) { - int64_t peak_bytes_profile_step_id = 0; - for (const auto& snapshot : memory_profile->memory_profile_snapshots()) { - // Get the step id of the peak memory usage. - if (peak_bytes_profile == - snapshot.aggregation_stats().heap_allocated_bytes() + - snapshot.aggregation_stats().stack_reserved_bytes()) { - DCHECK(snapshot.has_activity_metadata()); - peak_bytes_profile_step_id = snapshot.activity_metadata().step_id(); - } - } - return peak_bytes_profile_step_id; -} - -// Functor that compares (index, metadata) pair to sort in the order of -// allocation bytes and requested bytes (descending), as well as TF Op name, -// region type, data type, and tensor shape (ascending). -struct MetadataComparator { - bool operator()(const IndexMetaPair& a, const IndexMetaPair& b) const { - const MemoryActivityMetadata* a_meta = a.second; - const MemoryActivityMetadata* b_meta = b.second; - DCHECK_NE(a_meta, nullptr); - DCHECK_NE(b_meta, nullptr); - - auto lhs = - std::make_tuple(-a_meta->allocation_bytes(), -a_meta->requested_bytes(), - a_meta->tf_op_name(), a_meta->region_type(), - a_meta->data_type(), a_meta->tensor_shape()); - auto rhs = - std::make_tuple(-b_meta->allocation_bytes(), -b_meta->requested_bytes(), - b_meta->tf_op_name(), b_meta->region_type(), - b_meta->data_type(), b_meta->tensor_shape()); - return lhs < rhs; - } -}; - -// If applicable, add items into active_allocs vector and special_allocations -// proto for the unmapped memory usage (in heap) and stack reservation at peak. 
-void InsertSpecialAllocations(int64_t unmapped_allocation_bytes, - int64_t step_id, - PerAllocatorMemoryProfile* memory_profile, - std::vector* active_allocs) { - int index = 0; - if (unmapped_allocation_bytes > 0) { - MemoryActivityMetadata* special_allocation = - memory_profile->add_special_allocations(); - special_allocation->set_memory_activity(ALLOCATION); - special_allocation->set_requested_bytes(unmapped_allocation_bytes); - special_allocation->set_allocation_bytes(unmapped_allocation_bytes); - special_allocation->set_address(0); - special_allocation->set_tf_op_name("unused preallocated device memory"); - special_allocation->set_step_id(step_id); - special_allocation->set_region_type("persist/dynamic"); - special_allocation->set_data_type( - tensorflow::DataTypeString(static_cast(0))); - special_allocation->set_tensor_shape("unknown"); - active_allocs->push_back({--index, special_allocation}); - } - int64_t stack_bytes = - memory_profile->profile_summary().peak_stats().stack_reserved_bytes(); - if (stack_bytes > 0) { - MemoryActivityMetadata* special_allocation = - memory_profile->add_special_allocations(); - special_allocation->set_memory_activity(ALLOCATION); - special_allocation->set_requested_bytes(stack_bytes); - special_allocation->set_allocation_bytes(stack_bytes); - special_allocation->set_address(0); - special_allocation->set_tf_op_name("stack"); - special_allocation->set_step_id(step_id); - special_allocation->set_region_type("stack"); - special_allocation->set_data_type( - tensorflow::DataTypeString(static_cast(0))); - special_allocation->set_tensor_shape("unknown"); - active_allocs->push_back({--index, special_allocation}); - } -} - -bool operator==(const IndexMetaPair& a, const IndexMetaPair& b) { - const MemoryActivityMetadata* a_meta = a.second; - const MemoryActivityMetadata* b_meta = b.second; - return a_meta->allocation_bytes() == b_meta->allocation_bytes() && - a_meta->requested_bytes() == b_meta->requested_bytes() && - 
a_meta->tf_op_name() == b_meta->tf_op_name() && - a_meta->region_type() == b_meta->region_type() && - a_meta->data_type() == b_meta->data_type() && - a_meta->tensor_shape() == b_meta->tensor_shape(); -} - -// Generate the memory breakdown table of active allocations at the peak usage -// (within profiling window) and fill each ActiveAllocation proto (i.e. a row). -void ProcessActiveAllocations(int64_t peak_bytes_profile_step_id, - PerAllocatorMemoryProfile* memory_profile) { - int64_t unmapped_allocation_bytes = - memory_profile->profile_summary().peak_stats().heap_allocated_bytes(); - int64_t unmapped_deallocation_bytes = 0; - absl::flat_hash_map active_alloc_map; - // Only account for the memory activities in the step that includes peak - // memory usage. - for (int i = 0; i < memory_profile->memory_profile_snapshots_size(); i++) { - const auto& snapshot = memory_profile->memory_profile_snapshots().at(i); - DCHECK(snapshot.has_activity_metadata()); - const MemoryActivityMetadata& metadata = snapshot.activity_metadata(); - if (snapshot.time_offset_ps() > - memory_profile->profile_summary().peak_stats_time_ps()) - break; - if (metadata.step_id() != peak_bytes_profile_step_id) continue; - - if (metadata.memory_activity() == ALLOCATION) { - active_alloc_map[metadata.address()] = {i, &metadata}; - unmapped_allocation_bytes -= metadata.allocation_bytes(); - } else { - DCHECK_EQ(metadata.memory_activity(), DEALLOCATION); - if (active_alloc_map.contains(metadata.address())) { - active_alloc_map.erase(metadata.address()); - } else { - unmapped_deallocation_bytes += metadata.allocation_bytes(); - } - unmapped_allocation_bytes += metadata.allocation_bytes(); - } - } - // This separates the persistent memory from the freed memory from last step's - // allocations. 
- unmapped_allocation_bytes -= unmapped_deallocation_bytes; - - VLOG(2) << "unmapped_allocation_bytes=" << unmapped_allocation_bytes - << ", unmapped_deallocation_bytes=" << unmapped_deallocation_bytes; - - // Using pair of (index, MemoryActivityMetadata*) so that we can sort by the - // metadata, and fetch metadata by indexing the time-sorted snapshots at - // frontend. - std::vector active_allocs; - for (const auto& address_and_index_meta : active_alloc_map) { - active_allocs.push_back(address_and_index_meta.second); - } - - InsertSpecialAllocations(unmapped_allocation_bytes, - peak_bytes_profile_step_id, memory_profile, - &active_allocs); - - std::sort(active_allocs.begin(), active_allocs.end(), MetadataComparator()); - - // Fill the sorted active_allocations proto messages at peak memory usage. - // Merge identical allocations and show occurrences. - for (int i = 0, end = active_allocs.size(); i < end; i++) { - ActiveAllocation* allocation = memory_profile->add_active_allocations(); - allocation->set_snapshot_index(active_allocs[i].first); - if (active_allocs[i].first < 0) { - allocation->set_special_index(-active_allocs[i].first - 1); - } else { - allocation->set_special_index(-1); - } - allocation->set_num_occurrences(1); - const int last_alloc = active_allocs.size() - 1; - while (i < last_alloc && active_allocs[i] == active_allocs[i + 1]) { - allocation->set_num_occurrences(allocation->num_occurrences() + 1); - i++; - } - } - - VLOG(2) << "Distinctive active allocation count=" - << memory_profile->active_allocations_size(); -} - -// This function saves the MemoryProfileSnapshots referenced by -// max_num_snapshots. -void SaveActiveAllocationSnapshots( - tsl::protobuf::RepeatedPtrField* snapshots, - tsl::protobuf::RepeatedPtrField* active_allocations) { - std::vector samples; - // Puts the snapshots referenced by active_allocations in . 
- for (const auto& allocation : *active_allocations) { - auto orig_index = allocation.snapshot_index(); - if (orig_index < 0) continue; - samples.push_back(&(*snapshots)[orig_index]); - } - - // Change the reference index in . - int new_index = 0; - for (auto& allocation : *active_allocations) { - int64_t origin_index = allocation.snapshot_index(); - if (origin_index < 0) continue; - allocation.set_snapshot_index(new_index); - new_index++; - } - - tsl::protobuf::RepeatedPtrField new_snapshots; - new_snapshots.Reserve(samples.size()); - for (const auto& sample : samples) { - *new_snapshots.Add() = std::move(*sample); - } - *snapshots = std::move(new_snapshots); -} - -// Sample memory profile snapshots from the original memory -// profile data. -void SampleMemoryProfileTimeline(int64_t max_num_snapshots, - PerAllocatorMemoryProfile* memory_profile) { - const tsl::protobuf::RepeatedPtrField& - original_snapshots = memory_profile->memory_profile_snapshots(); - tsl::protobuf::RepeatedPtrField* timeline_snapshots = - memory_profile->mutable_sampled_timeline_snapshots(); - int64_t snapshot_count = original_snapshots.size(); - if (snapshot_count > max_num_snapshots) { - // When there are more memory profile data than , we - // sample the origin data using a max box filter. Filter width is - // , collect samples starting from the index - // in the original snapshots. - auto max_box_filter = [&](int filter_width, int count, int start) { - for (int i = 0; i < count; i++) { - // Use a max function to get the MemoryProfileSnapshot with the largest - // memory usage in the box filter. 
- const MemoryProfileSnapshot* max_snapshot = - &original_snapshots[start + filter_width * i]; - int64_t max_bytes = - max_snapshot->aggregation_stats().heap_allocated_bytes() + - max_snapshot->aggregation_stats().stack_reserved_bytes(); - for (int index = start + filter_width * i + 1; - index < start + filter_width * (i + 1); index++) { - int64_t bytes = original_snapshots[index] - .aggregation_stats() - .heap_allocated_bytes() + - original_snapshots[index] - .aggregation_stats() - .stack_reserved_bytes(); - if (bytes > max_bytes) { - max_snapshot = &original_snapshots[index]; - max_bytes = bytes; - } - } - *timeline_snapshots->Add() = *max_snapshot; - } - }; - - int width = snapshot_count / max_num_snapshots; - int count1 = max_num_snapshots * (width + 1) - snapshot_count; - int count2 = max_num_snapshots - count1; - - // Collect samples with box filter width , then collect - // samples with box filter width , the total number of - // samples collected will be . - max_box_filter(width, count1, 0); - max_box_filter(width + 1, count2, width * count1); - } else { - // When the number of original snapshots are smaller than - // , just copy all the data points to the timeline. - *timeline_snapshots = original_snapshots; - } -} - -// Post-process the memory profile to correctly update proto fields, and break -// down peak memory usage for each allocator. -void ProcessMemoryProfileProto(int64_t max_num_snapshots, - MemoryProfile* memory_profile) { - memory_profile->set_num_hosts(1); - // Add sorted memory ids within memory profile data to the selection list. 
- for (const auto& id_and_allocator_profile : - memory_profile->memory_profile_per_allocator()) { - if (!id_and_allocator_profile.second.memory_profile_snapshots().empty()) { - memory_profile->add_memory_ids(id_and_allocator_profile.first); - } - } - absl::c_sort(*memory_profile->mutable_memory_ids()); - - for (auto& id_and_allocator_profile : - *memory_profile->mutable_memory_profile_per_allocator()) { - PerAllocatorMemoryProfile* allocator_memory_profile = - &id_and_allocator_profile.second; - tsl::protobuf::RepeatedPtrField* snapshots = - allocator_memory_profile->mutable_memory_profile_snapshots(); - // Sort the memory_profile_snapshots by time_offset_ps (ascending) in proto. - absl::c_sort(*snapshots, [](const MemoryProfileSnapshot& a, - const MemoryProfileSnapshot& b) { - return a.time_offset_ps() < b.time_offset_ps(); - }); - - UpdateStepId(allocator_memory_profile); - UpdateDeallocation(allocator_memory_profile); - - // Sample a subset of MemoryProfileSnapshots to display in the frontend - // memory timeline graph. - SampleMemoryProfileTimeline(max_num_snapshots, allocator_memory_profile); - - int64_t peak_step_id = - GetPeakMemoryStep(allocator_memory_profile->profile_summary() - .peak_stats() - .peak_bytes_in_use(), - allocator_memory_profile); - ProcessActiveAllocations(peak_step_id, allocator_memory_profile); - SaveActiveAllocationSnapshots( - snapshots, allocator_memory_profile->mutable_active_allocations()); - } -} - -template -absl::Status ConvertProtoToJson(const Proto& proto_output, - std::string* json_output) { - tsl::protobuf::util::JsonPrintOptions json_options; - json_options.always_print_primitive_fields = true; - auto status = tsl::protobuf::util::MessageToJsonString( - proto_output, json_output, json_options); - if (!status.ok()) { - // Convert error_msg google::protobuf::StringPiece (or absl::string_view) to - // tensorflow::StringPiece. 
- auto error_msg = status.message(); - return errors::Internal( - "Could not convert proto to JSON string: ", - absl::string_view(error_msg.data(), error_msg.length())); - } - return absl::OkStatus(); -} - -} // namespace - -MemoryProfile ConvertXPlaneToMemoryProfile(const XPlane& host_plane, - int64_t max_num_snapshots) { - MemoryProfile memory_profile = GenerateMemoryProfile(&host_plane); - ProcessMemoryProfileProto(max_num_snapshots, &memory_profile); - // Default version number is 0, set version number to 1 here due to the new - // memory profile sampling algorithm. - memory_profile.set_version(1); - return memory_profile; -} - -absl::Status ConvertXSpaceToMemoryProfileJson(const XSpace& xspace, - std::string* json_output) { - if (const XPlane* host_plane = - FindPlaneWithName(xspace, kHostThreadsPlaneName)) { - MemoryProfile memory_profile = ConvertXPlaneToMemoryProfile(*host_plane); - TF_RETURN_IF_ERROR(ConvertProtoToJson(memory_profile, json_output)); - } - return absl::OkStatus(); -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/xplane_to_memory_profile.h b/tensorflow/core/profiler/convert/xplane_to_memory_profile.h deleted file mode 100644 index 00f919d4dbd42e..00000000000000 --- a/tensorflow/core/profiler/convert/xplane_to_memory_profile.h +++ /dev/null @@ -1,40 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_MEMORY_PROFILE_H_ -#define TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_MEMORY_PROFILE_H_ - -#include - -#include "tensorflow/core/platform/status.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/profiler/protobuf/memory_profile.pb.h" -#include "tsl/profiler/protobuf/xplane.pb.h" - -namespace tensorflow { -namespace profiler { - -// Process the host threads XPlane and generate MemoryProfile result; at most -// max_num_snapshots will be displayed on the UI. -// REQUIRED: host_plane should have been grouped by calling GroupTfEvents(). -MemoryProfile ConvertXPlaneToMemoryProfile(const XPlane& host_plane, - int64_t max_num_snapshots = 1000); - -absl::Status ConvertXSpaceToMemoryProfileJson(const XSpace& xspace, - std::string* json_output); -} // namespace profiler -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_MEMORY_PROFILE_H_ diff --git a/tensorflow/core/profiler/convert/xplane_to_memory_profile_test.cc b/tensorflow/core/profiler/convert/xplane_to_memory_profile_test.cc deleted file mode 100644 index a60d505cfc786f..00000000000000 --- a/tensorflow/core/profiler/convert/xplane_to_memory_profile_test.cc +++ /dev/null @@ -1,128 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/core/profiler/convert/xplane_to_memory_profile.h" - -#include - -#include "absl/strings/string_view.h" -#include "xla/tsl/profiler/utils/group_events.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/profiler/protobuf/memory_profile.pb.h" -#include "tensorflow/core/profiler/utils/xplane_builder.h" -#include "tensorflow/core/profiler/utils/xplane_schema.h" -#include "tensorflow/core/profiler/utils/xplane_test_utils.h" -#include "tsl/profiler/protobuf/xplane.pb.h" - -namespace tensorflow { -namespace profiler { -namespace { - -// Tests with a sample profile with multiple memory allocation and deallocation -// activities within one memory allocator captured in host trace. -TEST(ConvertXPlaneToMemoryProfile, OneAllocatorMultiActivitiesTest) { - XSpace space; - XPlane* host_plane = GetOrCreateHostXPlane(&space); - XPlaneBuilder host_plane_builder(host_plane); - host_plane_builder.ReserveLines(1); - - auto tf_executor_thread = host_plane_builder.GetOrCreateLine(0); - CreateXEvent(&host_plane_builder, &tf_executor_thread, "MemoryAllocation", - 40000, 1000, - {{StatType::kBytesReserved, int64_t{2000}}, - {StatType::kBytesAllocated, int64_t{3000}}, - {StatType::kBytesAvailable, int64_t{5000}}, - {StatType::kPeakBytesInUse, int64_t{8500}}, - {StatType::kRequestedBytes, int64_t{200}}, - {StatType::kAllocationBytes, int64_t{256}}, - {StatType::kAddress, int64_t{222333}}, - {StatType::kStepId, int64_t{-93746}}, - {StatType::kDataType, int64_t{1}}, - {StatType::kAllocatorName, "GPU_0_bfc"}, - {StatType::kTfOp, "foo/bar"}, - {StatType::kRegionType, "output"}, - {StatType::kTensorShapes, "[3, 3, 512, 512]"}}); - - CreateXEvent(&host_plane_builder, &tf_executor_thread, "MemoryDeallocation", - 50000, 1000, - {{StatType::kBytesReserved, int64_t{2000}}, - {StatType::kBytesAllocated, int64_t{2744}}, - {StatType::kBytesAvailable, int64_t{5256}}, - 
{StatType::kPeakBytesInUse, int64_t{8500}}, - {StatType::kRequestedBytes, int64_t{200}}, - {StatType::kAllocationBytes, int64_t{256}}, - {StatType::kAddress, int64_t{222333}}, - {StatType::kStepId, int64_t{0}}, - {StatType::kDataType, int64_t{0}}, - {StatType::kAllocatorName, "GPU_0_bfc"}, - {StatType::kRegionType, ""}, - {StatType::kTensorShapes, ""}}); - - CreateXEvent(&host_plane_builder, &tf_executor_thread, "MemoryAllocation", - 70000, 1000, - {{StatType::kBytesReserved, int64_t{2000}}, - {StatType::kBytesAllocated, int64_t{5000}}, - {StatType::kBytesAvailable, int64_t{3000}}, - {StatType::kPeakBytesInUse, int64_t{9500}}, - {StatType::kRequestedBytes, int64_t{300}}, - {StatType::kAllocationBytes, int64_t{300}}, - {StatType::kAddress, int64_t{345678}}, - {StatType::kStepId, int64_t{-93746}}, - {StatType::kDataType, int64_t{9}}, - {StatType::kAllocatorName, "GPU_0_bfc"}, - {StatType::kTfOp, "mul_grad/Sum"}, - {StatType::kRegionType, "temp"}, - {StatType::kTensorShapes, "[1, 2]"}}); - - tsl::profiler::GroupTfEvents(&space); - MemoryProfile memory_profile = ConvertXPlaneToMemoryProfile(*host_plane); - EXPECT_EQ(memory_profile.memory_profile_per_allocator().size(), 1); - EXPECT_EQ(memory_profile.num_hosts(), 1); - EXPECT_EQ(memory_profile.memory_ids_size(), 1); - EXPECT_EQ(memory_profile.memory_profile_per_allocator().begin()->first, - "GPU_0_bfc"); - EXPECT_EQ(memory_profile.version(), 1); - const auto& allocator_memory_profile = - memory_profile.memory_profile_per_allocator().begin()->second; - EXPECT_EQ( - allocator_memory_profile.profile_summary().peak_bytes_usage_lifetime(), - 9500); - EXPECT_EQ(allocator_memory_profile.profile_summary() - .peak_stats() - .peak_bytes_in_use(), - 7000); - EXPECT_EQ(allocator_memory_profile.profile_summary().peak_stats_time_ps(), - 70000); - EXPECT_EQ(allocator_memory_profile.sampled_timeline_snapshots_size(), 3); - EXPECT_EQ(allocator_memory_profile.memory_profile_snapshots_size(), 1); - 
EXPECT_EQ(allocator_memory_profile.memory_profile_snapshots() - .at(0) - .activity_metadata() - .tf_op_name(), - "mul_grad/Sum"); - EXPECT_EQ(allocator_memory_profile.active_allocations_size(), 3); - EXPECT_EQ( - allocator_memory_profile.active_allocations().at(2).snapshot_index(), 0); - EXPECT_EQ(allocator_memory_profile.special_allocations_size(), 2); - EXPECT_EQ(allocator_memory_profile.special_allocations().at(1).tf_op_name(), - "stack"); - EXPECT_EQ( - allocator_memory_profile.special_allocations().at(1).allocation_bytes(), - 2000); -} - -} // namespace -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc deleted file mode 100644 index b216f95a4a2fcf..00000000000000 --- a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc +++ /dev/null @@ -1,332 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/core/profiler/convert/xplane_to_op_metrics_db.h" - -#include -#include -#include -#include -#include -#include -#include - -#include "absl/algorithm/container.h" -#include "absl/container/flat_hash_map.h" -#include "absl/log/log.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" -#include "xla/tsl/profiler/utils/tf_op_utils.h" -#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" -#include "xla/tsl/profiler/utils/timespan.h" -#include "xla/tsl/profiler/utils/xplane_schema.h" -#include "xla/tsl/profiler/utils/xplane_utils.h" -#include "tensorflow/core/lib/gtl/map_util.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/profiler/convert/op_metrics_db_combiner.h" -#include "tensorflow/core/profiler/convert/op_stack.h" -#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" -#include "tensorflow/core/profiler/utils/cost_utils.h" -#include "tensorflow/core/profiler/utils/op_metrics_db_utils.h" -#include "tensorflow/core/profiler/utils/op_utils.h" -#include "tensorflow/core/profiler/utils/trace_utils.h" -#include "tensorflow/core/profiler/utils/xplane_schema.h" -#include "tensorflow/core/profiler/utils/xplane_visitor.h" -#include "tsl/profiler/protobuf/xplane.pb.h" - -namespace tensorflow { -namespace profiler { -namespace { - -using tsl::profiler::GetDeviceEventTimespan; - -// Type of a TensorFlow Op activity, which is either beginning or ending an Op. -enum TfActivityType { kTfOpBegin, kTfOpEnd }; - -// Instant activity representing the begin or end of a host-side TF Op. -struct TfActivity { - // The timestamp in picoseconds when this activity happened. - uint64 timestamp_ps; - // The ID of this Op. - uint32 tf_op_id; - // Type of this activity. - TfActivityType activity_type; - // Full TF op name and type of this activity (backed by XEvent::name). 
- tsl::profiler::TfOp tf_op; - // Whether it is eagerly executed. - bool is_eager; -}; - -// TF Op metrics stored as element in OpStack. -struct TfOpInfo { - explicit TfOpInfo(uint64 ts) : start_timestamp_ps(ts) {} - - // Start timestamp in picoseconds. - uint64 start_timestamp_ps; - // Children duration in picoseconds. - uint64 children_duration_ps = 0; -}; - -// Processes a TF-activity on particular core. -void ProcessOneTfActivity(const TfActivity& activity, - OpStack* tf_op_stack, - TfMetricsDbData* tf_metrics_data) { - uint32 tf_op_id = activity.tf_op_id; - switch (activity.activity_type) { - case kTfOpBegin: { - tf_op_stack->Push(tf_op_id, - std::make_unique(activity.timestamp_ps)); - break; - } - case kTfOpEnd: { - std::unique_ptr info = tf_op_stack->Pop(tf_op_id); - if (info == nullptr) { - // This happens if TraceMes overlap. - VLOG(1) << "No begin event found for TF activity id=" << tf_op_id - << " name=" << activity.tf_op.name - << " type=" << activity.tf_op.type; - break; - } - tsl::profiler::Timespan tf_op_span = tsl::profiler::PicoSpan( - info->start_timestamp_ps, activity.timestamp_ps); - tf_metrics_data->tf_metrics_db_builder.EnterOp( - activity.tf_op.name, activity.tf_op.type, activity.is_eager, - tf_op_span.duration_ps(), info->children_duration_ps); - TfOpInfo* parent_info = tf_op_stack->Top(); - if (parent_info != nullptr) { - parent_info->children_duration_ps += tf_op_span.duration_ps(); - } - if (tsl::profiler::IsInfeedEnqueueOp(activity.tf_op.type)) { - tf_metrics_data->tf_metrics_db_builder.EnterHostInfeedEnqueue( - tf_op_span); - } - break; - } - } -} - -// Processes all TF-activities on the given core. 
-void ProcessTfActivities(std::vector* tf_activities, - TfMetricsDbData* tf_metrics_db_data) { - if (tf_activities->empty()) return; - absl::c_stable_sort(*tf_activities, - [](const TfActivity& a, const TfActivity& b) { - return a.timestamp_ps < b.timestamp_ps; - }); - OpStack tf_op_stack; - for (const auto& tf_activity : *tf_activities) { - ProcessOneTfActivity(tf_activity, &tf_op_stack, tf_metrics_db_data); - } - SetTotalTimePs( - tf_metrics_db_data->tf_metrics_db, - tf_activities->back().timestamp_ps - tf_activities->front().timestamp_ps); -} - -void CollectTfActivities( - const XLineVisitor& line, - const absl::flat_hash_map& tf_ops, - std::vector* tf_activities) { - uint32 tf_op_id = 0; - if (IsDerivedThreadId(line.Id())) return; - tf_activities->reserve(line.NumEvents() * 2); - line.ForEachEvent( - [&tf_ops, &tf_op_id, &tf_activities](const XEventVisitor& event) { - const tsl::profiler::TfOp* tf_op = gtl::FindOrNull(tf_ops, event.Id()); - if (tf_op != nullptr) { - ++tf_op_id; - bool is_eager = false; - if (std::optional stat = - event.GetStat(StatType::kIsEager)) { - is_eager = stat->IntValue(); - } - tsl::profiler::Timespan span = event.GetTimespan(); - tf_activities->push_back( - {span.begin_ps(), tf_op_id, kTfOpBegin, *tf_op, is_eager}); - tf_activities->push_back( - {span.end_ps(), tf_op_id, kTfOpEnd, *tf_op, is_eager}); - } - if (auto tf_op_stat = event.GetStat(StatType::kTfOp); - tf_op_stat.has_value()) { - ++tf_op_id; - tsl::profiler::TfOp tf_op = - tsl::profiler::ParseTfOpFullname(tf_op_stat->StrOrRefValue()); - tsl::profiler::Timespan span = event.GetTimespan(); - tf_activities->push_back( - {span.begin_ps(), tf_op_id, kTfOpBegin, tf_op, false}); - tf_activities->push_back( - {span.end_ps(), tf_op_id, kTfOpEnd, tf_op, false}); - } - }); -} - -} // namespace - -absl::flat_hash_map -CollectTfOpsFromHostThreadsXPlane(const XPlane& host_trace) { - absl::flat_hash_map tf_ops; - for (const auto& id_metadata : host_trace.event_metadata()) { - const 
XEventMetadata& metadata = id_metadata.second; - // On the host, we have added some user-specified TraceMe's in addition to - // the TraceMe's added to every TensorFlow op by the system. These - // user-inserted TraceMe's have "unknown" type. We don't count them in - // Tf-stats. - tsl::profiler::TfOp tf_op = - tsl::profiler::ParseTfOpFullname(metadata.name()); - if (tf_op.category != tsl::profiler::Category::kUnknown) { - tf_ops.try_emplace(metadata.id(), tf_op); - } - } - return tf_ops; -} - -TfMetricsDbData ConvertHostThreadsXLineToTfMetricsDbData( - const XLineVisitor& line, - const absl::flat_hash_map& tf_ops) { - TfMetricsDbData tf_metrics_db_data; - std::vector tf_activities; - CollectTfActivities(line, tf_ops, &tf_activities); - ProcessTfActivities(&tf_activities, &tf_metrics_db_data); - return tf_metrics_db_data; -} - -void ConsumeTfMetricsDbData(TfMetricsDbData src, OpMetricsDbCombiner* dst) { - AddIdleOp(src.tf_metrics_db); - // Host OpMetricsDb does not need to update the number of cores a certain op - // occurs. 
- dst->Combine(src.tf_metrics_db, /*update_num_cores=*/false); - src.tf_metrics_db.Clear(); -} - -OpMetricsDb ConvertHostThreadsXPlaneToOpMetricsDb(const XPlane& host_trace) { - absl::flat_hash_map tf_ops = - CollectTfOpsFromHostThreadsXPlane(host_trace); - OpMetricsDb result; - OpMetricsDbCombiner combiner(&result); - XPlaneVisitor plane = tsl::profiler::CreateTfXPlaneVisitor(&host_trace); - plane.ForEachLine([&tf_ops, &combiner](const XLineVisitor& line) { - ConsumeTfMetricsDbData( - ConvertHostThreadsXLineToTfMetricsDbData(line, tf_ops), &combiner); - }); - return result; -} - -OpMetricsDb ConvertTpuDeviceTraceXPlaneToOpMetricsDb( - const XPlane& device_trace) { - XPlaneVisitor plane = tsl::profiler::CreateTfXPlaneVisitor(&device_trace); - XEventsOpMetricsDbBuilder builder; - uint64_t first_op_timestamp_ps = std::numeric_limits::max(); - uint64_t last_op_timestamp_ps = 0; - - struct ParentReference { - const XEventVisitor event; - tsl::profiler::Timespan device_timespan; - uint64_t children_duration_ps = 0; - }; - - tsl::profiler::AncestorStack event_stack( - [&](const ParentReference& parent) { - OpMetrics op_metrics = FromXEvent(parent.event); - op_metrics.set_time_ps(parent.device_timespan.duration_ps()); - op_metrics.set_self_time_ps(op_metrics.time_ps() - - parent.children_duration_ps); - builder.AddOpMetric(op_metrics, GetOpKeyFromXEvent(parent.event)); - }, - [](const ParentReference& parent, const ParentReference& child) { - return parent.device_timespan.Includes(child.device_timespan); - }, - [](ParentReference& parent, ParentReference& child) { - parent.children_duration_ps += child.device_timespan.duration_ps(); - }); - - auto track_first_and_last_op_timestamps = [&](const XEventVisitor& event) { - tsl::profiler::Timespan timespan = GetDeviceEventTimespan(event); - first_op_timestamp_ps = - std::min(first_op_timestamp_ps, timespan.begin_ps()); - last_op_timestamp_ps = std::max(last_op_timestamp_ps, timespan.end_ps()); - }; - - 
plane.ForEachLine([&](const XLineVisitor& line) { - if (line.Name() == tsl::profiler::kSparseCoreStepLineName || - line.Name() == tsl::profiler::kStepLineName) { - line.ForEachEvent(track_first_and_last_op_timestamps); - } - if (!tsl::profiler::IsOpLineName(line.Name())) return; - line.ForEachEvent([&](const XEventVisitor& event) { - tsl::profiler::Timespan timespan = GetDeviceEventTimespan(event); - track_first_and_last_op_timestamps(event); - - event_stack.Push({.event = event, .device_timespan = timespan}); - }); - event_stack.Flush(); - }); - - return builder.Finalize(last_op_timestamp_ps - first_op_timestamp_ps); -} - -OpMetricsDb ConvertDeviceTraceXPlaneToOpMetricsDb(const XPlane& device_trace) { - OpMetricsDb result; - DeviceOpMetricsDbBuilder device_op_metrics_db_builder(&result); - - int64_t first_op_offset_ps = kint64max; - int64_t last_op_offset_ps = 0; - - TfOpRoofLineCostEstimator op_level_cost_estimator; - XPlaneVisitor plane = tsl::profiler::CreateTfXPlaneVisitor(&device_trace); - plane.ForEachLine([&](const XLineVisitor& line) { - if (IsDerivedThreadId(line.Id())) return; - line.ForEachEvent([&](const XEventVisitor& event) { - first_op_offset_ps = std::min(first_op_offset_ps, event.OffsetPs()); - last_op_offset_ps = std::max(last_op_offset_ps, event.EndOffsetPs()); - - absl::string_view tf_op_full_name; - bool is_eager = false; - int64_t program_id = 0; - absl::string_view deduplicated_name = ""; - event.ForEachStat([&](const XStatVisitor& stat) { - if (stat.Type() == StatType::kTfOp) { - tf_op_full_name = stat.StrOrRefValue(); - } else if (stat.Type() == StatType::kIsEager) { - is_eager = stat.IntValue(); - } else if (stat.Type() == StatType::kProgramId) { - program_id = stat.IntOrUintValue(); - } else if (stat.Type() == StatType::kDeduplicatedName) { - deduplicated_name = stat.StrOrRefValue(); - } - }); - if (tf_op_full_name.empty()) return; - tsl::profiler::TfOp tf_op = - tsl::profiler::ParseTfOpFullname(tf_op_full_name); - 
TfOpRoofLineCostEstimator::OpRoofLineStats costs; - if (tf_op.category != tsl::profiler::Category::kUnknown) { - costs = op_level_cost_estimator.Predict(event); - } - device_op_metrics_db_builder.EnterOp( - /*program_id=*/program_id, - /**name=*/absl::StrCat(tf_op.name, "/", event.Name()), - /**category=*/tf_op.type, - /*provenance=*/tf_op_full_name, deduplicated_name, is_eager, - /*occurrences=*/1, event.DurationPs(), - /*children_time_ps=*/0, costs.flops, costs.bytes_accessed); - }); - }); - SetTotalTimePs( - result, last_op_offset_ps ? last_op_offset_ps - first_op_offset_ps : 0); - AddIdleOp(result); - return result; -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.h b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.h deleted file mode 100644 index 06bcec66d136cb..00000000000000 --- a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.h +++ /dev/null @@ -1,59 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_OP_METRICS_DB_H_ -#define TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_OP_METRICS_DB_H_ - -#include "absl/container/flat_hash_map.h" -#include "absl/types/optional.h" -#include "xla/tsl/profiler/utils/tf_op_utils.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/profiler/convert/op_metrics_db_combiner.h" -#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" -#include "tensorflow/core/profiler/utils/op_utils.h" -#include "tensorflow/core/profiler/utils/xplane_visitor.h" -#include "tsl/profiler/protobuf/xplane.pb.h" - -namespace tensorflow { -namespace profiler { - -// Data per host thread for TensorFlow Op Metrics Database. -struct TfMetricsDbData { - // A database of TF-Op metrics for this core. - OpMetricsDb tf_metrics_db; - HostOpMetricsDbBuilder tf_metrics_db_builder{&tf_metrics_db}; -}; - -absl::flat_hash_map -CollectTfOpsFromHostThreadsXPlane(const XPlane& host_trace); - -TfMetricsDbData ConvertHostThreadsXLineToTfMetricsDbData( - const XLineVisitor& line, - const absl::flat_hash_map& tf_ops); - -void ConsumeTfMetricsDbData(TfMetricsDbData src, OpMetricsDbCombiner* dst); - -OpMetricsDb ConvertHostThreadsXPlaneToOpMetricsDb(const XPlane& host_trace); - -OpMetricsDb ConvertDeviceTraceXPlaneToOpMetricsDb(const XPlane& device_trace); - -// Convert TPU DeviceTrace XPlane to OpMetricDb -OpMetricsDb ConvertTpuDeviceTraceXPlaneToOpMetricsDb( - const XPlane& device_trace); - -} // namespace profiler -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_OP_METRICS_DB_H_ diff --git a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db_test.cc b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db_test.cc deleted file mode 100644 index c877fe8ec8d942..00000000000000 --- a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db_test.cc +++ /dev/null @@ -1,295 +0,0 @@ -/* 
Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/profiler/convert/xplane_to_op_metrics_db.h" - -#include -#include -#include -#include - -#include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" -#include "xla/tsl/profiler/utils/math_utils.h" -#include "xla/tsl/profiler/utils/xplane_schema.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" -#include "tensorflow/core/profiler/utils/op_metrics_db_utils.h" -#include "tensorflow/core/profiler/utils/xplane_builder.h" -#include "tensorflow/core/profiler/utils/xplane_schema.h" -#include "tensorflow/core/profiler/utils/xplane_test_utils.h" -#include "tsl/profiler/protobuf/xplane.pb.h" - -namespace tensorflow { -namespace profiler { -namespace { - -#if defined(PLATFORM_GOOGLE) -// NOLINTNEXTLINE: clang-tidy missing-includes -using ::testing::EqualsProto; -#endif - -void AddTensorFlowTpuOpEvent(std::string&& name, std::string&& tf_op_fullname, - int64_t start_timestamp_ns, int64_t duration_ns, - std::string&& hlo_category, uint64 flops, - uint64 bytes_accessed, int64_t occurences, - int64_t self_duration, int64_t program_id, - int64_t symbol_id, XPlaneBuilder* plane, - XLineBuilder* line) { - XEventBuilder event = 
line->AddEvent(*plane->GetOrCreateEventMetadata(name)); - event.SetTimestampNs(start_timestamp_ns); - event.SetDurationNs(duration_ns); - event.SetNumOccurrences(occurences); - XStatsBuilder event_metadata( - plane->GetOrCreateEventMetadata(name), plane); - event_metadata.AddStatValue( - *plane->GetOrCreateStatMetadata(GetStatTypeStr(StatType::kTfOp)), - tf_op_fullname); - event_metadata.AddStatValue( - *plane->GetOrCreateStatMetadata(GetStatTypeStr(StatType::kHloCategory)), - hlo_category); - event_metadata.AddStatValue( - *plane->GetOrCreateStatMetadata(GetStatTypeStr(StatType::kFlops)), flops); - event_metadata.AddStatValue( - *plane->GetOrCreateStatMetadata(GetStatTypeStr(StatType::kSymbolId)), - symbol_id); - event_metadata.AddStatValue( - *plane->GetOrCreateStatMetadata(GetStatTypeStr(StatType::kProgramId)), - program_id); -} - -void AddTensorFlowOpEvent(std::string&& tf_op_fullname, - int64_t start_timestamp_ns, int64_t duration_ns, - bool on_device, absl::string_view kernel_name, - XPlaneBuilder* plane, XLineBuilder* line) { - absl::string_view name = on_device ? 
kernel_name : tf_op_fullname; - XEventBuilder event = line->AddEvent(*plane->GetOrCreateEventMetadata(name)); - event.SetTimestampNs(start_timestamp_ns); - event.SetDurationNs(duration_ns); - if (!on_device) return; - event.AddStatValue( - *plane->GetOrCreateStatMetadata(GetStatTypeStr(StatType::kTfOp)), - *plane->GetOrCreateStatMetadata(std::move(tf_op_fullname))); -} - -void AddXlaCpuOpEvent(std::string&& hlo_op_name, std::string&& tf_op, - int64_t start_timestamp_ns, int64_t duration_ns, - XPlaneBuilder* plane, XLineBuilder* line) { - XEventBuilder event = - line->AddEvent(*plane->GetOrCreateEventMetadata(hlo_op_name)); - event.SetTimestampNs(start_timestamp_ns); - event.SetDurationNs(duration_ns); - event.ParseAndAddStatValue( - *plane->GetOrCreateStatMetadata(GetStatTypeStr(StatType::kTfOp)), tf_op); -} - -TEST(ConvertXPlaneToOpMetricsDb, HostOpMetricsDb) { - static constexpr char kTfOp1[] = "TfOp1"; - static constexpr char kTfOp2[] = "TfOp2"; - constexpr int64_t kTfOp1StartNs = 100000; - constexpr int64_t kTfOp1DurationNs = 8000; - constexpr int64_t kTfOp2StartNs = 110000; - constexpr int64_t kTfOp2DurationNs = 10000; - - XSpace xspace; - XPlane* xplane = GetOrCreateHostXPlane(&xspace); - XPlaneBuilder host_plane(xplane); - XLineBuilder thread1 = host_plane.GetOrCreateLine(/*line_id=*/10); - AddTensorFlowOpEvent(absl::StrCat(kTfOp1, ":", kTfOp1), kTfOp1StartNs, - kTfOp1DurationNs, /*on_device=*/false, - /*kernel_name=*/"", &host_plane, &thread1); - XLineBuilder thread2 = host_plane.GetOrCreateLine(/*line_id=*/20); - AddTensorFlowOpEvent(absl::StrCat(kTfOp1, ":", kTfOp1), kTfOp1StartNs, - kTfOp1DurationNs, /*on_device=*/false, - /*kernel_name=*/"", &host_plane, &thread2); - AddTensorFlowOpEvent(absl::StrCat(kTfOp2, ":", kTfOp2), kTfOp2StartNs, - kTfOp2DurationNs, /*on_device=*/false, - /*kernel_name=*/"", &host_plane, &thread2); - - OpMetricsDb op_metrics = ConvertHostThreadsXPlaneToOpMetricsDb(*xplane); - // Op1, Op2, Idle. 
- EXPECT_EQ(3, op_metrics.metrics_db_size()); - uint64 total_op_duration = - tsl::profiler::NanoToPico(kTfOp1DurationNs * 2 + kTfOp2DurationNs); - EXPECT_EQ(total_op_duration, op_metrics.total_op_time_ps()); - uint64 total_duration = tsl::profiler::NanoToPico( - kTfOp2StartNs - kTfOp1StartNs + kTfOp2DurationNs + kTfOp1DurationNs); - EXPECT_EQ(total_duration, op_metrics.total_time_ps()); - - // Verifies OpMetricsDb is built correctly. - const OpMetrics& op_1 = op_metrics.metrics_db().at(0); - EXPECT_EQ(kTfOp1, op_1.name()); - EXPECT_EQ(kTfOp1, op_1.category()); - EXPECT_EQ(2, op_1.occurrences()); - EXPECT_EQ(tsl::profiler::NanoToPico(kTfOp1DurationNs) * 2, op_1.time_ps()); - - const OpMetrics& idle = op_metrics.metrics_db().at(1); - EXPECT_EQ(kIdle, idle.name()); - EXPECT_EQ(kIdle, idle.category()); - // Idle time is the gap between Op2 start and the end of Op1, which is 2000ns. - EXPECT_EQ(tsl::profiler::NanoToPico(2000), idle.time_ps()); - - const OpMetrics& op_2 = op_metrics.metrics_db().at(2); - EXPECT_EQ(kTfOp2, op_2.name()); - EXPECT_EQ(kTfOp2, op_2.category()); - EXPECT_EQ(1, op_2.occurrences()); - EXPECT_EQ(tsl::profiler::NanoToPico(kTfOp2DurationNs), op_2.time_ps()); -} - -TEST(ConvertXPlaneToOpMetricsDb, DeviceOpMetricsDb) { - // TfOp1 has kernel1 and kernel2; TfOp2 has kernel3. 
- static constexpr char kTfOp1[] = "TfOp1"; - static constexpr char kTfOp2[] = "TfOp2"; - static constexpr char kKernel1[] = "kernel1"; - static constexpr char kKernel2[] = "kernel2"; - static constexpr char kKernel3[] = "kernel3"; - constexpr int64_t kKernel1StartNs = 100000; - constexpr int64_t kKernel1DurationNs = 8000; - constexpr int64_t kKernel2StartNs = 110000; - constexpr int64_t kKernel2DurationNs = 10000; - constexpr int64_t kKernel3StartNs = 120000; - constexpr int64_t kKernel3DurationNs = 10000; - - XSpace xspace; - XPlane* xplane = GetOrCreateGpuXPlane(&xspace, /*device_ordinal=*/0); - XPlaneBuilder device_plane(xplane); - XLineBuilder stream1 = device_plane.GetOrCreateLine(/*line_id=*/10); - AddTensorFlowOpEvent(absl::StrCat(kTfOp1, ":", kTfOp1), kKernel1StartNs, - kKernel1DurationNs, /*on_device=*/true, kKernel1, - &device_plane, &stream1); - AddTensorFlowOpEvent(absl::StrCat(kTfOp1, ":", kTfOp1), kKernel2StartNs, - kKernel2DurationNs, /*on_device=*/true, kKernel2, - &device_plane, &stream1); - XLineBuilder stream2 = device_plane.GetOrCreateLine(/*line_id=*/20); - AddTensorFlowOpEvent(absl::StrCat(kTfOp1, ":", kTfOp1), kKernel1StartNs, - kKernel1DurationNs, /*on_device=*/true, kKernel1, - &device_plane, &stream2); - AddTensorFlowOpEvent(absl::StrCat(kTfOp1, ":", kTfOp1), kKernel2StartNs, - kKernel2DurationNs, /*on_device=*/true, kKernel2, - &device_plane, &stream2); - AddTensorFlowOpEvent(absl::StrCat(kTfOp2, ":", kTfOp2), kKernel3StartNs, - kKernel3DurationNs, /*on_device=*/true, kKernel3, - &device_plane, &stream2); - - OpMetricsDb op_metrics = ConvertDeviceTraceXPlaneToOpMetricsDb(*xplane); - - // kernel1, kernel2, kernel3, Idle. 
- EXPECT_EQ(4, op_metrics.metrics_db_size()); - uint64 total_op_duration = tsl::profiler::NanoToPico( - kKernel1DurationNs * 2 + kKernel2DurationNs * 2 + kKernel3DurationNs); - EXPECT_EQ(total_op_duration, op_metrics.total_op_time_ps()); - // For device, the total_duration for each device is the total duration merged - // from all GPU streams, which is from 100000 to 130000. - uint64 total_duration = tsl::profiler::NanoToPico( - kKernel3StartNs + kKernel3DurationNs - kKernel1StartNs); - EXPECT_EQ(std::max(total_duration, total_op_duration), - op_metrics.total_time_ps()); - - // Verifies OpMetricsDb is built correctly. - const OpMetrics& op_1 = op_metrics.metrics_db().at(0); - EXPECT_EQ(absl::StrCat(kTfOp1, "/", kKernel1), op_1.name()); - EXPECT_EQ(kTfOp1, op_1.category()); - EXPECT_EQ(2, op_1.occurrences()); - EXPECT_EQ(tsl::profiler::NanoToPico(kKernel1DurationNs) * 2, op_1.time_ps()); - - const OpMetrics& op_2 = op_metrics.metrics_db().at(1); - EXPECT_EQ(absl::StrCat(kTfOp1, "/", kKernel2), op_2.name()); - EXPECT_EQ(kTfOp1, op_2.category()); - EXPECT_EQ(2, op_2.occurrences()); - EXPECT_EQ(tsl::profiler::NanoToPico(kKernel2DurationNs) * 2, op_2.time_ps()); - - const OpMetrics& op_3 = op_metrics.metrics_db().at(2); - EXPECT_EQ(absl::StrCat(kTfOp2, "/", kKernel3), op_3.name()); - EXPECT_EQ(kTfOp2, op_3.category()); - EXPECT_EQ(1, op_3.occurrences()); - EXPECT_EQ(tsl::profiler::NanoToPico(kKernel3DurationNs), op_3.time_ps()); - - const OpMetrics& idle = op_metrics.metrics_db().at(3); - EXPECT_EQ(kIdle, idle.name()); - EXPECT_EQ(kIdle, idle.category()); - // GPU is always busy in this example. 
- EXPECT_EQ(tsl::profiler::NanoToPico(0), idle.time_ps()); -} - -TEST(ConvertXPlaneToOpMetricsDb, TpuDeviceOpMetricsDb) { - XSpace xspace; - XPlane* xplane = GetOrCreateTpuXPlane(&xspace, /*device_ordinal=*/0, "TPU V4", - /*peak_tera_flops_per_second=*/0, - /*peak_hbm_bw_gigabytes_per_second=*/0); - XPlaneBuilder device_plane(xplane); - XLineBuilder stream1 = device_plane.GetOrCreateLine(/*line_id=*/10); - stream1.SetName(tsl::profiler::kTensorFlowOpLineName); - AddTensorFlowTpuOpEvent("MatMul", "while:MatMul", 0, 10, "MatMul", 34, 45, 2, - 5, 1, 1, &device_plane, &stream1); - OpMetricsDb op_metrics = ConvertTpuDeviceTraceXPlaneToOpMetricsDb(*xplane); -#if defined(PLATFORM_GOOGLE) - EXPECT_THAT(op_metrics, - EqualsProto(R"pb(metrics_db { - hlo_module_id: 1 - self_time_ps: 10000 - flops: 68 - model_flops: 68 - num_cores: 1 - occurrences: 2 - name: "MatMul" - time_ps: 10000 - category: "MatMul" - provenance: "while:MatMul" - min_time_ps: 10000 - } - metrics_db { name: "IDLE" category: "IDLE" } - total_time_ps: 10000 - total_op_time_ps: 10000 - )pb")); -#endif -} - -TEST(ConvertXPlaneToOpMetricsDb, HostXPlaneWithXlaOps) { - XPlane xplane; - XPlaneBuilder plane(&xplane); - XLineBuilder line = plane.GetOrCreateLine(/*line_id=*/10); - AddXlaCpuOpEvent("xla_op", "tf_op", 100000, 8000, &plane, &line); - AddXlaCpuOpEvent("xla_op2", "tf_op2", 110000, 10000, &plane, &line); - OpMetricsDb op_metrics = ConvertHostThreadsXPlaneToOpMetricsDb(xplane); -#if defined(PLATFORM_GOOGLE) - EXPECT_THAT(op_metrics, EqualsProto(R"pb(metrics_db { - self_time_ps: 8000000 - occurrences: 1 - name: "tf_op" - time_ps: 8000000 - } - metrics_db { - self_time_ps: 10000000 - occurrences: 1 - name: "tf_op2" - time_ps: 10000000 - } - metrics_db { - self_time_ps: 2000000 - name: "IDLE" - time_ps: 2000000 - category: "IDLE" - } - total_time_ps: 20000000 - total_op_time_ps: 18000000 - precision_stats {} - )pb")); -#endif -} - -} // namespace -} // namespace profiler -} // namespace tensorflow diff --git 
a/tensorflow/core/profiler/convert/xplane_to_op_stats.cc b/tensorflow/core/profiler/convert/xplane_to_op_stats.cc deleted file mode 100644 index de75de4bc2b77f..00000000000000 --- a/tensorflow/core/profiler/convert/xplane_to_op_stats.cc +++ /dev/null @@ -1,470 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/profiler/convert/xplane_to_op_stats.h" - -#include -#include -#include -#include - -#include "absl/container/flat_hash_set.h" -#include "absl/log/check.h" -#include "absl/log/log.h" -#include "absl/strings/match.h" -#include "absl/strings/string_view.h" -#include "xla/tsl/profiler/convert/xla_op_utils.h" -#include "xla/tsl/profiler/utils/math_utils.h" -#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" -#include "xla/tsl/profiler/utils/timespan.h" -#include "xla/tsl/profiler/utils/tpu_xplane_utils.h" -#include "xla/tsl/profiler/utils/xplane_utils.h" -#include "tensorflow/core/profiler/convert/duty_cycle_combiner.h" -#include "tensorflow/core/profiler/convert/duty_cycle_tracker.h" -#include "tensorflow/core/profiler/convert/op_metrics_db_combiner.h" -#include "tensorflow/core/profiler/convert/step_events_to_steps_db.h" -#include "tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.h" -#include "tensorflow/core/profiler/convert/xplane_to_op_metrics_db.h" -#include 
"tensorflow/core/profiler/convert/xplane_to_step_events.h" -#include "tensorflow/core/profiler/convert/xplane_to_tf_functions.h" -#include "tensorflow/core/profiler/protobuf/diagnostics.pb.h" -#include "tensorflow/core/profiler/protobuf/hardware_types.pb.h" -#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" -#include "tensorflow/core/profiler/protobuf/op_stats.pb.h" -#include "tensorflow/core/profiler/protobuf/steps_db.pb.h" -#include "tensorflow/core/profiler/protobuf/tf_function.pb.h" -#include "tensorflow/core/profiler/utils/device_caps_utils.h" -#include "tensorflow/core/profiler/utils/event_span.h" -#include "tensorflow/core/profiler/utils/gpu_event_stats.h" -#include "tensorflow/core/profiler/utils/hardware_type_utils.h" -#include "tensorflow/core/profiler/utils/hlo_module_map.h" -#include "tensorflow/core/profiler/utils/hlo_proto_map.h" -#include "tensorflow/core/profiler/utils/kernel_stats_utils.h" -#include "tensorflow/core/profiler/utils/op_utils.h" -#include "tensorflow/core/profiler/utils/xplane_schema.h" -#include "tensorflow/core/profiler/utils/xplane_utils.h" -#include "tensorflow/core/profiler/utils/xplane_visitor.h" -#include "tsl/profiler/protobuf/xplane.pb.h" - -namespace tensorflow { -namespace profiler { -namespace { - -using tsl::profiler::FindPlanesWithPrefix; -using tsl::profiler::FindTensorCorePlanes; -using tsl::profiler::Timespan; - -std::string Hostname(const XSpace& space) { - if (space.hostnames().empty()) return "localhost"; - DCHECK_EQ(space.hostnames_size(), 1); - const std::string& hostname = space.hostnames(0); - return hostname; -} - -} // namespace - -PerfEnv MakePerfEnv(double peak_tera_flops_per_second, - std::vector peak_bws) { - PerfEnv result; - result.set_peak_tera_flops_per_second(peak_tera_flops_per_second); - - for (const auto bw : peak_bws) { - result.add_peak_bws_giga_bytes_per_second(bw); - } - result.set_ridge_point(tsl::profiler::TeraToGiga(peak_tera_flops_per_second) / - 
peak_bws[MemBwType::MEM_BW_TYPE_HBM_RW]); - return result; -} - -PerfEnv MakePerfEnvForTpu(double peak_tera_flops_per_second, - std::vector peak_bws, bool has_merged_vmem, - bool has_megacore) { - PerfEnv result = MakePerfEnv(peak_tera_flops_per_second, peak_bws); - result.set_has_cmem(peak_bws[MemBwType::MEM_BW_TYPE_CMEM_RD] > 0 || - peak_bws[MemBwType::MEM_BW_TYPE_CMEM_WR] > 0); - result.set_has_merged_vmem(has_merged_vmem); - result.set_has_megacore(has_megacore); - return result; -} - -PerfEnv MakePerfEnvForGpu(double peak_tera_flops_per_second, - std::vector peak_bws) { - return MakePerfEnv(peak_tera_flops_per_second, peak_bws); -} - -PerfEnv GetPerfEnvFromXPlane(const XPlane& device_plane) { - DeviceCapabilities cap = GetDeviceCaps(device_plane); - if (!absl::StartsWith(device_plane.name(), kTpuPlanePrefix)) { - double peak_tera_flops_per_second = - cap.num_cores() * - tsl::profiler::GigaToTera(GetFlopMaxThroughputPerSM(cap)); - double hbm_bw_giga_bytes_per_second = - tsl::profiler::UniToGiga(cap.memory_bandwidth()); - double shm_giga_bytes_per_second = - cap.num_cores() * - tsl::profiler::UniToGiga(GetSharedMemoryBandwidthPerSM(cap)); - // Note that treat SRAM_RD and SRAM_WR as the same. So in future, we could - // only use one for shared memory / L1 cache, one for another like L2. - return MakePerfEnvForGpu(peak_tera_flops_per_second, - {/*HBM_RW=*/hbm_bw_giga_bytes_per_second, - /*SRAM_RD=*/shm_giga_bytes_per_second, - /*SRAM_WR=*/shm_giga_bytes_per_second}); - } else { - XPlaneVisitor visitor = tsl::profiler::CreateTfXPlaneVisitor(&device_plane); - std::optional peak_tera_flops_per_second = - visitor.GetStat(StatType::kDevCapPeakTeraflopsPerSecond); - double peak_tera_flops_per_second_val = - peak_tera_flops_per_second.has_value() - ? 
peak_tera_flops_per_second->DoubleValue() - : 0.0; - std::optional peak_hbm_bw_giga_bytes_per_second = - visitor.GetStat(StatType::kDevCapPeakHbmBwGigabytesPerSecond); - double peak_hbm_bw_giga_bytes_per_second_val = - peak_hbm_bw_giga_bytes_per_second.has_value() - ? peak_hbm_bw_giga_bytes_per_second->DoubleValue() - : 0.0; - std::optional peak_sram_rd_bw_giga_bytes_per_second = - visitor.GetStat(StatType::kDevCapPeakSramRdBwGigabytesPerSecond); - double peak_sram_rd_bw_giga_bytes_per_second_val = - peak_sram_rd_bw_giga_bytes_per_second.has_value() - ? peak_sram_rd_bw_giga_bytes_per_second->DoubleValue() - : 0.0; - std::optional peak_sram_wr_bw_giga_bytes_per_second = - visitor.GetStat(StatType::kDevCapPeakSramWrBwGigabytesPerSecond); - double peak_sram_wr_bw_giga_bytes_per_second_val = - peak_sram_wr_bw_giga_bytes_per_second.has_value() - ? peak_sram_wr_bw_giga_bytes_per_second->DoubleValue() - : 0.0; - std::optional cmem_rd_bw_giga_bytes_per_second = - visitor.GetStat(StatType::kDevCapPeakCmemRdBwGigabytesPerSecond); - double cmem_rd_bw_giga_bytes_per_second_val = - cmem_rd_bw_giga_bytes_per_second.has_value() - ? cmem_rd_bw_giga_bytes_per_second->DoubleValue() - : 0.0; - std::optional cmem_wr_bw_giga_bytes_per_second = - visitor.GetStat(StatType::kDevCapPeakCmemWrBwGigabytesPerSecond); - double cmem_wr_bw_giga_bytes_per_second_val = - cmem_wr_bw_giga_bytes_per_second.has_value() - ? cmem_wr_bw_giga_bytes_per_second->DoubleValue() - : 0.0; - std::optional vmem_rd_bw_giga_bytes_per_second = - visitor.GetStat(StatType::kDevCapPeakVmemRdBwGigabytesPerSecond); - double vmem_rd_bw_giga_bytes_per_second_val = - vmem_rd_bw_giga_bytes_per_second.has_value() - ? vmem_rd_bw_giga_bytes_per_second->DoubleValue() - : 0.0; - std::optional vmem_wr_bw_giga_bytes_per_second = - visitor.GetStat(StatType::kDevCapPeakVmemWrBwGigabytesPerSecond); - double vmem_wr_bw_giga_bytes_per_second_val = - vmem_wr_bw_giga_bytes_per_second.has_value() - ? 
vmem_wr_bw_giga_bytes_per_second->DoubleValue() - : 0.0; - std::optional has_megacore = - visitor.GetStat(StatType::kDevHasMegacore); - bool has_megacore_val = - has_megacore.has_value() ? has_megacore->BoolValue() : false; - std::optional has_merged_vmem = - visitor.GetStat(StatType::kDevHasMergedVmem); - bool has_merged_vmem_val = - has_merged_vmem.has_value() ? has_merged_vmem->BoolValue() : false; - return MakePerfEnvForTpu( - peak_tera_flops_per_second_val, - {/*HBM_RW=*/peak_hbm_bw_giga_bytes_per_second_val, - /*SRAM_RD=*/peak_sram_rd_bw_giga_bytes_per_second_val, - /*SRAM_WR=*/peak_sram_wr_bw_giga_bytes_per_second_val, - /**CMEM_RD=*/cmem_rd_bw_giga_bytes_per_second_val, - /**CMEM_WR=*/cmem_wr_bw_giga_bytes_per_second_val, - /**VMEM_RD=*/vmem_rd_bw_giga_bytes_per_second_val, - /**VMEM_WR=*/vmem_wr_bw_giga_bytes_per_second_val}, - has_merged_vmem_val, has_megacore_val); - } -} - -void SetRunEnvironment(const XSpace& space, RunEnvironment* env) { - // Currently, we only support profiling one host and one program. 
- env->set_host_count(1); - env->set_task_count(1); - env->mutable_hostnames()->insert({Hostname(space), true}); - - std::vector gpu_planes = - FindPlanesWithPrefix(space, kGpuPlanePrefix); - if (!gpu_planes.empty()) { - absl::string_view gpu_model = - GpuModelName(GetDeviceCaps(*gpu_planes.front())); - if (!gpu_model.empty()) { - env->set_device_type(std::string(gpu_model)); - } else { - env->set_device_type("GPU"); - } - env->set_device_core_count(gpu_planes.size()); - env->set_hardware_type(tensorflow::profiler::HardwareType::GPU); - } else if (std::vector tpu_planes = - FindTensorCorePlanes(space); - !tpu_planes.empty()) { - XPlaneVisitor visitor = - tsl::profiler::CreateTfXPlaneVisitor(tpu_planes.at(0)); - auto xstat = visitor.GetStat(StatType::kDeviceTypeString); - if (xstat.has_value()) { - env->set_device_type(std::string(xstat->StrOrRefValue())); - } - env->set_device_core_count(tpu_planes.size()); - env->set_hardware_type(tensorflow::profiler::HardwareType::TPU); - } else { - env->set_device_type("CPU"); - env->set_device_core_count(0); - env->set_hardware_type(tensorflow::profiler::HardwareType::CPU_ONLY); - } -} - -void PropagateXSpaceDiagnosticsToOpStats(const XSpace& space, - OpStats* op_stats) { - if (!space.errors().empty()) { - absl::flat_hash_set unique_errors; - unique_errors.insert(space.errors().begin(), space.errors().end()); - *op_stats->mutable_diagnostics()->mutable_errors() = {unique_errors.begin(), - unique_errors.end()}; - } - if (!space.warnings().empty()) { - absl::flat_hash_set unique_warnings; - unique_warnings.insert(space.warnings().begin(), space.warnings().end()); - *op_stats->mutable_diagnostics()->mutable_warnings() = { - unique_warnings.begin(), unique_warnings.end()}; - } -} - -// This function should be idempotent to be called -void SetProgramIdToNameMap(const HloProtoMap& hlo_proto_map, - tensorflow::profiler::OpStats& op_stats) { - auto& program_id_to_name_map = *op_stats.mutable_program_id_to_name_map(); - for (const 
auto& [program_id, hlo_proto] : hlo_proto_map) { - program_id_to_name_map[program_id] = hlo_proto->hlo_module().name(); - } -} - -void UpdateOpMetricsDbFromHloModuleMap(OpMetricsDb& op_metrics_db, - const HloModuleMap& hlo_module_map) { - for (OpMetrics& op_metrics : *op_metrics_db.mutable_metrics_db()) { - EnterOpMetadataFromHloModuleMap(&op_metrics, hlo_module_map); - } -} - -DutyCycleTracker ConstructDutyCycleTracker(XPlaneVisitor& visitor) { - DutyCycleTracker duty_cycle_tracker; - visitor.ForEachLine([&](const XLineVisitor& line) { - if (line.Name() == kXlaOpLineName) { - line.ForEachEvent([&](const XEventVisitor& event) { - auto hlo_category_stat = event.GetStat(StatType::kHloCategory); - duty_cycle_tracker.AddInterval( - Timespan(event.OffsetPs(), event.DurationPs()), - !(hlo_category_stat && - tsl::profiler::IsOffDutyOp(hlo_category_stat->StrOrRefValue()))); - }); - } else if (line.Name() == kSparseCoreOpLineName || - line.Name() == kSparseCoreModuleLineName) { - line.ForEachEvent([&](const XEventVisitor& event) { - duty_cycle_tracker.AddInterval( - Timespan(event.OffsetPs(), event.DurationPs()), - /*is_active=*/line.Name() == kSparseCoreOpLineName); - }); - } - }); - return duty_cycle_tracker; -} - -OpStats ConvertXSpaceToOpStats(const XSpace& space, - const OpStatsOptions& options) { - OpStats op_stats; - StepEvents step_events; - PropagateXSpaceDiagnosticsToOpStats(space, &op_stats); - // Convert device planes. - OpMetricsDbCombiner op_metrics_db_combiner( - op_stats.mutable_device_op_metrics_db()); - SetRunEnvironment(space, op_stats.mutable_run_environment()); - - KernelReportMap reports; - - // Handle device planes first. device_planes will contain either GPU or TPU. 
- std::vector device_planes = - FindPlanesWithPrefix(space, kTpuPlanePrefix); - const bool is_gpu = device_planes.empty(); - if (is_gpu) { - device_planes = FindPlanesWithPrefix(space, kGpuPlanePrefix); - } - const bool is_tpu = !is_gpu; - std::string hostname = Hostname(space); - auto& core_id_to_details_map = *op_stats.mutable_core_id_to_details(); - if (is_gpu) { - core_id_to_details_map[kDefaultGpuLocalCoreId].set_hostname(hostname); - } - DutyCycleCombiner duty_cycle_combiner; - // TODO(b/161942993) parallelize XPlane processing per thread. - HloModuleMap hlo_module_map; - if (options.generate_kernel_stats_db || - (is_tpu && options.generate_op_metrics_db)) { - ProcessHloModuleMapFromXSpace(hlo_module_map, &space); - } - for (const XPlane* device_trace : device_planes) { - if (options.generate_op_metrics_db) { - if (!op_stats.has_perf_env()) { - *op_stats.mutable_perf_env() = GetPerfEnvFromXPlane(*device_trace); - } - if (!is_tpu) { - OpMetricsDb device_op_metrics_db = - ConvertDeviceTraceXPlaneToOpMetricsDb(*device_trace); - op_metrics_db_combiner.Combine(device_op_metrics_db); - } else { - // TODO(b/397774568): Remove this once the SparseCore OpMetricsDb is - // implemented. - if (!tsl::profiler::GetSparseCoreId(device_trace->name()).has_value()) { - OpMetricsDb device_op_metrics_db = - ConvertTpuDeviceTraceXPlaneToOpMetricsDb(*device_trace); - UpdateOpMetricsDbFromHloModuleMap(device_op_metrics_db, - hlo_module_map); - op_metrics_db_combiner.Combine(device_op_metrics_db); - } - } - } - if (options.generate_step_db) { - StepEvents device_step_events = - ConvertDeviceTraceXPlaneToStepEvents(*device_trace); - if (is_tpu) { - // In TPU, we take the intersection of step events across cores as well - // as hosts.see b/158249775 and cl/331842545. 
- IntersectCombineStepEvents(device_step_events, &step_events); - } else { - UnionCombineStepEvents(device_step_events, &step_events); - } - } - if (options.generate_kernel_stats_db) { - ConvertDeviceTraceXPlaneToKernelReports( - *device_trace, - // TODO(cleanup): Move this to xplane_to_kernel_stats_db.cc - [&](const GpuEventStats& stats, KernelReport* kernel) { - if (!stats.IsXlaOp()) return; - const HloInstructionWrapper* hlo_instruction = GetHloInstruction( - hlo_module_map, stats.program_id, stats.hlo_op_names.back()); - if (hlo_instruction != nullptr) { - kernel->set_op_name(std::string(hlo_instruction->TfOpName())); - bool tc_eligible = IsOpTensorCoreEligible(kernel->op_name()); - if (VLOG_IS_ON(1) && !tc_eligible && - kernel->is_kernel_using_tensor_core()) { - VLOG(1) << "Detected new Op using TensorCores: " - << kernel->op_name() << std::endl; - } - kernel->set_is_op_tensor_core_eligible( - tc_eligible || kernel->is_op_tensor_core_eligible()); - } - }, - &reports); - } - XPlaneVisitor visitor = tsl::profiler::CreateTfXPlaneVisitor(device_trace); - DutyCycleTracker duty_cycle_tracker = ConstructDutyCycleTracker(visitor); - if (std::optional core_details_stat = - visitor.GetStat(StatType::kCoreDetails)) { - CoreDetails core_details; - absl::string_view core_details_bytes = core_details_stat->BytesValue(); - if (core_details.ParseFromArray(core_details_bytes.data(), - core_details_bytes.size())) { - core_details.set_hostname(hostname); - // This is a backfill for XPlanes that were create before this field was - // added. 
- core_details.set_is_sparse_core( - tsl::profiler::GetSparseCoreId(device_trace->name()).has_value()); - core_id_to_details_map[device_trace->id()] = core_details; - } - } - if (core_id_to_details_map.contains(device_trace->id())) { - CoreDetails& core_details = core_id_to_details_map[device_trace->id()]; - duty_cycle_combiner.CombineCore(duty_cycle_tracker, - core_details.local_chip_id()); - } else { - LOG(WARNING) << "No CoreDetails found for TPU device plane: " - << device_trace->name(); - duty_cycle_combiner.CombineChip(duty_cycle_tracker); - } - } - - if (is_tpu) { - OpMetricsDb& op_metrics_db = *op_stats.mutable_device_op_metrics_db(); - op_metrics_db.set_idle_time_ps(duty_cycle_combiner.GetTotalIdleTimePs()); - op_metrics_db.set_busy_time_ps(duty_cycle_combiner.GetTotalActiveTimePs()); - } - - // Combine into reports. - if (options.generate_kernel_stats_db) { - CopyTopKDurationKernelReportsToDb(reports, - op_stats.mutable_kernel_stats_db()); - } - - bool has_device = !device_planes.empty(); - // Convert a host plane. 
- const XPlane* host_plane = FindPlaneWithName(space, kHostThreadsPlaneName); - if (host_plane) { - if (options.generate_op_metrics_db) { - *op_stats.mutable_host_op_metrics_db() = - ConvertHostThreadsXPlaneToOpMetricsDb(*host_plane); - } - if (options.generate_step_db && !has_device) { - StepEvents host_step_events = - ConvertHostThreadsXPlaneToStepEvents(*host_plane, nullptr); - UnionCombineStepEvents(host_step_events, &step_events); - } - XPlaneVisitor visitor = tsl::profiler::CreateTfXPlaneVisitor(host_plane); - auto stat = visitor.GetStat(StatType::kMatrixUnitUtilizationPercent); - if (stat.has_value()) { - op_stats.mutable_performance_counter_result() - ->set_matrix_unit_utilization_percent(stat->DoubleValue()); - } - TfFunctionDb* tf_function_db = op_stats.mutable_tf_function_db(); - visitor.ForEachLine([&](const XLineVisitor& line) { - CombineTfFunctionDb(ConvertHostThreadsXLineToTfFunctionDb(line), - tf_function_db); - }); - } - if (options.generate_step_db) { - if (is_tpu) { - // TPU steps relies on step number in step line in Xplane which has - // already dropped the incomplete steps at both beginning and end. 
- *op_stats.mutable_step_db() = ConvertStepEventsToStepDb( - has_device, /*maybe_drop_incomplete_steps=*/false, step_events); - *op_stats.mutable_device_op_metrics_db()->mutable_precision_stats() = - ComputePrecisionStats(step_events); - OpMetricsDbCombiner combiner( - op_stats.mutable_hlo_metrics_db_complete_steps_only()); - for (const auto& step_info : op_stats.step_db().step_sequence()) { - combiner.Combine(step_info.hlo_metrics_db()); - } - } else { - StepEvents nonoverlapped_step_events = - ToNonOverlappedStepEvents(step_events); - *op_stats.mutable_step_db() = ConvertStepEventsToStepDb( - has_device, options.maybe_drop_incomplete_steps, - nonoverlapped_step_events); - *op_stats.mutable_device_op_metrics_db()->mutable_precision_stats() = - ComputePrecisionStats(nonoverlapped_step_events); - } - } - - // Set program_id_to_name map in OpStats from Xspace - // Will be non-op if the space does not have materialized device traces - HloProtoMap hlo_proto_map; - hlo_proto_map.AddHloProtosFromXSpace(space); - SetProgramIdToNameMap(hlo_proto_map, op_stats); - - return op_stats; -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/xplane_to_op_stats.h b/tensorflow/core/profiler/convert/xplane_to_op_stats.h index cd180e7c8dcd0e..4f116494393761 100644 --- a/tensorflow/core/profiler/convert/xplane_to_op_stats.h +++ b/tensorflow/core/profiler/convert/xplane_to_op_stats.h @@ -16,51 +16,6 @@ limitations under the License. 
#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_OP_STATS_H_ #define TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_OP_STATS_H_ -#include - -#include "tensorflow/core/profiler/convert/duty_cycle_tracker.h" -#include "tensorflow/core/profiler/convert/repository.h" -#include "tensorflow/core/profiler/protobuf/op_stats.pb.h" -#include "tensorflow/core/profiler/utils/hlo_proto_map.h" -#include "tensorflow/core/profiler/utils/xplane_visitor.h" -#include "tsl/profiler/protobuf/xplane.pb.h" - -namespace tensorflow { -namespace profiler { - -struct OpStatsOptions { - bool maybe_drop_incomplete_steps = false; - bool generate_op_metrics_db = false; - bool generate_step_db = false; - bool generate_kernel_stats_db = false; -}; - -// NOTE: call GroupTfEvents before if OpStats.step_db needs to be generated. -OpStats ConvertXSpaceToOpStats(const XSpace& space, - const OpStatsOptions& options); - -// Populates the program_id_to_name map in OpStats. -void SetProgramIdToNameMap(const HloProtoMap& hlo_proto_map, - tensorflow::profiler::OpStats& op_stats); - -// Populates the given RunEnvironment with data from XSpace. -void SetRunEnvironment(const XSpace& space, RunEnvironment* env); - -// Propagate and dedup the diagnostics in XSpace and add to OpStats. -void PropagateXSpaceDiagnosticsToOpStats(const XSpace& space, - OpStats* op_stats); - -// Populates PerfEnv. -PerfEnv MakePerfEnv(double peak_tera_flops_per_second, - std::vector peak_bws); - -// Extracts PerfEnv from XPlane stats. -PerfEnv GetPerfEnvFromXPlane(const XPlane& device_plane); - -// Constructs a DutyCycleTracker from the given XPlaneVisitor. 
-DutyCycleTracker ConstructDutyCycleTracker(XPlaneVisitor& visitor); - -} // namespace profiler -} // namespace tensorflow +#include "xprof/convert/xplane_to_op_stats.h" // from @org_xprof // IWYU pragma: export #endif // TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_OP_STATS_H_ diff --git a/tensorflow/core/profiler/convert/xplane_to_op_stats_test.cc b/tensorflow/core/profiler/convert/xplane_to_op_stats_test.cc deleted file mode 100644 index c1a310e0127165..00000000000000 --- a/tensorflow/core/profiler/convert/xplane_to_op_stats_test.cc +++ /dev/null @@ -1,814 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/core/profiler/convert/xplane_to_op_stats.h" - -#include -#include -#include -#include -#include - -#include -#include -#include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" -#include "xla/tsl/platform/status.h" -#include "xla/tsl/profiler/convert/xla_op_utils.h" -#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" -#include "xla/tsl/profiler/utils/xplane_schema.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/profiler/convert/duty_cycle_tracker.h" -#include "tensorflow/core/profiler/convert/multi_xplanes_to_op_stats.h" -#include "tensorflow/core/profiler/convert/repository.h" -#include "tensorflow/core/profiler/convert/step_events_to_steps_db.h" -#include "tensorflow/core/profiler/protobuf/diagnostics.pb.h" -#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" -#include "tensorflow/core/profiler/protobuf/op_stats.pb.h" -#include "tensorflow/core/profiler/protobuf/steps_db.pb.h" -#include "tensorflow/core/profiler/protobuf/tf_function.pb.h" -#include "tensorflow/core/profiler/utils/xplane_builder.h" -#include "tensorflow/core/profiler/utils/xplane_schema.h" -#include "tensorflow/core/profiler/utils/xplane_test_utils.h" -#include "tensorflow/core/profiler/utils/xplane_visitor.h" -#include "tsl/profiler/protobuf/xplane.pb.h" - -namespace tensorflow { -namespace profiler { -namespace { - -using ::testing::Property; -using ::testing::UnorderedElementsAre; - -TEST(ConvertXPlaneToOpStats, GpuPerfEnv) { - auto space = std::make_unique(); - constexpr double kMaxError = 0.01; - constexpr int kClockRateKHz = 1530000; - constexpr int kCoreCount = 80; - constexpr uint64 kMemoryBandwidthBytesPerSecond = - uint64{900} * 1000 * 1000 * 1000; - // Volta. 
- constexpr int kComputeCapMajor = 7; - constexpr int kComputeCapMinor = 0; - - XPlaneBuilder device_plane( - GetOrCreateGpuXPlane(space.get(), /*device_ordinal=*/0)); - device_plane.AddStatValue(*device_plane.GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kDevVendor)), - kDeviceVendorNvidia); - device_plane.AddStatValue(*device_plane.GetOrCreateStatMetadata("clock_rate"), - kClockRateKHz); - device_plane.AddStatValue(*device_plane.GetOrCreateStatMetadata("core_count"), - kCoreCount); - device_plane.AddStatValue( - *device_plane.GetOrCreateStatMetadata("memory_bandwidth"), - kMemoryBandwidthBytesPerSecond); - device_plane.AddStatValue( - *device_plane.GetOrCreateStatMetadata("compute_cap_major"), - kComputeCapMajor); - device_plane.AddStatValue( - *device_plane.GetOrCreateStatMetadata("compute_cap_minor"), - kComputeCapMinor); - - std::vector> xspaces; - xspaces.push_back(std::move(space)); - auto session_snapshot_or = - SessionSnapshot::Create({"test_xspace"}, std::move(xspaces)); - TF_CHECK_OK(session_snapshot_or.status()); - OpStatsOptions options; - options.generate_op_metrics_db = true; - OpStats op_stats; - TF_CHECK_OK(ConvertMultiXSpacesToCombinedOpStats(session_snapshot_or.value(), - options, &op_stats)); - const PerfEnv& perf_env = op_stats.perf_env(); - // Change to lower flops number that we do not use sum of the tensor core peak - // flops and the cuda core peak flops together as peak flops. Only use the - // tensor core peak flops as all those white papers are using. - EXPECT_NEAR(125.34, perf_env.peak_tera_flops_per_second(), kMaxError); - EXPECT_NEAR( - 900, - perf_env.peak_bws_giga_bytes_per_second(MemBwType::MEM_BW_TYPE_HBM_RW), - kMaxError); - // Ridge point changed accordingly from above peak flops change. 
- EXPECT_NEAR(139.26, perf_env.ridge_point(), kMaxError); -} - -TEST(ConvertXPlaneToOpStats, GpuRunEnvironment) { - auto space = std::make_unique(); - XPlaneBuilder device_plane1( - GetOrCreateGpuXPlane(space.get(), /*device_ordinal=*/0)); - device_plane1.AddStatValue(*device_plane1.GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kDevVendor)), - kDeviceVendorNvidia); - XPlaneBuilder device_plane2( - GetOrCreateGpuXPlane(space.get(), /*device_ordinal=*/1)); - device_plane2.AddStatValue(*device_plane2.GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kDevVendor)), - kDeviceVendorNvidia); - - std::vector> xspaces; - xspaces.push_back(std::move(space)); - auto session_snapshot_or = - SessionSnapshot::Create({"test_xspace"}, std::move(xspaces)); - TF_CHECK_OK(session_snapshot_or.status()); - OpStats op_stats; - TF_CHECK_OK(ConvertMultiXSpacesToCombinedOpStats( - session_snapshot_or.value(), OpStatsOptions(), &op_stats)); - const RunEnvironment& run_env = op_stats.run_environment(); - - EXPECT_EQ("Nvidia GPU", run_env.device_type()); - EXPECT_EQ(1, run_env.host_count()); - EXPECT_EQ(1, run_env.task_count()); - EXPECT_EQ(2, run_env.device_core_count()); -} - -TEST(ConvertXPlaneToOpStats, CpuOnlyStepDbTest) { - constexpr int64_t kStepNum = 123; - constexpr int64_t kStepId = 0; - - auto space = std::make_unique(); - XPlaneBuilder host_plane_builder(GetOrCreateHostXPlane(space.get())); - host_plane_builder.ReserveLines(2); - - auto main_thread = host_plane_builder.GetOrCreateLine(0); - CreateXEvent(&host_plane_builder, &main_thread, HostEventType::kTraceContext, - 0, 100, {{StatType::kStepNum, kStepNum}}); - CreateXEvent(&host_plane_builder, &main_thread, HostEventType::kFunctionRun, - 10, 90, - {{StatType::kStepId, kStepId}, - {StatType::kProducerType, int64_t{1}}, - {StatType::kProducerId, kStepId}}); - - auto tf_executor_thread = host_plane_builder.GetOrCreateLine(1); - CreateXEvent(&host_plane_builder, &tf_executor_thread, - HostEventType::kExecutorStateProcess, 
20, 80, - {{StatType::kStepId, kStepId}, - {StatType::kConsumerType, int64_t{1}}, - {StatType::kConsumerId, kStepId}}); - CreateXEvent(&host_plane_builder, &tf_executor_thread, "matmul", 30, 70); - - OpStatsOptions options; - options.generate_op_metrics_db = true; - options.generate_step_db = true; - std::vector> xspaces; - xspaces.push_back(std::move(space)); - auto session_snapshot_or = - SessionSnapshot::Create({"test_xspace"}, std::move(xspaces)); - TF_CHECK_OK(session_snapshot_or.status()); - OpStats op_stats; - TF_CHECK_OK(ConvertMultiXSpacesToCombinedOpStats(session_snapshot_or.value(), - options, &op_stats)); - const StepDatabaseResult& step_db = op_stats.step_db(); - - EXPECT_EQ(step_db.step_sequence_size(), 1); -} - -TEST(ConvertXPlaneToOpStats, GpuStepDbTest) { - constexpr int64_t kStepNum = 123; - constexpr int64_t kStepId = 0; - constexpr int64_t kCorrelationId = 100; - - auto space = std::make_unique(); - XPlaneBuilder host_plane_builder(GetOrCreateHostXPlane(space.get())); - host_plane_builder.ReserveLines(2); - - auto main_thread = host_plane_builder.GetOrCreateLine(0); - CreateXEvent(&host_plane_builder, &main_thread, HostEventType::kTraceContext, - 0, 100, {{StatType::kStepNum, kStepNum}}); - CreateXEvent(&host_plane_builder, &main_thread, HostEventType::kFunctionRun, - 10, 90, - {{StatType::kStepId, kStepId}, - {StatType::kProducerType, int64_t{1}}, - {StatType::kProducerId, kStepId}}); - - auto tf_executor_thread = host_plane_builder.GetOrCreateLine(1); - CreateXEvent(&host_plane_builder, &tf_executor_thread, - HostEventType::kExecutorStateProcess, 20, 20, - {{StatType::kStepId, kStepId}, - {StatType::kConsumerType, int64_t{1}}, - {StatType::kConsumerId, kStepId}}); - CreateXEvent(&host_plane_builder, &tf_executor_thread, "matmul", 30, 10, - {{StatType::kCorrelationId, kCorrelationId}}); - - XPlaneBuilder device_plane_builder( - GetOrCreateGpuXPlane(space.get(), /*device_ordinal=*/0)); - device_plane_builder.ReserveLines(1); - - auto stream = 
device_plane_builder.GetOrCreateLine(0); - CreateXEvent(&device_plane_builder, &stream, "matmul", 50, 40, - {{StatType::kCorrelationId, kCorrelationId}}); - - OpStatsOptions options; - options.generate_op_metrics_db = true; - options.generate_step_db = true; - std::vector> xspaces; - xspaces.push_back(std::move(space)); - auto session_snapshot_or = - SessionSnapshot::Create({"test_xspace"}, std::move(xspaces)); - TF_CHECK_OK(session_snapshot_or.status()); - OpStats op_stats; - TF_CHECK_OK(ConvertMultiXSpacesToCombinedOpStats(session_snapshot_or.value(), - options, &op_stats)); - const StepDatabaseResult& step_db = op_stats.step_db(); - - EXPECT_EQ(step_db.step_sequence_size(), 1); - - PrecisionStats precision_stats = - op_stats.device_op_metrics_db().precision_stats(); - EXPECT_EQ(precision_stats.compute_16bit_ps(), 0); - EXPECT_EQ(precision_stats.compute_32bit_ps(), 40); -} - -TEST(ConvertXPlaneToOpStats, PropagateAndDedupErrors) { - XSpace space; - static constexpr char kError[] = "host: error"; - *space.add_errors() = kError; - *space.add_errors() = kError; - - OpStats op_stats = ConvertXSpaceToOpStats(space, OpStatsOptions()); - - EXPECT_EQ(1, op_stats.diagnostics().errors_size()); - EXPECT_EQ(kError, op_stats.diagnostics().errors(/*index=*/0)); -} - -TEST(ConvertXPlaneToOpStats, Hostnames) { - XSpace space; - static constexpr char kHost[] = "host1"; - *space.add_hostnames() = kHost; - - OpStats op_stats = ConvertXSpaceToOpStats(space, OpStatsOptions()); - EXPECT_EQ( - kHost, - op_stats.core_id_to_details().at(kDefaultGpuLocalCoreId).hostname()); -} - -void BuildXSpaceForTest(XSpace& xspace, absl::string_view hostname) { - constexpr int64_t kStepNum = 123; - constexpr int64_t kStepId = 456; - // Create a host only XSpace for test. 
- XPlaneBuilder host_plane_builder(GetOrCreateHostXPlane(&xspace)); - host_plane_builder.ReserveLines(2); - - auto main_thread = host_plane_builder.GetOrCreateLine(0); - CreateXEvent(&host_plane_builder, &main_thread, HostEventType::kTraceContext, - 0, 100, {{StatType::kStepNum, kStepNum}}); - CreateXEvent(&host_plane_builder, &main_thread, HostEventType::kFunctionRun, - 10, 90, - {{StatType::kStepId, kStepId}, - {StatType::kProducerType, int64_t{1}}, - {StatType::kProducerId, kStepId}}); - - auto executor_thread = host_plane_builder.GetOrCreateLine(1); - CreateXEvent(&host_plane_builder, &executor_thread, - HostEventType::kExecutorStateProcess, 20, 80, - {{StatType::kStepId, kStepId}, - {StatType::kConsumerType, int64_t{1}}, - {StatType::kConsumerId, kStepId}}); - // Create a TensorFlow op that runs for 70 ps. - CreateXEvent(&host_plane_builder, &executor_thread, "aaa:bbb", 30, 70); - xspace.add_hostnames(std::string(hostname)); -} - -TEST(ConvertXPlaneToOpStats, TestConvertMultiXSpacesToCombinedOpStats) { - static constexpr char kHost1[] = "host1"; - static constexpr char kHost2[] = "host2"; - - auto xspace1 = std::make_unique(); - auto xspace2 = std::make_unique(); - - BuildXSpaceForTest(*xspace1, kHost1); - BuildXSpaceForTest(*xspace2, kHost2); - - std::vector xspace_paths; - xspace_paths.push_back("host1.pb"); - xspace_paths.push_back("host2.pb"); - - std::vector> xspaces; - xspaces.push_back(std::move(xspace1)); - xspaces.push_back(std::move(xspace2)); - - auto session_snapshot_or = - SessionSnapshot::Create(std::move(xspace_paths), std::move(xspaces)); - TF_CHECK_OK(session_snapshot_or.status()); - - OpStatsOptions options; - options.generate_op_metrics_db = true; - options.generate_step_db = true; - OpStats combined_op_stats; - - TF_CHECK_OK(ConvertMultiXSpacesToCombinedOpStats(session_snapshot_or.value(), - options, &combined_op_stats)) - << "Failed to convert multi XSpace to OpStats"; - - // Result OpStats has 2 Host Ops, "IDLE" and "aaa:bbb". 
- ASSERT_EQ(combined_op_stats.host_op_metrics_db().metrics_db_size(), 2); - const auto& metric = combined_op_stats.host_op_metrics_db().metrics_db(1); - EXPECT_EQ(metric.name(), "aaa"); - EXPECT_EQ(metric.category(), "bbb"); - // Each host has the HostOp "aaa:bbb" running for 70 ps, so the combined - // OpStats has "aaa:bbb" running for 140 ps in total. - EXPECT_EQ(metric.self_time_ps(), 140); - - // Result OpStats has 1 step, 2 cores. - ASSERT_EQ(combined_op_stats.step_db().step_sequence_size(), 1); - ASSERT_EQ( - combined_op_stats.step_db().step_sequence(0).step_info_per_core_size(), - 2); - const auto& step_info_per_core = - combined_op_stats.step_db().step_sequence(0).step_info_per_core(); - // global_core_id is computed using: 1000 * host_id + local_core_id. - EXPECT_TRUE(step_info_per_core.contains(kDefaultGpuLocalCoreId)); - EXPECT_TRUE(step_info_per_core.contains(1000 + kDefaultGpuLocalCoreId)); - - const auto& core_details_map = combined_op_stats.core_id_to_details(); - EXPECT_EQ(kHost1, core_details_map.at(kDefaultGpuLocalCoreId).hostname()); - EXPECT_EQ(kHost2, - core_details_map.at(1000 + kDefaultGpuLocalCoreId).hostname()); -} - -TEST(ConvertXPlaneToOpStats, RunEnvironmentExtractedFromTpuPlane) { - XSpace xspace; - for (int i : {0, 1, 2, 3}) { - GetOrCreateTpuXPlane(&xspace, i, "TPU V4", 0, 0); - } - - OpStats op_stats = ConvertXSpaceToOpStats(xspace, OpStatsOptions()); - - EXPECT_EQ(op_stats.run_environment().device_type(), "TPU V4"); - EXPECT_EQ(op_stats.run_environment().device_core_count(), 4); -} - -TEST(ConvertXPlaneToOpStats, TpuPerfEnv) { - auto space = std::make_unique(); - constexpr double kMaxError = 0.01; - constexpr int kClockRateKHz = 1530000; - constexpr int kCoreCount = 80; - constexpr uint64 kMemoryBandwidthBytesPerSecond = - uint64{900} * 1000 * 1000 * 1000; - // Volta. 
- constexpr int kComputeCapMajor = 7; - constexpr int kComputeCapMinor = 0; - constexpr double kDevCapPeakTeraflopsPerSecond = 141.0; - constexpr double kDevCapPeakHbmBwGigabytesPerSecond = 900.0; - constexpr double kDevCapPeakSramRdBwGigabytesPerSecond = 101.0; - constexpr double kDevCapPeakSramWrBwGigabytesPerSecond = 102.0; - constexpr double kDevCapPeakCmemRdBwGigabytesPerSecond = 101.0; - constexpr double kDevCapPeakCmemWrBwGigabytesPerSecond = 102.0; - constexpr double kDevCapPeakVmemRdBwGigabytesPerSecond = 201.0; - constexpr double kDevCapPeakVmemWrBwGigabytesPerSecond = 202.0; - - XPlaneBuilder device_plane(GetOrCreateTpuXPlane( - space.get(), /*device_ordinal=*/0, "TPU V4", - kDevCapPeakTeraflopsPerSecond, kDevCapPeakHbmBwGigabytesPerSecond)); - /*device_plane.AddStatValue(*device_plane.GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kDevVendor)), - kDeviceVendorNvidia); // "Google, Inc.");*/ - device_plane.AddStatValue(*device_plane.GetOrCreateStatMetadata("clock_rate"), - kClockRateKHz); - device_plane.AddStatValue(*device_plane.GetOrCreateStatMetadata("core_count"), - kCoreCount); - device_plane.AddStatValue( - *device_plane.GetOrCreateStatMetadata("memory_bandwidth"), - kMemoryBandwidthBytesPerSecond); - device_plane.AddStatValue( - *device_plane.GetOrCreateStatMetadata("compute_cap_major"), - kComputeCapMajor); - device_plane.AddStatValue( - *device_plane.GetOrCreateStatMetadata("compute_cap_minor"), - kComputeCapMinor); - device_plane.AddStatValue(*device_plane.GetOrCreateStatMetadata( - "peak_sram_rd_bw_gigabytes_per_second"), - kDevCapPeakSramRdBwGigabytesPerSecond); - device_plane.AddStatValue(*device_plane.GetOrCreateStatMetadata( - "peak_sram_wr_bw_gigabytes_per_second"), - kDevCapPeakSramWrBwGigabytesPerSecond); - device_plane.AddStatValue(*device_plane.GetOrCreateStatMetadata( - "peak_cmem_rd_bw_gigabytes_per_second"), - kDevCapPeakCmemRdBwGigabytesPerSecond); - device_plane.AddStatValue(*device_plane.GetOrCreateStatMetadata( - 
"peak_cmem_wr_bw_gigabytes_per_second"), - kDevCapPeakCmemWrBwGigabytesPerSecond); - device_plane.AddStatValue(*device_plane.GetOrCreateStatMetadata( - "peak_vmem_rd_bw_gigabytes_per_second"), - kDevCapPeakVmemRdBwGigabytesPerSecond); - device_plane.AddStatValue(*device_plane.GetOrCreateStatMetadata( - "peak_vmem_wr_bw_gigabytes_per_second"), - kDevCapPeakVmemWrBwGigabytesPerSecond); - - OpStatsOptions options; - options.generate_op_metrics_db = true; - std::vector> xspaces; - xspaces.push_back(std::move(space)); - auto session_snapshot_or = - SessionSnapshot::Create({"test_xspace"}, std::move(xspaces)); - TF_CHECK_OK(session_snapshot_or.status()); - OpStats op_stats; - TF_CHECK_OK(ConvertMultiXSpacesToCombinedOpStats(session_snapshot_or.value(), - options, &op_stats)); - const PerfEnv& perf_env = op_stats.perf_env(); - EXPECT_NEAR(kDevCapPeakTeraflopsPerSecond, - perf_env.peak_tera_flops_per_second(), kMaxError); - EXPECT_NEAR( - kDevCapPeakHbmBwGigabytesPerSecond, - perf_env.peak_bws_giga_bytes_per_second(MemBwType::MEM_BW_TYPE_HBM_RW), - kMaxError); - EXPECT_NEAR( - kDevCapPeakSramRdBwGigabytesPerSecond, - perf_env.peak_bws_giga_bytes_per_second(MemBwType::MEM_BW_TYPE_SRAM_RD), - kMaxError); - EXPECT_NEAR( - kDevCapPeakSramWrBwGigabytesPerSecond, - perf_env.peak_bws_giga_bytes_per_second(MemBwType::MEM_BW_TYPE_SRAM_WR), - kMaxError); - EXPECT_NEAR( - kDevCapPeakCmemRdBwGigabytesPerSecond, - perf_env.peak_bws_giga_bytes_per_second(MemBwType::MEM_BW_TYPE_CMEM_RD), - kMaxError); - EXPECT_NEAR( - kDevCapPeakCmemWrBwGigabytesPerSecond, - perf_env.peak_bws_giga_bytes_per_second(MemBwType::MEM_BW_TYPE_CMEM_WR), - kMaxError); - EXPECT_NEAR( - kDevCapPeakVmemRdBwGigabytesPerSecond, - perf_env.peak_bws_giga_bytes_per_second(MemBwType::MEM_BW_TYPE_VMEM_RD), - kMaxError); - EXPECT_NEAR( - kDevCapPeakVmemWrBwGigabytesPerSecond, - perf_env.peak_bws_giga_bytes_per_second(MemBwType::MEM_BW_TYPE_VMEM_WR), - kMaxError); - EXPECT_NEAR(156.67, perf_env.ridge_point(), kMaxError); -} 
- -TEST(ConvertXPlaneToOpStats, TpuRunEnvironment) { - auto space = std::make_unique(); - XPlaneBuilder device_plane1( - GetOrCreateTpuXPlane(space.get(), /*device_ordinal=*/0, "TPU V4", 0, 0)); - XPlaneBuilder device_plane2( - GetOrCreateTpuXPlane(space.get(), /*device_ordinal=*/1, "TPU V4", 0, 0)); - - std::vector> xspaces; - xspaces.push_back(std::move(space)); - auto session_snapshot_or = - SessionSnapshot::Create({"test_xspace"}, std::move(xspaces)); - TF_CHECK_OK(session_snapshot_or.status()); - OpStats op_stats; - TF_CHECK_OK(ConvertMultiXSpacesToCombinedOpStats( - session_snapshot_or.value(), OpStatsOptions(), &op_stats)); - const RunEnvironment& run_env = op_stats.run_environment(); - - EXPECT_EQ("TPU V4", run_env.device_type()); - EXPECT_EQ(1, run_env.host_count()); - EXPECT_EQ(1, run_env.task_count()); - EXPECT_EQ(2, run_env.device_core_count()); -} - -TEST(ConvertXPlaneToOpStats, TpuDeviceTraceToStepDb) { - auto space = std::make_unique(); - constexpr double kDevCapPeakTeraflopsPerSecond = 141.0; - constexpr double kDevCapPeakHbmBwGigabytesPerSecond = 1000.0; - XPlaneBuilder xplane_builder(GetOrCreateTpuXPlane( - space.get(), /*device_ordinal=*/0, "TPU V4", - kDevCapPeakTeraflopsPerSecond, kDevCapPeakHbmBwGigabytesPerSecond)); - - XEventMetadata* event_metadata = xplane_builder.GetOrCreateEventMetadata(1); - event_metadata->set_name("op_name"); - XStatsBuilder stats(event_metadata, &xplane_builder); - - stats.AddStatValue(*xplane_builder.GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kProgramId)), - 1); - stats.AddStatValue(*xplane_builder.GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kSymbolId)), - 1); - stats.AddStatValue(*xplane_builder.GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kSelfDurationPs)), - 10); - stats.AddStatValue( - *xplane_builder.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kTfOp)), - "tf_op_name"); - stats.AddStatValue(*xplane_builder.GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kHloCategory)), - 
"category"); - XLineBuilder line = xplane_builder.GetOrCreateLine(1); - line.SetName(kTensorFlowOpLineName); - XEventBuilder event = line.AddEvent(*event_metadata); - event.SetOffsetNs(0); - event.SetDurationNs(10); - - OpStatsOptions options; - options.generate_op_metrics_db = true; - std::vector> xspaces; - xspaces.push_back(std::move(space)); - auto session_snapshot_or = - SessionSnapshot::Create({"test_xspace"}, std::move(xspaces)); - TF_CHECK_OK(session_snapshot_or.status()); - OpStats op_stats; - TF_CHECK_OK(ConvertMultiXSpacesToCombinedOpStats(session_snapshot_or.value(), - options, &op_stats)); - EXPECT_THAT(op_stats.device_op_metrics_db().metrics_db(), - UnorderedElementsAre(Property(&OpMetrics::name, "op_name"), - Property(&OpMetrics::name, "IDLE"))); -} - -// Verifies that the step db is generated correctly by intersecting for -// multi-device TPU. -TEST(ConvertXPlaneToOpStats, TpuMultiDeviceStepDbTest) { - auto space = std::make_unique(); - - XPlaneBuilder device_plane_builder1( - GetOrCreateTpuXPlane(space.get(), /*device_ordinal=*/0, "TPU V4", 0, 0)); - XPlaneBuilder device_plane_builder2( - GetOrCreateTpuXPlane(space.get(), /*device_ordinal=*/1, "TPU V4", 0, 0)); - device_plane_builder1.ReserveLines(1); - device_plane_builder2.ReserveLines(1); - - // Create 1 step in xplane in TPU ordinal 0. - XStatMetadata* kGroupId1 = device_plane_builder1.GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kGroupId)); - XLineBuilder line = device_plane_builder1.GetOrCreateLine(1); - line.SetName(kXlaOpLineName); - // Step 1 - XEventMetadata* event_metadata = - device_plane_builder1.GetOrCreateEventMetadata(1); - event_metadata->set_name("Step 1"); - XEventBuilder event_builder = line.AddEvent(*event_metadata); - event_builder.AddStatValue(*kGroupId1, 1); // step num - event_builder.SetDurationNs(100); - event_builder.SetOffsetNs(100); - - // Create 2 steps in xplane in TPU ordinal 1. 
- line = device_plane_builder2.GetOrCreateLine(1); - line.SetName(kXlaOpLineName); - // Step 1 - XStatMetadata* kGroupId2 = device_plane_builder2.GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kGroupId)); - XEventMetadata* event_metadata2 = - device_plane_builder2.GetOrCreateEventMetadata(2); - event_metadata2->set_name("Step 1"); - XEventBuilder event_builder2 = line.AddEvent(*event_metadata2); - event_builder2.AddStatValue(*kGroupId2, 1); // step num - event_builder2.SetDurationNs(100); - event_builder2.SetOffsetNs(300); - // Step 2 - XStatMetadata* kGroupId3 = device_plane_builder2.GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kGroupId)); - XEventMetadata* event_metadata3 = - device_plane_builder2.GetOrCreateEventMetadata(2); - event_metadata3->set_name("Step 2"); - XEventBuilder event_builder3 = line.AddEvent(*event_metadata3); - event_builder3.AddStatValue(*kGroupId3, 2); // step num - event_builder3.SetDurationNs(100); - event_builder3.SetOffsetNs(300); - - OpStatsOptions options; - options.generate_op_metrics_db = true; - options.generate_step_db = true; - OpStats op_stats = ConvertXSpaceToOpStats(*space, options); - const StepDatabaseResult& step_db = op_stats.step_db(); - // For TPU step events, we intersect the step events by step num across - // different TPU devices. 
- EXPECT_EQ(step_db.step_sequence_size(), 1); -} - -TEST(ConvertXPlaneToOpStats, ConstructDutyCycleTrackerFromXlaOps) { - XSpace space; - XPlane* device_plane = GetOrCreateTpuXPlane( - &space, /*device_ordinal=*/0, /*device_type=*/"TPU v4", - /*peak_tera_flops_per_second=*/0, - /*peak_hbm_bw_gigabytes_per_second=*/0); - XPlaneBuilder device_plane_builder(device_plane); - XLineBuilder op_line = device_plane_builder.GetOrCreateLine(0); - op_line.SetName(kXlaOpLineName); - CreateXEvent(&device_plane_builder, &op_line, "op.1", /*offset_ps=*/10, - /*duration_ps=*/10, - {{StatType::kHloCategory, tsl::profiler::kHloInfeed}}); - CreateXEvent(&device_plane_builder, &op_line, "op.2", /*offset_ps=*/20, - /*duration_ps=*/10, - {{StatType::kHloCategory, tsl::profiler::kHloCall}}); - CreateXEvent(&device_plane_builder, &op_line, "op.3", /*offset_ps=*/30, - /*duration_ps=*/10); - CreateXEvent(&device_plane_builder, &op_line, "op.4", /*offset_ps=*/40, - /*duration_ps=*/10, - {{StatType::kHloCategory, tsl::profiler::kHloOutfeed}}); - - XPlaneVisitor visitor = tsl::profiler::CreateTfXPlaneVisitor(device_plane); - DutyCycleTracker tracker = ConstructDutyCycleTracker(visitor); - EXPECT_EQ(tracker.GetActiveTimePs(), 20); - EXPECT_EQ(tracker.GetIdleTimePs(), 20); -} - -TEST(ConvertXPlaneToOpStats, ConstructDutyCycleTrackerFromSparseCore) { - XSpace space; - XPlane* sc_plane = GetOrCreateTpuXPlane( - &space, /*device_ordinal=*/0, /*device_type=*/"TPU v4", - /*peak_tera_flops_per_second=*/0, - /*peak_hbm_bw_gigabytes_per_second=*/0); - XPlaneBuilder sc_plane_builder(sc_plane); - XLineBuilder op_line = sc_plane_builder.GetOrCreateLine(0); - op_line.SetName(kSparseCoreOpLineName); - CreateXEvent(&sc_plane_builder, &op_line, "op.1", /*offset_ps=*/10, - /*duration_ps=*/10); - CreateXEvent(&sc_plane_builder, &op_line, "op.2", /*offset_ps=*/20, - /*duration_ps=*/10); - CreateXEvent(&sc_plane_builder, &op_line, "op.3", /*offset_ps=*/30, - /*duration_ps=*/10); - CreateXEvent(&sc_plane_builder, 
&op_line, "op.4", /*offset_ps=*/40, - /*duration_ps=*/10); - XLineBuilder module_line = sc_plane_builder.GetOrCreateLine(1); - module_line.SetName(kSparseCoreModuleLineName); - CreateXEvent(&sc_plane_builder, &module_line, "module.1", /*offset_ps=*/5, - /*duration_ps=*/50); - - XPlaneVisitor visitor = tsl::profiler::CreateTfXPlaneVisitor(sc_plane); - DutyCycleTracker tracker = ConstructDutyCycleTracker(visitor); - EXPECT_EQ(tracker.GetActiveTimePs(), 40); - EXPECT_EQ(tracker.GetIdleTimePs(), 10); -} - -TEST(ConvertXPlaneToOpStats, MultiCoreChipBusyAndIdleTimeTest) { - XSpace space; - CoreDetails tc_core_details; - tc_core_details.set_local_chip_id(0); - CoreDetails sc_core_details; - sc_core_details.set_local_chip_id(0); - XPlane* tc_plane = GetOrCreateTpuXPlane( - &space, /*device_ordinal=*/0, /*device_type=*/"TPU v4", - /*peak_tera_flops_per_second=*/0, - /*peak_hbm_bw_gigabytes_per_second=*/0); - XPlaneBuilder tc_plane_builder(tc_plane); - tc_plane_builder.AddStatValue(*tc_plane_builder.GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kCoreDetails)), - tc_core_details); - XLineBuilder xla_op_line = tc_plane_builder.GetOrCreateLine(0); - xla_op_line.SetName(kXlaOpLineName); - CreateXEvent(&tc_plane_builder, &xla_op_line, "op.1", /*offset_ps=*/10, - /*duration_ps=*/10, - {{StatType::kHloCategory, tsl::profiler::kHloInfeed}}); - CreateXEvent(&tc_plane_builder, &xla_op_line, "op.2", /*offset_ps=*/20, - /*duration_ps=*/10, - {{StatType::kHloCategory, tsl::profiler::kHloCall}}); - CreateXEvent(&tc_plane_builder, &xla_op_line, "op.3", /*offset_ps=*/30, - /*duration_ps=*/10); - CreateXEvent(&tc_plane_builder, &xla_op_line, "op.4", /*offset_ps=*/40, - /*duration_ps=*/10, - {{StatType::kHloCategory, tsl::profiler::kHloOutfeed}}); - - XPlane* sc_plane = GetOrCreateTpuXPlane( - &space, /*device_ordinal=*/1, /*device_type=*/"TPU v4", - /*peak_tera_flops_per_second=*/0, - /*peak_hbm_bw_gigabytes_per_second=*/0); - XPlaneBuilder sc_plane_builder(sc_plane); - 
sc_plane_builder.AddStatValue(*sc_plane_builder.GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kCoreDetails)), - sc_core_details); - XLineBuilder sc_op_line = sc_plane_builder.GetOrCreateLine(0); - sc_op_line.SetName(kSparseCoreOpLineName); - CreateXEvent(&sc_plane_builder, &sc_op_line, "op.1", /*offset_ps=*/10, - /*duration_ps=*/10); - CreateXEvent(&sc_plane_builder, &sc_op_line, "op.2", /*offset_ps=*/20, - /*duration_ps=*/10); - CreateXEvent(&sc_plane_builder, &sc_op_line, "op.3", /*offset_ps=*/30, - /*duration_ps=*/10); - CreateXEvent(&sc_plane_builder, &sc_op_line, "op.4", /*offset_ps=*/40, - /*duration_ps=*/10); - XLineBuilder sc_module_line = sc_plane_builder.GetOrCreateLine(1); - sc_module_line.SetName(kSparseCoreModuleLineName); - CreateXEvent(&sc_plane_builder, &sc_module_line, "module.1", /*offset_ps=*/5, - /*duration_ps=*/50); - - OpStats op_stats = ConvertXSpaceToOpStats(space, OpStatsOptions()); - EXPECT_EQ(op_stats.device_op_metrics_db().idle_time_ps(), 10); - EXPECT_EQ(op_stats.device_op_metrics_db().busy_time_ps(), 40); -} - -TEST(ConvertXPlaneToOpStats, HandleSparseCoreBusyOpMetrics) { - XSpace space; - XPlane* tc_plane = GetOrCreateTpuXPlane( - &space, /*device_ordinal=*/0, /*device_type=*/"TPU v4", - /*peak_tera_flops_per_second=*/0, - /*peak_hbm_bw_gigabytes_per_second=*/0); - XPlaneBuilder tc_plane_builder(tc_plane); - tc_plane_builder.SetId(0); - XLineBuilder tc_step_line = tc_plane_builder.GetOrCreateLine(0); - tc_step_line.SetName(tsl::profiler::kStepLineName); - CreateXEvent(&tc_plane_builder, &tc_step_line, "step.1", /*offset_ps=*/10, - /*duration_ps=*/10, {{StatType::kGroupId, int64_t{1}}}); - CreateXEvent(&tc_plane_builder, &tc_step_line, "step.2", /*offset_ps=*/20, - /*duration_ps=*/10, {{StatType::kGroupId, int64_t{2}}}); - CreateXEvent(&tc_plane_builder, &tc_step_line, "step.3", /*offset_ps=*/30, - /*duration_ps=*/10, {{StatType::kGroupId, int64_t{3}}}); - CreateXEvent(&tc_plane_builder, &tc_step_line, "step.4", /*offset_ps=*/40, 
- /*duration_ps=*/10, {{StatType::kGroupId, int64_t{4}}}); - XLineBuilder tc_module_line = tc_plane_builder.GetOrCreateLine(1); - tc_module_line.SetName(tsl::profiler::kXlaModuleLineName); - CreateXEvent(&tc_plane_builder, &tc_module_line, "module.1", /*offset_ps=*/10, - /*duration_ps=*/10, {{StatType::kGroupId, int64_t{1}}}); - CreateXEvent(&tc_plane_builder, &tc_module_line, "module.2", /*offset_ps=*/20, - /*duration_ps=*/10, {{StatType::kGroupId, int64_t{2}}}); - CreateXEvent(&tc_plane_builder, &tc_module_line, "module.3", /*offset_ps=*/30, - /*duration_ps=*/10, {{StatType::kGroupId, int64_t{3}}}); - CreateXEvent(&tc_plane_builder, &tc_module_line, "module.4", /*offset_ps=*/40, - /*duration_ps=*/10, {{StatType::kGroupId, int64_t{4}}}); - XLineBuilder tc_op_line = tc_plane_builder.GetOrCreateLine(2); - tc_op_line.SetName(kXlaOpLineName); - auto& program_id_stat = *tc_plane_builder.GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kProgramId)); - auto& symbol_id_stat = *tc_plane_builder.GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kSymbolId)); - XStatsBuilder op1_stats( - tc_plane_builder.GetOrCreateEventMetadata("op.1"), &tc_plane_builder); - op1_stats.AddStatValue(program_id_stat, 1); - op1_stats.AddStatValue(symbol_id_stat, 1); - XStatsBuilder op2_stats( - tc_plane_builder.GetOrCreateEventMetadata("op.2"), &tc_plane_builder); - op2_stats.AddStatValue(program_id_stat, 1); - op2_stats.AddStatValue(symbol_id_stat, 2); - XStatsBuilder op3_stats( - tc_plane_builder.GetOrCreateEventMetadata("op.3"), &tc_plane_builder); - op3_stats.AddStatValue(program_id_stat, 1); - op3_stats.AddStatValue(symbol_id_stat, 3); - XStatsBuilder op4_stats( - tc_plane_builder.GetOrCreateEventMetadata("op.4"), &tc_plane_builder); - op4_stats.AddStatValue(program_id_stat, 1); - op4_stats.AddStatValue(symbol_id_stat, 4); - CreateXEvent(&tc_plane_builder, &tc_op_line, "op.1", /*offset_ps=*/15, - /*duration_ps=*/5, {{StatType::kGroupId, int64_t{1}}}); - CreateXEvent(&tc_plane_builder, 
&tc_op_line, "op.2", /*offset_ps=*/25, - /*duration_ps=*/5, {{StatType::kGroupId, int64_t{2}}}); - CreateXEvent(&tc_plane_builder, &tc_op_line, "op.3", /*offset_ps=*/35, - /*duration_ps=*/5, {{StatType::kGroupId, int64_t{3}}}); - CreateXEvent(&tc_plane_builder, &tc_op_line, "op.4", /*offset_ps=*/45, - /*duration_ps=*/5, {{StatType::kGroupId, int64_t{4}}}); - XPlane* sc_plane = GetOrCreateTpuXPlane( - &space, /*device_ordinal=*/1, /*device_type=*/"TPU v4", - /*peak_tera_flops_per_second=*/0, - /*peak_hbm_bw_gigabytes_per_second=*/0); - XPlaneBuilder sc_plane_builder(sc_plane); - sc_plane_builder.SetId(1); - sc_plane_builder.SetName( - absl::StrCat(sc_plane->name(), " SparseCore ", sc_plane->id())); - XLineBuilder sc_step_line = sc_plane_builder.GetOrCreateLine(0); - sc_step_line.SetName(tsl::profiler::kSparseCoreStepLineName); - CreateXEvent(&sc_plane_builder, &sc_step_line, "step.1", /*offset_ps=*/10, - /*duration_ps=*/10, - {{StatType::kStepIdleTimePs, int64_t{5}}, - {StatType::kGroupId, int64_t{1}}}); - CreateXEvent(&sc_plane_builder, &sc_step_line, "step.2", /*offset_ps=*/20, - /*duration_ps=*/10, - {{StatType::kStepIdleTimePs, int64_t{5}}, - {StatType::kGroupId, int64_t{2}}}); - CreateXEvent(&sc_plane_builder, &sc_step_line, "step.3", /*offset_ps=*/30, - /*duration_ps=*/10, - {{StatType::kStepIdleTimePs, int64_t{5}}, - {StatType::kGroupId, int64_t{3}}}); - CreateXEvent(&sc_plane_builder, &sc_step_line, "step.4", /*offset_ps=*/40, - /*duration_ps=*/10, - {{StatType::kStepIdleTimePs, int64_t{5}}, - {StatType::kGroupId, int64_t{4}}}); - XLineBuilder sc_module_line = sc_plane_builder.GetOrCreateLine(1); - sc_module_line.SetName(kSparseCoreModuleLineName); - CreateXEvent(&sc_plane_builder, &sc_module_line, "module.1", /*offset_ps=*/10, - /*duration_ps=*/10, {{StatType::kGroupId, int64_t{1}}}); - CreateXEvent(&sc_plane_builder, &sc_module_line, "module.2", /*offset_ps=*/20, - /*duration_ps=*/10, {{StatType::kGroupId, int64_t{2}}}); - CreateXEvent(&sc_plane_builder, 
&sc_module_line, "module.3", /*offset_ps=*/30, - /*duration_ps=*/10, {{StatType::kGroupId, int64_t{3}}}); - CreateXEvent(&sc_plane_builder, &sc_module_line, "module.4", /*offset_ps=*/40, - /*duration_ps=*/10, {{StatType::kGroupId, int64_t{4}}}); - XLineBuilder sc_op_line = sc_plane_builder.GetOrCreateLine(2); - sc_op_line.SetName(kSparseCoreOpLineName); - CreateXEvent(&sc_plane_builder, &sc_op_line, "scs op.1", /*offset_ps=*/15, - /*duration_ps=*/5, {{StatType::kGroupId, int64_t{1}}}); - CreateXEvent(&sc_plane_builder, &sc_op_line, "scs op.2", /*offset_ps=*/25, - /*duration_ps=*/5, {{StatType::kGroupId, int64_t{2}}}); - CreateXEvent(&sc_plane_builder, &sc_op_line, "scs op.3", /*offset_ps=*/35, - /*duration_ps=*/5, {{StatType::kGroupId, int64_t{3}}}); - CreateXEvent(&sc_plane_builder, &sc_op_line, "scs op.4", /*offset_ps=*/45, - /*duration_ps=*/5, {{StatType::kGroupId, int64_t{4}}}); - OpStats op_stats = ConvertXSpaceToOpStats( - space, - OpStatsOptions{.generate_op_metrics_db = true, .generate_step_db = true}); - EXPECT_EQ(op_stats.device_op_metrics_db().total_time_ps(), 40); - EXPECT_EQ(op_stats.device_op_metrics_db().total_op_time_ps(), 20); - EXPECT_EQ(op_stats.step_db().step_sequence_size(), 4); - EXPECT_EQ(op_stats.hlo_metrics_db_complete_steps_only().total_time_ps(), 40); - EXPECT_EQ(op_stats.hlo_metrics_db_complete_steps_only().total_op_time_ps(), - 20); -} - -} // namespace -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/xplane_to_step_events.cc b/tensorflow/core/profiler/convert/xplane_to_step_events.cc deleted file mode 100644 index e7debb44da8996..00000000000000 --- a/tensorflow/core/profiler/convert/xplane_to_step_events.cc +++ /dev/null @@ -1,392 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/profiler/convert/xplane_to_step_events.h" - -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/strings/match.h" -#include "absl/strings/numbers.h" -#include "absl/strings/str_split.h" -#include "absl/strings/string_view.h" -#include "xla/tsl/profiler/utils/tf_op_utils.h" -#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" -#include "xla/tsl/profiler/utils/timespan.h" -#include "xla/tsl/profiler/utils/tpu_xplane_utils.h" -#include "xla/tsl/profiler/utils/xplane_schema.h" -#include "xla/tsl/profiler/utils/xplane_utils.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" -#include "tensorflow/core/profiler/protobuf/steps_db.pb.h" -#include "tensorflow/core/profiler/utils/event_span.h" -#include "tensorflow/core/profiler/utils/op_metrics_db_utils.h" -#include "tensorflow/core/profiler/utils/trace_utils.h" -#include "tensorflow/core/profiler/utils/xplane_schema.h" -#include "tensorflow/core/profiler/utils/xplane_visitor.h" -#include "tsl/profiler/protobuf/xplane.pb.h" - -namespace tensorflow { -namespace profiler { -namespace { - -inline AllReduceInfo GetAllReduceInfo(const XEventVisitor& event, - uint64_t all_reduce_unique_id) { - AllReduceInfo collective_ops; - collective_ops.set_id(all_reduce_unique_id); - collective_ops.set_start_time_ps(event.TimestampPs()); - if (auto device_offset_ps_stat = event.GetStat(StatType::kDeviceOffsetPs)) { - 
collective_ops.set_start_time_ps(device_offset_ps_stat->IntOrUintValue()); - } - collective_ops.set_end_time_ps(event.EndTimestampPs()); - if (auto device_duration_ps_stat = - event.GetStat(StatType::kDeviceDurationPs)) { - collective_ops.set_end_time_ps(collective_ops.start_time_ps() + - device_duration_ps_stat->IntOrUintValue()); - } - if (auto all_reduce_id_stat = event.GetStat(StatType::kAllReduceId)) { - collective_ops.set_all_reduce_id(all_reduce_id_stat->IntOrUintValue()); - } - if (auto bytes_accessed_stat = - event.Metadata().GetStat(StatType::kBytesAccessed)) { - collective_ops.set_byte_size(bytes_accessed_stat->IntOrUintValue()); - } - return collective_ops; -} - -inline bool IsExplicitHostStepMarker(absl::string_view event_name) { - return (absl::StartsWith(event_name, "train") || - absl::StartsWith(event_name, "test") || - absl::StartsWith(event_name, "TraceContext")) && - !absl::StrContains(event_name, "/"); -} - -// Returns true if the given event_name should be considered as real computation -// on CPU. -inline bool IsRealCpuCompute(absl::string_view event_name) { - bool not_real = absl::StartsWith(event_name, "EagerExecute") || - absl::StartsWith(event_name, "EagerLocalExecute") || - absl::StartsWith(event_name, "EagerKernelExecute") || - absl::StartsWith(event_name, "FunctionRun") || - IsExplicitHostStepMarker(event_name); - return !not_real; -} - -uint64 ParseNumBytesFromMemcpyDetail(absl::string_view memcpy_detail) { - const std::vector params = - absl::StrSplit(memcpy_detail, absl::ByAnyChar(":\n")); - - // Processes value pairs. - for (uint32 ii = 0; ii < params.size(); ii += 2) { - if (params[ii] != "num_bytes") continue; - uint64 value = 0; - if (absl::SimpleAtoi(params[ii + 1], &value)) return value; - break; - } - return 0ULL; -} - -EventType ClassifyGpuCompute(absl::string_view event_name, - absl::string_view tensor_shapes) { - if (tensor_shapes.empty()) { - // Deduces the precision from the name. 
- return (absl::StrContains(event_name, "half") || - absl::StrContains(event_name, "fp16")) - ? DEVICE_COMPUTE_16 - : DEVICE_COMPUTE_32; - } else { - // Deduces the precision from the shapes. - return (absl::StrContains(tensor_shapes, "half")) ? DEVICE_COMPUTE_16 - : DEVICE_COMPUTE_32; - } -} - -EventType ClassifyGpuEvent(absl::string_view event_name, - absl::string_view tensor_shapes) { - tsl::profiler::TfOp tf_op = tsl::profiler::ParseTfOpFullname(event_name); - if (tsl::profiler::IsMemcpyHToDOp(tf_op)) { - return HOST_TO_DEVICE; - } else if (tsl::profiler::IsMemcpyDToHOp(tf_op)) { - return DEVICE_TO_HOST; - } else if (tsl::profiler::IsMemcpyDToDOp(tf_op)) { - return DEVICE_TO_DEVICE; - } else if (absl::StartsWithIgnoreCase(event_name, "nccl")) { - return DEVICE_COLLECTIVES; - } else { - return ClassifyGpuCompute(event_name, tensor_shapes); - } -} - -EventType ClassifyCpuEvent(absl::string_view event_name, bool has_device, - bool has_correlation_id) { - tsl::profiler::TfOp tf_op = tsl::profiler::ParseTfOpFullname(event_name); - if (tsl::profiler::IsInfeedEnqueueOp(tf_op) || - tsl::profiler::IsMemcpyHToDOp(tf_op)) { - return HOST_TO_DEVICE; - } else if (tsl::profiler::IsMemcpyHToHOp(tf_op)) { - return HOST_TO_HOST; - } else if (has_device && (has_correlation_id || - absl::StartsWithIgnoreCase( - event_name, "ExecutorState::Process"))) { - // TODO(b/150420972): Separate runtime overhead from actual compute for - // CPU-only. 
- return HOST_PREPARE; - } else if (absl::StartsWithIgnoreCase(event_name, "IteratorGetNext")) { - return HOST_WAIT_INPUT; - } else { - return HOST_COMPUTE; - } -} - -} // namespace - -StepEvents ConvertHostThreadsXLineToStepEvents( - const XLineVisitor& line, const StepEvents* device_step_events) { - StepEvents result; - line.ForEachEvent([&](const XEventVisitor& event) { - int64_t correlation_id = -1; - int64_t group_id = -1; - absl::string_view step_name; - event.ForEachStat([&](const XStatVisitor& stat) { - if (!stat.Type().has_value()) return; - switch (stat.Type().value()) { - case StatType::kCorrelationId: - correlation_id = stat.IntValue(); - break; - case StatType::kGroupId: - group_id = stat.IntValue(); - break; - case StatType::kStepName: - step_name = stat.StrOrRefValue(); - break; - } - }); - if (group_id < 0) return; - // Don't add CPU events when (1) it includes device step events and (2) it - // doesn't have a device and that the group_id (i.e. step number) already - // appears on the device. This will filter out all cpu events that do not - // correspond to any steps executed on the device. - bool has_device = (device_step_events != nullptr); - if (has_device && !device_step_events->contains(group_id)) return; - if (IsExplicitHostStepMarker(event.Name())) { - result[group_id].AddMarker( - StepMarker(StepMarkerType::kExplicitHostStepMarker, event.Name(), - event.GetTimespan())); - } else if (!step_name.empty()) { - // Grouping adds a step_name stat to implicit host step markers. 
- result[group_id].AddMarker( - StepMarker(StepMarkerType::kImplicitHostStepMarker, event.Name(), - event.GetTimespan())); - } else if (IsRealCpuCompute(event.Name())) { - result[group_id].AddEvent(EventTypeSpan( - ClassifyCpuEvent(event.Name(), has_device, correlation_id >= 0), - event.GetTimespan())); - } - if (!step_name.empty()) { - result[group_id].SetStepName(std::string(step_name)); - } - }); - return result; -} - -StepEvents ConvertHostThreadsXPlaneToStepEvents( - const XPlane& host_trace, const StepEvents* device_step_events) { - StepEvents host_step_events; - XPlaneVisitor plane = tsl::profiler::CreateTfXPlaneVisitor(&host_trace); - plane.ForEachLine([&](const XLineVisitor& line) { - StepEvents thread_step_events = - ConvertHostThreadsXLineToStepEvents(line, device_step_events); - UnionCombineStepEvents(thread_step_events, &host_step_events); - }); - return host_step_events; -} - -StepEvents ConvertDeviceStepInfoToStepMarkers(const XLineVisitor& line) { - StepEvents result; - line.ForEachEvent([&](const XEventVisitor& event) { - if (std::optional stat = event.GetStat(StatType::kGroupId)) { - result[stat->IntValue()].AddMarker( - StepMarker(StepMarkerType::kDeviceStepMarker, event.Name(), - event.GetTimespan())); - } - }); - return result; -} - -StepEvents ConvertDeviceTraceXLineToStepEvents(const uint64 device_id, - const XLineVisitor& line) { - StepEvents result; - line.ForEachEvent([&](const XEventVisitor& event) { - int64_t correlation_id = -1; - int64_t group_id = -1; - absl::string_view tensor_shapes; - absl::string_view memcpy_details; - event.ForEachStat([&](const XStatVisitor& stat) { - if (!stat.Type().has_value()) return; - switch (stat.Type().value()) { - case StatType::kCorrelationId: - correlation_id = stat.IntValue(); - break; - case StatType::kGroupId: - group_id = stat.IntValue(); - break; - case StatType::kTensorShapes: - tensor_shapes = stat.StrOrRefValue(); - break; - case StatType::kMemcpyDetails: - memcpy_details = 
stat.StrOrRefValue(); - break; - } - }); - - if (correlation_id >= 0 && group_id >= 0) { - EventType event_type = ClassifyGpuEvent(event.Name(), tensor_shapes); - EventTypeSpan event_type_span(event_type, event.GetTimespan()); - result[group_id].AddEvent(event_type_span); - switch (event_type) { - case DEVICE_COLLECTIVES: { - AllReduceInfo collective_ops; - collective_ops.set_start_time_ps(event.TimestampPs()); - collective_ops.set_end_time_ps(event.EndOffsetPs()); - // TODO(jiesun): figure out how to get size info etc. - result[group_id].AddCollectiveOpEvent(device_id, collective_ops); - break; - } - case HOST_TO_DEVICE: - case DEVICE_TO_DEVICE: - case DEVICE_TO_HOST: { - // TODO(jiesun): not all memcpy events are grouped, figure out a - // better way to attribute them to steps. - uint64 bytes_transferred = - ParseNumBytesFromMemcpyDetail(memcpy_details); - result[group_id].AddDeviceMemoryTransferEvent( - event_type, event.GetTimespan(), bytes_transferred); - break; - } - default: - return; - } - } - }); - return result; -} - -StepEvents ConvertTpuDeviceTraceXLineToStepEvents(const uint64 device_id, - const XLineVisitor& line) { - StepEvents result; - absl::flat_hash_map - op_metrics_builder; - struct ParentRef { - const XEventVisitor event; - tsl::profiler::Timespan device_timespan; - uint64_t children_duration_ps = 0; - int64_t group_id = -1; - }; - tsl::profiler::AncestorStack event_stack( - // Adds an OpMetric to the builder based on the provided parent reference. - [&](const ParentRef& parent) { - OpMetrics op_metrics = FromXEvent(parent.event); - op_metrics.set_time_ps(parent.device_timespan.duration_ps()); - // TODO(b/397774568): Remove this once the SparseCore OpMetricsDb is - // implemented. 
- if (device_id < kSparseCoreIndexStart) { - op_metrics.set_self_time_ps(op_metrics.time_ps() - - parent.children_duration_ps); - } - op_metrics_builder[parent.group_id].AddOpMetric( - op_metrics, GetOpKeyFromXEvent(parent.event)); - }, - // Checks if the child event is a child of the parent event. - [](const ParentRef& parent, const ParentRef& child) { - return parent.device_timespan.Includes(child.device_timespan); - }, - // Adds the child duration to the parent. - [](ParentRef& parent, ParentRef& child) { - parent.children_duration_ps += child.device_timespan.duration_ps(); - }); - line.ForEachEvent([&](const XEventVisitor& event) { - auto group_id_stat = event.GetStat(StatType::kGroupId); - if (!group_id_stat.has_value()) return; - int64_t group_id = group_id_stat->IntOrUintValue(); - event_stack.Push(ParentRef{ - .event = event, - .device_timespan = tsl::profiler::GetDeviceEventTimespan(event), - .group_id = group_id, - }); - - if (auto all_reduce_unique_id_stat = - event.GetStat(StatType::kAllReduceUniqueId)) { - result[group_id].AddCollectiveOpEvent( - device_id, - GetAllReduceInfo(event, all_reduce_unique_id_stat->IntOrUintValue())); - } - }); - event_stack.Flush(); - for (auto& [group_id, builder] : op_metrics_builder) { - // Finalize Without the step time now. 
- result[group_id].SetPerCoreOpMetricsDb(builder.Finalize(), device_id); - } - return result; -} - -StepEvents ConvertDeviceTraceXPlaneToStepEvents(const XPlane& device_trace) { - StepEvents device_step_events; - XPlaneVisitor plane = tsl::profiler::CreateTfXPlaneVisitor(&device_trace); - std::optional tpu_core_id = tsl::profiler::GetTensorCoreId(plane.Name()); - std::optional sc_core_id = tsl::profiler::GetSparseCoreId(plane.Name()); - plane.ForEachLine([&](const XLineVisitor& line) { - int64_t line_id = line.Id(); - if (line_id == kThreadIdStepInfo || - (tpu_core_id.has_value() && - line.Name() == tsl::profiler::kStepLineName)) { - // TODO(b/397774568): Re-add processing of SparseCore steps once the - // SparseCore OpMetricsDb is implemented. - StepEvents step_marker_events = ConvertDeviceStepInfoToStepMarkers(line); - UnionCombineStepEvents(step_marker_events, &device_step_events); - } else if (IsDerivedThreadId(line_id)) { - return; - } else { - StepEvents stream_step_events; - if (tpu_core_id.has_value()) { - if (!tsl::profiler::IsOpLineName(line.Name())) return; - // In TPU sampling mode, the profiling session could stop in the middle - // of a training step. In this case, the "XLA Ops" line will have - // one more step than the "Step" line. We need to intersect them to get - // the common step numbers. - stream_step_events = - ConvertTpuDeviceTraceXLineToStepEvents(plane.Id(), line); - IntersectCombineStepEvents(stream_step_events, &device_step_events); - } else if (sc_core_id.has_value()) { - // TODO(b/397774568): Switch to IsOpLineName once SparseCore OpMetricsDb - // is implemented. 
- if (line.Name() != tsl::profiler::kSparseCoreStepLineName) return; - stream_step_events = ConvertTpuDeviceTraceXLineToStepEvents( - kSparseCoreIndexStart + plane.Id(), line); - IntersectCombineStepEvents(stream_step_events, &device_step_events); - } else { - stream_step_events = - ConvertDeviceTraceXLineToStepEvents(plane.Id(), line); - UnionCombineStepEvents(stream_step_events, &device_step_events); - } - } - }); - return device_step_events; -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/xplane_to_step_events.h b/tensorflow/core/profiler/convert/xplane_to_step_events.h deleted file mode 100644 index 35580f95281589..00000000000000 --- a/tensorflow/core/profiler/convert/xplane_to_step_events.h +++ /dev/null @@ -1,47 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_STEP_EVENTS_H_ -#define TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_STEP_EVENTS_H_ - -#include "tensorflow/core/profiler/utils/event_span.h" -#include "tensorflow/core/profiler/utils/xplane_visitor.h" -#include "tsl/profiler/protobuf/xplane.pb.h" - -namespace tensorflow { -namespace profiler { - -// Convert the host threads in XLine format to StepEvents format. If -// device_step_events is non-null, we will filter out events that only happens -// on CPU. 
-StepEvents ConvertHostThreadsXLineToStepEvents( - const XLineVisitor& line, const StepEvents* device_step_events); - -// Convert the host threads in XPlane format to StepEvents format. If -// device_step_events is non-null, we will filter out events that only happens -// on CPU. -StepEvents ConvertHostThreadsXPlaneToStepEvents( - const XPlane& host_trace, const StepEvents* device_step_events); - -// Convert the device trace in XLine format to StepEvents. -StepEvents ConvertDeviceTraceXLineToStepEvents(const XLineVisitor& line); - -// Convert the device trace in XPlane format to StepEvents. -StepEvents ConvertDeviceTraceXPlaneToStepEvents(const XPlane& device_trace); - -} // namespace profiler -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_STEP_EVENTS_H_ diff --git a/tensorflow/core/profiler/convert/xplane_to_step_events_test.cc b/tensorflow/core/profiler/convert/xplane_to_step_events_test.cc deleted file mode 100644 index 2389f619edf3c5..00000000000000 --- a/tensorflow/core/profiler/convert/xplane_to_step_events_test.cc +++ /dev/null @@ -1,189 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/core/profiler/convert/xplane_to_step_events.h" - -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "xla/tsl/profiler/utils/group_events.h" -#include "xla/tsl/profiler/utils/xplane_schema.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/profiler/utils/event_span.h" -#include "tensorflow/core/profiler/utils/xplane_builder.h" -#include "tensorflow/core/profiler/utils/xplane_schema.h" -#include "tensorflow/core/profiler/utils/xplane_test_utils.h" -#include "tsl/profiler/protobuf/xplane.pb.h" - -namespace tensorflow { -namespace profiler { -namespace { - -// Tests with a sample profile with two steps captured on the host but only one -// step on the device. On the host, each step consists of TraceContext -> -// FunctionRun -> ExecutorState::Process -> matmul. On the host, each step -// consists of matmul. The host's step db should be created only for the step -// observed on the host. 
-TEST(ConvertXPlaneToOpStats, CpuOnlyStepDbTest) { - constexpr int64_t kFirstStepNum = 123; - constexpr int64_t kSecondStepNum = 456; - constexpr int64_t kFirstStepId = 0; - constexpr int64_t kSecondStepId = 1; - constexpr int64_t kFirstCorrelationId = 100; - constexpr int64_t kSecondCorrelationId = 200; - - XSpace space; - XPlane* host_plane = GetOrCreateHostXPlane(&space); - XPlaneBuilder host_plane_builder(host_plane); - host_plane_builder.ReserveLines(2); - - auto main_thread = host_plane_builder.GetOrCreateLine(0); - CreateXEvent(&host_plane_builder, &main_thread, HostEventType::kTraceContext, - 0, 100, {{StatType::kStepNum, kFirstStepNum}}); - CreateXEvent(&host_plane_builder, &main_thread, HostEventType::kFunctionRun, - 10, 90, - {{StatType::kStepId, kFirstStepId}, - {StatType::kProducerType, int64_t{1}}, - {StatType::kProducerId, kFirstStepId}}); - CreateXEvent(&host_plane_builder, &main_thread, HostEventType::kTraceContext, - 300, 100, {{StatType::kStepNum, kSecondStepNum}}); - CreateXEvent(&host_plane_builder, &main_thread, HostEventType::kFunctionRun, - 310, 90, - {{StatType::kStepId, kSecondStepId}, - {StatType::kProducerType, int64_t{1}}, - {StatType::kProducerId, kSecondStepId}}); - - auto tf_executor_thread = host_plane_builder.GetOrCreateLine(1); - CreateXEvent(&host_plane_builder, &tf_executor_thread, - HostEventType::kExecutorStateProcess, 20, 20, - {{StatType::kStepId, kFirstStepId}, - {StatType::kConsumerType, int64_t{1}}, - {StatType::kConsumerId, kFirstStepId}}); - CreateXEvent(&host_plane_builder, &tf_executor_thread, "matmul", 30, 10, - {{StatType::kCorrelationId, kFirstCorrelationId}}); - CreateXEvent(&host_plane_builder, &tf_executor_thread, - HostEventType::kExecutorStateProcess, 320, 20, - {{StatType::kStepId, kSecondStepId}, - {StatType::kConsumerType, int64_t{1}}, - {StatType::kConsumerId, kSecondStepId}}); - CreateXEvent(&host_plane_builder, &tf_executor_thread, "matmul", 330, 10, - {{StatType::kCorrelationId, kSecondCorrelationId}}); 
- - XPlane* device_plane = space.add_planes(); - XPlaneBuilder device_plane_builder(device_plane); - device_plane_builder.ReserveLines(1); - - auto stream = device_plane_builder.GetOrCreateLine(0); - CreateXEvent(&device_plane_builder, &stream, "matmul", 50, 40, - {{StatType::kCorrelationId, kFirstCorrelationId}}); - - tsl::profiler::GroupTfEvents(&space); - StepEvents device_step_events = - ConvertDeviceTraceXPlaneToStepEvents(*device_plane); - EXPECT_EQ(device_step_events.size(), 1); - EXPECT_EQ(device_step_events[0].Events().size(), 1); - StepEvents host_step_events = - ConvertHostThreadsXPlaneToStepEvents(*host_plane, &device_step_events); - // Should contain only the step which is also present on the device. - EXPECT_EQ(host_step_events.size(), 1); - // TraceContext should be added as a step marker. - EXPECT_EQ(host_step_events[0].Markers().size(), 1); - // FunctionRun shouldn't be added. - EXPECT_EQ(host_step_events[0].Events().size(), 2); -} - -TEST(ConvertXPlaneToStepEvents, TpuDevicePlaneToStepEvents) { - XPlane raw_plane; - XPlaneBuilder plane(&raw_plane); - int64_t device_id = 1; - plane.SetId(device_id); - plane.SetName("/device:TPU:0"); - XLineBuilder op_line = plane.GetOrCreateLine(0); - op_line.SetName(tsl::profiler::kXlaOpLineName); - const XStatMetadata& program_id_stat = - *plane.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kProgramId)); - const XStatMetadata& symbol_id_stat = - *plane.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kSymbolId)); - const XStatMetadata& group_id_stat = - *plane.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kGroupId)); - { - XEventMetadata* event_metadata = - plane.GetOrCreateEventMetadata("op_long_name"); - event_metadata->set_display_name("op_name"); - XStatsBuilder stats(event_metadata, &plane); - stats.AddStatValue(program_id_stat, 1); - stats.AddStatValue(symbol_id_stat, 1); - { - XEventBuilder event = op_line.AddEvent(*event_metadata); - event.SetOffsetPs(0); - event.SetDurationPs(50); - 
event.AddStatValue(group_id_stat, 1); - } - { - XEventBuilder event = op_line.AddEvent(*event_metadata); - event.SetOffsetPs(100); - event.SetDurationPs(50); - event.AddStatValue(group_id_stat, 2); - } - } - { - XEventMetadata* event_metadata = - plane.GetOrCreateEventMetadata("op_long_name2"); - event_metadata->set_display_name("op_name2"); - XStatsBuilder stats(event_metadata, &plane); - stats.AddStatValue(program_id_stat, 1); - stats.AddStatValue(symbol_id_stat, 2); - XEventBuilder event = op_line.AddEvent(*event_metadata); - event.SetOffsetPs(50); - event.SetDurationPs(50); - event.AddStatValue(group_id_stat, 1); - } - XLineBuilder step_line = plane.GetOrCreateLine(1); - step_line.SetName(tsl::profiler::kStepLineName); - { - XEventMetadata* event_metadata = plane.CreateEventMetadata(); - XStatsBuilder stats(event_metadata, &plane); - { - XEventBuilder event = step_line.AddEvent(*event_metadata); - event.SetOffsetPs(0); - event.SetDurationPs(100); - event.AddStatValue(group_id_stat, 1); - } - { - XEventBuilder event = step_line.AddEvent(*event_metadata); - event.SetOffsetPs(100); - event.SetDurationPs(100); - event.AddStatValue(group_id_stat, 2); - } - } - - StepEvents step_events = ConvertDeviceTraceXPlaneToStepEvents(raw_plane); - EXPECT_EQ(step_events.size(), 2); - EXPECT_TRUE(step_events.contains(1)); - StepDetails step_1 = step_events[/*group_id=*/1]; - ASSERT_TRUE(step_1.PerCoreOpMetricsDb().contains(device_id)); - EXPECT_EQ(step_1.PerCoreOpMetricsDb().at(device_id).metrics_db_size(), 2); - EXPECT_EQ(step_1.Markers().size(), 1); - EXPECT_TRUE(step_events.contains(2)); - StepDetails step_2 = step_events[/*group_id=*/2]; - ASSERT_TRUE(step_2.PerCoreOpMetricsDb().contains(device_id)); - EXPECT_EQ(step_2.PerCoreOpMetricsDb().at(device_id).metrics_db_size(), 1); - EXPECT_EQ(step_2.Markers().size(), 1); -} - -} // namespace -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/xplane_to_step_stats.cc 
b/tensorflow/core/profiler/convert/xplane_to_step_stats.cc index b653d7723f2310..8645876955bb1f 100644 --- a/tensorflow/core/profiler/convert/xplane_to_step_stats.cc +++ b/tensorflow/core/profiler/convert/xplane_to_step_stats.cc @@ -35,6 +35,7 @@ limitations under the License. #include "tensorflow/core/profiler/utils/xplane_utils.h" #include "tensorflow/core/profiler/utils/xplane_visitor.h" #include "tsl/profiler/protobuf/xplane.pb.h" +#include "xprof/utils/gpu_event_stats.h" // from @org_xprof namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/xplane_to_tf_data_stats.cc b/tensorflow/core/profiler/convert/xplane_to_tf_data_stats.cc deleted file mode 100644 index ae7871569e3046..00000000000000 --- a/tensorflow/core/profiler/convert/xplane_to_tf_data_stats.cc +++ /dev/null @@ -1,523 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/core/profiler/convert/xplane_to_tf_data_stats.h" - -#include -#include -#include -#include -#include - -#include "absl/algorithm/container.h" -#include "absl/container/flat_hash_map.h" -#include "absl/container/flat_hash_set.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_format.h" -#include "absl/strings/str_split.h" -#include "absl/strings/string_view.h" -#include "xla/tsl/profiler/utils/group_events.h" -#include "xla/tsl/profiler/utils/tf_op_utils.h" -#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" -#include "xla/tsl/profiler/utils/timespan.h" -#include "tensorflow/core/lib/gtl/map_util.h" -#include "tensorflow/core/profiler/protobuf/tf_data_stats.pb.h" -#include "tensorflow/core/profiler/utils/html_utils.h" -#include "tensorflow/core/profiler/utils/xplane_schema.h" -#include "tensorflow/core/profiler/utils/xplane_visitor.h" - -namespace tensorflow { -namespace profiler { - -// 50 us from https://www.tensorflow.org/guide/data_performance_analysis -const int64_t kSlowCallThresholdPs = 50 * 1000000; - -namespace { - -// Returns true if the given iterator event is for a root iterator. -bool IsRootIteratorEvent(const XEventVisitor& iterator_event) { - std::vector split_result = - absl::StrSplit(iterator_event.Name(), "::"); - // The root iterator's name contains only its own name (no parent - // information). - return split_result.size() == 2; -} - -// Returns true if the given iterator event name is for an async iterator. 
-bool IsAsyncIterator(absl::string_view iterator_event_name) { - static auto* kAsyncIterators = new absl::flat_hash_set( - {"Prefetch", "ParallelInterleave", "ParallelMap", "ParseExample", - "MapAndBatch", "DataService", "LegacyParallelInterleave", - "ParallelBatch"}); - return kAsyncIterators->contains(iterator_event_name); -} - -void SetIteratorMetadata(int64_t id, const XEventVisitor& event, - IteratorMetadata* metadata) { - metadata->set_id(id); - auto parent_id_stat = event.GetStat(StatType::kParentId); - if (parent_id_stat.has_value()) { - metadata->set_parent_id(parent_id_stat->IntValue()); - } - metadata->set_name(tsl::profiler::IteratorName(event.Name())); - metadata->set_long_name(event.Name().data(), event.Name().size()); - metadata->set_is_async(IsAsyncIterator(metadata->name())); - // TODO(b/161831651): Set params. -} - -// Returns the parent iterator's id if it is a root of a device input -// pipeline. -std::optional FindDeviceInputPipeline(const XEventVisitor& event) { - if (event.Type() == HostEventType::kDeviceInputPipelineSecondIterator) { - auto parent_id_stat = event.GetStat(StatType::kParentId); - if (parent_id_stat.has_value()) return parent_id_stat->IntValue(); - } - return std::nullopt; -} - -// Processes tsl::profiler::EventForest to do the following: -// (1) set iterator metadata -// (2) find root iterator events -// (3) find device input pipeline ids -void ProcessEventForest( - const tsl::profiler::EventForest& event_forest, - absl::flat_hash_set* device_input_pipeline_ids, - absl::flat_hash_map>* - root_iterator_event_map, - TfDataStats* tf_data_stats) { - const tsl::profiler::EventNodeMap& event_node_map = - event_forest.GetEventNodeMap(); - auto* iterator_event_list = - gtl::FindOrNull(event_node_map, HostEventType::kIterator); - if (!iterator_event_list) return; - for (const tsl::profiler::EventNode& iterator_event : *iterator_event_list) { - const XEventVisitor& iterator_event_visitor = - iterator_event.GetEventVisitor(); - auto 
iterator_id_stat = iterator_event_visitor.GetStat(StatType::kStepId); - if (!iterator_id_stat.has_value()) continue; - int64_t iterator_id = iterator_id_stat->IntValue(); - auto result = tf_data_stats->mutable_iterator_metadata()->insert( - {iterator_id, IteratorMetadata()}); - IteratorMetadata& metadata = result.first->second; - if (result.second) { - // First time processing this iterator. - SetIteratorMetadata(iterator_id, iterator_event_visitor, &metadata); - } - if (IsRootIteratorEvent(iterator_event_visitor)) { - // Record root iterator events. - (*root_iterator_event_map)[iterator_id].push_back(&iterator_event); - } - } - auto* device_input_pipeline_second_iterator_events = gtl::FindOrNull( - event_node_map, HostEventType::kDeviceInputPipelineSecondIterator); - if (!device_input_pipeline_second_iterator_events) return; - for (const tsl::profiler::EventNode& iterator_event : - *device_input_pipeline_second_iterator_events) { - const XEventVisitor& iterator_event_visitor = - iterator_event.GetEventVisitor(); - auto iterator_id_stat = iterator_event_visitor.GetStat(StatType::kStepId); - if (!iterator_id_stat.has_value()) continue; - int64_t iterator_id = iterator_id_stat->IntValue(); - auto result = tf_data_stats->mutable_iterator_metadata()->insert( - {iterator_id, IteratorMetadata()}); - IteratorMetadata& metadata = result.first->second; - if (result.second) { - // First time processing this iterator. - SetIteratorMetadata(iterator_id, iterator_event_visitor, &metadata); - // Find and record device input pipeline ids. 
- std::optional device_input_pipeline_id = - FindDeviceInputPipeline(iterator_event_visitor); - if (device_input_pipeline_id.has_value()) { - device_input_pipeline_ids->insert(*device_input_pipeline_id); - } - } - } -} - -void SetInputPipelineMetadata(int64_t id, int64_t name_id, - bool is_device_input_pipeline, - InputPipelineMetadata* metadata) { - constexpr absl::string_view kHostInputPipelinePrefix = "Host:"; - constexpr absl::string_view kDeviceInputPipelinePrefix = "Device:"; - metadata->set_id(id); - if (is_device_input_pipeline) { - metadata->set_type(InputPipelineMetadata::DEVICE); - metadata->set_name(absl::StrCat(kDeviceInputPipelinePrefix, name_id)); - } else { - metadata->set_type(InputPipelineMetadata::HOST); - metadata->set_name(absl::StrCat(kHostInputPipelinePrefix, name_id)); - } -} - -void ProcessIteratorEvent(const tsl::profiler::EventNode& iterator_event, - InputPipelineStat* input_pipeline_stat, - bool is_blocking, int level = 0) { - if (level > 100) return; - const XEventVisitor& visitor = iterator_event.GetEventVisitor(); - auto iterator_id_stat = visitor.GetStat(StatType::kStepId); - if (!iterator_id_stat.has_value()) return; - int64_t iterator_id = iterator_id_stat->IntValue(); - auto result = input_pipeline_stat->mutable_iterator_stats()->insert( - {iterator_id, IteratorStat()}); - IteratorStat& iterator_stat = result.first->second; - if (result.second) { - iterator_stat.set_id(iterator_id); - iterator_stat.set_start_time_ps(visitor.TimestampPs()); - } - iterator_stat.set_duration_ps(iterator_stat.duration_ps() + - visitor.DurationPs()); - int64_t self_time_ps = visitor.DurationPs(); - tsl::profiler::Timespan self_time_span = visitor.GetTimespan(); - for (const tsl::profiler::EventNode* child : iterator_event.GetChildren()) { - const XEventVisitor& child_visitor = child->GetEventVisitor(); - if (tsl::profiler::ParseTfOpFullname(child_visitor.Name()).category == - tsl::profiler::Category::kTfData) { - int64_t overlap_duration_ps = - 
self_time_span.OverlappedDurationPs(child_visitor.GetTimespan()); - ProcessIteratorEvent(*child, input_pipeline_stat, - is_blocking && overlap_duration_ps, level + 1); - // Note: Assume no overlap between child events. - self_time_ps -= overlap_duration_ps; - } - } - iterator_stat.set_self_time_ps(iterator_stat.self_time_ps() + self_time_ps); - iterator_stat.set_is_blocking(iterator_stat.is_blocking() || is_blocking); - iterator_stat.set_num_calls(iterator_stat.num_calls() + 1); -} - -void SetBottleneckIteratorId(InputPipelineStat* input_pipeline_stat) { - int64_t bottleneck_iterator_id = 0; - int64_t max_self_time = 0; - for (const auto& pair : input_pipeline_stat->iterator_stats()) { - const auto& id = pair.first; - const auto& iterator_stat = pair.second; - if (iterator_stat.is_blocking() && - iterator_stat.self_time_ps() > max_self_time) { - bottleneck_iterator_id = id; - max_self_time = iterator_stat.self_time_ps(); - } - } - input_pipeline_stat->set_bottleneck_iterator_id(bottleneck_iterator_id); - input_pipeline_stat->set_bottleneck_iterator_latency_ps(max_self_time); -} - -void ProcessInputPipelines( - const absl::flat_hash_set& device_input_pipeline_ids, - absl::flat_hash_map>* - root_iterator_event_map, - TfDataStats* tf_data_stats) { - auto* input_pipelines = tf_data_stats->mutable_input_pipelines(); - int64_t num_host_input_pipelines = 0; - int64_t num_device_input_pipelines = 0; - for (auto& id_and_events : *root_iterator_event_map) { - auto& root_iterator_id = id_and_events.first; - auto& root_iterator_events = id_and_events.second; - absl::c_sort(root_iterator_events, [](const tsl::profiler::EventNode* lhs, - const tsl::profiler::EventNode* rhs) { - return lhs->GetEventVisitor().DurationPs() > - rhs->GetEventVisitor().DurationPs(); - }); - auto result = - input_pipelines->insert({root_iterator_id, InputPipelineStats()}); - InputPipelineStats& input_pipeline_stats = result.first->second; - InputPipelineMetadata* metadata = 
input_pipeline_stats.mutable_metadata(); - if (result.second) { - bool is_device_input_pipeline = - device_input_pipeline_ids.contains(root_iterator_id); - int64_t name_id = is_device_input_pipeline ? num_device_input_pipelines++ - : num_host_input_pipelines++; - SetInputPipelineMetadata(root_iterator_id, name_id, - is_device_input_pipeline, metadata); - } - int64_t sum_latency_ps = 0; - int64_t min_latency_ps = INT64_MAX; - int64_t max_latency_ps = 0; - int64_t num_slow_calls = 0; - for (const tsl::profiler::EventNode* root_iterator_event : - root_iterator_events) { - InputPipelineStat* stat = input_pipeline_stats.add_stats(); - ProcessIteratorEvent(*root_iterator_event, stat, - /*is_blocking*/ true); - SetBottleneckIteratorId(stat); - int64_t latency_ps = root_iterator_event->GetEventVisitor().DurationPs(); - sum_latency_ps += latency_ps; - min_latency_ps = std::min(min_latency_ps, latency_ps); - max_latency_ps = std::max(max_latency_ps, latency_ps); - if (latency_ps > kSlowCallThresholdPs) num_slow_calls++; - } - input_pipeline_stats.set_avg_latency_ps(sum_latency_ps / - root_iterator_events.size()); - input_pipeline_stats.set_min_latency_ps(min_latency_ps); - input_pipeline_stats.set_max_latency_ps(max_latency_ps); - input_pipeline_stats.set_num_slow_calls(num_slow_calls); - } -} - -void SetBottleneckAnalysis(CombinedTfDataStats* combined_tf_data_stats) { - struct InputPipeline { - InputPipeline(absl::string_view host_name, - absl::string_view input_pipeline_name, int64_t max_latency_ps, - absl::string_view iterator_name, - absl::string_view iterator_long_name, - int64_t iterator_latency_ps) - : host_name(host_name), - input_pipeline_name(input_pipeline_name), - max_latency_ps(max_latency_ps), - iterator_name(iterator_name), - iterator_long_name(iterator_long_name), - iterator_latency_ps(iterator_latency_ps) {} - absl::string_view host_name; - absl::string_view input_pipeline_name; - int64_t max_latency_ps; - absl::string_view iterator_name; - absl::string_view 
iterator_long_name; - int64_t iterator_latency_ps; - - bool operator<(const InputPipeline& rhs) const { - return max_latency_ps > rhs.max_latency_ps; - } - }; - std::vector slow_input_pipelines; - for (const auto& host_name_and_tf_data_stats : - combined_tf_data_stats->tf_data_stats()) { - absl::string_view host_name = host_name_and_tf_data_stats.first; - const TfDataStats& tf_data_stats = host_name_and_tf_data_stats.second; - for (const auto& id_and_stats : tf_data_stats.input_pipelines()) { - const InputPipelineStats& input_pipeline_stats = id_and_stats.second; - if (input_pipeline_stats.metadata().type() == - InputPipelineMetadata::DEVICE) { - // Ignore device input pipelines. - continue; - } - // Choose the slowest execution trace of the input pipeline. - // `input_pipeline_stats.stats` is already sorted so choose the first one. - const InputPipelineStat& input_pipeline_stat = - input_pipeline_stats.stats(0); - const IteratorMetadata& metadata = tf_data_stats.iterator_metadata().at( - input_pipeline_stat.bottleneck_iterator_id()); - slow_input_pipelines.emplace_back( - host_name, input_pipeline_stats.metadata().name(), - input_pipeline_stats.max_latency_ps(), metadata.name(), - metadata.long_name(), - input_pipeline_stat.bottleneck_iterator_latency_ps()); - } - } - std::sort(slow_input_pipelines.begin(), slow_input_pipelines.end()); - for (const auto& input_pipeline : slow_input_pipelines) { - TfDataBottleneckAnalysis* bottleneck_analysis = - combined_tf_data_stats->add_bottleneck_analysis(); - bottleneck_analysis->set_host(input_pipeline.host_name.data(), - input_pipeline.host_name.size()); - bottleneck_analysis->set_input_pipeline( - input_pipeline.input_pipeline_name.data(), - input_pipeline.input_pipeline_name.size()); - bottleneck_analysis->set_max_latency_ps(input_pipeline.max_latency_ps); - bottleneck_analysis->set_iterator_name(input_pipeline.iterator_name.data(), - input_pipeline.iterator_name.size()); - bottleneck_analysis->set_iterator_long_name( - 
input_pipeline.iterator_long_name.data(), - input_pipeline.iterator_long_name.size()); - bottleneck_analysis->set_iterator_latency_ps( - input_pipeline.iterator_latency_ps); - } -} - -std::string GetSuggestion(BottleneckType type) { - constexpr absl::string_view kPlaybookLink = - "https://www.tensorflow.org/guide/data_performance_analysis"; - constexpr absl::string_view kPlaybookSourceDatasetLink = - "https://www.tensorflow.org/guide/" - "data_performance_analysis#source_datasets"; - constexpr absl::string_view kPlaybookCpuUtilizationLink = - "https://www.tensorflow.org/guide/" - "data_performance_analysis#3_are_you_reaching_high_cpu_utilization"; - constexpr absl::string_view kPlaybookTransformationLink = - "https://www.tensorflow.org/guide/" - "data_performance_analysis#transformation_datasets"; - constexpr absl::string_view kTfGuideParallelDataExtractionLink = - "https://www.tensorflow.org/guide/" - "data_performance#parallelizing_data_extraction"; - constexpr absl::string_view kTfGuideParallelTransformationLink = - "https://www.tensorflow.org/guide/" - "data_performance#parallelizing_data_transformation"; - constexpr absl::string_view kTfGuideCacheLink = - "https://www.tensorflow.org/guide/data_performance#caching"; - constexpr absl::string_view kTfDataServiceLink = - "https://www.tensorflow.org/api_docs/python/tf/data/experimental/" - "service?version=nightly"; - switch (type) { - case BottleneckType::kSlowSource: - return absl::StrFormat( - "1. Check the locality of a host and input data. Ideally, they " - "should be in the same cell (or very close, like the same " - "region).
" - "2. Parallelize reading from this dataset source. See %s and %s for " - "more details.
", - AnchorElement(kPlaybookSourceDatasetLink, "here"), - AnchorElement(kTfGuideParallelDataExtractionLink, "here")); - case BottleneckType::kSlowDataService: - return absl::StrFormat( - "1. Fetching data from tf.data service took a while. Profile the " - "tf.data service worker to analyze the issue further.
" - "2. See %s for more details on tf.data service.
" - "3. See %s for other suggestions.", - AnchorElement(kTfDataServiceLink, "this"), - AnchorElement(kPlaybookLink, "this")); - case BottleneckType::kSlowRemoteSource: - return absl::StrFormat( - "1. The remote data source is slow. Profile its host to analyze the " - "issue further.
" - "2. See %s for other suggestions.", - AnchorElement(kPlaybookLink, "this")); - case BottleneckType::kSlowTransformationWithParallelVersion: - return absl::StrFormat( - "1. Parallelize this transformation by setting " - "num_parallel_calls=tf.data.experimental.AUTOTUNE. See " - "%s for more details.
" - "2. Consider adding cache after this transformation if " - "your data fits into memory and it is appropriate (e.g., there is no " - "randomness in upstream transformations like shuffle). " - "See %s for more details.
" - "3. Find more resources %s.", - AnchorElement(kTfGuideParallelTransformationLink, "this"), - AnchorElement(kTfGuideCacheLink, "this"), - AnchorElement(kPlaybookTransformationLink, "here")); - case BottleneckType::kSlowTransformationWithoutParallelVersion: - return absl::StrFormat( - "1. This transformation is inherently sequential. Add outer " - "parallelism by running multiple copies of the input pipeline over " - "sharded inputs and combining the results. See %s for more " - "details.
" - "2. Consider adding cache after this transformation if " - "your data fits into memory and it is appropriate (e.g., there is no " - "randomness in upstream transformations like shuffle). " - "See %s for more details.
" - "3. Find more resources %s.", - AnchorElement(kPlaybookTransformationLink, "this"), - AnchorElement(kTfGuideCacheLink, "this"), - AnchorElement(kPlaybookCpuUtilizationLink, "here")); - default: - return absl::StrFormat("See %s for suggestions.", - AnchorElement(kPlaybookLink, "this")); - } -} - -void SetSuggestion(CombinedTfDataStats* combined_tf_data_stats) { - for (TfDataBottleneckAnalysis& bottleneck_analysis : - *combined_tf_data_stats->mutable_bottleneck_analysis()) { - bottleneck_analysis.set_suggestion( - GetSuggestion(GetBottleneckType(bottleneck_analysis.iterator_name()))); - } -} - -void SetSummary(CombinedTfDataStats* combined_tf_data_stats) { - int64_t max_latency_ps = 0; - if (combined_tf_data_stats->bottleneck_analysis_size()) { - max_latency_ps = - combined_tf_data_stats->bottleneck_analysis().at(0).max_latency_ps(); - } - if (max_latency_ps > kSlowCallThresholdPs) { - combined_tf_data_stats->set_is_input_bound(true); - combined_tf_data_stats->set_summary( - "Your profile has a tf.data input pipeline slower than 50 us. For each " - "slow input pipeline, below shows a bottleneck in the input pipeline " - "and a suggestion on how to fix it."); - } else if (max_latency_ps > 0) { - combined_tf_data_stats->set_is_input_bound(false); - combined_tf_data_stats->set_summary( - "Your profile does not have any tf.data input pipeline slower than 50 " - "us. Your job could be still input bound if this profile didn't " - "capture all workers."); - } else { - combined_tf_data_stats->set_is_input_bound(false); - combined_tf_data_stats->set_summary( - "No tf.data activity captured in your profile. If your job uses " - "tf.data, try to capture a longer profile."); - } -} - -} // namespace - -BottleneckType GetBottleneckType(absl::string_view bottleneck_iterator_name) { - static auto* kBottleneckTypeMap = new absl::flat_hash_map( - {// Read from storage. 
- {"TFRecord", BottleneckType::kSlowSource}, - {"SSTable", BottleneckType::kSlowSource}, - {"RecordIO", BottleneckType::kSlowSource}, - {"Spanner", BottleneckType::kSlowSource}, - {"TFColumn", BottleneckType::kSlowSource}, - {"SleepwalkRemoteDataset", BottleneckType::kSlowSource}, - {"TextLine", BottleneckType::kSlowSource}, - {"StitchedTimelineDataset", BottleneckType::kSlowSource}, - {"DateKeyDataset", BottleneckType::kSlowSource}, - {"CapacitorProto", BottleneckType::kSlowSource}, - {"LMDB", BottleneckType::kSlowSource}, - {"ExternalDataset", BottleneckType::kSlowSource}, - {"PearModel", BottleneckType::kSlowSource}, - {"FixedLengthRecordV2", BottleneckType::kSlowSource}, - // Read from local memory. - {"FromTensor", BottleneckType::kSlowSource}, - {"TensorSlice", BottleneckType::kSlowSource}, - {"Generator", BottleneckType::kSlowSource}, - {"SyntheticDatasetOp", BottleneckType::kSlowSource}, - // tf.data service. - {"DataService", BottleneckType::kSlowDataService}, - // Read from remote memory. - {"GuzzlerDataGuzzlerRemoteDataset", BottleneckType::kSlowRemoteSource}, - {"ReverbDataset", BottleneckType::kSlowRemoteSource}, - {"DatasetSampleGame", BottleneckType::kSlowRemoteSource}, - {"Courier", BottleneckType::kSlowRemoteSource}, - {"ReverbEpisodeDataset", BottleneckType::kSlowRemoteSource}, - // Transformations with parallel version. - {"Map", BottleneckType::kSlowTransformationWithParallelVersion}, - {"Interleave", BottleneckType::kSlowTransformationWithParallelVersion}, - // Transformations without parallel version. 
- {"Filter", BottleneckType::kSlowTransformationWithoutParallelVersion}, - {"Batch", BottleneckType::kSlowTransformationWithoutParallelVersion}, - {"Unbatch", BottleneckType::kSlowTransformationWithoutParallelVersion}}); - if (auto type = - gtl::FindOrNull(*kBottleneckTypeMap, bottleneck_iterator_name)) { - return *type; - } - return BottleneckType::kOther; -} - -void CombinedTfDataStatsBuilder::Add(absl::string_view host_name, - XPlane* host_plane) { - TfDataStats& tf_data_stats = - (*combined_tf_data_stats_ - ->mutable_tf_data_stats())[std::string(host_name)]; - tsl::profiler::EventForest event_forest; - event_forest.AddPlanes(tsl::profiler::CreateTfXPlaneVisitor, {host_plane}); - event_forest.ConnectEvents(); - event_forest.ConnectTfDataEvents(); - absl::flat_hash_set device_input_pipeline_ids; - absl::flat_hash_map> - root_iterator_event_map; - ProcessEventForest(event_forest, &device_input_pipeline_ids, - &root_iterator_event_map, &tf_data_stats); - ProcessInputPipelines(device_input_pipeline_ids, &root_iterator_event_map, - &tf_data_stats); -} - -void CombinedTfDataStatsBuilder::Finalize() { - SetBottleneckAnalysis(combined_tf_data_stats_); - if (generate_suggestion_) SetSuggestion(combined_tf_data_stats_); - SetSummary(combined_tf_data_stats_); -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/xplane_to_tf_data_stats.h b/tensorflow/core/profiler/convert/xplane_to_tf_data_stats.h deleted file mode 100644 index f5f53488791942..00000000000000 --- a/tensorflow/core/profiler/convert/xplane_to_tf_data_stats.h +++ /dev/null @@ -1,62 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_TF_DATA_STATS_H_ -#define TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_TF_DATA_STATS_H_ - -#include "absl/strings/string_view.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/profiler/protobuf/tf_data_stats.pb.h" -#include "tsl/profiler/protobuf/xplane.pb.h" - -namespace tensorflow { -namespace profiler { - -TF_CONST_INIT extern const int64_t kSlowCallThresholdPs; - -enum class BottleneckType { - kSlowSource, - kSlowDataService, - kSlowRemoteSource, - kSlowTransformationWithParallelVersion, - kSlowTransformationWithoutParallelVersion, - kOther, -}; - -BottleneckType GetBottleneckType(absl::string_view bottleneck_iterator_name); - -class CombinedTfDataStatsBuilder { - public: - explicit CombinedTfDataStatsBuilder( - CombinedTfDataStats* combined_tf_data_stats, - bool generate_suggestion = true) - : combined_tf_data_stats_(combined_tf_data_stats), - generate_suggestion_(generate_suggestion) {} - - void Add(absl::string_view host_name, XPlane* host_plane); - - // Finalizes by populating TfDataBottleneckAnalysis. 
- void Finalize(); - - private: - CombinedTfDataStats* combined_tf_data_stats_; - bool generate_suggestion_; -}; - -} // namespace profiler -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_TF_DATA_STATS_H_ diff --git a/tensorflow/core/profiler/convert/xplane_to_tf_data_stats_test.cc b/tensorflow/core/profiler/convert/xplane_to_tf_data_stats_test.cc deleted file mode 100644 index 64f1f68fe3226e..00000000000000 --- a/tensorflow/core/profiler/convert/xplane_to_tf_data_stats_test.cc +++ /dev/null @@ -1,419 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/core/profiler/convert/xplane_to_tf_data_stats.h" - -#include - -#include -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/profiler/protobuf/tf_data_stats.pb.h" -#include "tensorflow/core/profiler/utils/xplane_builder.h" -#include "tensorflow/core/profiler/utils/xplane_schema.h" -#include "tensorflow/core/profiler/utils/xplane_test_utils.h" -#include "tsl/profiler/protobuf/xplane.pb.h" - -namespace tensorflow { -namespace profiler { -namespace { - -using ::testing::EqualsProto; - -// Test with the following example dataset: -// dataset = tf.data.Dataset.range(8) -// dataset = dataset.prefetch(2) -// for _ in dataset: -// pass -TEST(XPlaneToTfDataStatsTest, HostInputPipeline) { - constexpr int64_t kPrefetchIteratorId = 123; - constexpr int64_t kRangeIteratorId = 456; - constexpr int64_t kFirstElementId = 100; - constexpr int64_t kSecondElementId = 200; - - XPlane host_plane; - XPlaneBuilder host_plane_builder(&host_plane); - host_plane_builder.ReserveLines(2); - - auto consumer_thread = host_plane_builder.GetOrCreateLine(0); - CreateXEvent(&host_plane_builder, &consumer_thread, "Iterator::Prefetch", 0, - 100000000, {{StatType::kStepId, kPrefetchIteratorId}}); - CreateXEvent(&host_plane_builder, &consumer_thread, - HostEventType::kPrefetchConsume, 80000000, 20000000, - {{StatType::kElementId, kFirstElementId}}); - CreateXEvent(&host_plane_builder, &consumer_thread, "Iterator::Prefetch", - 200000000, 20000000, {{StatType::kStepId, kPrefetchIteratorId}}); - CreateXEvent(&host_plane_builder, &consumer_thread, - HostEventType::kPrefetchConsume, 210000000, 10000000, - {{StatType::kElementId, kSecondElementId}}); - - auto producer_thread = host_plane_builder.GetOrCreateLine(1); - // Blocking producer. 
- CreateXEvent(&host_plane_builder, &producer_thread, - HostEventType::kPrefetchProduce, 0, 80000000, - {{StatType::kElementId, kFirstElementId}}); - CreateXEvent(&host_plane_builder, &producer_thread, - "Iterator::Prefetch::Range", 0, 80000000, - {{StatType::kStepId, kRangeIteratorId}, - {StatType::kParentId, kPrefetchIteratorId}}); - // Non-blocking producer. - CreateXEvent(&host_plane_builder, &producer_thread, - HostEventType::kPrefetchProduce, 100000000, 80000000, - {{StatType::kElementId, kSecondElementId}}); - CreateXEvent(&host_plane_builder, &producer_thread, - "Iterator::Prefetch::Range", 100000000, 80000000, - {{StatType::kStepId, kRangeIteratorId}, - {StatType::kParentId, kPrefetchIteratorId}}); - - CombinedTfDataStats combined_tf_data_stats; - CombinedTfDataStatsBuilder builder(&combined_tf_data_stats); - builder.Add("host1", &host_plane); - builder.Finalize(); - EXPECT_THAT( - combined_tf_data_stats, EqualsProto(R"pb( - bottleneck_analysis: { - host: "host1" - input_pipeline: "Host:0" - max_latency_ps: 100000000 - iterator_name: "Range" - iterator_long_name: "Iterator::Prefetch::Range" - iterator_latency_ps: 80000000 - suggestion: "See this for suggestions." 
- } - tf_data_stats: { - key: "host1" - value: { - iterator_metadata: { - key: 123, - value: { - id: 123 - name: "Prefetch" - long_name: "Iterator::Prefetch" - is_async: true - } - } - iterator_metadata: { - key: 456, - value: { - id: 456 - parent_id: 123 - name: "Range" - long_name: "Iterator::Prefetch::Range" - is_async: false - } - } - input_pipelines { - key: 123, - value: { - metadata { id: 123 type: HOST name: "Host:0" } - avg_latency_ps: 60000000 - min_latency_ps: 20000000 - max_latency_ps: 100000000 - num_slow_calls: 1 - stats { - bottleneck_iterator_id: 456 - bottleneck_iterator_latency_ps: 80000000 - iterator_stats { - key: 123, - value: { - id: 123 - start_time_ps: 0 - duration_ps: 100000000 - self_time_ps: 20000000 - is_blocking: true - num_calls: 1 - } - } - iterator_stats { - key: 456, - value: { - id: 456 - start_time_ps: 0 - duration_ps: 80000000 - self_time_ps: 80000000 - is_blocking: true - num_calls: 1 - } - } - } - stats { - bottleneck_iterator_id: 123 - bottleneck_iterator_latency_ps: 20000000 - iterator_stats { - key: 123, - value: { - id: 123 - start_time_ps: 200000000 - duration_ps: 20000000 - self_time_ps: 20000000 - is_blocking: true - num_calls: 1 - } - } - iterator_stats { - key: 456, - value: { - id: 456 - start_time_ps: 100000000 - duration_ps: 80000000 - self_time_ps: 80000000 - is_blocking: false - num_calls: 1 - } - } - } - } - } - } - } - is_input_bound: true - summary: "Your profile has a tf.data input pipeline slower than 50 us. For each slow input pipeline, below shows a bottleneck in the input pipeline and a suggestion on how to fix it." 
- )pb")); -} - -TEST(XPlaneToTfDataStatsTest, DeviceInputPipeline) { - constexpr int64_t kPrefetchIteratorId = 123; - constexpr int64_t kRangeIteratorId = 456; - constexpr int64_t kElementId = 100; - - XPlane host_plane; - XPlaneBuilder host_plane_builder(&host_plane); - host_plane_builder.ReserveLines(2); - - auto consumer_thread = host_plane_builder.GetOrCreateLine(0); - CreateXEvent(&host_plane_builder, &consumer_thread, "Iterator::Prefetch", 0, - 30000000, {{StatType::kStepId, kPrefetchIteratorId}}); - CreateXEvent(&host_plane_builder, &consumer_thread, "Iterator::Prefetch", - 100000000, 100000000, - {{StatType::kStepId, kPrefetchIteratorId}}); - CreateXEvent(&host_plane_builder, &consumer_thread, - HostEventType::kPrefetchConsume, 180000000, 20000000, - {{StatType::kElementId, kElementId}}); - - auto producer_thread = host_plane_builder.GetOrCreateLine(1); - CreateXEvent(&host_plane_builder, &producer_thread, - HostEventType::kPrefetchProduce, 100000000, 80000000, - {{StatType::kElementId, kElementId}}); - CreateXEvent(&host_plane_builder, &producer_thread, - "Iterator::Prefetch::Generator", 100000000, 80000000, - {{StatType::kStepId, kRangeIteratorId}, - {StatType::kParentId, kPrefetchIteratorId}}); - - CombinedTfDataStats combined_tf_data_stats; - CombinedTfDataStatsBuilder builder(&combined_tf_data_stats); - builder.Add("host1", &host_plane); - builder.Finalize(); - // Device input pipeline is not considered for bottleneck analysis. 
- EXPECT_THAT( - combined_tf_data_stats, EqualsProto(R"pb( - tf_data_stats: { - key: "host1" - value: { - iterator_metadata: { - key: 123, - value: { - id: 123 - name: "Prefetch" - long_name: "Iterator::Prefetch" - is_async: true - } - } - iterator_metadata: { - key: 456, - value: { - id: 456 - parent_id: 123 - name: "Generator" - long_name: "Iterator::Prefetch::Generator" - is_async: false - } - } - input_pipelines { - key: 123, - value: { - metadata { id: 123 type: DEVICE name: "Device:0" } - avg_latency_ps: 65000000 - min_latency_ps: 30000000 - max_latency_ps: 100000000 - num_slow_calls: 1 - stats { - bottleneck_iterator_id: 456 - bottleneck_iterator_latency_ps: 80000000 - iterator_stats { - key: 123, - value: { - id: 123 - start_time_ps: 100000000 - duration_ps: 100000000 - self_time_ps: 20000000 - is_blocking: true - num_calls: 1 - } - } - iterator_stats { - key: 456, - value: { - id: 456 - start_time_ps: 100000000 - duration_ps: 80000000 - self_time_ps: 80000000 - is_blocking: true - num_calls: 1 - } - } - } - stats { - bottleneck_iterator_id: 123 - bottleneck_iterator_latency_ps: 30000000 - iterator_stats { - key: 123, - value: { - id: 123 - start_time_ps: 0 - duration_ps: 30000000 - self_time_ps: 30000000 - is_blocking: true - num_calls: 1 - } - } - } - } - } - } - } - summary: "No tf.data activity captured in your profile. If your job uses tf.data, try to capture a longer profile." 
- )pb")); -} - -// Test with the following example dataset: -// dataset = tf.data.Dataset.range(8) -// dataset = dataset.map(lambda x: x + 1) -// dataset = dataset.batch(2) -// for _ in dataset: -// pass -TEST(XPlaneToTfDataStatsTest, MapAndBatch) { - constexpr int64_t kMapAndBatchIteratorId = 123; - constexpr int64_t kRangeIteratorId = 456; - constexpr int64_t kElementId = 100; - - XPlane host_plane; - XPlaneBuilder host_plane_builder(&host_plane); - host_plane_builder.ReserveLines(2); - - XLineBuilder consumer_thread = host_plane_builder.GetOrCreateLine(0); - CreateXEvent(&host_plane_builder, &consumer_thread, "Iterator::MapAndBatch", - 0, 100000000, {{StatType::kStepId, kMapAndBatchIteratorId}}); - CreateXEvent(&host_plane_builder, &consumer_thread, - HostEventType::kMapAndBatchConsume, 80000000, 20000000, - {{StatType::kElementId, kElementId}}); - - XLineBuilder producer_thread = host_plane_builder.GetOrCreateLine(1); - CreateXEvent(&host_plane_builder, &producer_thread, - HostEventType::kMapAndBatchProduce, 0, 30000000, - {{StatType::kElementId, kElementId}}); - CreateXEvent(&host_plane_builder, &producer_thread, - "Iterator::MapAndBatch::Range", 0, 30000000, - {{StatType::kStepId, kRangeIteratorId}, - {StatType::kParentId, kMapAndBatchIteratorId}}); - CreateXEvent(&host_plane_builder, &producer_thread, - HostEventType::kMapAndBatchProduce, 40000000, 30000000, - {{StatType::kElementId, kElementId}}); - CreateXEvent(&host_plane_builder, &producer_thread, - "Iterator::MapAndBatch::Range", 40000000, 30000000, - {{StatType::kStepId, kRangeIteratorId}, - {StatType::kParentId, kMapAndBatchIteratorId}}); - - CombinedTfDataStats combined_tf_data_stats; - CombinedTfDataStatsBuilder builder(&combined_tf_data_stats); - builder.Add("host1", &host_plane); - builder.Finalize(); - EXPECT_THAT( - combined_tf_data_stats, EqualsProto(R"pb( - bottleneck_analysis: { - host: "host1" - input_pipeline: "Host:0" - max_latency_ps: 100000000 - iterator_name: "Range" - 
iterator_long_name: "Iterator::MapAndBatch::Range" - iterator_latency_ps: 60000000 - suggestion: "See this for suggestions." - } - tf_data_stats: { - key: "host1" - value: { - iterator_metadata: { - key: 123, - value: { - id: 123 - name: "MapAndBatch" - long_name: "Iterator::MapAndBatch" - is_async: true - } - } - iterator_metadata: { - key: 456, - value: { - id: 456 - parent_id: 123 - name: "Range" - long_name: "Iterator::MapAndBatch::Range" - is_async: false - } - } - input_pipelines { - key: 123, - value: { - metadata { id: 123 type: HOST name: "Host:0" } - avg_latency_ps: 100000000 - min_latency_ps: 100000000 - max_latency_ps: 100000000 - num_slow_calls: 1 - stats { - bottleneck_iterator_id: 456 - bottleneck_iterator_latency_ps: 60000000 - iterator_stats { - key: 123, - value: { - id: 123 - start_time_ps: 0 - duration_ps: 100000000 - self_time_ps: 40000000 - is_blocking: true - num_calls: 1 - } - } - iterator_stats { - key: 456, - value: { - id: 456 - start_time_ps: 0 - duration_ps: 60000000 - self_time_ps: 60000000 - is_blocking: true - num_calls: 2 - } - } - } - } - } - } - } - is_input_bound: true - summary: "Your profile has a tf.data input pipeline slower than 50 us. For each slow input pipeline, below shows a bottleneck in the input pipeline and a suggestion on how to fix it." - )pb")); -} - -} // namespace -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/xplane_to_tf_functions.cc b/tensorflow/core/profiler/convert/xplane_to_tf_functions.cc deleted file mode 100644 index 1a61c032d442d5..00000000000000 --- a/tensorflow/core/profiler/convert/xplane_to_tf_functions.cc +++ /dev/null @@ -1,305 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -You may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/profiler/convert/xplane_to_tf_functions.h" - -#include -#include -#include -#include -#include -#include -#include - -#include "absl/algorithm/container.h" -#include "absl/log/check.h" -#include "absl/log/log.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" -#include "xla/tsl/profiler/utils/math_utils.h" -#include "xla/tsl/profiler/utils/timespan.h" -#include "tensorflow/core/lib/gtl/map_util.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/profiler/utils/xplane_schema.h" -#include "tensorflow/core/profiler/utils/xplane_visitor.h" -#include "tsl/platform/protobuf.h" -#include "tsl/profiler/protobuf/xplane.pb.h" - -namespace tensorflow { -namespace profiler { - -namespace { - -std::pair Decode( - absl::string_view function_name, absl::string_view mode) { - // mode is one of ["eager", "concrete", "traced-xla", "traced-nonXla", - // "notTraced-xla", "notTraced-nonXla"] - if (mode == "eager") return {EAGER_MODE, INVALID_COMPILER}; - if (mode == "concrete") return {CONCRETE_MODE, INVALID_COMPILER}; - if (mode == "traced-xla") return {TRACED_MODE, XLA_COMPILER}; - if (mode == "traced-nonXla") return {TRACED_MODE, OTHER_COMPILER}; - if (mode == "notTraced-xla") return {NOT_TRACED_MODE, XLA_COMPILER}; - if (mode == "notTraced-nonXla") return {NOT_TRACED_MODE, OTHER_COMPILER}; - // Shouldn't reach here. 
- LOG(ERROR) << absl::StrCat("tf-function '", function_name, - "' has an unexpected execution mode '", mode, "'") - << std::endl; - return {INVALID_MODE, INVALID_COMPILER}; - DCHECK(false); -} - -double ComputeExpensiveCallPercent(const TfFunction& tf_function) { - // Computes the expensiveness in terms of time (rather than count). - uint64 total_call_time_ps = 0; - uint64 expensive_call_time_ps = 0; - for (const auto& mode_metrics : tf_function.metrics()) { - const auto mode = mode_metrics.first; - const auto& metrics = mode_metrics.second; - total_call_time_ps += metrics.self_time_ps(); - if (mode == TRACED_MODE || mode == EAGER_MODE) { - expensive_call_time_ps += metrics.self_time_ps(); - } - } - return tsl::profiler::SafeDivide(100.0 * expensive_call_time_ps, - total_call_time_ps); -} - -// Each invocation of a tf-function creates an ActivationRecord. -struct ActivationRecord { - std::string function_name; // name of the tf-function. - tsl::profiler::Timespan timespan; // timespan of this invocation. - TfFunctionExecutionMode execution_mode; // execution mode. - TfFunctionCompiler compiler; // compiler used. - int64_t tracing_count; // the total tracing count of this function when this - // invocation happened. - uint64 children_duration_ps; // Sum of the duration of all (immediate) - // children tf-functions of this function. 
- ActivationRecord() - : function_name(""), - execution_mode(INVALID_MODE), - compiler(INVALID_COMPILER), - tracing_count(0), - children_duration_ps(0) {} - ActivationRecord(absl::string_view name, - const tsl::profiler::Timespan& timespan, - TfFunctionExecutionMode exe_mode, - TfFunctionCompiler compiler, int64_t tracing_cnt) - : function_name(std::string(name)), - timespan(timespan), - execution_mode(exe_mode), - compiler(compiler), - tracing_count(tracing_cnt), - children_duration_ps(0) {} - std::string DebugString() const { - return absl::StrCat("{", function_name, ", ", - TfFunctionExecutionMode_Name(execution_mode), ", ", - TfFunctionCompiler_Name(compiler), - ", tracing_count:", tracing_count, - ", children_duration:", children_duration_ps, - " ps, timespan:", timespan.DebugString(), "}"); - } -}; - -// Entry or exit point of a tf-function. -struct EntryOrExit { - bool is_entry; // true for entry, false for exit. - int64_t index; // index to the ActivationRecord. - uint64 timestamp_ps; // the time when this entry/exit happens. - EntryOrExit() : is_entry(false), index(-1), timestamp_ps(0) {} - EntryOrExit(bool is_entry, int64_t index, uint64 timestamp_ps) - : is_entry(is_entry), index(index), timestamp_ps(timestamp_ps) {} - std::string DebugString() const { - std::string entry_or_exit = is_entry ? 
"entry, " : "exit, "; - return absl::StrCat("{", entry_or_exit, "idx:", index, - ", timestamp:", timestamp_ps, "}"); - } -}; - -TfFunctionCompiler CombineCompilers(TfFunctionCompiler a, - TfFunctionCompiler b) { - if (a == INVALID_COMPILER) return b; - if (b == INVALID_COMPILER) return a; - if (a == b) return a; - return MIXED_COMPILER; -} - -void CombineTfFunctionMetrics(const TfFunctionMetrics& src, - TfFunctionMetrics* dst) { - dst->set_count(src.count() + dst->count()); - dst->set_self_time_ps(src.self_time_ps() + dst->self_time_ps()); -} - -void CombineTfFunction(const TfFunction& src, TfFunction* dst) { - dst->set_total_tracing_count( - std::max(src.total_tracing_count(), dst->total_tracing_count())); - dst->set_compiler(CombineCompilers(src.compiler(), dst->compiler())); - for (const auto& mode_metrics : src.metrics()) { - int32_t execution_mode = mode_metrics.first; - const TfFunctionMetrics& src_metrics = mode_metrics.second; - TfFunctionMetrics* dst_metrics = - gtl::FindOrNull(*dst->mutable_metrics(), execution_mode); - if (dst_metrics == nullptr) { - (*dst->mutable_metrics())[execution_mode] = src_metrics; - } else { - CombineTfFunctionMetrics(src_metrics, dst_metrics); - } - } - dst->set_expensive_call_percent(ComputeExpensiveCallPercent(*dst)); -} - -// Execution history of all tf-functions invoked. -class TfFunctionExecutions { - public: - explicit TfFunctionExecutions(const XLineVisitor& line) { - // Creates points_ and activations_ from line. - line.ForEachEvent([&](const XEventVisitor& event) { - absl::string_view mode; - int64_t tracing_count = 0; - event.ForEachStat([&mode, &tracing_count](const XStatVisitor& stat) { - if (!stat.Type().has_value()) return; - switch (stat.Type().value()) { - case StatType::kTfFunctionCall: - mode = stat.StrOrRefValue(); - break; - case StatType::kTfFunctionTracingCount: - tracing_count = stat.IntValue(); - break; - } - }); - if (mode.empty()) return; - - // event is a tf-function. 
- int64_t index = activations_.size(); - auto timespan = event.GetTimespan(); - auto mode_compiler = Decode(event.Name(), mode); - ActivationRecord activation_record = - ActivationRecord(event.Name(), timespan, mode_compiler.first, - mode_compiler.second, tracing_count); - activations_.push_back(activation_record); - EntryOrExit entry_point = - EntryOrExit(/*is_entry=*/true, index, timespan.begin_ps()); - EntryOrExit exit_point = - EntryOrExit(/*is_entry=*/false, index, timespan.end_ps()); - points_.push_back(entry_point); - points_.push_back(exit_point); - }); - - // Sorts points_ in ascending order of timestamps. - auto ascending_in_timestamp = [](const EntryOrExit& a, - const EntryOrExit& b) { - return a.timestamp_ps < b.timestamp_ps; - }; - absl::c_sort(points_, ascending_in_timestamp); - - // Calculates the children duration for each activation record. - CalculateChildrenDurations(); - } - - std::string DebugString() const { - std::string result = "\nActivations:\n"; - for (int i = 0, end = activations_.size(); i < end; i++) { - absl::StrAppend(&result, "[", i, "] ", activations_[i].DebugString(), - "\n"); - } - absl::StrAppend(&result, "tf-function Entry/Exit Points:\n"); - for (const auto& pt : points_) { - absl::StrAppend(&result, pt.DebugString(), "\n"); - } - return result; - } - - // Converts this execution history to a TfFunctionDb. - TfFunctionDb ConvertToTfFunctionDb() { - TfFunctionDb result; - for (const auto& record : activations_) { - TfFunction* fun = &(*result.mutable_tf_functions())[record.function_name]; - fun->set_total_tracing_count( - std::max(static_cast(fun->total_tracing_count()), - record.tracing_count)); - fun->set_compiler(CombineCompilers(fun->compiler(), record.compiler)); - // The self-time of this function is the difference between the duration - // of this function and the duration of its children. 
- uint64 self_time_ps = - record.timespan.duration_ps() - record.children_duration_ps; - // Updates the metrics for this execution mode with this invocation. - TfFunctionMetrics* metrics = - &(*fun->mutable_metrics())[record.execution_mode]; - metrics->set_count(metrics->count() + 1); - metrics->set_self_time_ps(metrics->self_time_ps() + self_time_ps); - } - for (auto& name_fun : *result.mutable_tf_functions()) { - TfFunction& fun = name_fun.second; - fun.set_expensive_call_percent(ComputeExpensiveCallPercent(fun)); - } - return result; - } - - // Calculates the children duration of every tf-function. - void CalculateChildrenDurations() { - std::stack call_stack; - for (const auto& pt : points_) { - if (pt.is_entry) { - // Function entry. - call_stack.push(pt.index); - } else { - // Function exit. - DCHECK(call_stack.top() == pt.index); // must be well nested. - uint64 call_duration = activations_[pt.index].timespan.duration_ps(); - call_stack.pop(); - if (!call_stack.empty()) { - // call_stack.top() is the parent tf-function; adds call_duration to - // its children_duration. - activations_[call_stack.top()].children_duration_ps += call_duration; - } - } - } - } - - private: - // ActivationRecords for all tf-function invocations. - std::vector activations_; - // Entry and exit points of all invocations. 
- std::vector points_; -}; - -} // namespace - -std::string DebugString(const TfFunctionDb& tf_function_db) { - std::string str; - tsl::protobuf::TextFormat::PrintToString(tf_function_db, &str); - return str; -} - -void CombineTfFunctionDb(const TfFunctionDb& src, TfFunctionDb* dst) { - for (const auto& name_function : src.tf_functions()) { - const auto& name = name_function.first; - const auto& src_fun = name_function.second; - TfFunction* dst_fun = gtl::FindOrNull(*dst->mutable_tf_functions(), name); - if (dst_fun == nullptr) { - (*dst->mutable_tf_functions())[name] = src_fun; - } else { - CombineTfFunction(src_fun, dst_fun); - } - } -} - -TfFunctionDb ConvertHostThreadsXLineToTfFunctionDb(const XLineVisitor& line) { - TfFunctionExecutions tf_function_executions = TfFunctionExecutions(line); - return tf_function_executions.ConvertToTfFunctionDb(); -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/xplane_to_tf_functions.h b/tensorflow/core/profiler/convert/xplane_to_tf_functions.h deleted file mode 100644 index fbff7ccecc72d2..00000000000000 --- a/tensorflow/core/profiler/convert/xplane_to_tf_functions.h +++ /dev/null @@ -1,39 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_TF_FUNCTIONS_H_ -#define TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_TF_FUNCTIONS_H_ - -#include - -#include "tensorflow/core/profiler/protobuf/tf_function.pb.h" -#include "tensorflow/core/profiler/utils/xplane_visitor.h" - -namespace tensorflow { -namespace profiler { - -// Converts from the given XLine to a TfFunctionDb. -TfFunctionDb ConvertHostThreadsXLineToTfFunctionDb(const XLineVisitor& line); - -// Returns a debugging string for the given TfFunctionDb. -std::string DebugString(TfFunctionDb tf_function_db); - -// Combines the tf-function statistics from src and dst into dst. -void CombineTfFunctionDb(const TfFunctionDb& src, TfFunctionDb* dst); - -} // namespace profiler -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_TF_FUNCTIONS_H_ diff --git a/tensorflow/core/profiler/convert/xplane_to_tf_functions_test.cc b/tensorflow/core/profiler/convert/xplane_to_tf_functions_test.cc deleted file mode 100644 index e77883c847e53c..00000000000000 --- a/tensorflow/core/profiler/convert/xplane_to_tf_functions_test.cc +++ /dev/null @@ -1,185 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/core/profiler/convert/xplane_to_tf_functions.h" - -#include - -#include "absl/strings/string_view.h" -#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/profiler/protobuf/tf_function.pb.h" -#include "tensorflow/core/profiler/utils/xplane_builder.h" -#include "tensorflow/core/profiler/utils/xplane_schema.h" -#include "tensorflow/core/profiler/utils/xplane_test_utils.h" -#include "tensorflow/core/profiler/utils/xplane_utils.h" -#include "tensorflow/core/profiler/utils/xplane_visitor.h" -#include "tsl/profiler/protobuf/xplane.pb.h" - -namespace tensorflow { -namespace profiler { -namespace { - -const absl::string_view kEager = "eager"; -const absl::string_view kConcrete = "concrete"; -const absl::string_view kTracedNonXla = "traced-nonXla"; -const absl::string_view kTracedXla = "traced-xla"; -const absl::string_view kNotTracedNonXla = "notTraced-nonXla"; -const absl::string_view kNotTracedXla = "notTraced-xla"; - -constexpr double kMaxError = 0.001; - -TfFunctionDb ConvertXSpaceToTfFunctionDb(const XSpace& space) { - TfFunctionDb result; - const XPlane* host_plane = FindPlaneWithName(space, kHostThreadsPlaneName); - if (host_plane) { - XPlaneVisitor plane = tsl::profiler::CreateTfXPlaneVisitor(host_plane); - plane.ForEachLine([&result](const XLineVisitor& line) { - TfFunctionDb tf_function_db = ConvertHostThreadsXLineToTfFunctionDb(line); - CombineTfFunctionDb(tf_function_db, &result); - }); - } - return result; -} - -TEST(ConvertXPlaneToTfFunctions, CombineTwoThreads) { - XSpace space; - XPlaneBuilder host_plane_builder(space.add_planes()); - host_plane_builder.SetName(kHostThreadsPlaneName); - host_plane_builder.ReserveLines(2); - std::string kFunctionName = "decrement"; - - auto main_thread = host_plane_builder.GetOrCreateLine(0); - CreateTfFunctionCallEvent(&host_plane_builder, 
&main_thread, kFunctionName, - 10, 100, kTracedNonXla, 1); - CreateTfFunctionCallEvent(&host_plane_builder, &main_thread, kFunctionName, - 150, 20, kNotTracedNonXla, 2); - CreateTfFunctionCallEvent(&host_plane_builder, &main_thread, kFunctionName, - 200, 80, kTracedNonXla, 3); - - auto other_thread = host_plane_builder.GetOrCreateLine(1); - CreateTfFunctionCallEvent(&host_plane_builder, &other_thread, kFunctionName, - 20, 100, kTracedNonXla, 2); - CreateTfFunctionCallEvent(&host_plane_builder, &other_thread, kFunctionName, - 160, 20, kNotTracedNonXla, 2); - CreateTfFunctionCallEvent(&host_plane_builder, &other_thread, kFunctionName, - 210, 80, kTracedXla, 4); - - TfFunctionDb tf_function_db = ConvertXSpaceToTfFunctionDb(space); - EXPECT_EQ(tf_function_db.tf_functions().size(), 1); - EXPECT_EQ(tf_function_db.tf_functions().count(kFunctionName), 1); - const TfFunction& tf_function = - tf_function_db.tf_functions().at(kFunctionName); - EXPECT_EQ(tf_function.total_tracing_count(), 4); - EXPECT_EQ(tf_function.compiler(), MIXED_COMPILER); - EXPECT_NEAR(tf_function.expensive_call_percent(), 90, kMaxError); - - const auto& metrics = tf_function.metrics(); - EXPECT_EQ(metrics.size(), 2); - EXPECT_EQ(metrics.count(TRACED_MODE), 1); - EXPECT_EQ(metrics.count(NOT_TRACED_MODE), 1); - const auto& traced_mode = metrics.at(TRACED_MODE); - EXPECT_EQ(traced_mode.count(), 4); - EXPECT_EQ(traced_mode.self_time_ps(), 360); - const auto& not_traced_mode = metrics.at(NOT_TRACED_MODE); - EXPECT_EQ(not_traced_mode.count(), 2); - EXPECT_EQ(not_traced_mode.self_time_ps(), 40); -} - -TEST(ConvertXPlaneToTfFunctions, NestedFunctions) { - XSpace space; - XPlaneBuilder host_plane_builder(space.add_planes()); - host_plane_builder.SetName(kHostThreadsPlaneName); - host_plane_builder.ReserveLines(1); - std::string kOuterFunctionName = "outer"; - std::string kInnerFunctionName = "inner"; - - auto main_thread = host_plane_builder.GetOrCreateLine(0); - CreateTfFunctionCallEvent(&host_plane_builder, 
&main_thread, - kOuterFunctionName, 10, 100, kTracedNonXla, 1); - CreateTfFunctionCallEvent(&host_plane_builder, &main_thread, - kInnerFunctionName, 30, 40, kNotTracedXla, 0); - TfFunctionDb tf_function_db = ConvertXSpaceToTfFunctionDb(space); - EXPECT_EQ(tf_function_db.tf_functions().size(), 2); - EXPECT_EQ(tf_function_db.tf_functions().count(kOuterFunctionName), 1); - EXPECT_EQ(tf_function_db.tf_functions().count(kInnerFunctionName), 1); - const TfFunction& outer = - tf_function_db.tf_functions().at(kOuterFunctionName); - EXPECT_EQ(outer.total_tracing_count(), 1); - EXPECT_EQ(outer.compiler(), OTHER_COMPILER); - EXPECT_NEAR(outer.expensive_call_percent(), 100, kMaxError); - const auto& outer_metrics = outer.metrics(); - EXPECT_EQ(outer_metrics.size(), 1); - EXPECT_EQ(outer_metrics.count(TRACED_MODE), 1); - const auto& traced_mode = outer_metrics.at(TRACED_MODE); - EXPECT_EQ(traced_mode.count(), 1); - EXPECT_EQ(traced_mode.self_time_ps(), 60); - const TfFunction& inner = - tf_function_db.tf_functions().at(kInnerFunctionName); - EXPECT_EQ(inner.total_tracing_count(), 0); - EXPECT_EQ(inner.compiler(), XLA_COMPILER); - EXPECT_NEAR(inner.expensive_call_percent(), 0, kMaxError); - const auto& inner_metrics = inner.metrics(); - EXPECT_EQ(inner_metrics.size(), 1); - EXPECT_EQ(inner_metrics.count(NOT_TRACED_MODE), 1); - const auto& not_traced_mode = inner_metrics.at(NOT_TRACED_MODE); - EXPECT_EQ(not_traced_mode.count(), 1); - EXPECT_EQ(not_traced_mode.self_time_ps(), 40); -} - -TEST(ConvertXPlaneToTfFunctions, EagerPlusConcrete) { - XSpace space; - XPlaneBuilder host_plane_builder(GetOrCreateHostXPlane(&space)); - host_plane_builder.ReserveLines(2); - std::string kEagerFunctionName = "i_am_eager"; - std::string kConcreteFunctionName = "i_am_concrete"; - - auto main_thread = host_plane_builder.GetOrCreateLine(0); - CreateTfFunctionCallEvent(&host_plane_builder, &main_thread, - kEagerFunctionName, 10, 200, kEager); - auto other_thread = 
host_plane_builder.GetOrCreateLine(1); - CreateTfFunctionCallEvent(&host_plane_builder, &other_thread, - kConcreteFunctionName, 20, 40, kConcrete); - TfFunctionDb tf_function_db = ConvertXSpaceToTfFunctionDb(space); - EXPECT_EQ(tf_function_db.tf_functions().size(), 2); - EXPECT_EQ(tf_function_db.tf_functions().count(kEagerFunctionName), 1); - EXPECT_EQ(tf_function_db.tf_functions().count(kConcreteFunctionName), 1); - const TfFunction& eager = - tf_function_db.tf_functions().at(kEagerFunctionName); - EXPECT_EQ(eager.total_tracing_count(), 0); - EXPECT_EQ(eager.compiler(), INVALID_COMPILER); - EXPECT_NEAR(eager.expensive_call_percent(), 100, kMaxError); - const auto& eager_metrics = eager.metrics(); - EXPECT_EQ(eager_metrics.size(), 1); - EXPECT_EQ(eager_metrics.count(EAGER_MODE), 1); - const auto& eager_mode = eager_metrics.at(EAGER_MODE); - EXPECT_EQ(eager_mode.count(), 1); - EXPECT_EQ(eager_mode.self_time_ps(), 200); - const TfFunction& concrete = - tf_function_db.tf_functions().at(kConcreteFunctionName); - EXPECT_EQ(concrete.total_tracing_count(), 0); - EXPECT_EQ(concrete.compiler(), INVALID_COMPILER); - EXPECT_NEAR(concrete.expensive_call_percent(), 0, kMaxError); - const auto& concrete_metrics = concrete.metrics(); - EXPECT_EQ(concrete_metrics.size(), 1); - EXPECT_EQ(concrete_metrics.count(CONCRETE_MODE), 1); - const auto& concrete_mode = concrete_metrics.at(CONCRETE_MODE); - EXPECT_EQ(concrete_mode.count(), 1); - EXPECT_EQ(concrete_mode.self_time_ps(), 40); -} - -} // namespace -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/xplane_to_tool_names.cc b/tensorflow/core/profiler/convert/xplane_to_tool_names.cc deleted file mode 100644 index 9fb564899e5641..00000000000000 --- a/tensorflow/core/profiler/convert/xplane_to_tool_names.cc +++ /dev/null @@ -1,76 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/profiler/convert/xplane_to_tool_names.h" - -#include -#include -#include - -#include "absl/status/statusor.h" -#include "absl/strings/str_join.h" -#include "xla/tsl/platform/statusor.h" -#include "tensorflow/core/profiler/convert/repository.h" -#include "tensorflow/core/profiler/convert/xplane_to_dcn_collective_stats.h" -#include "tensorflow/core/profiler/convert/xplane_to_hlo.h" -#include "tensorflow/core/profiler/utils/xplane_schema.h" -#include "tensorflow/core/profiler/utils/xplane_utils.h" - -namespace tensorflow { -namespace profiler { - -absl::StatusOr GetAvailableToolNames( - const SessionSnapshot& session_snapshot) { - std::vector tools; - bool is_cloud_vertex_ai = !session_snapshot.HasAccessibleRunDir(); - if (session_snapshot.XSpaceSize() != 0) { - tools.reserve(11); - tools.push_back(is_cloud_vertex_ai ? 
"trace_viewer" : "trace_viewer@"); - tools.push_back("overview_page"); - tools.push_back("input_pipeline_analyzer"); - tools.push_back("framework_op_stats"); - tools.push_back("memory_profile"); - tools.push_back("pod_viewer"); - tools.push_back("op_profile"); - tools.push_back("inference_profile"); - tools.push_back("hlo_stats"); - tools.push_back("roofline_model"); - - TF_ASSIGN_OR_RETURN(std::unique_ptr xspace, - session_snapshot.GetXSpace(0)); - - if (!FindPlanesWithPrefix(*xspace, kGpuPlanePrefix).empty()) { - tools.push_back("kernel_stats"); - } - - TF_ASSIGN_OR_RETURN(bool has_hlo, - ConvertMultiXSpaceToHloProto(session_snapshot)); - if (has_hlo) { - tools.push_back("memory_viewer"); - tools.push_back("graph_viewer"); - } - - TF_ASSIGN_OR_RETURN(bool has_dcn_collective_stats, - HasDcnCollectiveStatsInMultiXSpace(session_snapshot)); - if (has_dcn_collective_stats) { - tools.push_back("dcn_collective_stats"); - } - } - - return absl::StrJoin(tools, ","); -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/xplane_to_tool_names.h b/tensorflow/core/profiler/convert/xplane_to_tool_names.h deleted file mode 100644 index a1e936940d2b91..00000000000000 --- a/tensorflow/core/profiler/convert/xplane_to_tool_names.h +++ /dev/null @@ -1,35 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_TOOL_NAMES_H_ -#define TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_TOOL_NAMES_H_ - -#include - -#include "tensorflow/core/platform/statusor.h" -#include "tensorflow/core/profiler/convert/repository.h" - -namespace tensorflow { -namespace profiler { - -// Gets the names of the available tools given a session snapshot. -// Returns a comma separated list of tool names. -absl::StatusOr GetAvailableToolNames( - const SessionSnapshot& session_snapshot); - -} // namespace profiler -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_TOOL_NAMES_H_ diff --git a/tensorflow/core/profiler/convert/xplane_to_tool_names_test.cc b/tensorflow/core/profiler/convert/xplane_to_tool_names_test.cc deleted file mode 100644 index 83fa3111374622..00000000000000 --- a/tensorflow/core/profiler/convert/xplane_to_tool_names_test.cc +++ /dev/null @@ -1,167 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/core/profiler/convert/xplane_to_tool_names.h" - -#include -#include -#include -#include -#include - -#include -#include -#include "absl/status/statusor.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_split.h" -#include "xla/tsl/platform/env.h" -#include "xla/tsl/platform/status.h" -#include "tensorflow/core/platform/file_system.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/profiler/convert/repository.h" -#include "tensorflow/core/profiler/utils/xplane_schema.h" -#include "tensorflow/core/profiler/utils/xplane_utils.h" -#include "tsl/profiler/protobuf/xplane.pb.h" - -namespace tensorflow { -namespace profiler { -namespace { - -struct XPlaneToToolsTestCase { - std::string test_name; - std::string_view plane_name; - bool has_hlo_module; - bool has_dcn_collective_stats; - std::vector expected_tools; -}; - -SessionSnapshot CreateSessionSnapshot(std::unique_ptr xspace, - bool has_hlo_module, - bool has_dcn_collective_stats) { - std::string test_name = - ::testing::UnitTest::GetInstance()->current_test_info()->name(); - std::string path = absl::StrCat("ram://", test_name, "/"); - std::unique_ptr xplane_file; - tsl::Env::Default() - ->NewAppendableFile(absl::StrCat(path, "hostname.xplane.pb"), - &xplane_file) - .IgnoreError(); - std::vector paths = {path}; - - if (has_hlo_module) { - tsl::Env::Default() - ->NewAppendableFile(absl::StrCat(path, "module_name.hlo_proto.pb"), - &xplane_file) - .IgnoreError(); - } else { - tsl::Env::Default() - ->NewAppendableFile(absl::StrCat(path, "NO_MODULE.hlo_proto.pb"), - &xplane_file) - .IgnoreError(); - } - - if (has_dcn_collective_stats) { - tsl::Env::Default() - ->NewAppendableFile( - absl::StrCat(path, "hostname.dcn_collective_stats.pb"), - &xplane_file) - .IgnoreError(); - tsl::Env::Default() - ->NewAppendableFile( - absl::StrCat(path, "ALL_HOSTS.dcn_collective_stats.pb"), - 
&xplane_file) - .IgnoreError(); - } else { - tsl::Env::Default() - ->NewAppendableFile( - absl::StrCat(path, "NO_HOST.dcn_collective_stats.pb"), &xplane_file) - .IgnoreError(); - } - - std::vector> xspaces; - xspaces.push_back(std::move(xspace)); - - absl::StatusOr session_snapshot = - SessionSnapshot::Create(paths, std::move(xspaces)); - TF_CHECK_OK(session_snapshot.status()); - return std::move(session_snapshot.value()); -} - -using XPlaneToToolsTest = ::testing::TestWithParam; - -TEST_P(XPlaneToToolsTest, ToolsList) { - const XPlaneToToolsTestCase& test_case = GetParam(); - auto xspace = std::make_unique(); - FindOrAddMutablePlaneWithName(xspace.get(), test_case.plane_name); - - SessionSnapshot sessionSnapshot = - CreateSessionSnapshot(std::move(xspace), test_case.has_hlo_module, - test_case.has_dcn_collective_stats); - - absl::StatusOr toolsString = - GetAvailableToolNames(sessionSnapshot); - ASSERT_TRUE(toolsString.ok()); - - std::vector tools = absl::StrSplit(toolsString.value(), ','); - - std::vector expected_tools = { - "trace_viewer", - "overview_page", - "input_pipeline_analyzer", - "framework_op_stats", - "memory_profile", - "pod_viewer", - "op_profile", - "hlo_stats", - "roofline_model", - "inference_profile", - }; - expected_tools.insert(expected_tools.end(), test_case.expected_tools.begin(), - test_case.expected_tools.end()); - EXPECT_THAT(tools, ::testing::UnorderedElementsAreArray(expected_tools)); -} - -INSTANTIATE_TEST_SUITE_P( - XPlaneToToolsTests, XPlaneToToolsTest, - ::testing::ValuesIn({ - {"ToolsForTpuWithoutHloModule", kTpuPlanePrefix, false, false, {}}, - {"ToolsForTpuWithHloModule", - kTpuPlanePrefix, - true, - false, - {"graph_viewer", "memory_viewer"}}, - {"ToolsForGpuWithoutHloModule", - kGpuPlanePrefix, - false, - false, - {"kernel_stats"}}, - {"ToolsForGpuWithHloModule", - kGpuPlanePrefix, - true, - false, - {"kernel_stats", "graph_viewer", "memory_viewer"}}, - {"ToolsForTpuWithDcnCollectiveStats", - kTpuPlanePrefix, - false, - true, 
- {"dcn_collective_stats"}}, - }), - [](const ::testing::TestParamInfo& info) { - return info.param.test_name; - }); - -} // namespace -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/xplane_to_tools_data.cc b/tensorflow/core/profiler/convert/xplane_to_tools_data.cc deleted file mode 100644 index 60476773873064..00000000000000 --- a/tensorflow/core/profiler/convert/xplane_to_tools_data.cc +++ /dev/null @@ -1,420 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/core/profiler/convert/xplane_to_tools_data.h" - -#include -#include -#include -#include -#include - -#include "absl/log/log.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/numbers.h" -#include "absl/strings/string_view.h" -#include "xla/tsl/platform/env.h" -#include "xla/tsl/platform/errors.h" -#include "xla/tsl/platform/file_system.h" -#include "xla/tsl/platform/statusor.h" -#include "xla/tsl/profiler/convert/xplane_to_trace_events.h" -#include "xla/tsl/profiler/utils/timespan.h" -#include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/errors.h" -#include "tensorflow/core/profiler/convert/compute_inference_latency.h" -#include "tensorflow/core/profiler/convert/hlo_to_tools_data.h" -#include "tensorflow/core/profiler/convert/multi_xplanes_to_op_stats.h" -#include "tensorflow/core/profiler/convert/multi_xspace_to_inference_stats.h" -#include "tensorflow/core/profiler/convert/op_stats_to_hlo_stats.h" -#include "tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.h" -#include "tensorflow/core/profiler/convert/op_stats_to_op_profile.h" -#include "tensorflow/core/profiler/convert/op_stats_to_overview_page.h" -#include "tensorflow/core/profiler/convert/op_stats_to_pod_viewer.h" -#include "tensorflow/core/profiler/convert/op_stats_to_roofline_model.h" -#include "tensorflow/core/profiler/convert/op_stats_to_tf_stats.h" -#include "tensorflow/core/profiler/convert/preprocess_single_host_xplane.h" -#include "tensorflow/core/profiler/convert/process_megascale_dcn.h" -#include "tensorflow/core/profiler/convert/repository.h" -#include "tensorflow/core/profiler/convert/tool_options.h" -#include "tensorflow/core/profiler/convert/xplane_to_dcn_collective_stats.h" -#include "tensorflow/core/profiler/convert/xplane_to_memory_profile.h" -#include 
"tensorflow/core/profiler/convert/xplane_to_op_stats.h" -#include "tensorflow/core/profiler/convert/xplane_to_tf_data_stats.h" -#include "tensorflow/core/profiler/convert/xplane_to_tool_names.h" -#include "tensorflow/core/profiler/convert/xplane_to_trace_container.h" -#include "tensorflow/core/profiler/protobuf/dcn_slack_analysis.pb.h" -#include "tensorflow/core/profiler/protobuf/hardware_types.pb.h" -#include "tensorflow/core/profiler/protobuf/inference_stats.pb.h" -#include "tensorflow/core/profiler/protobuf/input_pipeline.pb.h" -#include "tensorflow/core/profiler/protobuf/kernel_stats.pb.h" -#include "tensorflow/core/profiler/protobuf/op_profile.pb.h" -#include "tensorflow/core/profiler/protobuf/op_stats.pb.h" -#include "tensorflow/core/profiler/protobuf/overview_page.pb.h" -#include "tensorflow/core/profiler/protobuf/roofline_model.pb.h" -#include "tensorflow/core/profiler/protobuf/tf_data_stats.pb.h" -#include "tensorflow/core/profiler/protobuf/tf_stats.pb.h" -#include "tensorflow/core/profiler/utils/hardware_type_utils.h" -#include "tensorflow/core/profiler/utils/xplane_schema.h" -#include "tensorflow/core/profiler/utils/xplane_utils.h" -#include "tsl/platform/protobuf.h" -#include "tsl/profiler/protobuf/xplane.pb.h" -#include "xprof/convert/trace_viewer/trace_events_to_json.h" // from @org_xprof -#include "xprof/convert/trace_viewer/trace_viewer_visibility.h" // from @org_xprof - -namespace tensorflow { -namespace profiler { - -namespace { - -struct TraceViewOption { - uint64_t resolution = 0; - double start_time_ms = 0.0; - double end_time_ms = 0.0; -}; - -absl::StatusOr GetTraceViewOption(const ToolOptions& options) { - TraceViewOption trace_options; - auto start_time_ms_opt = - GetParamWithDefault(options, "start_time_ms", "0.0"); - auto end_time_ms_opt = - GetParamWithDefault(options, "end_time_ms", "0.0"); - auto resolution_opt = - GetParamWithDefault(options, "resolution", "0"); - - if (!absl::SimpleAtoi(resolution_opt, &trace_options.resolution) || - 
!absl::SimpleAtod(start_time_ms_opt, &trace_options.start_time_ms) || - !absl::SimpleAtod(end_time_ms_opt, &trace_options.end_time_ms)) { - return errors::InvalidArgument("wrong arguments"); - } - return trace_options; -} - -absl::StatusOr ConvertXSpaceToTraceEvents( - const SessionSnapshot& session_snapshot, const absl::string_view tool_name, - const ToolOptions& options) { - if (session_snapshot.XSpaceSize() != 1) { - return errors::InvalidArgument( - "Trace events tool expects only 1 XSpace path but gets ", - session_snapshot.XSpaceSize()); - } - - TF_ASSIGN_OR_RETURN(std::unique_ptr xspace, - session_snapshot.GetXSpace(0)); - PreprocessSingleHostXSpace(xspace.get(), /*step_grouping=*/true, - /*derived_timeline=*/true); - std::string content; - if (tool_name == "trace_viewer") { - tsl::profiler::ConvertXSpaceToTraceEventsString(*xspace, &content); - return content; - } else { // streaming trace viewer. - std::string host_name = session_snapshot.GetHostname(0); - auto sstable_path = session_snapshot.GetFilePath(tool_name, host_name); - if (!sstable_path) { - return errors::Unimplemented( - "streaming trace viewer hasn't been supported in Cloud AI"); - } - if (!Env::Default()->FileExists(*sstable_path).ok()) { - ProcessMegascaleDcn(xspace.get()); - TraceEventsContainer trace_container; - ConvertXSpaceToTraceEventsContainer(host_name, *xspace, &trace_container); - std::unique_ptr file; - TF_RETURN_IF_ERROR( - tsl::Env::Default()->NewWritableFile(*sstable_path, &file)); - TF_RETURN_IF_ERROR(trace_container.StoreAsLevelDbTable(std::move(file))); - } - TF_ASSIGN_OR_RETURN(TraceViewOption trace_option, - GetTraceViewOption(options)); - auto visibility_filter = std::make_unique( - tsl::profiler::MilliSpan(trace_option.start_time_ms, - trace_option.end_time_ms), - trace_option.resolution); - TraceEventsContainer trace_container; - // Trace smaller than threshold will be disabled from streaming. 
- constexpr int64_t kDisableStreamingThreshold = 500000; - TF_RETURN_IF_ERROR(trace_container.LoadFromLevelDbTable( - *sstable_path, /*filter=*/nullptr, std::move(visibility_filter), - kDisableStreamingThreshold)); - JsonTraceOptions options; - IOBufferAdapter adapter(&content); - TraceEventsToJson( - options, trace_container, &adapter); - return content; - } -} - -absl::StatusOr ConvertMultiXSpacesToOverviewPage( - const SessionSnapshot& session_snapshot) { - OpStatsOptions options; - options.generate_kernel_stats_db = true; - options.generate_op_metrics_db = true; - options.generate_step_db = true; - OpStats combined_op_stats; - TF_RETURN_IF_ERROR(ConvertMultiXSpacesToCombinedOpStats( - session_snapshot, options, &combined_op_stats)); - OverviewPage overview_page = ConvertOpStatsToOverviewPage(combined_op_stats); - InferenceStats inference_stats; - TF_RETURN_IF_ERROR(ConvertMultiXSpaceToInferenceStats(session_snapshot, "", - "", &inference_stats)); - *overview_page.mutable_inference_latency() = - ComputeInferenceLatencyResult(inference_stats); - return overview_page.SerializeAsString(); -} - -absl::StatusOr ConvertMultiXSpacesToInputPipeline( - const SessionSnapshot& session_snapshot) { - OpStatsOptions options; - options.generate_op_metrics_db = true; - options.generate_step_db = true; - OpStats combined_op_stats; - TF_RETURN_IF_ERROR(ConvertMultiXSpacesToCombinedOpStats( - session_snapshot, options, &combined_op_stats)); - return ConvertOpStatsToInputPipelineAnalysis(combined_op_stats) - .SerializeAsString(); -} - -absl::StatusOr ConvertMultiXSpacesToTfStats( - const SessionSnapshot& session_snapshot) { - OpStatsOptions options; - options.generate_op_metrics_db = true; - options.generate_kernel_stats_db = true; - OpStats combined_op_stats; - TF_RETURN_IF_ERROR(ConvertMultiXSpacesToCombinedOpStats( - session_snapshot, options, &combined_op_stats)); - return ConvertOpStatsToTfStats(combined_op_stats).SerializeAsString(); -} - -absl::StatusOr 
ConvertMultiXSpacesToKernelStats( - const SessionSnapshot& session_snapshot) { - OpStatsOptions options; - options.generate_kernel_stats_db = true; - OpStats combined_op_stats; - TF_RETURN_IF_ERROR(ConvertMultiXSpacesToCombinedOpStats( - session_snapshot, options, &combined_op_stats)); - return combined_op_stats.kernel_stats_db().SerializeAsString(); -} - -absl::StatusOr ConvertXSpaceToMemoryProfile( - const SessionSnapshot& session_snapshot) { - if (session_snapshot.XSpaceSize() != 1) { - return errors::InvalidArgument( - "Memory profile tool expects only 1 XSpace path but gets ", - session_snapshot.XSpaceSize()); - } - - std::string json_output; - TF_ASSIGN_OR_RETURN(std::unique_ptr xspace, - session_snapshot.GetXSpace(0)); - PreprocessSingleHostXSpace(xspace.get(), /*step_grouping=*/true, - /*derived_timeline=*/false); - TF_RETURN_IF_ERROR(ConvertXSpaceToMemoryProfileJson(*xspace, &json_output)); - return json_output; -} - -absl::StatusOr ConvertMultiXSpacesToPodViewer( - const SessionSnapshot& session_snapshot) { - OpStatsOptions options; - options.generate_op_metrics_db = true; - options.generate_step_db = true; - OpStats combined_op_stats; - TF_RETURN_IF_ERROR(ConvertMultiXSpacesToCombinedOpStats( - session_snapshot, options, &combined_op_stats)); - - std::string json_output; - tsl::protobuf::util::JsonPrintOptions opts; - opts.always_print_primitive_fields = true; - auto encode_status = tsl::protobuf::util::MessageToJsonString( - ConvertOpStatsToPodViewer(combined_op_stats), &json_output, opts); - if (!encode_status.ok()) { - const auto& error_message = encode_status.message(); - return errors::Internal( - "Could not convert pod viewer to json. 
Error: ", - absl::string_view(error_message.data(), error_message.length())); - } - return json_output; -} - -absl::StatusOr ConvertMultiXSpacesToTfDataBottleneckAnalysis( - const SessionSnapshot& session_snapshot) { - CombinedTfDataStats combined_tf_data_stats; - CombinedTfDataStatsBuilder builder(&combined_tf_data_stats); - - for (int idx = 0; idx < session_snapshot.XSpaceSize(); ++idx) { - TF_ASSIGN_OR_RETURN(std::unique_ptr xspace, - session_snapshot.GetXSpace(idx)); - - PreprocessSingleHostXSpace(xspace.get(), /*step_grouping=*/true, - /*derived_timeline=*/false); - XPlane* host_plane = - FindMutablePlaneWithName(xspace.get(), kHostThreadsPlaneName); - std::string host_name_from_file = session_snapshot.GetHostname(idx); - if (host_plane == nullptr) { - return errors::InvalidArgument( - "Could not find host XPlane for tf data stats: ", - host_name_from_file); - } - absl::string_view host_name = - xspace->hostnames_size() ? xspace->hostnames(0) : host_name_from_file; - builder.Add(host_name, host_plane); - } - builder.Finalize(); - return combined_tf_data_stats.SerializeAsString(); -} - -absl::StatusOr ConvertMultiXSpacesToHloStats( - const SessionSnapshot& session_snapshot) { - OpStatsOptions options; - options.generate_op_metrics_db = true; - OpStats combined_op_stats; - TF_RETURN_IF_ERROR(ConvertMultiXSpacesToCombinedOpStats( - session_snapshot, options, &combined_op_stats)); - hlo_stats::HloStatsDatabase hlo_stats_db = - ConvertOpStatsToHloStats(combined_op_stats); - return HloStatsToDataTableJson(hlo_stats_db); -} - -absl::StatusOr ConvertMultiXSpacesToRooflineModel( - const SessionSnapshot& session_snapshot) { - OpStatsOptions op_stats_options; - op_stats_options.generate_op_metrics_db = true; - OpStats combined_op_stats; - TF_RETURN_IF_ERROR(ConvertMultiXSpacesToCombinedOpStats( - session_snapshot, op_stats_options, &combined_op_stats)); - RooflineModelDatabase result = - ConvertOpStatsToRooflineModel(combined_op_stats, true); - RooflineModelDatabase 
result_without_infeed_outfeed = - ConvertOpStatsToRooflineModel(combined_op_stats, false); - result.mutable_roofline_model_record()->MergeFrom( - result_without_infeed_outfeed.roofline_model_record()); - return result.SerializeAsString(); -} - -absl::StatusOr ConvertMultiXSpacesToOpProfileViewer( - const SessionSnapshot& session_snapshot) { - OpStatsOptions options; - options.generate_op_metrics_db = true; - OpStats combined_op_stats; - TF_RETURN_IF_ERROR(ConvertMultiXSpacesToCombinedOpStats( - session_snapshot, options, &combined_op_stats)); - - tensorflow::profiler::op_profile::Profile profile; - ConvertOpStatsToOpProfile( - combined_op_stats, - ParseHardwareType(combined_op_stats.run_environment().device_type()), - profile); - std::string json_output; - tsl::protobuf::util::JsonPrintOptions opts; - opts.always_print_primitive_fields = true; - - auto encode_status = - tsl::protobuf::util::MessageToJsonString(profile, &json_output, opts); - if (!encode_status.ok()) { - const auto& error_message = encode_status.message(); - return errors::Internal( - "Could not convert op profile proto to json. Error: ", - absl::string_view(error_message.data(), error_message.length())); - } - return json_output; -} - -absl::StatusOr PreprocessXSpace( - const SessionSnapshot& session_snapshot) { - if (session_snapshot.XSpaceSize() != 1) { - return errors::InvalidArgument( - "PreprocessXSpace tool expects only 1 XSpace path but gets ", - session_snapshot.XSpaceSize()); - } - - TF_ASSIGN_OR_RETURN(std::unique_ptr xspace, - session_snapshot.GetXSpace(0)); - PreprocessSingleHostXSpace(xspace.get(), /*step_grouping=*/true, - /*derived_timeline=*/true); - return xspace->SerializeAsString(); -} - -absl::StatusOr ConvertDcnCollectiveStatsToToolData( - const SessionSnapshot& session_snapshot, const ToolOptions& options) { - // must provide a host_name field. 
- std::optional hostname = - GetParam(options, "host_name"); - if (!hostname.has_value() || hostname->empty()) { - return absl::InvalidArgumentError( - "Cannot find host_name from options for dcn_collective_stats tool."); - } - - // Load DcnSlackAnalysis for a host. - TF_ASSIGN_OR_RETURN( - DcnSlackAnalysis dcnSlackAnalysis, - GetDcnSlackAnalysisByHostName(session_snapshot, hostname.value())); - - return dcnSlackAnalysis.SerializeAsString(); -} - -absl::StatusOr ConvertMultiXSpacesToInferenceStats( - const SessionSnapshot& session_snapshot, const ToolOptions& options) { - InferenceStats inference_stats; - std::string request_column = - GetParamWithDefault(options, "request_column", ""); - std::string batch_column = - GetParamWithDefault(options, "batch_column", ""); - TF_RETURN_IF_ERROR(ConvertMultiXSpaceToInferenceStats( - session_snapshot, request_column, batch_column, &inference_stats)); - return inference_stats.SerializeAsString(); -} - -} // namespace - -absl::StatusOr ConvertMultiXSpacesToToolData( - const SessionSnapshot& session_snapshot, const absl::string_view tool_name, - const ToolOptions& options) { - LOG(INFO) << "serving tool: " << tool_name - << " with options: " << DebugString(options); - if (tool_name == "trace_viewer" || tool_name == "trace_viewer@") { - return ConvertXSpaceToTraceEvents(session_snapshot, tool_name, options); - } else if (tool_name == "overview_page") { - return ConvertMultiXSpacesToOverviewPage(session_snapshot); - } else if (tool_name == "input_pipeline_analyzer") { - return ConvertMultiXSpacesToInputPipeline(session_snapshot); - } else if (tool_name == "framework_op_stats") { - return ConvertMultiXSpacesToTfStats(session_snapshot); - } else if (tool_name == "kernel_stats") { - return ConvertMultiXSpacesToKernelStats(session_snapshot); - } else if (tool_name == "memory_profile") { - return ConvertXSpaceToMemoryProfile(session_snapshot); - } else if (tool_name == "pod_viewer") { - return 
ConvertMultiXSpacesToPodViewer(session_snapshot); - } else if (tool_name == "op_profile") { - return ConvertMultiXSpacesToOpProfileViewer(session_snapshot); - } else if (tool_name == "hlo_stats") { - return ConvertMultiXSpacesToHloStats(session_snapshot); - } else if (tool_name == "roofline_model") { - return ConvertMultiXSpacesToRooflineModel(session_snapshot); - } else if (tool_name == "memory_viewer" || tool_name == "graph_viewer") { - return ConvertHloProtoToToolData(session_snapshot, tool_name, options); - } else if (tool_name == "tool_names") { - return GetAvailableToolNames(session_snapshot); - } else if (tool_name == "_xplane.pb") { // internal test only. - return PreprocessXSpace(session_snapshot); - } else if (tool_name == "inference_profile") { - return ConvertMultiXSpacesToInferenceStats(session_snapshot, options); - } else { - return errors::InvalidArgument( - "Can not find tool: ", tool_name, - ". Please update to the latest version of Tensorflow."); - } -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/xplane_to_tools_data.h b/tensorflow/core/profiler/convert/xplane_to_tools_data.h deleted file mode 100644 index 8a40e03a7cd1dd..00000000000000 --- a/tensorflow/core/profiler/convert/xplane_to_tools_data.h +++ /dev/null @@ -1,39 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_TOOLS_DATA_H_ -#define TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_TOOLS_DATA_H_ - -#include - -#include "absl/strings/string_view.h" -#include "tensorflow/core/platform/statusor.h" -#include "tensorflow/core/profiler/convert/repository.h" -#include "tensorflow/core/profiler/convert/tool_options.h" - -namespace tensorflow { -namespace profiler { - -// Convert XSpace protos to a tool specific data. -// Return the serialized string of tool specific data when the conversion is -// successful, else return error status. -absl::StatusOr ConvertMultiXSpacesToToolData( - const SessionSnapshot& session_snapshot, absl::string_view tool_name, - const ToolOptions& options); - -} // namespace profiler -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_TOOLS_DATA_H_ diff --git a/tensorflow/core/profiler/convert/xplane_to_trace_container.cc b/tensorflow/core/profiler/convert/xplane_to_trace_container.cc deleted file mode 100644 index 27aaa7af86d039..00000000000000 --- a/tensorflow/core/profiler/convert/xplane_to_trace_container.cc +++ /dev/null @@ -1,253 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/core/profiler/convert/xplane_to_trace_container.h" - -#include -#include -#include -#include -#include -#include - -#include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" -#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" -#include "xla/tsl/profiler/utils/timespan.h" -#include "xla/tsl/profiler/utils/trace_utils.h" -#include "xla/tsl/profiler/utils/xplane_schema.h" -#include "xla/tsl/profiler/utils/xplane_utils.h" -#include "xla/tsl/profiler/utils/xplane_visitor.h" -#include "tensorflow/core/profiler/protobuf/trace_events.pb.h" -#include "tensorflow/core/profiler/protobuf/trace_events_raw.pb.h" -#include "xprof/convert/trace_viewer/trace_event_arguments_builder.h" // from @org_xprof -#include "xprof/convert/trace_viewer/trace_events_util.h" // from @org_xprof - -namespace tensorflow { -namespace profiler { -namespace { - -using tsl::profiler::FindPlanesWithPrefix; -using tsl::profiler::FindPlaneWithName; -using tsl::profiler::HostEventType; -using tsl::profiler::StatType; -using tsl::profiler::XEventVisitor; -using tsl::profiler::XFlow; -using tsl::profiler::XLineVisitor; -using tsl::profiler::XPlaneVisitor; -using tsl::profiler::XStatVisitor; - -struct SpecialArguments { - std::optional group_id; - absl::string_view step_name; - bool is_async_event = false; - // Both flow and async events share the flow specification. 
- std::optional flow; -}; - -inline TraceEvent::FlowEntryType FlowEntryTypeFromDirection( - XFlow::FlowDirection direction) { - switch (direction) { - case XFlow::kFlowUnspecified: - return TraceEvent::FLOW_NONE; - case XFlow::kFlowIn: - return TraceEvent::FLOW_END; - case XFlow::kFlowOut: - return TraceEvent::FLOW_START; - case XFlow::kFlowInOut: - return TraceEvent::FLOW_MID; - } -} - -template -void ConvertXStatToTraceEventArgument(const XStatVisitor& stat, T value, - SpecialArguments& special_args, - TraceEventArgumentsBuilder& args) { - if (stat.Type() == StatType::kFlow) { - special_args.flow = XFlow::FromStatValue(value); - } else if (stat.Type() == StatType::kGroupId) { - special_args.group_id = value; - } else if (stat.Type() == StatType::kIsAsync) { - special_args.is_async_event = true; - } else { - args.Append(stat.Name(), value); - } -} - -SpecialArguments ConvertXStatsToTraceEventArguments( - const XEventVisitor& event, RawData* raw_data, - TraceEventArguments* raw_args) { - TraceEventArgumentsBuilder args(raw_args); - SpecialArguments special_args; - auto for_each_stat = [&special_args, &args](const XStatVisitor& stat) { - if (tsl::profiler::IsInternalStat(stat.Type())) return; - switch (stat.ValueCase()) { - case XStat::kInt64Value: - ConvertXStatToTraceEventArgument(stat, stat.IntValue(), special_args, - args); - break; - case XStat::kUint64Value: - ConvertXStatToTraceEventArgument(stat, stat.UintValue(), special_args, - args); - break; - case XStat::kDoubleValue: - args.Append(stat.Name(), stat.DoubleValue()); - break; - case XStat::kStrValue: - case XStat::kRefValue: { - auto stat_value = stat.StrOrRefValue(); - if (stat.Type() == StatType::kStepName) { - special_args.step_name = stat_value; - } - args.Append(stat.Name(), stat_value); - break; - } - case XStat::kBytesValue: - break; - case XStat::VALUE_NOT_SET: - break; - } - }; - // Ensure the metadata stats appear before the per-occurrence stats. 
- event.Metadata().ForEachStat(for_each_stat); - event.ForEachStat(for_each_stat); - return special_args; -} - -void ConvertXLineToTraceEventsContainer(uint32_t device_id, - const XLineVisitor& line, - TraceEventsContainer* container) { - std::optional resource_id; - - if (line.Name() != tsl::profiler::kCounterEventsLineName) { - resource_id = line.DisplayId(); - Resource* resource = container->MutableResource(*resource_id, device_id); - resource->set_resource_id(*resource_id); - resource->set_name(std::string(line.DisplayName())); - resource->set_num_events(line.NumEvents()); - } - - RawData raw_data; // hoisted for performance - line.ForEachEvent([device_id, resource_id, &raw_data, - container](const XEventVisitor& event) { - int64_t event_type = - event.Type().value_or(HostEventType::kUnknownHostEventType); - if (tsl::profiler::IsInternalEvent(event_type)) return; - TraceEventArguments* raw_args = raw_data.mutable_args(); - absl::string_view event_name; - if (event.HasDisplayName()) { - event_name = event.DisplayName(); - TraceEventArgumentsBuilder args(raw_args); - constexpr size_t kMaxLongName = 10000; - if (event.Name().size() > kMaxLongName) { - args.Append("long_name", - absl::StrCat(event.Name().substr(0, kMaxLongName), - "...")); - } else { - args.Append("long_name", event.Name()); - } - } else { - event_name = event.Name(); - } - SpecialArguments special_args = - ConvertXStatsToTraceEventArguments(event, &raw_data, raw_args); - if (!special_args.step_name.empty()) { - event_name = special_args.step_name; - } - if (!resource_id) { - container->AddCounterEvent(event_name, device_id, event.TimestampPs(), - raw_data); - } else if (special_args.flow) { - tsl::profiler::Timespan span(event.TimestampPs(), event.DurationPs()); - if (special_args.is_async_event) { - container->AddAsyncEvent( - event_name, device_id, span, special_args.flow->Id(), - FlowEntryTypeFromDirection(special_args.flow->Direction()), - special_args.flow->Category(), &raw_data, 
special_args.group_id); - } else { - container->AddFlowEvent( - event_name, *resource_id, device_id, span, special_args.flow->Id(), - FlowEntryTypeFromDirection(special_args.flow->Direction()), - special_args.flow->Category(), &raw_data, special_args.group_id); - } - } else { - tsl::profiler::Timespan span(event.TimestampPs(), event.DurationPs()); - container->AddCompleteEvent(event_name, *resource_id, device_id, span, - &raw_data, special_args.group_id); - } - // Cleanup hoisted structure for next event. - if (raw_data.has_args()) raw_args->clear_arg(); - }); -} - -void ConvertXPlaneToTraceEventsContainer(uint64_t device_id, - absl::string_view hostname, - const XPlane& xplane, - TraceEventsContainer* container) { - XPlaneVisitor plane = tsl::profiler::CreateTfXPlaneVisitor(&xplane); - std::unique_ptr resource_grouper = - CreateDefaultResourceGrouper(device_id, plane.Name()); - - if (plane.NumLines() == 0) return; - - for (const auto& [device_id, name] : resource_grouper->Devices()) { - Device* device = container->MutableDevice(device_id); - device->set_device_id(device_id); - device->set_name(absl::StrCat(hostname, " ", name)); - } - - plane.ForEachLine([&](const XLineVisitor& line) { - if (line.DisplayName() == tsl::profiler::kXlaAsyncOpLineName) return; - if (line.NumEvents() == 0) return; - // Capture a copy of XLineVisitor because it will go out of scope. 
- uint32_t device_id = resource_grouper->GetDeviceId(line.DisplayId()); - ConvertXLineToTraceEventsContainer(device_id, line, container); - }); -} - -} // namespace - -void ConvertXSpaceToTraceEventsContainer(absl::string_view hostname, - const XSpace& space, - TraceEventsContainer* container) { - const XPlane* host_plane = - FindPlaneWithName(space, tsl::profiler::kHostThreadsPlaneName); - if (host_plane != nullptr) { - ConvertXPlaneToTraceEventsContainer(tsl::profiler::kHostThreadsDeviceId, - hostname, *host_plane, container); - } - - std::vector device_planes = - FindPlanesWithPrefix(space, tsl::profiler::kGpuPlanePrefix); - - if (device_planes.empty()) { - device_planes = FindPlanesWithPrefix(space, tsl::profiler::kTpuPlanePrefix); - } - - for (const XPlane* device_plane : device_planes) { - ConvertXPlaneToTraceEventsContainer( - tsl::profiler::kFirstDeviceId + device_plane->id(), hostname, - *device_plane, container); - } - for (const XPlane* custom_plane : - FindPlanesWithPrefix(space, tsl::profiler::kCustomPlanePrefix)) { - ConvertXPlaneToTraceEventsContainer( - tsl::profiler::kFirstCustomPlaneDeviceId + custom_plane->id(), hostname, - *custom_plane, container); - } -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/xplane_to_trace_container.h b/tensorflow/core/profiler/convert/xplane_to_trace_container.h deleted file mode 100644 index 644848460661e6..00000000000000 --- a/tensorflow/core/profiler/convert/xplane_to_trace_container.h +++ /dev/null @@ -1,36 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_TRACE_CONTAINER_H_ -#define TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_TRACE_CONTAINER_H_ - -#include "tensorflow/core/profiler/protobuf/trace_events_raw.pb.h" -#include "tsl/profiler/protobuf/xplane.pb.h" -#include "xprof/convert/trace_viewer/trace_events.h" // from @org_xprof - -namespace tensorflow { -namespace profiler { - -using TraceEventsContainer = TraceEventsContainerBase; - -// Converts XEvents within the XSpace into trace_viewer events container. -void ConvertXSpaceToTraceEventsContainer(absl::string_view hostname, - const XSpace& xspace, - TraceEventsContainer* container); - -} // namespace profiler -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_TRACE_CONTAINER_H_ diff --git a/tensorflow/core/profiler/convert/xplane_to_trace_container_test.cc b/tensorflow/core/profiler/convert/xplane_to_trace_container_test.cc deleted file mode 100644 index 821582610fd6ec..00000000000000 --- a/tensorflow/core/profiler/convert/xplane_to_trace_container_test.cc +++ /dev/null @@ -1,122 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/profiler/convert/xplane_to_trace_container.h" - -#include -#include - -#include -#include -#include "absl/container/flat_hash_map.h" -#include "absl/log/check.h" -#include "absl/strings/match.h" -#include "absl/strings/substitute.h" -#include "xla/tsl/profiler/utils/math_utils.h" -#include "tensorflow/core/util/proto/proto_utils.h" - -namespace tensorflow { -namespace profiler { -namespace { - -using ::testing::Pair; -using ::testing::UnorderedElementsAre; - -TEST(XPlaneToTraceContainerTest, CounterLine) { - XSpace xspace; - CHECK_OK(tensorflow::proto_utils::ParseTextFormatFromString( - absl::Substitute( - "planes {" - " name: \"/device:GPU:0\"" - " lines {" - " name: \"_counters_\"" - " events {" - " metadata_id: 100" - " offset_ps: $0" - " stats { metadata_id: 200 uint64_value: 100 }" - " }" - " events {" - " metadata_id: 100" - " offset_ps: $1" - " stats { metadata_id: 200 uint64_value: 200 }" - " }" - " events {" - " metadata_id: 101" - " offset_ps: $0" - " stats { metadata_id: 201 uint64_value: 300 }" - " }" - " events {" - " metadata_id: 101" - " offset_ps: $1" - " stats { metadata_id: 201 uint64_value: 400 }" - " }" - " }" - " lines {" - " id: 14" - " name: \"Stream #14(MemcpyH2D)\"" - " timestamp_ns: $3" - " events {" - " metadata_id: 10" - " offset_ps: 0" - " duration_ps: $1" - " stats { metadata_id: 8 uint64_value: 100 }" - " stats { metadata_id: 9 str_value: \"$$1\" }" - " }" - " events {" - " metadata_id: 10" - " offset_ps: $0" - " duration_ps: $3" - 
" stats { metadata_id: 8 uint64_value: 200 }" - " stats { metadata_id: 9 str_value: \"abcd\" }" - " }" - " }" - " event_metadata {key: 10 value: { id: 10 name: \"MemcpyD2D\" }}" - " event_metadata {key: 100 value: { id: 100 name: \"Counter 1\" }}" - " event_metadata {key: 101 value: { id: 101 name: \"Counter 2\" }}" - " stat_metadata {key: 8 value: { id: 8 name: \"RemoteCall\"}}" - " stat_metadata {key: 9 value: { id: 8 name: \"context_id\"}}" - " stat_metadata {key: 200 value: { id: 200 name: \"counter_1\"}}" - " stat_metadata {key: 201 value: { id: 201 name: \"counter_2\"}}" - "}", - tsl::profiler::UniToPico(1), tsl::profiler::UniToPico(2), - tsl::profiler::UniToNano(1), tsl::profiler::UniToNano(500)), - &xspace)); - TraceEventsContainer container; - ConvertXSpaceToTraceEventsContainer("localhost", xspace, &container); - absl::flat_hash_map> - counter_offset_to_values; - container.ForAllEvents([&counter_offset_to_values](const TraceEvent& event) { - if (absl::StrContains(event.name(), "Counter")) { - uint64_t offset = event.timestamp_ps(); - RawData raw_data; - raw_data.ParseFromString(event.raw_data()); - counter_offset_to_values[event.name()][offset] = - raw_data.args().arg(0).uint_value(); - } - }); - EXPECT_THAT( - counter_offset_to_values, - UnorderedElementsAre( - Pair("Counter 1", - UnorderedElementsAre(Pair(tsl::profiler::UniToPico(1), 100), - Pair(tsl::profiler::UniToPico(2), 200))), - Pair("Counter 2", - UnorderedElementsAre(Pair(tsl::profiler::UniToPico(1), 300), - Pair(tsl::profiler::UniToPico(2), 400))))); -} - -} // namespace -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/xspace_to_dcn_slack_analysis.cc b/tensorflow/core/profiler/convert/xspace_to_dcn_slack_analysis.cc deleted file mode 100644 index 55977e2ed00833..00000000000000 --- a/tensorflow/core/profiler/convert/xspace_to_dcn_slack_analysis.cc +++ /dev/null @@ -1,525 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/core/profiler/convert/xspace_to_dcn_slack_analysis.h" - -#include - -#include -#include -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/log/log.h" -#include "absl/strings/match.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_opcode.h" -#include "xla/shape_util.h" -#include "xla/side_effect_util.h" -#include "xla/tsl/platform/statusor.h" -#include "xla/tsl/profiler/utils/math_utils.h" -#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" -#include "xla/tsl/profiler/utils/timespan.h" -#include "xla/tsl/profiler/utils/tpu_xplane_utils.h" -#include "xla/tsl/profiler/utils/xplane_schema.h" -#include "xla/tsl/profiler/utils/xplane_utils.h" -#include "xla/tsl/profiler/utils/xplane_visitor.h" -#include "xla/xla_data.pb.h" -#include "tensorflow/core/profiler/protobuf/dcn_collective_info.pb.h" -#include "tensorflow/core/profiler/protobuf/dcn_slack_analysis.pb.h" -#include "tensorflow/core/profiler/protobuf/topology.pb.h" -#include "tensorflow/core/profiler/utils/hlo_module_utils.h" -#include "tensorflow/core/profiler/utils/hlo_proto_map.h" -#include "tensorflow/core/profiler/utils/hlo_proto_to_module.h" -#include "tensorflow/core/profiler/utils/xplane_utils.h" -#include "tsl/platform/regexp.h" 
-#include "tsl/profiler/protobuf/xplane.pb.h" - -namespace tensorflow { -namespace profiler { -namespace { - -using tensorflow::profiler::DcnSlackSummary; -using tensorflow::profiler::Topology; -using tsl::profiler::CreateTfXPlaneVisitor; -using tsl::profiler::FindLineWithName; -using tsl::profiler::kXlaOpLineName; -using tsl::profiler::NanoToMicro; -using tsl::profiler::PicoToMicro; -using tsl::profiler::SafeDivide; -using tsl::profiler::StatType; -using tsl::profiler::Timespan; -using tsl::profiler::XEventContextTracker; -using tsl::profiler::XEventVisitor; -using tsl::profiler::XLineVisitor; -using tsl::profiler::XPlaneVisitor; -using tsl::profiler::XStatVisitor; -using xla::HloOpcode; - -// TODO: Identify mechanism to maintain consistency between producer and -// consumer here. -const char kHostEventRegex[] = { - "device_[0-9]+([0-9][0-9][0-9][0-9][0-9])_gid_(.*)"}; - -std::optional GetAttributeFromInstr( - const xla::HloInstruction* instr, std::string_view attribute) { - std::optional attribute_value; - if (instr->frontend_attributes().IsInitialized() && - !instr->frontend_attributes().map().empty() && - instr->frontend_attributes().map().contains(attribute)) { - attribute_value = instr->frontend_attributes().map().at(attribute); - } - return attribute_value; -} -std::optional GetRendezvous(const xla::HloInstruction* instr) { - return GetAttributeFromInstr(instr, xla::kXlaHostTransferRendezvousNameAttr); -} - -dcn_analysis_internal::DcnHostEvent ParseDcnHostEvent( - const XEventVisitor& visitor) { - dcn_analysis_internal::DcnHostEvent event; - static const LazyRE2 re = {kHostEventRegex}; - RE2::FullMatch(visitor.Name(), *re, &event.multi_slice_device_id, - &event.rendezvous_name); - - event.timespan = visitor.GetTimespan(); - return event; -} - -std::optional GetTransferType(const xla::HloInstruction* instr) { - return GetAttributeFromInstr(instr, "_xla_megascale_transfer_type"); -} - -std::string HostCollectiveKey(int index_on_host, - std::string_view 
rendezvous_name) { - return absl::StrCat(index_on_host, "_", rendezvous_name); -} - -DcnCollectiveInfoProto GetDcnCollectiveInfoProto(const XEventVisitor& xevent) { - DcnCollectiveInfoProto dcn_collective_info; - xevent.Metadata().ForEachStat([&](const XStatVisitor& xstat) { - if (static_cast(*xstat.Type()) == StatType::kDcnCollectiveInfo) { - absl::string_view byte_value = xstat.BytesValue(); - if (!dcn_collective_info.ParseFromArray(byte_value.data(), - byte_value.size())) { - LOG(WARNING) << "Could not parse DcnCollectiveInfoProto from metadata."; - } - } - }); - - return dcn_collective_info; -} - -} // namespace - -namespace dcn_analysis_internal { - -void DcnHostEventList::insert(DcnHostEvent event) { - if (iter_ != events_.end() && event.timespan < iter_->timespan) { - // The event being inserted is from a new line, Reset iterator to the - // beginning. - iter_ = events_.begin(); - } - while (iter_ != events_.end() && iter_->timespan < event.timespan) { - iter_++; - } - iter_ = events_.insert(iter_, event); -} - -std::optional DcnHostEventList::pop(const Timespan& timespan) { - while (!events_.empty() && events_.front().timespan < timespan) { - events_.pop_front(); - } - - if (!events_.empty() && - (timespan.Includes(events_.front().timespan.begin_ps()) || - events_.front().timespan.Includes(timespan.begin_ps()))) { - DcnHostEvent front = events_.front(); - events_.pop_front(); - return front; - } else { - return std::nullopt; - } -} - -absl::StatusOr DcnTracker::GetInstrMetadataFromHloModule( - std::string_view module_name, std::string_view instr_name) { - if (!hlo_module_cache_.contains(module_name)) { - TF_ASSIGN_OR_RETURN(auto hlo_proto, - hlo_proto_map_.GetHloProtoByModuleName(module_name)); - TF_ASSIGN_OR_RETURN(auto module, ConvertHloProtoToModule(*hlo_proto)); - hlo_module_cache_[module_name] = std::move(module); - } - const auto& hlo_module = hlo_module_cache_[module_name]; - dcn_analysis_internal::InstrMetadata instr_metadata; - auto instr = 
FindInstruction(*hlo_module, std::string(instr_name)); - - instr_metadata.opcode = instr->opcode(); - instr_metadata.channel_id = instr->channel_id().value(); - instr_metadata.rendezvous_name = GetRendezvous(instr); - instr_metadata.transfer_type = GetTransferType(instr); - instr_metadata.size = 0; - if (instr->shape().IsArray()) { - instr_metadata.size = xla::ShapeUtil::ByteSizeOfElements(instr->shape()); - } else if (instr->shape().IsTuple()) { - for (const auto& shape : instr->shape().tuple_shapes()) { - instr_metadata.size += xla::ShapeUtil::ByteSizeOf(shape); - } - } - return instr_metadata; -} - -absl::StatusOr DcnTracker::GetInstructionMetadata( - std::string_view module, std::string_view instr) { - std::string key = absl::StrCat(module, "_", instr); - if (const auto& it = instruction_metadata_map_.find(key); - it != instruction_metadata_map_.end()) { - return it->second; - } - - absl::StatusOr instr_metadata = - GetInstrMetadataFromHloModule(module, instr); - if (instr_metadata.ok()) { - instruction_metadata_map_[key] = *instr_metadata; - } - - return instr_metadata; -} - -DcnSlackAnalysis DcnTracker::Finalize() { - SummarizeDcnSlackAnalysis(); - return slack_analysis_; -} - -void DcnTracker::DebugString() { - for (const DcnSlack& analysis : slack_analysis_.dcn_slack()) { - LOG(INFO) << analysis.rendezvous() << " : " << analysis.slack_us(); - } -} - -void DcnTracker::UpdateActiveOps(uint64_t duration) { - for (auto& [rendezvous, opState] : rendezvous_to_op_map_) { - opState.overlapping_duration += duration; - } -} - -int DcnTracker::GetReplicaGroupSize(const std::string& rendezvous_name, - const XEventVisitor& visitor) { - if (rendezvous_to_replica_group_size_map_.contains(rendezvous_name)) { - return rendezvous_to_replica_group_size_map_[rendezvous_name]; - } - - DcnCollectiveInfoProto dcn_collective_info = - GetDcnCollectiveInfoProto(visitor); - - if (dcn_collective_info.one_to_one_groups_size() != 0) { - // OneToOneGroup has a source and a destination, 
which is one replica group - rendezvous_to_replica_group_size_map_[rendezvous_name] = 1; - } else if (dcn_collective_info.endpoint_groups_size() != 0) { - rendezvous_to_replica_group_size_map_[rendezvous_name] = - dcn_collective_info.endpoint_groups(0).endpoints().size(); - } else { - rendezvous_to_replica_group_size_map_[rendezvous_name] = 0; - } - - return rendezvous_to_replica_group_size_map_[rendezvous_name]; -} - -// ComputeTransmittedDataSize is called with the buffer_size for recv-done. -uint64_t DcnTracker::ComputeTransmittedDataSize( - const int64_t recv_buffer_size, const int group_size, - const std::string& transfer_type) { - uint64_t transmitted_bytes = 0; - if (group_size == 0) { - LOG(ERROR) << "Replica group size is 0."; - return transmitted_bytes; - } - - if (transfer_type == "ONE_TO_ONE") { - transmitted_bytes = group_size * recv_buffer_size; - } else if (transfer_type == "ALL_GATHER") { - transmitted_bytes = - SafeDivide((group_size - 1) * recv_buffer_size, group_size); - } else if (transfer_type == "ALL_REDUCE") { - // Since the reduced buffer now has to be sent back to the replicas, - // the total bytes transmitted over the network is 2x the shape of the op. 
- transmitted_bytes = - 2 * SafeDivide(group_size - 1, group_size) * recv_buffer_size; - } else if (transfer_type == "ALL_TO_ALL") { - transmitted_bytes = - SafeDivide(group_size - 1, group_size) * recv_buffer_size; - } else if (transfer_type == "REDUCE_SCATTER") { - transmitted_bytes = recv_buffer_size * (group_size - 1); - } else { - LOG(ERROR) << "Unsupported transfer type: " << transfer_type; - } - return transmitted_bytes; -} - -void DcnTracker::VisitOp(const InstrMetadata& instr, - const XEventVisitor& visitor) { - std::string rendezvous_name; - if (instr.rendezvous_name.has_value()) { - rendezvous_name = *instr.rendezvous_name; - channel_id_to_rendezvous_map_[instr.channel_id] = rendezvous_name; - } else { - if (auto it = channel_id_to_rendezvous_map_.find(instr.channel_id); - it != channel_id_to_rendezvous_map_.end()) { - rendezvous_name = it->second; - } else { - // Ignore ops as we have not seen the corresponding send/recv. - return; - } - } - - DcnOpState& opState = rendezvous_to_op_map_[rendezvous_name]; - opState.stall_duration_ns += visitor.DurationNs(); - - switch (instr.opcode) { - case HloOpcode::kSend: - opState.start_time = visitor.TimestampNs(); - opState.rendezvous_name = rendezvous_name; - opState.transfer_type = - instr.transfer_type.has_value() ? 
*instr.transfer_type : ""; - opState.overlapping_duration = 0; - opState.stall_duration_ns = visitor.DurationNs(); - opState.send_op_name = visitor.DisplayName(); - opState.send.set_duration_ps(visitor.DurationPs()); - opState.send.set_start_time_ps(visitor.TimestampPs()); - opState.replica_group_size = - GetReplicaGroupSize(rendezvous_name, visitor); - break; - case HloOpcode::kRecv: - opState.recv.set_duration_ps(visitor.DurationPs()); - opState.recv.set_start_time_ps(visitor.TimestampPs()); - break; - case HloOpcode::kSendDone: - opState.send_done.set_duration_ps(visitor.DurationPs()); - opState.send_done.set_start_time_ps(visitor.TimestampPs()); - break; - case HloOpcode::kRecvDone: { - opState.recv_done.set_duration_ps(visitor.DurationPs()); - opState.recv_done.set_start_time_ps(visitor.TimestampPs()); - if (opState.start_time != 0) { - DcnSlack* analysis = slack_analysis_.add_dcn_slack(); - analysis->set_rendezvous(rendezvous_name); - analysis->set_transfer_type(opState.transfer_type); - analysis->set_send_start_time_us(NanoToMicro(opState.start_time)); - analysis->set_recv_done_end_time_us( - NanoToMicro(visitor.EndTimestampNs())); - analysis->set_slack_us(NanoToMicro(visitor.TimestampNs() - - opState.start_time - - opState.overlapping_duration)); - analysis->set_bytes_transmitted_over_network(ComputeTransmittedDataSize( - instr.size, opState.replica_group_size, opState.transfer_type)); - analysis->set_stall_duration_us(NanoToMicro(opState.stall_duration_ns)); - analysis->set_recv_op_name(std::string(visitor.DisplayName())); - analysis->set_send_op_name(opState.send_op_name); - *analysis->mutable_send() = opState.send; - *analysis->mutable_recv() = opState.recv; - *analysis->mutable_send_done() = opState.send_done; - *analysis->mutable_recv_done() = opState.recv_done; - } - - break; - } - default: - LOG(ERROR) << "Received unexpected op"; - } - UpdateActiveOps(visitor.DurationNs()); -} - -std::optional DcnTracker::GetCollectiveHostEvent( - int core_id, 
std::string_view rendezvous, Timespan timespan) { - return core_id_to_host_event_map_[HostCollectiveKey(core_id, rendezvous)].pop( - timespan); -} - -void DcnTracker::SummarizeDcnSlackAnalysis() { - absl::flat_hash_map summary; - // TODO(b/302596260) : Expand to process all cores. - int core_id = 0; - for (DcnSlack& analysis : *slack_analysis_.mutable_dcn_slack()) { - DcnSlackSummary& s = summary[analysis.rendezvous()]; - s.set_slack_us(s.slack_us() + analysis.slack_us()); - s.set_occurrences(s.occurrences() + 1); - s.set_rendezvous(analysis.rendezvous()); - s.set_transfer_type(analysis.transfer_type()); - s.set_bytes_transmitted_over_network( - analysis.bytes_transmitted_over_network()); - s.set_stall_duration_us(s.stall_duration_us() + - analysis.stall_duration_us()); - s.set_observed_duration_us(s.observed_duration_us() + - analysis.recv_done_end_time_us() - - analysis.send_start_time_us()); - s.set_recv_op_name(analysis.recv_op_name()); - s.set_send_op_name(analysis.send_op_name()); - s.set_send_duration_us(s.send_duration_us() + - PicoToMicro(analysis.send().duration_ps())); - s.set_recv_duration_us(s.recv_duration_us() + - PicoToMicro(analysis.recv().duration_ps()) / 1E6); - s.set_send_done_duration_us( - s.send_done_duration_us() + - PicoToMicro(analysis.send_done().duration_ps())); - s.set_recv_done_duration_us( - s.recv_done_duration_us() + - PicoToMicro(analysis.recv_done().duration_ps())); - - // Populate Host summary to DcnSlackSummary - std::optional host_event = GetCollectiveHostEvent( - core_id, analysis.rendezvous(), - Timespan::FromEndPoints(analysis.send().start_time_ps(), - analysis.recv_done().start_time_ps() + - analysis.recv_done().duration_ps())); - if (host_event.has_value()) { - OpInstance* host_graph_execution = - analysis.mutable_host_graph_execution(); - host_graph_execution->set_start_time_ps(host_event->timespan.begin_ps()); - host_graph_execution->set_duration_ps(host_event->timespan.duration_ps()); - 
s.set_host_stall_us(s.host_stall_us() + - (((int64_t)host_event->timespan.end_ps() - - (int64_t)analysis.recv_done().start_time_ps()) / - 1E6)); - s.set_host_events_count(s.host_events_count() + 1); - } - } - - for (auto& [_, s] : summary) { - s.set_slack_us(SafeDivide(s.slack_us(), s.occurrences())); - s.set_stall_duration_us(SafeDivide(s.stall_duration_us(), s.occurrences())); - s.set_observed_duration_us( - SafeDivide(s.observed_duration_us(), s.occurrences())); - s.set_send_done_duration_us( - SafeDivide(s.send_done_duration_us(), s.occurrences())); - s.set_recv_done_duration_us( - SafeDivide(s.recv_done_duration_us(), s.occurrences())); - s.set_send_duration_us(SafeDivide(s.send_duration_us(), s.occurrences())); - s.set_recv_duration_us(SafeDivide(s.recv_duration_us(), s.occurrences())); - s.set_host_stall_us(SafeDivide(s.host_stall_us(), s.host_events_count())); - *slack_analysis_.add_dcn_slack_summary() = s; - } -} - -void DcnTracker::ProcessTopology(const Topology& topology) { - for (const auto& mesh_location : topology.mesh_location()) { - global_chip_id_to_local_index_map_[mesh_location.global_id()] = - mesh_location.index_on_host(); - } -} - -int DcnTracker::GetLocalIndex(int dcn_device_id) { - /* Based on if megacore was present or not, the LocalIndex calculation will - * differ, - * dcn device id would use the global index in cases of megacore, and use - * 2*global_index (+1) for non megacore instances - * TODO(b/302145703): Identify if transformation can be obtained from the - * TpuTopology directly - */ - int global_device_id = dcn_device_id; - if (!is_megacore_) { - if (global_chip_id_to_local_index_map_.contains(global_device_id)) { - return global_chip_id_to_local_index_map_[dcn_device_id / 2] + - dcn_device_id % 2; - } - } - if (global_chip_id_to_local_index_map_.contains(global_device_id)) { - return global_chip_id_to_local_index_map_[global_device_id]; - } - LOG(WARNING) << "Could not map dcn_device_id to Local index, Using " - "dcn_device_id : 
" - << global_device_id; - return global_device_id; -} - -void DcnTracker::VisitHostEvent(const DcnHostEvent& event) { - std::string key = HostCollectiveKey( - GetLocalIndex(event.multi_slice_device_id), event.rendezvous_name); - if (event.rendezvous_name.empty()) return; - core_id_to_host_event_map_[key].insert(event); -} - -void ProcessDcnTraces(const XPlane& xplane, DcnTracker& dcn_tracker) { - XPlaneVisitor xplane_visitor = CreateTfXPlaneVisitor(&xplane); - HloProtoMap hlo_proto_map; - xplane_visitor.ForEachLine([&](const XLineVisitor& line) { - line.ForEachEvent([&](const XEventVisitor& event) { - dcn_tracker.VisitHostEvent(ParseDcnHostEvent(event)); - }); - }); -} - -} // namespace dcn_analysis_internal - -DcnSlackAnalysis ConvertXSpaceToDcnSlackAnalysis(const XSpace& xspace, - const XPlane* dcn_host_plane, - const Topology* topology, - bool is_megacore) { - int num_cores = tsl::profiler::FindTensorCorePlanes(xspace).size(); - if (num_cores == 0) return DcnSlackAnalysis(); - const XPlane* xplane = - FindPlaneWithName(xspace, tsl::profiler::TpuPlaneName(0)); - XPlaneVisitor xplane_visitor = CreateTfXPlaneVisitor(xplane); - HloProtoMap hlo_proto_map; - hlo_proto_map.AddHloProtosFromXSpace(xspace); - dcn_analysis_internal::DcnTracker dcn_tracker(hlo_proto_map, is_megacore); - XEventContextTracker hlo_module_context( - &xplane_visitor, - FindLineWithName(*xplane, tsl::profiler::kXlaModuleLineName)); - xplane_visitor.ForEachLine([&](const XLineVisitor& xline) { - if (xline.Name() == kXlaOpLineName) { - xline.ForEachEvent([&](const XEventVisitor& xevent) { - std::string_view hlo_category; - - xevent.Metadata().ForEachStat([&](const XStatVisitor& xstat) { - switch (static_cast(*xstat.Type())) { - case StatType::kHloCategory: - hlo_category = xstat.StrOrRefValue(); - break; - default: - break; - } - }); - auto module = - hlo_module_context.GetContainingEvent(xevent.GetTimespan()); - if (!module.has_value()) return; - if (absl::StrContains(hlo_category, "host send") 
|| - absl::StrContains(hlo_category, "host recv")) { - // All Dcn send/send-done/recv/recv-done ops. - auto instr = dcn_tracker.GetInstructionMetadata(module->Name(), - xevent.DisplayName()); - if (instr.ok()) { - dcn_tracker.VisitOp(*instr, xevent); - } - } - }); - } - }); - - if (dcn_host_plane != nullptr) { - VLOG(1) << "Processing host traces."; - if (topology != nullptr) { - dcn_tracker.ProcessTopology(*topology); - } - ProcessDcnTraces(*dcn_host_plane, dcn_tracker); - } - return dcn_tracker.Finalize(); -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/xspace_to_dcn_slack_analysis.h b/tensorflow/core/profiler/convert/xspace_to_dcn_slack_analysis.h deleted file mode 100644 index 388fe80d22d3b6..00000000000000 --- a/tensorflow/core/profiler/convert/xspace_to_dcn_slack_analysis.h +++ /dev/null @@ -1,165 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ -#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_XSPACE_TO_DCN_SLACK_ANALYSIS_H_ -#define TENSORFLOW_CORE_PROFILER_CONVERT_XSPACE_TO_DCN_SLACK_ANALYSIS_H_ - -#include -#include -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/status/statusor.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/ir/hlo_opcode.h" -#include "xla/tsl/profiler/utils/timespan.h" -#include "xla/tsl/profiler/utils/xplane_visitor.h" -#include "tensorflow/core/profiler/protobuf/dcn_collective_info.pb.h" -#include "tensorflow/core/profiler/protobuf/dcn_slack_analysis.pb.h" -#include "tensorflow/core/profiler/protobuf/topology.pb.h" -#include "tensorflow/core/profiler/utils/hlo_proto_map.h" -#include "tsl/profiler/protobuf/xplane.pb.h" - -namespace tensorflow { -namespace profiler { - -using tensorflow::profiler::DcnSlackAnalysis; - -namespace dcn_analysis_internal { - -struct DcnOpState { - uint64_t start_time = 0; - uint64_t end_time = 0; - - // Duration of containing send/send-done/recv/recv-done ops that needs to be - // subtracted from the total duration - uint64_t overlapping_duration = 0; - std::string rendezvous_name; - std::string transfer_type; - uint64_t stall_duration_ns = 0; - std::string send_op_name; - int replica_group_size = 0; - - OpInstance send; - OpInstance send_done; - OpInstance recv; - OpInstance recv_done; -}; - -// Structure to extract and store the DcnHostEvents. -struct DcnHostEvent { - std::string rendezvous_name; - tsl::profiler::Timespan timespan; - int multi_slice_device_id; -}; - -// When visiting DcnHostEvents from the megascale planes, The events are stored -// in separate lines in an ascending (by time) order. The List allows insertion -// of multiple arrays of sorted events. -class DcnHostEventList { - public: - // Insert the event into the sorted list. 
- void insert(DcnHostEvent event); - - // Pop the events from the front that is included within the timestamp when - // available. - std::optional pop(const tsl::profiler::Timespan& timespan); - - // Number of events. - int size() const { return events_.size(); } - - private: - std::list events_; - std::list::iterator iter_ = events_.begin(); -}; - -struct InstrMetadata { - xla::HloOpcode opcode; - uint64_t channel_id; - std::optional rendezvous_name; - int64_t size = 0; - std::optional transfer_type; -}; - -class DcnTracker { - public: - explicit DcnTracker(const tensorflow::profiler::HloProtoMap& hlo_proto_map, - bool is_megacore) - : hlo_proto_map_(hlo_proto_map), is_megacore_(is_megacore) {} - - absl::StatusOr GetInstructionMetadata(std::string_view module, - std::string_view instr); - - DcnSlackAnalysis Finalize(); - - void DebugString(); - - void VisitOp(const InstrMetadata& instr, - const tsl::profiler::XEventVisitor& visitor); - - void VisitHostEvent(const DcnHostEvent& event); - - void ProcessTopology(const tensorflow::profiler::Topology& topology); - - private: - DcnSlackAnalysis slack_analysis_; - absl::flat_hash_map rendezvous_to_op_map_; - absl::flat_hash_map channel_id_to_rendezvous_map_; - absl::flat_hash_map instruction_metadata_map_; - absl::flat_hash_map core_id_to_host_event_map_; - const tensorflow::profiler::HloProtoMap& hlo_proto_map_; - absl::flat_hash_map global_chip_id_to_local_index_map_; - absl::flat_hash_map> - hlo_module_cache_; - absl::flat_hash_map rendezvous_to_replica_group_size_map_; - bool is_megacore_ = true; - - absl::StatusOr GetInstrMetadataFromHloModule( - std::string_view module, std::string_view instr); - - void UpdateActiveOps(uint64_t duration); - - void SummarizeDcnSlackAnalysis(); - - std::optional GetCollectiveHostEvent( - int core_id, std::string_view rendezvous_name, - tsl::profiler::Timespan timespan); - - // GetLocalIndex when available, else return the global_device_id itself. 
- int GetLocalIndex(int dcn_device_id); - - // Get number of replica group - int GetReplicaGroupSize(const std::string& rendezvous_name, - const tsl::profiler::XEventVisitor& visitor); - - // Compute data transmitted size based on number of replica groups - uint64_t ComputeTransmittedDataSize(int64_t buffer_size, int group_size, - const std::string& transfer_type); -}; - -} // namespace dcn_analysis_internal - -// Convert Hlo Events in XSpace to Dcn Slack analysis. -DcnSlackAnalysis ConvertXSpaceToDcnSlackAnalysis( - const tsl::profiler::XSpace& xspace, - const tsl::profiler::XPlane* dcn_host_plane, - const tensorflow::profiler::Topology* topology, bool is_megacore = true); - -} // namespace profiler -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_PROFILER_CONVERT_XSPACE_TO_DCN_SLACK_ANALYSIS_H_ diff --git a/tensorflow/core/profiler/lib/BUILD b/tensorflow/core/profiler/lib/BUILD index 7e52d2c96a6bbb..76cafa8a8aa196 100644 --- a/tensorflow/core/profiler/lib/BUILD +++ b/tensorflow/core/profiler/lib/BUILD @@ -111,6 +111,8 @@ cc_library( hdrs = ["profiler_controller.h"], deps = [ "@com_google_absl//absl/base:core_headers", + "@local_tsl//tsl/profiler/lib:profiler_controller", + "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", ], ) diff --git a/tensorflow/core/profiler/protobuf/BUILD b/tensorflow/core/profiler/protobuf/BUILD index 6d949f5b67d504..fcdfed4c4711f9 100644 --- a/tensorflow/core/profiler/protobuf/BUILD +++ b/tensorflow/core/profiler/protobuf/BUILD @@ -26,19 +26,6 @@ package_group( ], ) -tf_proto_library( - name = "xplane_proto", - srcs = ["xplane.proto"], - make_default_target_header_only = True, - protodeps = [ - "@local_tsl//tsl/profiler/protobuf:xplane_proto", - ], - visibility = [":friends"], - exports = [ - "@local_tsl//tsl/profiler/protobuf:xplane_proto", - ], -) - # This is needed because of how tf_android_core_proto_sources parses proto paths. 
exports_files( srcs = ["xplane.proto"], @@ -280,11 +267,6 @@ tf_proto_library( ) # copybara:uncomment_begin(google-only) -# py_proto_library( -# name = "xplane_py_pb2", -# visibility = [":friends"], -# deps = [":xplane_proto"], -# ) # # py_proto_library( # name = "memory_viewer_preprocess_py_pb2", diff --git a/tensorflow/core/profiler/protobuf/xplane.proto b/tensorflow/core/profiler/protobuf/xplane.proto deleted file mode 100644 index 69655b76d3e189..00000000000000 --- a/tensorflow/core/profiler/protobuf/xplane.proto +++ /dev/null @@ -1,5 +0,0 @@ -syntax = "proto3"; - -package tensorflow.profiler.empty; - -import public "tsl/profiler/protobuf/xplane.proto"; diff --git a/tensorflow/core/profiler/utils/BUILD b/tensorflow/core/profiler/utils/BUILD index 2826a9747c29d5..e41694b21f67d2 100644 --- a/tensorflow/core/profiler/utils/BUILD +++ b/tensorflow/core/profiler/utils/BUILD @@ -1,4 +1,3 @@ -load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_cuda_library") load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//tensorflow/core/profiler/builds:build_config.bzl", "tf_profiler_copts") @@ -15,69 +14,61 @@ package_group( ], ) +# DO NOT ADD NEW DEPENDENCIES TO ANY TARGET IN THIS FILE. +# Instead, use //third_party/xprof/utils. 
+ cc_library( name = "diagnostics", - srcs = ["diagnostics.cc"], hdrs = ["diagnostics.h"], copts = tf_profiler_copts(), + visibility = [ + "//perftools/accelerators/xprof/convert:__pkg__", + "//perftools/gputools/profiler/collector:__pkg__", + "//tensorflow/core/profiler/convert:__pkg__", + ], deps = [ - "//tensorflow/core:lib", - "//tensorflow/core/profiler/protobuf:diagnostics_proto_cc", - "//tensorflow/core/profiler/protobuf:op_stats_proto_cc", - "//tensorflow/core/profiler/protobuf:steps_db_proto_cc", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/strings", + "@org_xprof//xprof/utils:diagnostics", ], ) cc_library( name = "event_span", - srcs = ["event_span.cc"], hdrs = ["event_span.h"], copts = tf_profiler_copts(), + visibility = [ + "//perftools/accelerators/xprof/convert:__pkg__", + "//perftools/gputools/profiler/collector:__pkg__", + "//tensorflow/core/profiler/convert:__pkg__", + ], deps = [ - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core/profiler/protobuf:op_metrics_proto_cc", - "//tensorflow/core/profiler/protobuf:steps_db_proto_cc", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/strings", - "@local_xla//xla/tsl/profiler/utils:timespan", + "@org_xprof//xprof/utils:event_span", ], ) cc_library( name = "hardware_type_utils", - srcs = ["hardware_type_utils.cc"], hdrs = ["hardware_type_utils.h"], copts = tf_profiler_copts(), - deps = [ - ":xplane_schema", - "//tensorflow/core/profiler/protobuf:hardware_types_proto_cc", - "@com_google_absl//absl/container:btree", - "@com_google_absl//absl/log", - "@com_google_absl//absl/strings", - "@local_xla//xla/tsl/profiler/utils:math_utils", + visibility = [ + "//perftools/accelerators/xprof/convert:__pkg__", + "//tensorflow/core/profiler/convert:__pkg__", ], -) - -tf_cc_test( - name = "hardware_type_utils_test", - srcs = 
["hardware_type_utils_test.cc"], deps = [ - ":hardware_type_utils", - "//tensorflow/core:lib", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "@local_xla//xla/tsl/profiler/utils:math_utils", + "@org_xprof//xprof/utils:hardware_type_utils", ], ) cc_library( name = "math_utils", hdrs = ["math_utils.h"], + visibility = [ + "//perftools/accelerators/xprof/convert:__pkg__", + "//perftools/accelerators/xprof/service:__pkg__", + "//perftools/accelerators/xprof/xplane:__pkg__", + "//perftools/accelerators/xprof/xprofilez/integration_tests:__pkg__", + "//perftools/gputools/profiler/collector:__pkg__", + "//tensorflow/core/profiler/rpc:__pkg__", + ], deps = [ "@com_google_absl//absl/base:core_headers", "@local_xla//xla/tsl/profiler/utils:math_utils", @@ -87,64 +78,40 @@ cc_library( cc_library( name = "html_utils", hdrs = ["html_utils.h"], + visibility = [ + "//perftools/gputools/profiler/collector:__pkg__", + "//tensorflow/core/profiler/convert:__pkg__", + ], deps = [ - "@com_google_absl//absl/strings", + "@org_xprof//xprof/utils:html_utils", ], ) cc_library( name = "op_metrics_db_utils", - srcs = ["op_metrics_db_utils.cc"], hdrs = ["op_metrics_db_utils.h"], copts = tf_profiler_copts(), - deps = [ - ":xplane_visitor", - "//tensorflow/core:lib", - "//tensorflow/core/profiler/protobuf:op_metrics_proto_cc", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/strings", - "@local_xla//xla/tsl/profiler/utils:math_utils", - "@local_xla//xla/tsl/profiler/utils:tf_op_utils", - "@local_xla//xla/tsl/profiler/utils:xplane_schema", - "@local_xla//xla/tsl/profiler/utils:xplane_visitor", + visibility = [ + "//perftools/accelerators/xprof/convert:__pkg__", + "//perftools/accelerators/xprof/xplane:__pkg__", + "//perftools/gputools/profiler/collector:__pkg__", + "//tensorflow/core/profiler/convert:__pkg__", ], -) - -tf_cc_test( - name = "op_metrics_db_utils_test", - srcs = ["op_metrics_db_utils_test.cc"], deps = [ 
- ":op_metrics_db_utils", - ":xplane_builder", - ":xplane_schema", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core/profiler/protobuf:op_metrics_proto_cc", - "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", - "@local_xla//xla/tsl/profiler/utils:xplane_visitor", + "@org_xprof//xprof/utils:op_metrics_db_utils", ], ) cc_library( name = "op_utils", - srcs = ["op_utils.cc"], hdrs = ["op_utils.h"], copts = tf_profiler_copts(), + visibility = [ + "//perftools/gputools/profiler/collector:__pkg__", + "//tensorflow/core/profiler/convert:__pkg__", + ], deps = [ - ":hlo_module_map", - ":op_metrics_db_utils", - "//tensorflow/core/profiler/convert:op_metrics_db_combiner", - "//tensorflow/core/profiler/protobuf:op_metrics_proto_cc", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:protobuf", - "@local_xla//xla/hlo/ir:hlo", - "@local_xla//xla/tsl/platform:types", - "@local_xla//xla/tsl/profiler/utils:tf_op_utils", - "@local_xla//xla/tsl/profiler/utils:timespan", + "@org_xprof//xprof/utils:op_utils", ], ) @@ -152,6 +119,12 @@ cc_library( name = "trace_utils", hdrs = ["trace_utils.h"], copts = tf_profiler_copts(), + visibility = [ + "//perftools/accelerators/xprof/convert:__pkg__", + "//perftools/accelerators/xprof/xprofilez/nvidia_gpu:__pkg__", + "//perftools/gputools/profiler/collector:__pkg__", + "//tensorflow/core/profiler/convert:__pkg__", + ], deps = [ "@local_xla//xla/tsl/profiler/utils:trace_utils", ], @@ -190,7 +163,11 @@ cc_library( testonly = True, hdrs = ["xplane_test_utils.h"], copts = tf_profiler_copts(), - visibility = [":friends"], + visibility = [ + "//perftools/accelerators/xprof/db:__pkg__", + "//perftools/gputools/profiler/collector:__pkg__", + "//tensorflow/core/profiler/convert:__pkg__", + ], deps = ["@local_xla//xla/tsl/profiler/utils:xplane_test_utils"], ) @@ -206,275 
+183,132 @@ cc_library( cc_library( name = "cost_utils", - srcs = ["cost_utils.cc"], hdrs = ["cost_utils.h"], copts = tf_profiler_copts(), - deps = [ - ":xplane_schema", - ":xplane_visitor", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core/grappler/costs:cost_estimator", - "//tensorflow/core/grappler/costs:op_context", - "//tensorflow/core/grappler/costs:op_level_cost_estimator", - "//tensorflow/core/grappler/costs:op_performance_data_cc", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", - "@com_google_absl//absl/strings", - "@local_xla//xla/tsl/profiler/utils:tf_op_utils", + visibility = [ + "//perftools/accelerators/xprof/convert:__pkg__", + "//perftools/gputools/profiler/collector:__pkg__", + "//tensorflow/core/profiler/convert:__pkg__", ], -) - -cc_library( - name = "host_offload_utils", - srcs = ["host_offload_utils.cc"], - hdrs = ["host_offload_utils.h"], - copts = tf_profiler_copts(), deps = [ - ":trace_utils", - ":xplane_builder", - ":xplane_schema", - ":xplane_visitor", - "//tensorflow/core:protos_all_cc", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/log", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@local_xla//xla:shape_util", - "@local_xla//xla/tsl/profiler/utils:timespan", + "@org_xprof//xprof/utils:cost_utils", ], ) cc_library( name = "derived_timeline", - srcs = ["derived_timeline.cc"], hdrs = ["derived_timeline.h"], copts = tf_profiler_copts(), - visibility = [":friends"], - deps = [ - ":gpu_event_stats", - ":hlo_module_map", - ":hlo_proto_map", - ":host_offload_utils", - ":trace_utils", - ":xplane_builder", - ":xplane_schema", - ":xplane_utils", - ":xplane_visitor", - "//tensorflow/core:lib_internal", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/strings", - 
"@com_google_absl//absl/types:optional", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_xla//xla/tsl/profiler/convert:xla_op_utils", - "@local_xla//xla/tsl/profiler/utils:device_utils", - "@local_xla//xla/tsl/profiler/utils:group_events", - "@local_xla//xla/tsl/profiler/utils:math_utils", - "@local_xla//xla/tsl/profiler/utils:tf_op_utils", - "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", - "@local_xla//xla/tsl/profiler/utils:timespan", - "@local_xla//xla/tsl/profiler/utils:tpu_xplane_utils", - "@local_xla//xla/tsl/profiler/utils:trace_utils", - "@local_xla//xla/tsl/profiler/utils:xplane_schema", - "@local_xla//xla/tsl/profiler/utils:xplane_visitor", - "@local_xla//xla/tsl/util:stats_calculator_portable", + visibility = [ + "//perftools/accelerators/xprof/convert:__pkg__", + "//perftools/accelerators/xprof/xplane:__pkg__", + "//platforms/darwinn/tools/xprof_trace:__pkg__", + "//tensorflow/core/profiler/convert:__pkg__", ], -) - -tf_cc_test( - name = "derived_timeline_test", - srcs = ["derived_timeline_test.cc"], deps = [ - ":derived_timeline", - ":trace_utils", - ":xplane_builder", - ":xplane_schema", - ":xplane_test_utils", - ":xplane_visitor", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "@com_google_absl//absl/log", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_xla//xla/tsl/profiler/utils:group_events", - "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", - "@local_xla//xla/tsl/profiler/utils:xplane_schema", + "@org_xprof//xprof/utils:derived_timeline", ], ) cc_library( name = "kernel_stats_utils", - srcs = ["kernel_stats_utils.cc"], hdrs = ["kernel_stats_utils.h"], copts = tf_profiler_copts(), - deps = [ - "//tensorflow/core:lib", - "//tensorflow/core/profiler/protobuf:kernel_stats_proto_cc", - "@com_google_absl//absl/algorithm:container", - 
"@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/hash", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/strings", - ], -) - -tf_cc_test( - name = "kernel_stats_utils_test", - srcs = ["kernel_stats_utils_test.cc"], - deps = [ - ":kernel_stats_utils", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core/profiler/protobuf:kernel_stats_proto_cc", - "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest_main", - "@local_xla//xla/backends/profiler/gpu:cupti_buffer_events", + visibility = [ + "//perftools/accelerators/xprof/convert:__pkg__", + "//perftools/gputools/profiler/collector:__pkg__", + "//tensorflow/core/profiler/convert:__pkg__", ], + deps = ["@org_xprof//xprof/utils:kernel_stats_utils"], ) cc_library( name = "tfstreamz_utils", - srcs = ["tfstreamz_utils.cc"], hdrs = ["tfstreamz_utils.h"], copts = tf_profiler_copts(), + visibility = ["//perftools/accelerators/xprof/xprofilez/cpu:__pkg__"], deps = [ - ":xplane_builder", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core/framework:protos_all_cc", - "//tensorflow/core/profiler/protobuf:tfstreamz_proto_cc", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", + "@org_xprof//xprof/utils:tfstreamz_utils", ], ) cc_library( name = "step_intersection", - srcs = ["step_intersection.cc"], hdrs = ["step_intersection.h"], copts = tf_profiler_copts(), - deps = [ - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core/platform:types", - "//tensorflow/core/profiler/protobuf:steps_db_proto_cc", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/strings", - "@local_xla//xla/tsl/profiler/utils:timespan", + visibility = [ + 
"//perftools/accelerators/xprof/convert:__pkg__", + "//tensorflow/core/profiler/convert:__pkg__", ], -) - -tf_cc_test( - name = "step_intersection_test", - srcs = ["step_intersection_test.cc"], deps = [ - ":step_intersection", - "//tensorflow/core:lib", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "@com_google_absl//absl/container:flat_hash_map", + "@org_xprof//xprof/utils:step_intersection", ], ) cc_library( name = "device_caps_utils", - srcs = ["device_caps_utils.cc"], hdrs = ["device_caps_utils.h"], copts = tf_profiler_copts(), - visibility = [":friends"], - deps = [ - ":xplane_builder", - ":xplane_schema", - ":xplane_visitor", - "//tensorflow/core/profiler/protobuf:hardware_types_proto_cc", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", + visibility = [ + "//perftools/accelerators/xprof/xplane:__pkg__", + "//platforms/xla/tools:__pkg__", + "//tensorflow/core/profiler/convert:__pkg__", ], + deps = ["@org_xprof//xprof/utils:device_caps_utils"], ) cc_library( name = "gpu_event_stats", - srcs = ["gpu_event_stats.cc"], hdrs = ["gpu_event_stats.h"], copts = tf_profiler_copts(), - visibility = [":friends"], + visibility = [ + "//perftools/accelerators/xprof/convert:__pkg__", + "//perftools/gputools/profiler/collector:__pkg__", + "//tensorflow/core/profiler/convert:__pkg__", + ], deps = [ - ":xplane_schema", - ":xplane_visitor", - "@com_google_absl//absl/strings", + "@org_xprof//xprof/utils:gpu_event_stats", ], ) cc_library( name = "hlo_proto_map", - srcs = ["hlo_proto_map.cc"], hdrs = ["hlo_proto_map.h"], - visibility = [":friends"], + visibility = [ + "//perftools/accelerators/xprof/convert:__pkg__", + "//perftools/accelerators/xprof/xplane:__pkg__", + "//perftools/accelerators/xprof/xprofilez/integration_tests:__pkg__", + "//perftools/gputools/profiler/collector:__pkg__", + "//tensorflow/core/profiler/convert:__pkg__", + "@org_xprof//xprof/convert/google:__pkg__", + ], deps = [ - 
":xplane_schema", - ":xplane_utils", - ":xplane_visitor", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/log", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_xla//xla/service:hlo_proto_cc", - "@local_xla//xla/tsl/profiler/convert:xla_op_utils", - "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", + "@org_xprof//xprof/utils:hlo_proto_map", ], ) cc_library( name = "hlo_proto_to_module", - srcs = ["hlo_proto_to_module.cc"], hdrs = ["hlo_proto_to_module.h"], - visibility = [":friends"], + visibility = [ + "//perftools/accelerators/xprof/convert:__pkg__", + "//perftools/gputools/profiler/collector:__pkg__", + "//tensorflow/core/profiler/convert:__pkg__", + ], deps = [ - "@com_google_absl//absl/log", - "@com_google_absl//absl/status:statusor", - "@local_xla//xla:util", - "@local_xla//xla/hlo/ir:hlo", - "@local_xla//xla/service:hlo_proto_cc", - "@local_xla//xla/tsl/platform:statusor", + "@org_xprof//xprof/utils:hlo_proto_to_module", ], ) -tf_cuda_library( +cc_library( name = "hlo_module_map", - srcs = ["hlo_module_map.cc"], hdrs = ["hlo_module_map.h"], - cuda_deps = [ - "@local_xla//xla/service/gpu/model:gpu_hlo_cost_analysis", + visibility = [ + "//perftools/accelerators/xprof/convert:__pkg__", + "//tensorflow/core/profiler/convert:__pkg__", ], - visibility = [":friends"], deps = [ - ":hlo_module_utils", - ":hlo_proto_map", - ":hlo_proto_to_module", - "//tensorflow/core/platform:path", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/profiler/lib:traceme_encode", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_xla//xla:shape_util", - "@local_xla//xla/hlo/ir:hlo", - "@local_xla//xla/service:hlo_cost_analysis", - 
"@local_xla//xla/service:hlo_proto_cc", - "@local_xla//xla/tsl/profiler/convert:xla_op_utils", + "@org_xprof//xprof/utils:hlo_module_map", ], ) @@ -482,77 +316,48 @@ cc_library( name = "hlo_module_utils", hdrs = ["hlo_module_utils.h"], visibility = [ - ":friends", - # copybara:uncomment "//tensorflow/compiler/mlir/lite/experimental/google/tooling/google:__subpackages__", + "//perftools/accelerators/xprof/convert:__pkg__", + "//perftools/gputools/profiler/collector:__pkg__", + "//tensorflow/compiler/mlir/lite/experimental/google/tooling/hlo_adapter:__pkg__", + "//tensorflow/core/profiler/convert:__pkg__", ], deps = [ - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/strings", - "@local_xla//xla/hlo/ir:hlo", - "@local_xla//xla/tsl/profiler/convert:xla_op_utils", - ], -) - -tf_cc_test( - name = "hlo_module_utils_test", - srcs = ["hlo_module_utils_test.cc"], - deps = [ - ":hlo_module_utils", - "//tensorflow/core:lib", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "@com_google_absl//absl/status:statusor", - "@com_google_googletest//:gtest_main", - "@local_xla//xla/hlo/ir:hlo", - "@local_xla//xla/tests:hlo_test_base", + "@org_xprof//xprof/utils:hlo_module_utils", ], ) cc_library( name = "xprof_gpu_cost_analysis", - srcs = ["xprof_gpu_cost_analysis.cc"], hdrs = ["xprof_gpu_cost_analysis.h"], - visibility = [":friends"], + visibility = [ + "//perftools/accelerators/xprof/convert:__pkg__", + "//perftools/accelerators/xprof/xplane:__pkg__", + ], deps = [ - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings:string_view", - "@local_xla//xla:shape_util", - "@local_xla//xla/hlo/ir:hlo", - "@local_xla//xla/service:hlo_cost_analysis", - "@local_xla//xla/service/gpu:cublas_cudnn", - "@local_xla//xla/service/gpu/model:gpu_hlo_cost_analysis", - "@local_xla//xla/tsl/platform:errors", + "@org_xprof//xprof/utils:xprof_gpu_cost_analysis", ], ) cc_library( name = "tpu_step_breakdown_utils", hdrs = ["tpu_step_breakdown_utils.h"], - 
visibility = [":friends"], - deps = ["//tensorflow/core/profiler/protobuf:steps_db_proto_cc"], + visibility = [ + "//perftools/accelerators/xprof/convert:__pkg__", + "//tensorflow/core/profiler/convert:__pkg__", + ], + deps = [ + "@org_xprof//xprof/utils:tpu_step_breakdown_utils", + ], ) cc_library( name = "tpu_step_details_utils", hdrs = ["tpu_step_details_utils.h"], - visibility = [":friends"], - deps = ["//tensorflow/core/profiler/protobuf:tpu_input_pipeline_proto_cc"], -) - -tf_cc_test( - name = "xprof_gpu_cost_analysis_test", - srcs = ["xprof_gpu_cost_analysis_test.cc"], + visibility = [ + "//perftools/gputools/profiler/collector:__pkg__", + "//tensorflow/core/profiler/convert:__pkg__", + ], deps = [ - ":xprof_gpu_cost_analysis", - "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest", - "@local_xla//xla:shape_util", - "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/hlo/ir:hlo", - "@local_xla//xla/hlo/testlib:test_helpers", - "@local_xla//xla/service:hlo_cost_analysis", - "@local_xla//xla/tests:hlo_test_base", - "@local_xla//xla/tests:xla_internal_test_main", - "@local_xla//xla/tsl/platform:statusor", + "@org_xprof//xprof/utils:tpu_step_details_utils", ], ) diff --git a/tensorflow/core/profiler/utils/cost_utils.cc b/tensorflow/core/profiler/utils/cost_utils.cc deleted file mode 100644 index 8d44fd513d6e91..00000000000000 --- a/tensorflow/core/profiler/utils/cost_utils.cc +++ /dev/null @@ -1,131 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/profiler/utils/cost_utils.h" - -#include -#include - -#include "absl/container/flat_hash_set.h" -#include "absl/log/log.h" -#include "absl/strings/numbers.h" -#include "absl/strings/str_join.h" -#include "absl/strings/str_split.h" -#include "absl/strings/string_view.h" -#include "absl/strings/strip.h" -#include "xla/tsl/profiler/utils/tf_op_utils.h" -#include "tensorflow/core/framework/tensor_shape.pb.h" -#include "tensorflow/core/framework/types.h" -#include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/core/grappler/costs/cost_estimator.h" -#include "tensorflow/core/grappler/costs/op_context.h" -#include "tensorflow/core/grappler/costs/op_performance_data.pb.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/profiler/utils/xplane_schema.h" -#include "tensorflow/core/profiler/utils/xplane_visitor.h" - -namespace tensorflow { -namespace profiler { - -namespace { - -// Decode the string that encodes tensor shape and type information and convert -// to TensorProperties. -// Returns an empty TensorProperties if error or input is "". -// See OpKernel::TraceString() to see when the shape is encoded as "". -// Input format is [, ,...] -static OpInfo::TensorProperties GetTensorProperties(absl::string_view info) { - OpInfo::TensorProperties tensor_prop; - std::vector parts = absl::StrSplit(info, '['); - if (parts.size() != 2) return tensor_prop; - DataType data_type = DT_INVALID; - if (!DataTypeFromString(parts[0], &data_type)) return tensor_prop; - tensor_prop.set_dtype(data_type); - absl::ConsumeSuffix(&parts[1], "]"); - if (parts[1].empty()) { // Scalar type. 
- tensor_prop.mutable_shape()->add_dim()->set_size(1); - return tensor_prop; - } - std::vector dims = absl::StrSplit(parts[1], ','); - for (const auto dim : dims) { - int size; - if (!absl::SimpleAtoi(dim, &size)) return OpInfo::TensorProperties(); - tensor_prop.mutable_shape()->add_dim()->set_size(size); - } - return tensor_prop; -} - -} // namespace - -TfOpRoofLineCostEstimator::~TfOpRoofLineCostEstimator() { - if (!unsupported_ops_.empty()) { - LOG(ERROR) << "Unsupported Op for Roofline Cost Analysis are:" - << absl::StrJoin(unsupported_ops_, ","); - } -} - -grappler::DeviceInfo TfOpRoofLineCostEstimator::GetDeviceInfo( - const DeviceProperties& device) const { - // Hypothetical devices that is used to measure peak flops and memory bytes - // accessed. - return grappler::DeviceInfo(/*gigaops=*/1, /*gb_per_sec=*/1); -} - -TfOpRoofLineCostEstimator::OpRoofLineStats TfOpRoofLineCostEstimator::Predict( - const XEventVisitor& event) { - tsl::profiler::TfOp tf_op; - absl::string_view tensor_shapes; - event.ForEachStat([&](const XStatVisitor& stat) { - if (!stat.Type().has_value()) return; - switch (stat.Type().value()) { - case StatType::kTfOp: - tf_op = tsl::profiler::ParseTfOpFullname(stat.StrOrRefValue()); - break; - case StatType::kTensorShapes: - tensor_shapes = stat.StrOrRefValue(); - break; - } - }); - - // Return empty OpRoofLineStats if shape is not traced or this is not a tf op. 
- if (tf_op.type.empty() || tensor_shapes.empty()) { - return {0ULL, 0ULL, /*inaccurate=*/true}; - } - - grappler::OpContext op_context; - op_context.name = std::string(tf_op.type); - op_context.op_info.set_op(op_context.name); - for (absl::string_view tensor : - tsl::profiler::ParseTensorShapes(tensor_shapes)) { - *op_context.op_info.add_inputs() = GetTensorProperties(tensor); - } - grappler::Costs costs = PredictCosts(op_context); - if (costs.inaccurate) unsupported_ops_.insert(std::string(tf_op.type)); - - VLOG(1) << tf_op.type << tensor_shapes - << " flops:" << costs.compute_time.count() - << " bytes:" << costs.memory_time.count(); - - /* The compute_time is measured in nanoseconds, therefore numerically it is - * equal to flops because giga ops / second cancel the nanoseconds. - * Same for memory_time */ - return {/*flops=*/static_cast(costs.compute_time.count()), - /*bytes_accessed=*/static_cast(costs.memory_time.count()), - /*inaccurate=*/costs.inaccurate}; -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/utils/cost_utils.h b/tensorflow/core/profiler/utils/cost_utils.h index 01a6540d8145cb..b7f139ffdb5915 100644 --- a/tensorflow/core/profiler/utils/cost_utils.h +++ b/tensorflow/core/profiler/utils/cost_utils.h @@ -15,45 +15,6 @@ limitations under the License. #ifndef TENSORFLOW_CORE_PROFILER_UTILS_COST_UTILS_H_ #define TENSORFLOW_CORE_PROFILER_UTILS_COST_UTILS_H_ -#include - -#include "absl/container/flat_hash_set.h" -#include "tensorflow/core/grappler/costs/cost_estimator.h" -#include "tensorflow/core/grappler/costs/op_level_cost_estimator.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/profiler/utils/xplane_visitor.h" - -namespace tensorflow { -namespace profiler { - -// This is a wrapper of tensorflow::grappler::OpLevelCostEstimator and use -// tracing time information to estimate the roof line stats for each traced -// tensorflow op. 
-class TfOpRoofLineCostEstimator - : public tensorflow::grappler::OpLevelCostEstimator { - public: - TfOpRoofLineCostEstimator() = default; - ~TfOpRoofLineCostEstimator() override; - - grappler::DeviceInfo GetDeviceInfo( - const DeviceProperties& device) const override; - - struct OpRoofLineStats { - uint64 flops = 0LL; - uint64 bytes_accessed = 0LL; - bool inaccurate = false; - }; - OpRoofLineStats Predict(const XEventVisitor& event); - - private: - absl::flat_hash_set - unsupported_ops_; // summary for unsupported ops. - - TfOpRoofLineCostEstimator(const TfOpRoofLineCostEstimator&) = delete; - void operator=(const TfOpRoofLineCostEstimator&) = delete; -}; - -} // namespace profiler -} // namespace tensorflow +#include "xprof/utils/cost_utils.h" // from @org_xprof // IWYU pragma: export #endif // TENSORFLOW_CORE_PROFILER_UTILS_COST_UTILS_H_ diff --git a/tensorflow/core/profiler/utils/derived_timeline.cc b/tensorflow/core/profiler/utils/derived_timeline.cc deleted file mode 100644 index 721c283c7dda7c..00000000000000 --- a/tensorflow/core/profiler/utils/derived_timeline.cc +++ /dev/null @@ -1,772 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ -#include "tensorflow/core/profiler/utils/derived_timeline.h" - -#include -#include -#include -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/log/check.h" -#include "absl/log/log.h" -#include "absl/strings/match.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" -#include "absl/types/optional.h" -#include "absl/types/span.h" -#include "xla/tsl/profiler/convert/xla_op_utils.h" -#include "xla/tsl/profiler/utils/device_utils.h" -#include "xla/tsl/profiler/utils/group_events.h" -#include "xla/tsl/profiler/utils/math_utils.h" -#include "xla/tsl/profiler/utils/tf_op_utils.h" -#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" -#include "xla/tsl/profiler/utils/timespan.h" -#include "xla/tsl/profiler/utils/tpu_xplane_utils.h" -#include "xla/tsl/profiler/utils/trace_utils.h" -#include "xla/tsl/profiler/utils/xplane_schema.h" -#include "xla/tsl/profiler/utils/xplane_visitor.h" -#include "xla/tsl/util/stats_calculator.h" -#include "tensorflow/core/lib/gtl/map_util.h" -#include "tensorflow/core/profiler/utils/gpu_event_stats.h" -#include "tensorflow/core/profiler/utils/hlo_module_map.h" -#include "tensorflow/core/profiler/utils/hlo_proto_map.h" -#include "tensorflow/core/profiler/utils/host_offload_utils.h" -#include "tensorflow/core/profiler/utils/trace_utils.h" -#include "tensorflow/core/profiler/utils/xplane_builder.h" -#include "tensorflow/core/profiler/utils/xplane_schema.h" -#include "tensorflow/core/profiler/utils/xplane_utils.h" -#include "tensorflow/core/profiler/utils/xplane_visitor.h" -#include "tsl/profiler/protobuf/xplane.pb.h" - -namespace tensorflow { -namespace profiler { -namespace { - -using ::tsl::profiler::DeviceType; -using ::tsl::profiler::FindMutableTensorCorePlanes; -using ::tsl::profiler::GetDeviceType; - -inline std::string HloModuleEventName(const GpuEventStats& stats) { - return 
stats.program_id ? tsl::profiler::HloModuleNameWithProgramId( - stats.hlo_module_name, *stats.program_id) - : std::string(stats.hlo_module_name); -} - -// Returns a prefix that uniquely identifies the HLO module. -inline std::string HloOpEventPrefix(const GpuEventStats& stats) { - return stats.program_id ? absl::StrCat(*stats.program_id, "/") - : absl::StrCat(stats.hlo_module_name, "/"); -} - -std::vector GetOrCreateHloOpEventsMetadata( - XPlaneBuilder& xplane, const GpuEventStats& stats, const Symbol symbol) { - DCHECK(stats.IsXlaOp()); - std::vector hlo_op_events_metadata; - hlo_op_events_metadata.reserve(stats.hlo_op_names.size()); - // Prepend an HLO module identifier so HLO operators with the same name but in - // different modules have different metadata. - std::string hlo_op_event_prefix = HloOpEventPrefix(stats); - for (absl::string_view hlo_op_name : stats.hlo_op_names) { - XEventMetadata* hlo_op_event_metadata = xplane.GetOrCreateEventMetadata( - absl::StrCat(hlo_op_event_prefix, hlo_op_name)); - // Display the HLO name without the module name in tools. - if (hlo_op_event_metadata->display_name().empty()) { - hlo_op_event_metadata->set_display_name(std::string(hlo_op_name)); - } - hlo_op_events_metadata.push_back(hlo_op_event_metadata); - if (!symbol.hlo_text.empty()) { - XStatsBuilder event_stats(hlo_op_event_metadata, &xplane); - event_stats.SetOrAddStatValue(*xplane.GetOrCreateStatMetadata("hlo_text"), - symbol.hlo_text); - } - } - return hlo_op_events_metadata; -} - -// Get the derived line id for a given derived line in group which starts from -// first_derived_line_id. -// According to definition in trace_utils.h, the derived lines are: -// kThreadIdTfNameScope to kThreadIdSource. Keep the line id sequence in each -// group as this original group.. 
-inline int64_t GetDerivedLineId(int64_t first_derived_line_id, - int64_t target_line_id) { - return first_derived_line_id + (target_line_id - kThreadIdTfNameScope); -} - -// Get the derived line name for a given derived line in group which starts from -// first_derived_line_id. -std::string GetDerivedLineName(int64_t first_derived_line_id, - int64_t target_line_id, - absl::Span source_line_ids) { - int64_t offset = target_line_id - kThreadIdTfNameScope; - std::string suffix; - if (first_derived_line_id != kThreadIdTfNameScope && - !source_line_ids.empty()) { - suffix = absl::StrCat(" - from #", source_line_ids[0]); - } - switch (offset) { - case kThreadIdTfNameScope - kThreadIdTfNameScope: - return absl::StrCat(kTensorFlowNameScopeLineName, suffix); - case kThreadIdHloOp - kThreadIdTfNameScope: - return absl::StrCat(kXlaOpLineName, suffix); - case kThreadIdHloModule - kThreadIdTfNameScope: - return absl::StrCat(kXlaModuleLineName, suffix); - case kThreadIdTfOp - kThreadIdTfNameScope: - return absl::StrCat(kTensorFlowOpLineName, suffix); - case kThreadIdSource - kThreadIdTfNameScope: - return absl::StrCat(kSourceLineName, suffix); - default: - LOG(ERROR) << "Invalid target line id: " << target_line_id - << " for first_derived_line_id: " << first_derived_line_id; - return absl::StrCat("UnknownDerived#", first_derived_line_id + offset); - } -} - -// Derive events from the given line ids using annotations. -// Returns the derived line ids in the order of tf_name_scope, tf_op, hlo_op, -// hlo_module, source. Where the derived line id for tf_name_scope is -// first_derived_line_id. 
-std::vector DeriveEventsFromAnnotationsForLines( - const SymbolResolver& symbol_resolver, XPlane* device_trace, - absl::Span line_ids, int64_t first_derived_line_id, - const ScopeRangeIdTree* scope_range_id_tree = nullptr) { - XPlaneVisitor plane_visitor = - tsl::profiler::CreateTfXPlaneVisitor(device_trace); - - XPlaneBuilder plane_builder(device_trace); - int64_t start_timestamp_ns = GetStartTimestampNs(*device_trace); - DerivedXLineBuilder tf_ops( - &plane_builder, GetDerivedLineId(first_derived_line_id, kThreadIdTfOp), - GetDerivedLineName(first_derived_line_id, kThreadIdTfOp, line_ids), - start_timestamp_ns, {}); - DerivedXLineBuilder tf_name_scope( - &plane_builder, - GetDerivedLineId(first_derived_line_id, kThreadIdTfNameScope), - GetDerivedLineName(first_derived_line_id, kThreadIdTfNameScope, line_ids), - start_timestamp_ns, {&tf_ops}); - DerivedXLineBuilder hlo_ops( - &plane_builder, GetDerivedLineId(first_derived_line_id, kThreadIdHloOp), - GetDerivedLineName(first_derived_line_id, kThreadIdHloOp, line_ids), - start_timestamp_ns, {}); - DerivedXLineBuilder hlo_modules( - &plane_builder, - GetDerivedLineId(first_derived_line_id, kThreadIdHloModule), - GetDerivedLineName(first_derived_line_id, kThreadIdHloModule, line_ids), - start_timestamp_ns, {&tf_name_scope, &hlo_ops}); - DerivedXLineBuilder source( - &plane_builder, GetDerivedLineId(first_derived_line_id, kThreadIdSource), - GetDerivedLineName(first_derived_line_id, kThreadIdSource, line_ids), - start_timestamp_ns, {}); - - // Declare this vector here so that its memory will be reused during the loop, - // instead of being allocated and deallocated for each iteration. - std::vector> level_range_ids; - for (const XEventVisitor& event : - GetSortedEvents(plane_visitor, false, line_ids)) { - GpuEventStats stats(&event); - // For HLO/TF op lines, only use kernel events (i.e. excluding memcpy or - // allocation events). Also CudaGraph executions are also treated as - // kernel events. 
- if (!stats.IsKernel() && !stats.IsCudaGraphExecution()) continue; - tsl::profiler::Timespan event_span = event.GetTimespan(); - - if ((!stats.hlo_module_name.empty() || stats.IsXlaOp())) { - level_range_ids.clear(); - if (stats.scope_range_id.has_value()) { - level_range_ids.push_back(stats.scope_range_id); - if (scope_range_id_tree) { - for (auto it = scope_range_id_tree->find(*stats.scope_range_id); - it != scope_range_id_tree->end(); - it = scope_range_id_tree->find(it->second)) { - level_range_ids.push_back(it->second); - } - } - } - // Now, level_range_ids looks like: - // [child_level_n, child_level_n-1, ..., child_level_1, root_level] - } - - if (!stats.hlo_module_name.empty()) { - // back() of the level_range_ids, i.e. root_level in above comment, - // is the scope range id of HLO module. - hlo_modules.ExpandOrAddEvent( - *plane_builder.GetOrCreateEventMetadata(HloModuleEventName(stats)), - event_span, stats.group_id, - level_range_ids.empty() ? std::nullopt : level_range_ids.back()); - } - - if (stats.IsXlaOp()) { - auto symbol = symbol_resolver(stats.program_id, stats.hlo_module_name, - stats.hlo_op_names.back()); - auto hlo_events_metadata = - GetOrCreateHloOpEventsMetadata(plane_builder, stats, symbol); - // level_range_ids, if not empty, should be of same size as - // hlo_events_metadata. If not of same size, do not use those ids. - absl::Span> xla_op_level_range_ids = {}; - if (level_range_ids.size() == hlo_events_metadata.size()) { - std::reverse(level_range_ids.begin(), level_range_ids.end()); - // after reverse, the level_range_ids looks like: - // [root_level, child_level_1, ..., child_level_n-1, child_level_n] - xla_op_level_range_ids = absl::MakeSpan(level_range_ids); - } - hlo_ops.ExpandOrAddEvents(hlo_events_metadata, event_span, stats.group_id, - xla_op_level_range_ids); - - // If the kernel event is nodes of a CudaGraph or a whole cuda graph - // exec, try to mark extra stats to to corresponding XLA op event here. 
- if (stats.cuda_graph_id_for_inner_node.has_value() && - *stats.cuda_graph_id_for_inner_node != 0) { - int level = static_cast(hlo_events_metadata.size()) - 1; - if (level >= 0) { - hlo_ops.AddStatToLevelEvent(level, *hlo_ops.GetCudaGraphIdMetadata(), - *stats.cuda_graph_id_for_inner_node); - if (stats.correlation_id.has_value()) { - hlo_ops.AddStatToLevelEvent(level, - *hlo_ops.GetCorrelationIdMetadata(), - *stats.correlation_id); - } - } - } - - if (!symbol.tf_op_name.empty()) { - ProcessTfOpEvent(symbol.tf_op_name, event_span, stats.group_id, - plane_builder, tf_name_scope, tf_ops); - } - if (!symbol.source_info.empty()) { - source.ExpandOrAddEvent( - *plane_builder.GetOrCreateEventMetadata(symbol.source_info), - event_span, stats.group_id); - } - } else if (stats.IsTfOp()) { - ProcessTfOpEvent(stats.tf_op_fullname, event_span, stats.group_id, - plane_builder, tf_name_scope, tf_ops); - } - } - return {tf_name_scope.Line().Id(), tf_ops.Line().Id(), - hlo_modules.Line().Id(), hlo_ops.Line().Id(), source.Line().Id()}; -} - -} // namespace - -void ProcessTfOpEvent(absl::string_view tf_op_full_name, - tsl::profiler::Timespan event_span, - std::optional group_id, - XPlaneBuilder& plane_builder, - DerivedXLineBuilder& tf_name_scope_line_builder, - DerivedXLineBuilder& tf_op_line_builder) { - tsl::profiler::TfOp tf_op = tsl::profiler::ParseTfOpFullname(tf_op_full_name); - tsl::profiler::Category category = tf_op.category; - if (category == tsl::profiler::Category::kTensorFlow || - category == tsl::profiler::Category::kJax) { - tf_name_scope_line_builder.ExpandOrAddEvents( - plane_builder.GetOrCreateEventsMetadata( - tsl::profiler::ParseTfNameScopes(tf_op)), - event_span, group_id); - } - XEventMetadata* tf_op_event_metadata = - plane_builder.GetOrCreateEventMetadata(tf_op_full_name); - // Set the display name to op_type so that the events of the same op_type have - // the same color in the trace viewer. 
- if (tf_op_event_metadata->display_name().empty()) { - tf_op_event_metadata->set_display_name(tsl::profiler::TfOpEventName(tf_op)); - } - tf_op_line_builder.ExpandOrAddEvent(*tf_op_event_metadata, event_span, - group_id); -} - -DerivedXEventBuilder::DerivedXEventBuilder( - XEventBuilder event, std::optional group_id, - std::optional scope_range_id) - : event_(std::move(event)), - group_id_(group_id), - scope_range_id_(scope_range_id) {} - -bool DerivedXEventBuilder::ShouldExpand( - const XEventMetadata& event_metadata, std::optional group_id, - std::optional scope_range_id) const { - return event_.MetadataId() == event_metadata.id() && group_id_ == group_id && - (!scope_range_id.has_value() || !scope_range_id_.has_value() || - scope_range_id_ == scope_range_id); -} - -void DerivedXEventBuilder::Expand(tsl::profiler::Timespan event_span) { - tsl::profiler::Timespan timespan = event_.GetTimespan(); - DCHECK_LE(timespan.begin_ps(), event_span.begin_ps()); - timespan.ExpandToInclude(event_span); - event_.SetTimespan(timespan); -} - -DerivedXLineBuilder::DerivedXLineBuilder( - XPlaneBuilder* plane, int64_t line_id, absl::string_view name, - int64_t timestamp_ns, std::vector dependent_lines) - : group_id_stat_metadata_( - plane->GetOrCreateStatMetadata(GetStatTypeStr(StatType::kGroupId))), - correlation_id_metadata_(plane->GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kCorrelationId))), - cuda_graph_id_metadata_(plane->GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kCudaGraphId))), - line_(plane->GetOrCreateLine(line_id)), - dependent_lines_(std::move(dependent_lines)) { - line_.SetName(name); - line_.SetTimestampNs(timestamp_ns); - is_gpu_plane_ = GetDeviceType(plane->Name()) == DeviceType::kGpu; -} - -void DerivedXLineBuilder::ExpandOrAddEvent( - const XEventMetadata& event_metadata, tsl::profiler::Timespan event_span, - std::optional group_id, std::optional scope_range_id) { - ExpandOrAddLevelEvent(event_metadata, event_span, group_id, scope_range_id, - 
/*level=*/0); -} - -void DerivedXLineBuilder::ExpandOrAddEvents( - const std::vector& events_metadata_per_level, - tsl::profiler::Timespan event_span, std::optional group_id, - absl::Span> scope_range_ids) { - if (events_metadata_per_level.empty()) return; - - size_t current_nested_level = events_metadata_per_level.size(); - for (size_t level = 0; level < current_nested_level; ++level) { - ExpandOrAddLevelEvent( - *events_metadata_per_level[level], event_span, group_id, - level < scope_range_ids.size() ? scope_range_ids[level] : std::nullopt, - level); - } - ResetLastEvents(current_nested_level); -} - -void DerivedXLineBuilder::ExpandOrAddLevelEvent( - const XEventMetadata& event_metadata, tsl::profiler::Timespan event_span, - std::optional group_id, std::optional scope_range_id, - int level) { - auto& last_event = last_event_by_level_[level]; - // If group_id is not set and we still choose to expand, put an extra check: - // Expand only if the gap between the last event and the new event is less - // than 2 * duration of the last event. - // TODO: b/373944719 - add the extra node_id check for GPU profiles. - if (last_event.has_value() && - last_event->ShouldExpand(event_metadata, group_id, scope_range_id) && - (is_gpu_plane_ || group_id.has_value() || - (last_event->GetTimespan().end_ps() + - 2 * last_event->GetTimespan().duration_ps()) >= - event_span.begin_ps())) { - // Expand the last event to cover the given event. - last_event->Expand(event_span); - } else { - // Otherwise, reset the last events lower than or equal to the given level. - ResetLastEvents(level); - // And create a new event for the given level. 
- XEventBuilder event = line_.AddEvent(event_metadata); - event.SetTimespan(event_span); - if (group_id.has_value()) { - event.AddStatValue(*group_id_stat_metadata_, *group_id); - } - last_event.emplace(std::move(event), group_id, scope_range_id); - } -} - -void DerivedXLineBuilder::AddStatToLevelEvent(int level, - const XStatMetadata& metadata, - int64_t value) { - if (auto it = last_event_by_level_.find(level); - it != last_event_by_level_.end() && it->second.has_value()) { - it->second->SetOrAddStatValue(metadata, value); - } -} - -void DerivedXLineBuilder::AddStatToLevelEvent(int level, - const XStatMetadata& metadata, - uint64_t value) { - if (auto it = last_event_by_level_.find(level); - it != last_event_by_level_.end() && it->second.has_value()) { - it->second->SetOrAddStatValue(metadata, value); - } -} - -// When deriving a bunch of events with the same timespan, there could be -// indeterministic behavior of how trace viewer stacking these events. -// This function will shrink the stack of events with the same timespan when -// necessary. Event at top of stack might shrink more than event at the -// bottom. Because the time unit in trace viewer is nanosecond, therefore the -// minimum difference is 1ns. However to prevent shrink induced inconsitency, -// we can not shrink more than the duration of event at the top of the stack. -void DerivedXLineBuilder::AdjustDurationForTraceViewer(int level) { - if (level >= last_event_by_level_.size() || !last_event_by_level_[level]) - return; - - int max_level = level; - for (; max_level < last_event_by_level_.size(); ++max_level) { - if (!last_event_by_level_[max_level].has_value()) { - break; - } - } - --max_level; - if (max_level <= level) return; - auto& event_on_top_stack = *last_event_by_level_[max_level]; - tsl::profiler::Timespan timespan = event_on_top_stack.GetTimespan(); - // We will at most shrink the top of the stack to 1ns. 
- int64_t max_shrink_ns = timespan.duration_ps() / 1000 - 1; - int64_t shrink_ns = 0; - std::optional last_level_timespan; - for (int i = level; i <= max_level; ++i) { - auto& current_event = *last_event_by_level_[i]; - if (shrink_ns < max_shrink_ns && - last_level_timespan == current_event.GetTimespan()) { - shrink_ns++; - } - last_level_timespan = current_event.GetTimespan(); - if (shrink_ns) { - current_event.SetTimespan(tsl::profiler::Timespan::FromEndPoints( - last_level_timespan->begin_ps(), - last_level_timespan->end_ps() - 1000 * shrink_ns)); - } - } -} - -void DerivedXLineBuilder::ResetLastEvents(int level) { - AdjustDurationForTraceViewer(level); - for (int i = level, end = last_event_by_level_.size(); i < end; ++i) { - last_event_by_level_[i].reset(); - } - if (level == 0) { - for (DerivedXLineBuilder* line : dependent_lines_) { - line->ResetLastEvents(0); - } - } -} - -void DeriveStepEventsFromGroups( - const tsl::profiler::GroupMetadataMap& group_metadata_map, - XPlane* device_trace) { - XPlaneVisitor plane_visitor = - tsl::profiler::CreateTfXPlaneVisitor(device_trace); - const XStatMetadata* group_id_stat_metadata = - plane_visitor.GetStatMetadataByType(StatType::kGroupId); - if (group_id_stat_metadata == nullptr) return; - XPlaneBuilder plane_builder(device_trace); - int64_t start_timestamp_ns = GetStartTimestampNs(*device_trace); - DerivedXLineBuilder steps(&plane_builder, kThreadIdStepInfo, kStepLineName, - start_timestamp_ns, {}); - for (const XEventVisitor& event_visitor : - GetSortedEvents(plane_visitor)) { - std::optional group_id_stat = - event_visitor.GetStat(StatType::kGroupId, *group_id_stat_metadata); - if (group_id_stat.has_value()) { - int64_t group_id = group_id_stat->IntValue(); - steps.ExpandOrAddEvent( - *plane_builder.GetOrCreateEventMetadata(absl::StrCat(group_id)), - event_visitor.GetTimespan(), group_id); - } - } - AddGroupMetadataToStepEvents(group_metadata_map, steps.Line()); -} - -void DeriveEventsFromAnnotations(const 
SymbolResolver& symbol_resolver, - XPlane* device_trace, - const ScopeRangeIdTree* scope_range_id_tree) { - if (tsl::profiler::GetDeviceType(*device_trace) != - tsl::profiler::DeviceType::kGpu) { - DeriveEventsFromAnnotationsForLines(symbol_resolver, device_trace, {}, - kThreadIdTfNameScope); - } else { - // TODO: Currently we derive events only from the line with the most number - // of events. We should consider deriving events from all lines in the - // future, also then we need to utilize the derived relation provided by - // DeriveEventsFromAnnotationsForLines(), and find solid way to sort all - // lines. - int64_t line_id_with_most_events = -1; - int64_t max_num_events_per_line = -1; - { - XPlaneVisitor plane_visitor = - tsl::profiler::CreateTfXPlaneVisitor(device_trace); - plane_visitor.ForEachLine([&](const XLineVisitor& line) { - if (IsDerivedThreadId(line.Id())) return; - int num_events = line.NumEvents(); - // make sure strong ordering - if (num_events > max_num_events_per_line || - (num_events == max_num_events_per_line && - line.Id() < line_id_with_most_events)) { - max_num_events_per_line = num_events; - line_id_with_most_events = line.Id(); - } - }); - } - - if (line_id_with_most_events >= 0) { - DeriveEventsFromAnnotationsForLines( - symbol_resolver, device_trace, {line_id_with_most_events}, - kThreadIdTfNameScope, scope_range_id_tree); - } - } - RemoveEmptyLines(device_trace); -} - -void DeriveEventsFromHostTrace( - const XPlane* host_trace, - const tsl::profiler::GroupMetadataMap& group_metadata_map, - std::vector device_traces) { - struct GroupLaunchInfo { // "Group" normally means step. 
- tsl::profiler::Timespan timespan; - tsl::Stat stat; - - void AddEventTimespan(tsl::profiler::Timespan event_span) { - if (stat.count() == 0) { - timespan = event_span; - } else { - timespan.ExpandToInclude(event_span); - } - stat.UpdateStat(event_span.duration_ps()); - } - }; - using DeviceLaunchInfo = - absl::flat_hash_map; - - const int num_devices = device_traces.size(); - std::vector per_device_launch_info(num_devices); - - XPlaneVisitor host_plane = tsl::profiler::CreateTfXPlaneVisitor(host_trace); - host_plane.ForEachLine([&](const XLineVisitor& line) { - if (IsDerivedThreadId(line.Id())) return; - line.ForEachEvent([&](const XEventVisitor& event) { - // Filter out API calls for cuEventRecord/cuEventQuery/cuCtxSynchronize - // etc for now. TODO: find a better way to filter out only the memcpy and - // kernel launch events. - if (absl::StartsWith(event.Name(), "cu")) return; - LaunchEventStats stats(&event); - if (stats.group_id.has_value() && stats.IsLaunch() && - 0 <= *stats.device_id && *stats.device_id < num_devices) { - // This is a launch event on a known device. 
- GroupLaunchInfo& group_launch_info = - per_device_launch_info[*stats.device_id][*stats.group_id]; - group_launch_info.AddEventTimespan(event.GetTimespan()); - } - }); - }); - - int64_t host_plane_start = GetStartTimestampNs(*host_trace); - for (int i = 0; i < num_devices; ++i) { - if (per_device_launch_info[i].empty()) continue; - int64_t device_plane_start = GetStartTimestampNs(*device_traces[i]); - - XPlaneBuilder device_plane(device_traces[i]); - const XStatMetadata& group_id_stat_metadata = - *device_plane.GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kGroupId)); - const XStatMetadata& num_launches_stat_metadata = - *device_plane.GetOrCreateStatMetadata("num_launches"); - const XStatMetadata& max_launch_time_us_stat_metadata = - *device_plane.GetOrCreateStatMetadata("max_launch_time_us"); - const XStatMetadata& avg_launch_time_us_stat_metadata = - *device_plane.GetOrCreateStatMetadata("avg_launch_time_us"); - - XLineBuilder launch_line = - device_plane.GetOrCreateLine(kThreadIdKernelLaunch); - launch_line.SetName(kKernelLaunchLineName); - launch_line.SetTimestampNs(std::min(device_plane_start, host_plane_start)); - for (const auto& kv : per_device_launch_info[i]) { - int64_t group_id = kv.first; - const GroupLaunchInfo& group_info = kv.second; - if (const tsl::profiler::GroupMetadata* group_metadata = - gtl::FindOrNull(group_metadata_map, group_id)) { - XEventBuilder device_event = - launch_line.AddEvent(*device_plane.GetOrCreateEventMetadata( - absl::StrCat("Launch Stats for ", group_metadata->name))); - device_event.SetTimespan(group_info.timespan); - device_event.AddStatValue(group_id_stat_metadata, group_id); - device_event.AddStatValue(num_launches_stat_metadata, - group_info.stat.count()); - device_event.AddStatValue( - max_launch_time_us_stat_metadata, - tsl::profiler::PicoToMicro(group_info.stat.max())); - device_event.AddStatValue( - avg_launch_time_us_stat_metadata, - tsl::profiler::PicoToMicro(group_info.stat.avg())); - } - } - } -} - -void 
GenerateDerivedTimeLines( - const tsl::profiler::GroupMetadataMap& group_metadata_map, XSpace* space) { - HloModuleMap hlo_module_map; - { - HloProtoMap hlo_proto_map; - hlo_proto_map.AddHloProtosFromXSpace(*space); - for (const auto& [program_id, hlo_proto] : hlo_proto_map) { - AddHloProto(hlo_module_map, program_id, *hlo_proto); - } - } - - auto symbol_resolver = [&](absl::optional program_id, - absl::string_view hlo_module, - absl::string_view hlo_op) -> Symbol { - Symbol output; - const auto* hlo_instruction = - GetHloInstruction(hlo_module_map, program_id, hlo_op); - if (hlo_instruction != nullptr) { - output.tf_op_name = hlo_instruction->op_full_name(); - output.source_info = std::string(hlo_instruction->source_info()); - } - return output; - }; - - ScopeRangeIdTree scope_range_id_tree; - const XPlane* namespace_tree_plane = - FindPlaneWithName(*space, tsl::profiler::kScopeRangeIdTreePlaneName); - if (namespace_tree_plane) { - XPlaneVisitor namespace_tree_visitor = - tsl::profiler::CreateTfXPlaneVisitor(namespace_tree_plane); - namespace_tree_visitor.ForEachStat([&](const XStatVisitor& stat) { - scope_range_id_tree.emplace(stat.Id(), stat.IntValue()); - }); - } - - std::vector device_planes = - FindMutablePlanesWithPrefix(space, kGpuPlanePrefix); - for (XPlane* plane : device_planes) { - DeriveStepEventsFromGroups(group_metadata_map, plane); - DeriveEventsFromAnnotations(symbol_resolver, plane, &scope_range_id_tree); - } - - const XPlane* host_plane = FindPlaneWithName(*space, kHostThreadsPlaneName); - if (host_plane) { - DeriveEventsFromHostTrace(host_plane, group_metadata_map, device_planes); - } - for (XPlane* plane : FindMutableTensorCorePlanes(space)) { - DeriveLinesFromStats(plane); - SortXPlane(plane); - } -} - -void DeriveLinesFromStats(XPlane* device_trace) { - XPlaneVisitor plane_visitor = - tsl::profiler::CreateTfXPlaneVisitor(device_trace); - XPlaneBuilder plane_builder(device_trace); - int64_t start_timestamp_ns = 
GetStartTimestampNs(*device_trace); - DerivedXLineBuilder tf_ops( - &plane_builder, tensorflow::profiler::kThreadIdTfOp, - tensorflow::profiler::kTensorFlowOpLineName, start_timestamp_ns, {}); - DerivedXLineBuilder tf_name_scope( - &plane_builder, tensorflow::profiler::kThreadIdTfNameScope, - tensorflow::profiler::kTensorFlowNameScopeLineName, start_timestamp_ns, - {&tf_ops}); - DerivedXLineBuilder source( - &plane_builder, tensorflow::profiler::kThreadIdSource, - tensorflow::profiler::kSourceLineName, start_timestamp_ns, {}); - - HostOffloadEventProcessor host_offload_event_processor(&plane_builder, - start_timestamp_ns); - - for (const XEventVisitor& event : - GetSortedEvents(plane_visitor, true)) { - tsl::profiler::Timespan event_span = event.GetTimespan(); - std::optional tf_op_name; - std::optional source_info; - std::optional group_id; - std::optional is_async; - auto for_each_stat = [&](const XStatVisitor& stat) { - if (stat.Type() == StatType::kTfOp) { - tf_op_name = stat.StrOrRefValue(); - } else if (stat.Type() == StatType::kGroupId) { - group_id = stat.IntOrUintValue(); - } else if (stat.Type() == StatType::kSourceInfo) { - source_info = stat.StrOrRefValue(); - } else if (stat.Type() == StatType::kIsAsync) { - is_async = stat.IntOrUintValue(); - } - }; - event.Metadata().ForEachStat(for_each_stat); - event.ForEachStat(for_each_stat); - - if (is_async && *is_async) continue; // Disregard asynchronous events. 
- - if (tf_op_name && !tf_op_name->empty()) { - ProcessTfOpEvent(*tf_op_name, event_span, group_id, plane_builder, - tf_name_scope, tf_ops); - } - if (source_info && !source_info->empty()) { - source.ExpandOrAddEvent( - *plane_builder.GetOrCreateEventMetadata(*source_info), event_span, - group_id); - } - if (host_offload_event_processor.IsHostOffloadOpName(event)) { - host_offload_event_processor.ProcessHostOffloadOpEvent(event, group_id); - } - } - tf_name_scope.ResetLastEvents(0); - - RemoveEmptyLines(device_trace); -} - -void DeriveLinesForXlaCpuOps(XPlane* host_trace) { - if (host_trace == nullptr || - !absl::StartsWith(host_trace->name(), kHostThreadsPlaneName)) - return; - XPlaneVisitor visitor = tsl::profiler::CreateTfXPlaneVisitor(host_trace); - XPlane destination_plane; - XPlaneBuilder plane_builder(&destination_plane); - int64_t line_id = tsl::profiler::kThreadIdHostXlaRegionStart; - visitor.ForEachLine([&](const XLineVisitor& line) { - int64_t start_timestamp_ns = line.TimestampNs(); - DerivedXLineBuilder tf_ops( - &plane_builder, line_id++, - absl::StrCat(line.Name(), "-", - tensorflow::profiler::kTensorFlowOpLineName), - start_timestamp_ns, {}); - DerivedXLineBuilder tf_name_scope( - &plane_builder, line_id++, - absl::StrCat(line.Name(), "-", - tensorflow::profiler::kTensorFlowNameScopeLineName), - start_timestamp_ns, {&tf_ops}); - DerivedXLineBuilder xla_cpu_ops( - &plane_builder, line_id++, - absl::StrCat(line.Name(), "-", tsl::profiler::kXlaModuleLineName), - start_timestamp_ns, {}); - line.ForEachEvent([&](const XEventVisitor& event) { - std::optional hlo_module_name; - std::optional framework_op_name; - event.ForEachStat([&](const XStatVisitor& stat) { - if (!stat.Type().has_value()) return; - // TODO: Add additional stats for framework ops. 
- switch (stat.Type().value()) { - case StatType::kHloModule: - hlo_module_name = stat.StrOrRefValue(); - break; - case StatType::kTfOp: - framework_op_name = stat.StrOrRefValue(); - break; - } - }); - if (hlo_module_name.has_value()) { - xla_cpu_ops.ExpandOrAddEvent( - *plane_builder.GetOrCreateEventMetadata(*hlo_module_name), - event.GetTimespan(), std::nullopt); - if (framework_op_name.has_value()) { - ProcessTfOpEvent(*framework_op_name, event.GetTimespan(), - std::nullopt, plane_builder, tf_name_scope, tf_ops); - } - } - }); - }); - RemoveEmptyLines(&destination_plane); - MergePlanes(destination_plane, host_trace); -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/utils/derived_timeline.h b/tensorflow/core/profiler/utils/derived_timeline.h index a152327319ccff..f2d41461fa2f1d 100644 --- a/tensorflow/core/profiler/utils/derived_timeline.h +++ b/tensorflow/core/profiler/utils/derived_timeline.h @@ -15,186 +15,6 @@ limitations under the License. #ifndef TENSORFLOW_CORE_PROFILER_UTILS_DERIVED_TIMELINE_H_ #define TENSORFLOW_CORE_PROFILER_UTILS_DERIVED_TIMELINE_H_ -#include -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "xla/tsl/profiler/utils/group_events.h" -#include "xla/tsl/profiler/utils/timespan.h" -#include "tensorflow/core/profiler/utils/xplane_builder.h" -#include "tsl/profiler/protobuf/xplane.pb.h" - -namespace tensorflow { -namespace profiler { - -// Store the mapping from child scope range id to parent scope range id, which -// logically form a scope range call stack tree/forest. -typedef absl::flat_hash_map - ScopeRangeIdTree; - -// Helper for deriving XEvents. 
-class DerivedXEventBuilder { - public: - DerivedXEventBuilder(XEventBuilder event, std::optional group_id, - std::optional scope_range_id = std::nullopt); - - bool ShouldExpand(const XEventMetadata& event_metadata, - std::optional group_id, - std::optional scope_range_id = std::nullopt) const; - - void Expand(tsl::profiler::Timespan event_span); - tsl::profiler::Timespan GetTimespan() const { return event_.GetTimespan(); } - void SetTimespan(tsl::profiler::Timespan event_span) { - event_.SetTimespan(event_span); - } - - template - void SetOrAddStatValue(const XStatMetadata& metadata, ValueT&& value) { - event_.SetOrAddStatValue(metadata, std::forward(value)); - } - - private: - XEventBuilder event_; - std::optional group_id_; - std::optional scope_range_id_; -}; - -// Helper for deriving an XLine from events in another XLine. -class DerivedXLineBuilder { - public: - DerivedXLineBuilder(XPlaneBuilder* plane, int64_t line_id, - absl::string_view name, int64_t timestamp_ns, - std::vector dependent_lines); - - XLineBuilder& Line() { return line_; } - - // Either merges event with the last event or creates a new event on this - // XLine. group_id and low_level_event_name may be passed to separate - // consecutive invocations of the same event, depending on the XEvent type: - // TF-op, TF name scope: both group_id and low_level_event_name are used. - // HLO-op, step: only group_id is used. - // HLO module, source: both group_id and low_level_event_name are NOT used. - // If scope_range_id is provided, it will be compared with the one in the - // event which is to be merged with. If they are different, merging is not - // allowed. - void ExpandOrAddEvent(const XEventMetadata& event_metadata, - tsl::profiler::Timespan event_span, - std::optional group_id, - std::optional scope_range_id = std::nullopt); - - // The multi-level version of ExpandOrAddEvent. Here, the XEvents at different - // levels all share the same group_id and low_level_event_name. 
- // Conceptually, the scope_range_ids should be of same length as the - // events_metadata_per_level. However, if it is shorter, this function will - // assume the missing elements at the end of scope_range_ids vector with the - // value of std::nullopt; and if it is longer, the extra elements in - // scope_range_ids will be ignored. - void ExpandOrAddEvents( - const std::vector& events_metadata_per_level, - tsl::profiler::Timespan event_span, std::optional group_id, - absl::Span> scope_range_ids = {}); - - // Reset the last events lower than or equal to the given level. - void ResetLastEvents(int level = 0); - - // To avoid using templates while need hide its implementation in .cc file, - // use two functions to set stat value for int64_t and uint64_t here. - void AddStatToLevelEvent(int level, const XStatMetadata& metadata, - int64_t value); - - void AddStatToLevelEvent(int level, const XStatMetadata& metadata, - uint64_t value); - - const XStatMetadata* GetCorrelationIdMetadata() const { - return correlation_id_metadata_; - } - - const XStatMetadata* GetCudaGraphIdMetadata() const { - return cuda_graph_id_metadata_; - } - - private: - // If the last event of the given level has the same metadata, expands it to - // include the time until the given event's end time. - // Otherwise, adds a new event and clears last_event_by_level_ for the levels - // below the given level and all levels of the dependent lines. Clearing - // last_event_by_level_ prevents a nested event from growing larger than the - // parent event(s). 
- void ExpandOrAddLevelEvent(const XEventMetadata& event_metadata, - tsl::profiler::Timespan event_span, - std::optional group_id, - std::optional scope_range_id, int level); - void AdjustDurationForTraceViewer(int level); - - const XStatMetadata* group_id_stat_metadata_ = nullptr; - const XStatMetadata* correlation_id_metadata_ = nullptr; - const XStatMetadata* cuda_graph_id_metadata_ = nullptr; - - XLineBuilder line_; - absl::flat_hash_map> - last_event_by_level_; - std::vector dependent_lines_; - bool is_gpu_plane_ = false; -}; - -struct Symbol { - absl::string_view tf_op_name; - std::string source_info; - std::string hlo_text; -}; - -using SymbolResolver = std::function program_id, - absl::string_view hlo_module_name, - absl::string_view hlo_op)>; - -// Derives TF name scope and op events from the TF op's fully qualified name -// with the name of the originating low-level event. -void ProcessTfOpEvent(absl::string_view tf_op_full_name, - tsl::profiler::Timespan event_span, - std::optional group_id, - XPlaneBuilder& plane_builder, - DerivedXLineBuilder& tf_name_scope_line_builder, - DerivedXLineBuilder& tf_op_line_builder); - -// Derives "Steps" line from group_id XStat in XEvents. -void DeriveStepEventsFromGroups( - const tsl::profiler::GroupMetadataMap& group_metadata_map, - XPlane* device_trace); - -// Derives "TensorFlow Ops", "TensorFlow Name Scope", "XLA Ops" and "XLA Module" -// lines in an NVIDIA_GPU device trace from data passed as ScopedAnnotations and -// stored as XStats in XEvents corresponding to GPU Kernels. Consecutive -// annotations with the same value are merged into a single event except for XLA -// modules. The device_trace is both input and output. -void DeriveEventsFromAnnotations( - const SymbolResolver& symbol_resolver, XPlane* device_trace, - const ScopeRangeIdTree* scope_range_id_tree = nullptr); - -// Derives "Launch Activities Summary" line from host trace. 
-void DeriveEventsFromHostTrace( - const XPlane* host_trace, - const tsl::profiler::GroupMetadataMap& group_metadata_map, - std::vector device_traces); - -// Loops through XPlanes of input XSpace, if it is "device" XPlane, generating -// derived timelines for the plane by calling DeriveEventsFromAnnotations. -void GenerateDerivedTimeLines( - const tsl::profiler::GroupMetadataMap& group_metadata_map, XSpace* space); - -// Derives `Tensorflow Ops`, `Tensorflow Name Scope` and `Source Code` lines -// from device_trace. -void DeriveLinesFromStats(tensorflow::profiler::XPlane* device_trace); - -// Devices Framework Op and Module lines for XLA:CPU ops. -void DeriveLinesForXlaCpuOps(tensorflow::profiler::XPlane* host_trace); - -} // namespace profiler -} // namespace tensorflow +#include "xprof/utils/derived_timeline.h" // from @org_xprof // IWYU pragma: export #endif // TENSORFLOW_CORE_PROFILER_UTILS_DERIVED_TIMELINE_H_ diff --git a/tensorflow/core/profiler/utils/derived_timeline_test.cc b/tensorflow/core/profiler/utils/derived_timeline_test.cc deleted file mode 100644 index 1e728003531fae..00000000000000 --- a/tensorflow/core/profiler/utils/derived_timeline_test.cc +++ /dev/null @@ -1,576 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/core/profiler/utils/derived_timeline.h" - -#include - -#include -#include -#include -#include - -#include -#include "absl/log/log.h" -#include "absl/strings/string_view.h" -#include "xla/tsl/profiler/utils/group_events.h" -#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" -#include "xla/tsl/profiler/utils/xplane_schema.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/profiler/utils/trace_utils.h" -#include "tensorflow/core/profiler/utils/xplane_builder.h" -#include "tensorflow/core/profiler/utils/xplane_schema.h" -#include "tensorflow/core/profiler/utils/xplane_test_utils.h" -#include "tensorflow/core/profiler/utils/xplane_visitor.h" -#include "tsl/profiler/protobuf/xplane.pb.h" - -namespace tensorflow { -namespace profiler { -namespace { - -TEST(DerivedTimelineTest, EmptySpaceTest) { - XSpace space; - tsl::profiler::GroupMetadataMap group_metadata_map; - GenerateDerivedTimeLines(group_metadata_map, &space); - EXPECT_EQ(space.planes_size(), 0); -} - -// Checks that HLO module events are expanded. 
-TEST(DerivedTimelineTest, HloModuleNameTest) { - const absl::string_view kHloModuleName = "hlo_module"; - const absl::string_view kKernelDetails = "kernel_details"; - XSpace space; - tsl::profiler::GroupMetadataMap group_metadata_map; - XPlane* plane = GetOrCreateGpuXPlane(&space, /*device_ordinal=*/0); - XPlaneBuilder plane_builder(plane); - auto line_builder = plane_builder.GetOrCreateLine(0); - CreateXEvent(&plane_builder, &line_builder, "op1", 0, 100, - {{StatType::kHloModule, kHloModuleName}, - {StatType::kKernelDetails, kKernelDetails}}); - CreateXEvent(&plane_builder, &line_builder, "op2", 200, 300, - {{StatType::kHloModule, kHloModuleName}, - {StatType::kKernelDetails, kKernelDetails}}); - GenerateDerivedTimeLines(group_metadata_map, &space); - XPlaneVisitor plane_visitor = tsl::profiler::CreateTfXPlaneVisitor(plane); - // Only the hlo module line is added and other empty lines are removed at the - // end. - EXPECT_EQ(plane_visitor.NumLines(), 2); - plane_visitor.ForEachLine([&](const XLineVisitor& line_visitor) { - if (line_visitor.Id() == 0) return; - EXPECT_EQ(line_visitor.Id(), kThreadIdHloModule); - EXPECT_EQ(line_visitor.NumEvents(), 1); - line_visitor.ForEachEvent([&](const XEventVisitor& event_visitor) { - EXPECT_EQ(event_visitor.Name(), kHloModuleName); - }); - }); -} - -// Checks that HLO module events are expanded, with both same name and scope -// range id. Note that strange XStatValue{int64_t{10}} is to handle different -// compilers behavior. 
-TEST(DerivedTimelineTest, HloModuleNameSameScopeRangeIdTest) { - const absl::string_view kHloModuleName = "hlo_module"; - const absl::string_view kKernelDetails = "kernel_details"; - XSpace space; - tsl::profiler::GroupMetadataMap group_metadata_map; - XPlane* plane = GetOrCreateGpuXPlane(&space, /*device_ordinal=*/0); - XPlaneBuilder plane_builder(plane); - auto line_builder = plane_builder.GetOrCreateLine(0); - CreateXEvent(&plane_builder, &line_builder, "op1", 0, 100, - {{StatType::kHloModule, XStatValue{kHloModuleName}}, - {StatType::kKernelDetails, XStatValue{kKernelDetails}}, - {StatType::kScopeRangeId, XStatValue{int64_t{10}}}}); - CreateXEvent(&plane_builder, &line_builder, "op2", 200, 300, - {{StatType::kHloModule, XStatValue{kHloModuleName}}, - {StatType::kKernelDetails, XStatValue{kKernelDetails}}, - {StatType::kScopeRangeId, XStatValue{int64_t{10}}}}); - GenerateDerivedTimeLines(group_metadata_map, &space); - XPlaneVisitor plane_visitor = tsl::profiler::CreateTfXPlaneVisitor(plane); - // Only the hlo module line is added and other empty lines are removed at the - // end. - EXPECT_EQ(plane_visitor.NumLines(), 2); - plane_visitor.ForEachLine([&](const XLineVisitor& line_visitor) { - if (line_visitor.Id() == 0) return; - EXPECT_EQ(line_visitor.Id(), kThreadIdHloModule); - EXPECT_EQ(line_visitor.NumEvents(), 1); - line_visitor.ForEachEvent([&](const XEventVisitor& event_visitor) { - EXPECT_EQ(event_visitor.Name(), kHloModuleName); - }); - }); -} - -// Checks that HLO module events are expanded, with same name only, -// but different scope range id. 
-TEST(DerivedTimelineTest, HloModuleNameDifferentScopeRangeIdTest) { - const absl::string_view kHloModuleName = "hlo_module"; - const absl::string_view kKernelDetails = "kernel_details"; - XSpace space; - tsl::profiler::GroupMetadataMap group_metadata_map; - XPlane* plane = GetOrCreateGpuXPlane(&space, /*device_ordinal=*/0); - XPlaneBuilder plane_builder(plane); - auto line_builder = plane_builder.GetOrCreateLine(0); - CreateXEvent(&plane_builder, &line_builder, "op1", 0, 100, - {{StatType::kHloModule, XStatValue{kHloModuleName}}, - {StatType::kKernelDetails, XStatValue{kKernelDetails}}, - {StatType::kScopeRangeId, XStatValue{int64_t{10}}}}); - CreateXEvent(&plane_builder, &line_builder, "op2", 200, 300, - {{StatType::kHloModule, XStatValue{kHloModuleName}}, - {StatType::kKernelDetails, XStatValue{kKernelDetails}}, - {StatType::kScopeRangeId, XStatValue{int64_t{20}}}}); - GenerateDerivedTimeLines(group_metadata_map, &space); - XPlaneVisitor plane_visitor = tsl::profiler::CreateTfXPlaneVisitor(plane); - // Only the hlo module line is added and other empty lines are removed at the - // end. - EXPECT_EQ(plane_visitor.NumLines(), 2); - plane_visitor.ForEachLine([&](const XLineVisitor& line_visitor) { - if (line_visitor.Id() == 0) return; - EXPECT_EQ(line_visitor.Id(), kThreadIdHloModule); - EXPECT_EQ(line_visitor.NumEvents(), 2); - line_visitor.ForEachEvent([&](const XEventVisitor& event_visitor) { - EXPECT_EQ(event_visitor.Name(), kHloModuleName); - }); - }); -} - -// Checks that HLO module events are expanded. 
-TEST(DerivedTimelineTest, NoHloModuleNameTest) { - const absl::string_view kKernelDetails = "kernel_details"; - const uint64_t kCudaGraphExecId = 1; - XSpace space; - tsl::profiler::GroupMetadataMap group_metadata_map; - XPlane& plane = *GetOrCreateGpuXPlane(&space, /*device_ordinal=*/0); - XPlaneBuilder plane_builder(&plane); - auto line_builder = plane_builder.GetOrCreateLine(0); - CreateXEvent(&plane_builder, &line_builder, "op1", 0, 100, - {{StatType::kKernelDetails, kKernelDetails}}); - CreateXEvent(&plane_builder, &line_builder, "op2", 200, 300, - {{StatType::kKernelDetails, kKernelDetails}}); - // Also add a CudaGraph Execution event. - CreateXEvent(&plane_builder, &line_builder, "op3", 500, 100, - {{StatType::kCudaGraphExecId, kCudaGraphExecId}}); - GenerateDerivedTimeLines(group_metadata_map, &space); - XPlaneVisitor plane_visitor = tsl::profiler::CreateTfXPlaneVisitor(&plane); - // Only the hlo module line is added and other empty lines are removed at the - // end. - EXPECT_EQ(plane_visitor.NumLines(), 1); - plane_visitor.ForEachLine([&](const XLineVisitor& line_visitor) { - if (line_visitor.Id() == 0) return; - EXPECT_EQ(line_visitor.Id(), kThreadIdHloModule); - EXPECT_EQ(line_visitor.NumEvents(), 0); - }); -} - -// Checks that the TF op events are expanded. 
-TEST(DerivedTimelineTest, TfOpLineTest) { - const absl::string_view kTfOpName = "mul:Mul"; - const absl::string_view kKernelDetails = "kernel_details"; - const uint64_t kCudaGraphExecId = 1; - XSpace space; - tsl::profiler::GroupMetadataMap group_metadata_map; - XPlane* plane = GetOrCreateGpuXPlane(&space, /*device_ordinal=*/0); - XPlaneBuilder plane_builder(plane); - auto line_builder = plane_builder.GetOrCreateLine(0); - CreateXEvent(&plane_builder, &line_builder, "op1", 0, 100, - {{StatType::kTfOp, kTfOpName}, - {StatType::kKernelDetails, kKernelDetails}}); - CreateXEvent(&plane_builder, &line_builder, "op2", 200, 300, - {{StatType::kTfOp, kTfOpName}, - {StatType::kKernelDetails, kKernelDetails}}); - // Also add a CudaGraph Execution event. - CreateXEvent(&plane_builder, &line_builder, "op3", 500, 100, - {{StatType::kTfOp, kTfOpName}, - {StatType::kCudaGraphExecId, kCudaGraphExecId}}); - GenerateDerivedTimeLines(group_metadata_map, &space); - XPlaneVisitor plane_visitor = tsl::profiler::CreateTfXPlaneVisitor(plane); - // Only the tf op line is added and other empty lines are removed at the end. - EXPECT_EQ(plane_visitor.NumLines(), 2); - plane_visitor.ForEachLine([&](const XLineVisitor& line_visitor) { - if (line_visitor.Id() == 0) return; - EXPECT_EQ(line_visitor.Id(), kThreadIdTfOp); - EXPECT_EQ(line_visitor.NumEvents(), 1); - line_visitor.ForEachEvent([&](const XEventVisitor& event_visitor) { - EXPECT_EQ(event_visitor.Name(), kTfOpName); - EXPECT_EQ(event_visitor.OffsetPs(), 0); - EXPECT_EQ(event_visitor.DurationPs(), 600); - }); - }); -} - -// Checks that the dependency between the step line and the TF op line prevents -// TF op events from being expanded. 
-TEST(DerivedTimelineTest, DependencyTest) { - constexpr int64_t kFirstGroupId = 0; - constexpr int64_t kSecondGroupId = 1; - - const absl::string_view kTfOpName = "mul:Mul"; - const absl::string_view kKernelDetails = "kernel_details"; - XSpace space; - tsl::profiler::GroupMetadataMap group_metadata_map( - {{0, {"train 0"}}, {1, {"train 1"}}}); - XPlane* plane = GetOrCreateGpuXPlane(&space, /*device_ordinal=*/0); - XPlaneBuilder plane_builder(plane); - auto line_builder = plane_builder.GetOrCreateLine(0); - CreateXEvent(&plane_builder, &line_builder, "op1", 0, 100, - {{StatType::kGroupId, kFirstGroupId}, - {StatType::kTfOp, kTfOpName}, - {StatType::kKernelDetails, kKernelDetails}}); - CreateXEvent(&plane_builder, &line_builder, "op2", 200, 300, - {{StatType::kGroupId, kSecondGroupId}, - {StatType::kTfOp, kTfOpName}, - {StatType::kKernelDetails, kKernelDetails}}); - GenerateDerivedTimeLines(group_metadata_map, &space); - XPlaneVisitor plane_visitor = tsl::profiler::CreateTfXPlaneVisitor(plane); - // The step line and the TF op line are added. - EXPECT_EQ(plane_visitor.NumLines(), 3); - plane_visitor.ForEachLine([&](const XLineVisitor& line_visitor) { - if (line_visitor.Id() == 0) return; - EXPECT_TRUE(line_visitor.Id() == kThreadIdStepInfo || - line_visitor.Id() == kThreadIdTfOp); - EXPECT_EQ(line_visitor.NumEvents(), 2); - }); -} - -// Checks that the TF op events are expanded. 
-TEST(DerivedTimelineTest, TfOpNameScopeTest) { - const absl::string_view kTfOpName = "scope1/scope2/mul:Mul"; - const absl::string_view kKernelDetails = "kernel_details"; - XSpace space; - tsl::profiler::GroupMetadataMap group_metadata_map; - XPlane* plane = GetOrCreateGpuXPlane(&space, /*device_ordinal=*/0); - XPlaneBuilder plane_builder(plane); - auto line_builder = plane_builder.GetOrCreateLine(0); - CreateXEvent(&plane_builder, &line_builder, "op1", 0, 100, - {{StatType::kTfOp, kTfOpName}, - {StatType::kKernelDetails, kKernelDetails}}); - CreateXEvent(&plane_builder, &line_builder, "op2", 200, 300, - {{StatType::kTfOp, kTfOpName}, - {StatType::kKernelDetails, kKernelDetails}}); - GenerateDerivedTimeLines(group_metadata_map, &space); - XPlaneVisitor plane_visitor = tsl::profiler::CreateTfXPlaneVisitor(plane); - // The TF name scope line and the TF op line are added. - EXPECT_EQ(plane_visitor.NumLines(), 3); - plane_visitor.ForEachLine([&](const XLineVisitor& line_visitor) { - int64_t line_id = line_visitor.Id(); - if (line_id == 0) { - return; - } else if (line_id == kThreadIdTfNameScope) { - EXPECT_EQ(line_visitor.NumEvents(), 2); - line_visitor.ForEachEvent([&](const XEventVisitor& event_visitor) { - EXPECT_EQ(event_visitor.OffsetPs(), 0); - EXPECT_EQ(event_visitor.DurationPs(), 500); - }); - } else if (line_id == kThreadIdTfOp) { - EXPECT_EQ(line_visitor.NumEvents(), 1); - line_visitor.ForEachEvent([&](const XEventVisitor& event_visitor) { - EXPECT_EQ(event_visitor.Name(), kTfOpName); - EXPECT_EQ(event_visitor.OffsetPs(), 0); - EXPECT_EQ(event_visitor.DurationPs(), 500); - }); - } - }); -} - -// Checks that the TF op events are expanded. 
-TEST(DerivedTimelineTest, TfNameScopeMaintainsOrder) { - const absl::string_view kTfOpName = "scope1/scope2/mul:Mul"; - const absl::string_view kKernelDetails = "kernel_details"; - XSpace space; - tsl::profiler::GroupMetadataMap group_metadata_map; - XPlane* plane = - GetOrCreateTpuXPlane(&space, /*device_ordinal=*/0, "TPU V4", 0, 0); - XPlaneBuilder plane_builder(plane); - auto line_builder = plane_builder.GetOrCreateLine(0); - CreateXEvent(&plane_builder, &line_builder, "op1", 0, 10000, - {{StatType::kTfOp, kTfOpName}, - {StatType::kKernelDetails, kKernelDetails}}); - GenerateDerivedTimeLines(group_metadata_map, &space); - XPlaneVisitor plane_visitor = tsl::profiler::CreateTfXPlaneVisitor(plane); - // The TF name scope line and the TF op line are added. - EXPECT_EQ(plane_visitor.NumLines(), 3); - plane_visitor.ForEachLine([&](const XLineVisitor& line_visitor) { - if (line_visitor.Name() == tsl::profiler::kTensorFlowNameScopeLineName) { - EXPECT_EQ(line_visitor.NumEvents(), 2); - uint64_t expected_duration = 10000; - line_visitor.ForEachEvent([&](const XEventVisitor& event_visitor) { - LOG(INFO) << "scope: " << event_visitor.Name(); - EXPECT_EQ(event_visitor.OffsetPs(), 0); - EXPECT_EQ(event_visitor.DurationPs(), expected_duration); - expected_duration -= 1000; - }); - } - }); -} - -// Checks only derived events from line with most events for gpu trace. -TEST(DerivedTimelineTest, OnlyDerivedEventsFromLineWithMostEvents) { - const absl::string_view kTfOpName = "scope1/scope2/mul:Mul"; - const absl::string_view kKernelDetails = "kernel_details"; - XSpace space; - tsl::profiler::GroupMetadataMap group_metadata_map; - XPlane* plane = GetOrCreateGpuXPlane(&space, /*device_ordinal=*/0); - XPlaneBuilder plane_builder(plane); - auto line_builder = plane_builder.GetOrCreateLine(0); - // Add first line with two events. 
- CreateXEvent(&plane_builder, &line_builder, "op1", 0, 100, - {{StatType::kTfOp, kTfOpName}, - {StatType::kKernelDetails, kKernelDetails}}); - CreateXEvent(&plane_builder, &line_builder, "op2", 200, 300, - {{StatType::kTfOp, kTfOpName}, - {StatType::kKernelDetails, kKernelDetails}}); - // Add second line with only one event. - auto line_builder_2 = plane_builder.GetOrCreateLine(1); - CreateXEvent(&plane_builder, &line_builder_2, "op3", 50, 850, - {{StatType::kTfOp, kTfOpName}, - {StatType::kKernelDetails, kKernelDetails}}); - // Derive lines for the plane. - GenerateDerivedTimeLines(group_metadata_map, &space); - XPlaneVisitor plane_visitor = tsl::profiler::CreateTfXPlaneVisitor(plane); - // The TF name scope line and the TF op line are added. - EXPECT_EQ(plane_visitor.NumLines(), 4); - plane_visitor.ForEachLine([&](const XLineVisitor& line_visitor) { - int64_t line_id = line_visitor.Id(); - if (line_id == 0 || line_id == 1) { - return; - } else if (line_id == kThreadIdTfNameScope) { - EXPECT_EQ(line_visitor.NumEvents(), 2); - line_visitor.ForEachEvent([&](const XEventVisitor& event_visitor) { - EXPECT_EQ(event_visitor.OffsetPs(), 0); - // When derived from first line only, we should get single event which - // starts from op1' start (0), end at op2's end (200 + 300), - // duration is 500. - // If derived from both lines, the derived event duration will be - // (50 + 850) - 0 = 900. - EXPECT_EQ(event_visitor.DurationPs(), 500); - }); - } else if (line_id == kThreadIdTfOp) { - EXPECT_EQ(line_visitor.NumEvents(), 1); - line_visitor.ForEachEvent([&](const XEventVisitor& event_visitor) { - EXPECT_EQ(event_visitor.Name(), kTfOpName); - EXPECT_EQ(event_visitor.OffsetPs(), 0); - EXPECT_EQ(event_visitor.DurationPs(), 500); - }); - } - }); -} - -// Checks that the TF op events are expanded. -TEST(DerivedTimelineTest, TfOpNameScopeShrinkTest) { - { - // Case 1: shirnk is possible. 
- XSpace space; - tsl::profiler::GroupMetadataMap group_metadata_map; - XPlane* plane = GetOrCreateGpuXPlane(&space, /*device_ordinal=*/0); - XPlaneBuilder plane_builder(plane); - auto line_builder = plane_builder.GetOrCreateLine(0); - CreateXEvent(&plane_builder, &line_builder, "op1", 0, 10000, - {{StatType::kTfOp, "a/b/c/Add:Add"}, - {StatType::kKernelDetails, "blah"}}); - CreateXEvent( - &plane_builder, &line_builder, "op2", 20000, 30000, - {{StatType::kTfOp, "a/d/Mul:Mul"}, {StatType::kKernelDetails, "blah"}}); - GenerateDerivedTimeLines(group_metadata_map, &space); - XPlaneVisitor plane_visitor = tsl::profiler::CreateTfXPlaneVisitor(plane); - // The TF name scope line and the TF op line are added. - EXPECT_EQ(plane_visitor.NumLines(), 3); - plane_visitor.ForEachLine([&](const XLineVisitor& line_visitor) { - int64_t line_id = line_visitor.Id(); - if (line_id == 0) { - return; - } else if (line_id == kThreadIdTfNameScope) { - EXPECT_EQ(line_visitor.NumEvents(), 4); - std::map durations; - line_visitor.ForEachEvent([&](const XEventVisitor& event_visitor) { - durations[event_visitor.Name()] = event_visitor.DurationPs(); - }); - EXPECT_EQ(durations["a"], 50000); - EXPECT_EQ(durations["b"], 10000); - EXPECT_EQ(durations["c"], 9000); // shrinked to be distinguish with b. - EXPECT_EQ(durations["d"], 30000); - } - }); - } - { - // Case 2: shirnk is impossible due to top event is too small. 
- XSpace space; - tsl::profiler::GroupMetadataMap group_metadata_map; - XPlane* plane = GetOrCreateGpuXPlane(&space, /*device_ordinal=*/0); - XPlaneBuilder plane_builder(plane); - auto line_builder = plane_builder.GetOrCreateLine(0); - CreateXEvent(&plane_builder, &line_builder, "op1", 0, 10000, - {{StatType::kTfOp, "a/b/c/d/e/Add:Add"}, - {StatType::kKernelDetails, "blah"}}); - CreateXEvent(&plane_builder, &line_builder, "op2", 10000, 2000, - {{StatType::kTfOp, "a/b/c/d/f/Sub:Sub"}, - {StatType::kKernelDetails, "blah"}}); - CreateXEvent( - &plane_builder, &line_builder, "op3", 20000, 30000, - {{StatType::kTfOp, "a/g/Mul:Mul"}, {StatType::kKernelDetails, "blah"}}); - GenerateDerivedTimeLines(group_metadata_map, &space); - XPlaneVisitor plane_visitor = tsl::profiler::CreateTfXPlaneVisitor(plane); - // The TF name scope line and the TF op line are added. - EXPECT_EQ(plane_visitor.NumLines(), 3); - plane_visitor.ForEachLine([&](const XLineVisitor& line_visitor) { - int64_t line_id = line_visitor.Id(); - if (line_id == 0) { - return; - } else if (line_id == kThreadIdTfNameScope) { - EXPECT_EQ(line_visitor.NumEvents(), 7); - std::map durations; - line_visitor.ForEachEvent([&](const XEventVisitor& event_visitor) { - durations[event_visitor.Name()] = event_visitor.DurationPs(); - }); - for (const auto& [name, duration] : durations) { - LOG(ERROR) << name << ": " << duration; - } - EXPECT_EQ(durations["a"], 50000); - EXPECT_EQ(durations["b"], 12000); - EXPECT_EQ(durations["c"], 11000); // shrinked to be distinguish with b. - EXPECT_EQ(durations["d"], 11000); // not shrinked because of f. - EXPECT_EQ(durations["e"], 10000); - EXPECT_EQ(durations["f"], 1000); - EXPECT_EQ(durations["g"], 30000); - } - }); - } -} - -// Checks that XLA Ops mapping to CudaGraph launch has extra stats. 
-TEST(DerivedTimelineTest, XloOpHasCudaGraphStats) { - constexpr absl::string_view kModuleName = "module"; - constexpr absl::string_view kHloOpName = "op_level_2"; - constexpr absl::string_view kKernelDetails = "kernel_details"; - constexpr int64_t kGroupIdValue = 1; - constexpr int64_t kCorrelationIdValue = 10000; - const uint64_t kCudaGraphIdValue = 20; - XSpace space; - tsl::profiler::GroupMetadataMap group_metadata_map; - - // Build Input Plane/Line/Events and derive events from them. - XPlane& plane = *GetOrCreateGpuXPlane(&space, /*device_ordinal=*/0); - XPlaneBuilder plane_builder(&plane); - auto line_builder = plane_builder.GetOrCreateLine(0); - CreateXEvent(&plane_builder, &line_builder, "op1", 0, 100, - {{StatType::kKernelDetails, kKernelDetails}, - {StatType::kGroupId, kGroupIdValue}, - {StatType::kHloModule, kModuleName}, - {StatType::kHloOp, kHloOpName}, - {StatType::kCorrelationId, kCorrelationIdValue}, - {StatType::kCudaGraphId, kCudaGraphIdValue}}); - CreateXEvent(&plane_builder, &line_builder, "op2", 200, 300, - {{StatType::kKernelDetails, kKernelDetails}, - {StatType::kGroupId, kGroupIdValue}, - {StatType::kHloModule, kModuleName}, - {StatType::kHloOp, kHloOpName}, - {StatType::kCorrelationId, kCorrelationIdValue}, - {StatType::kCudaGraphId, kCudaGraphIdValue}}); - GenerateDerivedTimeLines(group_metadata_map, &space); - - // Check that the HLO op line is added and has the extra stats for the first - // derived event. 
- size_t num_hlo_op_line = 0; - size_t num_events = 0; - std::optional correlation_id; - std::optional cuda_graph_id; - XPlaneVisitor plane_visitor = tsl::profiler::CreateTfXPlaneVisitor(&plane); - plane_visitor.ForEachLine([&](const XLineVisitor& line_visitor) { - if (line_visitor.Id() == kThreadIdHloOp) { - num_hlo_op_line++; - if (num_hlo_op_line == 1) { - num_events = line_visitor.NumEvents(); - line_visitor.ForEachEvent([&](const XEventVisitor& event_visitor) { - correlation_id = event_visitor.GetStat(StatType::kCorrelationId); - cuda_graph_id = event_visitor.GetStat(StatType::kCudaGraphId); - }); - } - } - }); - EXPECT_EQ(num_hlo_op_line, 1); - EXPECT_EQ(num_events, 1); - ASSERT_TRUE(correlation_id.has_value()); - EXPECT_EQ(correlation_id->IntValue(), kCorrelationIdValue); - ASSERT_TRUE(cuda_graph_id.has_value()); - EXPECT_EQ(cuda_graph_id->UintValue(), kCudaGraphIdValue); -} - -TEST(DerivedTimelineTest, DeriveLinesForXlaCpuOps) { - XPlane xplane; - XPlaneBuilder plane_builder(&xplane); - plane_builder.SetName(tsl::profiler::kHostThreadsPlaneName); - - absl::string_view main_line_name = "main"; - auto line_builder = plane_builder.GetOrCreateLine(0); - line_builder.SetName(main_line_name); - CreateXEvent(&plane_builder, &line_builder, "op1", 0, 100, - {{StatType::kHloModule, "Module1"}}); - CreateXEvent(&plane_builder, &line_builder, "op2", 200, 400, - {{StatType::kHloModule, "Module2"}}); - - DeriveLinesForXlaCpuOps(&xplane); - - XPlaneVisitor plane_visitor = tsl::profiler::CreateTfXPlaneVisitor(&xplane); - EXPECT_EQ(plane_visitor.NumLines(), 2); - plane_visitor.ForEachLine([&](const XLineVisitor& line_visitor) { - if (line_visitor.Name() == main_line_name) return; - line_visitor.ForEachEvent([&](const XEventVisitor& event_visitor) { - if (event_visitor.Name() == "Module1") { - EXPECT_EQ(event_visitor.DurationPs(), 100); - EXPECT_EQ(event_visitor.OffsetPs(), 0); - } else if (event_visitor.Name() == "Module2") { - EXPECT_EQ(event_visitor.DurationPs(), 400); - 
EXPECT_EQ(event_visitor.OffsetPs(), 200); - } else { - FAIL() << "Found Event " << event_visitor.Name(); - } - }); - }); -} - -TEST(DerivedTimelineTest, MergeAndNoMerge) { - constexpr absl::string_view kHloModuleName = "Framework Ops"; - static constexpr absl::string_view kTfOpName = "abc:model/layer/MatMul_1"; - XSpace space; - tsl::profiler::GroupMetadataMap group_metadata_map; - XPlane* plane = - GetOrCreateTpuXPlane(&space, /*device_ordinal=*/0, "DummyTPU", 1.0, 1.0); - XPlaneBuilder plane_builder(plane); - auto line_builder = plane_builder.GetOrCreateLine(0); - CreateXEvent( - &plane_builder, &line_builder, "op1", 0, 100, - {{StatType::kHloModule, kHloModuleName}, {StatType::kTfOp, kTfOpName}}); - CreateXEvent( - &plane_builder, &line_builder, "op2", 200, 300, - {{StatType::kHloModule, kHloModuleName}, {StatType::kTfOp, kTfOpName}}); - // The above two events are merged into one. This event will not be merged - // because the gap is > 2x(0..200+300) = 1000. - CreateXEvent( - &plane_builder, &line_builder, "op3", 1501, 300, - {{StatType::kHloModule, kHloModuleName}, {StatType::kTfOp, kTfOpName}}); - GenerateDerivedTimeLines(group_metadata_map, &space); - XPlaneVisitor plane_visitor = tsl::profiler::CreateTfXPlaneVisitor(plane); - // Only the hlo module line is added and other empty lines are removed at the - // end. 
- EXPECT_EQ(plane_visitor.NumLines(), 2); - plane_visitor.ForEachLine([](const XLineVisitor& line_visitor) { - if (line_visitor.Id() == 0) return; - EXPECT_EQ(line_visitor.NumEvents(), 2); - line_visitor.ForEachEvent([](const XEventVisitor& event_visitor) { - EXPECT_EQ(event_visitor.Name(), kTfOpName); - }); - }); -} - -} // namespace -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/utils/device_caps_utils.cc b/tensorflow/core/profiler/utils/device_caps_utils.cc deleted file mode 100644 index 3b149ad528b654..00000000000000 --- a/tensorflow/core/profiler/utils/device_caps_utils.cc +++ /dev/null @@ -1,90 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/core/profiler/utils/device_caps_utils.h" - -#include - -#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" -#include "tensorflow/core/profiler/utils/xplane_builder.h" -#include "tensorflow/core/profiler/utils/xplane_schema.h" -#include "tensorflow/core/profiler/utils/xplane_visitor.h" - -namespace tensorflow { -namespace profiler { - -void SetDeviceCaps(const DeviceCapabilities& caps, XPlane* plane) { - XPlaneBuilder xplane(plane); - int clock_rate_in_khz = - static_cast(caps.clock_rate_in_ghz() * 1000000.0); - xplane.AddStatValue(*xplane.GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kDevCapClockRateKHz)), - clock_rate_in_khz); - xplane.AddStatValue(*xplane.GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kDevCapCoreCount)), - caps.num_cores()); - xplane.AddStatValue(*xplane.GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kDevCapMemoryBandwidth)), - caps.memory_bandwidth()); - xplane.AddStatValue(*xplane.GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kDevCapMemorySize)), - caps.memory_size_in_bytes()); - if (caps.has_compute_capability()) { - xplane.AddStatValue(*xplane.GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kDevCapComputeCapMajor)), - caps.compute_capability().major()); - xplane.AddStatValue(*xplane.GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kDevCapComputeCapMinor)), - caps.compute_capability().minor()); - } - xplane.AddStatValue( - *xplane.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kDevVendor)), - caps.device_vendor()); -} - -DeviceCapabilities GetDeviceCaps(const XPlane& plane) { - DeviceCapabilities caps; - XPlaneVisitor xplane = tsl::profiler::CreateTfXPlaneVisitor(&plane); - xplane.ForEachStat([&](const tensorflow::profiler::XStatVisitor& stat) { - if (!stat.Type().has_value()) return; - switch (stat.Type().value()) { - case StatType::kDevCapClockRateKHz: - caps.set_clock_rate_in_ghz(stat.IntOrUintValue() / 
1000000.0); - break; - case StatType::kDevCapCoreCount: - caps.set_num_cores(stat.IntOrUintValue()); - break; - case StatType::kDevCapMemoryBandwidth: - caps.set_memory_bandwidth(stat.IntOrUintValue()); - break; - case StatType::kDevCapMemorySize: - caps.set_memory_size_in_bytes(stat.IntOrUintValue()); - break; - case StatType::kDevCapComputeCapMajor: - caps.mutable_compute_capability()->set_major(stat.IntOrUintValue()); - break; - case StatType::kDevCapComputeCapMinor: - caps.mutable_compute_capability()->set_minor(stat.IntOrUintValue()); - break; - case StatType::kDevVendor: - caps.set_device_vendor(std::string(stat.StrOrRefValue())); - break; - } - }); - return caps; -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/utils/device_caps_utils.h b/tensorflow/core/profiler/utils/device_caps_utils.h index c6c84133db3aaf..a500ed1d18acc6 100644 --- a/tensorflow/core/profiler/utils/device_caps_utils.h +++ b/tensorflow/core/profiler/utils/device_caps_utils.h @@ -16,16 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_CORE_PROFILER_UTILS_DEVICE_CAPS_UTILS_H_ #define TENSORFLOW_CORE_PROFILER_UTILS_DEVICE_CAPS_UTILS_H_ -#include "tensorflow/core/profiler/protobuf/hardware_types.pb.h" -#include "tsl/profiler/protobuf/xplane.pb.h" - -namespace tensorflow { -namespace profiler { - -void SetDeviceCaps(const DeviceCapabilities& caps, XPlane* plane); -DeviceCapabilities GetDeviceCaps(const XPlane& plane); - -} // namespace profiler -} // namespace tensorflow +#include "xprof/utils/device_caps_utils.h" // from @org_xprof // IWYU pragma: export #endif // TENSORFLOW_CORE_PROFILER_UTILS_DEVICE_CAPS_UTILS_H_ diff --git a/tensorflow/core/profiler/utils/diagnostics.cc b/tensorflow/core/profiler/utils/diagnostics.cc deleted file mode 100644 index c4ff0f2069f07a..00000000000000 --- a/tensorflow/core/profiler/utils/diagnostics.cc +++ /dev/null @@ -1,87 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/profiler/utils/diagnostics.h" - -#include - -#include "absl/algorithm/container.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" -#include "tensorflow/core/profiler/protobuf/steps_db.pb.h" - -namespace tensorflow { -namespace profiler { - -const absl::string_view kErrorIncompleteStep = - "Incomplete step observed and hence the step time is unknown." - "Instead, we use the trace duration as the step time. This may happen" - " if your profiling duration is shorter than the step time. In this" - " case, you may try to profile longer."; - -const absl::string_view kErrorEmptyIntersect = - "Although there are steps observed on some host(s), the intersection of " - "the steps over all hosts is empty (because the differences among " - "individual host's step sequences are too big). Consequently, the overall " - "step time is " - "unknown."; - -const absl::string_view kErrorNoStepMarker = - "No step marker observed and hence the step time is unknown." - " This may happen if (1) training steps are not instrumented (e.g., if" - " you are not using Keras) or (2) the profiling duration is shorter" - " than the step time. For (1), you need to add step instrumentation;" - " for (2), you may try to profile longer."; - -const absl::string_view kNoDeviceTraceCollected = - "No TensorCore device trace was collected. 
This might happen if your job " - "hadn't been run on the device when sampling was turned on. You could try " - "the sampling again later."; - -const absl::string_view kStepsDropped = - " steps dropped. This might happen when you profile many hosts and/or many " - "steps. You could try to profile shorter or reduce the number of hosts " - "you profile."; - -void PopulateStepDiagnostics(const OpStats& op_stats, Diagnostics* diag) { - if (op_stats.step_db().use_incomplete_step()) { - *diag->add_warnings() = std::string(kErrorIncompleteStep); - } else if (op_stats.step_db().step_sequence().empty()) { - *diag->add_warnings() = op_stats.step_db().empty_intersect() - ? std::string(kErrorEmptyIntersect) - : std::string(kErrorNoStepMarker); - } - if (op_stats.step_db().num_steps_dropped()) { - *diag->add_warnings() = - absl::StrCat(op_stats.step_db().num_steps_dropped(), kStepsDropped); - } -} - -void PopulateOverviewDiagnostics(const OpStats& op_stats, Diagnostics* diag) { - *diag->mutable_errors() = op_stats.diagnostics().errors(); - absl::c_sort(*diag->mutable_errors()); - if (diag->errors().empty()) { - // Shows run-environment error only if there is no other existing error. - if (op_stats.run_environment().device_type() != "CPU" && - op_stats.run_environment().device_core_count() <= 0) { - *diag->add_errors() = std::string(kNoDeviceTraceCollected); - } - } - *diag->mutable_warnings() = op_stats.diagnostics().warnings(); - PopulateStepDiagnostics(op_stats, diag); -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/utils/diagnostics.h b/tensorflow/core/profiler/utils/diagnostics.h index 25fb16900f2575..67eb4020d54c14 100644 --- a/tensorflow/core/profiler/utils/diagnostics.h +++ b/tensorflow/core/profiler/utils/diagnostics.h @@ -16,30 +16,6 @@ limitations under the License. 
#ifndef TENSORFLOW_CORE_PROFILER_UTILS_DIAGNOSTICS_H_ #define TENSORFLOW_CORE_PROFILER_UTILS_DIAGNOSTICS_H_ -#include "absl/strings/string_view.h" -#include "xla/tsl/platform/macros.h" -#include "tensorflow/core/profiler/protobuf/diagnostics.pb.h" -#include "tensorflow/core/profiler/protobuf/op_stats.pb.h" - -namespace tensorflow { -namespace profiler { - -// Error message that the visualization is based on incomplete step. -TF_CONST_INIT extern const absl::string_view kErrorIncompleteStep; - -// Error message that no step marker is seen and visualization contains no -// step info. -TF_CONST_INIT extern const absl::string_view kErrorNoStepMarker; - -TF_CONST_INIT extern const absl::string_view kNoDeviceTraceCollected; - -TF_CONST_INIT extern const absl::string_view kStepsDropped; - -void PopulateStepDiagnostics(const OpStats& op_stats, Diagnostics* diag); - -void PopulateOverviewDiagnostics(const OpStats& op_stats, Diagnostics* diag); - -} // namespace profiler -} // namespace tensorflow +#include "xprof/utils/diagnostics.h" // from @org_xprof // IWYU pragma: export #endif // TENSORFLOW_CORE_PROFILER_UTILS_DIAGNOSTICS_H_ diff --git a/tensorflow/core/profiler/utils/event_span.cc b/tensorflow/core/profiler/utils/event_span.cc deleted file mode 100644 index b5e9b813a15c01..00000000000000 --- a/tensorflow/core/profiler/utils/event_span.cc +++ /dev/null @@ -1,449 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ -#include "tensorflow/core/profiler/utils/event_span.h" - -#include -#include -#include -#include -#include - -#include "absl/algorithm/container.h" -#include "absl/container/flat_hash_map.h" -#include "absl/log/check.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" -#include "xla/tsl/profiler/utils/timespan.h" -#include "tensorflow/core/lib/gtl/map_util.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" - -namespace tensorflow { -namespace profiler { - -namespace { - -// Representing a boundary of an event. -struct EventBoundary { - // Time at this boundary. - uint64 time_ps; - // Type of the event. - EventType type; - // True if this is the start of the event; False if this is the end. - bool is_start; - EventBoundary(uint64 time_ps, EventType type, bool is_start) - : time_ps(time_ps), type(type), is_start(is_start) {} -}; - -// Returns true if EventBoundary a should appear before EventBoundary b. -bool CmpEventBoundaries(const EventBoundary& a, const EventBoundary& b) { - if (a.time_ps == b.time_ps) { - if (a.is_start == b.is_start) { - // Puts the higher-priority type before the lower-priority type if they - // have the same time and same boundary type. - return a.type > b.type; - } else { - // Puts the "end" bounary before the "start" boundary if they have the - // same time. - return !a.is_start; - } - } - // In ascending order of time. - return a.time_ps < b.time_ps; -} - -// Generates vector of event boundaries from the given overlapped_events. 
-std::vector GenerateEventBoundaries( - const std::vector& overlapped_events) { - std::vector boundaries; - boundaries.reserve(2 * overlapped_events.size()); - for (const auto& event : overlapped_events) { - boundaries.push_back( - {event.span.begin_ps(), event.type, /*is_start=*/true}); - boundaries.push_back({event.span.end_ps(), event.type, /*is_start=*/false}); - } - absl::c_sort(boundaries, CmpEventBoundaries); - return boundaries; -} - -// A class to track the highest priority that an event should be assigned. -class PriorityTracker { - private: - // The current maximum priority. - EventType current_max_priority_; - // A count for each possible priority. - std::vector priority_count_; - - public: - PriorityTracker() { - current_max_priority_ = UNKNOWN_TIME; - priority_count_.resize(LAST_EVENT_TYPE + 1, 0); - } - // Updates current_max_priority_ and priority_count_[] given the boundary. - // Returns the new current_max_priority_. - EventType Update(const EventBoundary& boundary) { - EventType event_type = boundary.type; - bool is_start = boundary.is_start; - if (is_start) { - priority_count_[event_type]++; - if (event_type > current_max_priority_) { - current_max_priority_ = event_type; - } - } else { - priority_count_[event_type]--; - if (event_type == current_max_priority_ && - priority_count_[event_type] == 0) { - // Reduces current_max_priority_ to the first event type (starting from - // the highest priority) that has a non-zero count. 
- bool found = false; - for (int i = event_type - 1; i >= 0; i--) { - if (priority_count_[i] > 0) { - current_max_priority_ = static_cast(i); - found = true; - break; - } - } - if (!found) current_max_priority_ = UNKNOWN_TIME; - } - } - return current_max_priority_; - } -}; - -constexpr int kNumGenericEventTypes = GenericEventType::kLastGenericEventType - - GenericEventType::kFirstGenericEventType + - 1; - -using GenericEventTypeStrMap = - absl::flat_hash_map; - -const GenericEventTypeStrMap& GetGenericEventTypeStrMap() { - static const auto* generic_event_type_str_map = new GenericEventTypeStrMap({ - {kDeviceCompute, "Device compute"}, - {kDeviceToDevice, "Device to device"}, - {kDeviceCollectives, "Device collective communication"}, - {kHostCompute, "Host compute"}, - {kHostPrepare, "Kernel launch"}, - {kInput, "Input"}, - {kOutput, "Output"}, - {kCompile, "Compilation"}, - {kAllOthers, "All others"}, - }); - DCHECK_EQ(generic_event_type_str_map->size(), kNumGenericEventTypes); - return *generic_event_type_str_map; -} - -} // namespace - -absl::string_view GetGenericEventTypeStr(GenericEventType event_type) { - return GetGenericEventTypeStrMap().at(event_type); -} - -std::string PrintEventType(EventType event_type) { - switch (event_type) { - case UNKNOWN_TIME: - return "unknown_time"; - case HOST_COMPUTE: - return "host_compute"; - case HOST_COMPILE: - return "host_compile"; - case HOST_TO_HOST: - return "host_to_host"; - case HOST_TO_DEVICE: - return "host_to_device"; - case HOST_PREPARE: - return "host_prepare"; - case DEVICE_COLLECTIVES: - return "device_collectives"; - case HOST_WAIT_INPUT: - return "host_wait_input"; - case DEVICE_TO_DEVICE: - return "device_to_device"; - case DEVICE_TO_HOST: - return "device_to_host"; - case DEVICE_COMPUTE_32: - return "device_compute_32"; - case DEVICE_COMPUTE_16: - return "device_compute_16"; - case DEVICE_WAIT_DEVICE: - return "device_wait_device"; - case DEVICE_WAIT_HOST: - return "device_wait_host"; - default: - 
return "unexpected"; - } -} - -std::string PrintEventTypeSpan(const EventTypeSpan& event_type_span) { - return absl::StrCat("(", PrintEventType(event_type_span.type), ", ", - event_type_span.span.DebugString(), ")"); -} - -absl::string_view PrintStepMarkerType(StepMarkerType type) { - switch (type) { - case StepMarkerType::kExplicitHostStepMarker: - return "ExplicitHostStepMarker"; - case StepMarkerType::kImplicitHostStepMarker: - return "ImplicitHostStepMarker"; - case StepMarkerType::kDeviceStepMarker: - return "DeviceStepMarker"; - } -} - -std::string PrintStepMarker(const StepMarker& step_marker) { - return absl::StrCat("(", PrintStepMarkerType(step_marker.type), ", ", - step_marker.event_name, ", ", - step_marker.span.DebugString(), ")"); -} - -std::string PrintStepEvents(const StepEvents& step_events) { - std::vector step_ids; - step_ids.reserve(step_events.size()); - for (const auto& id_details : step_events) { - step_ids.push_back(id_details.first); - } - absl::c_sort(step_ids); - std::string result = "{"; - for (auto id : step_ids) { - absl::StrAppend(&result, "\n"); - auto* details = gtl::FindOrNull(step_events, id); - std::string details_str = details ? details->DebugString() : "()"; - absl::StrAppend(&result, id, ":", details_str); - } - return absl::StrCat(result, "\n}"); -} - -void UnionCombineStepEvents(const StepEvents& src, StepEvents* dst) { - for (const auto& step_details : src) { - int64_t step_id = step_details.first; - const StepDetails& src_details = step_details.second; - StepDetails* dst_details = &(*dst)[step_id]; - dst_details->Combine(src_details); - } -} - -void IntersectCombineStepEvents(const StepEvents& src, StepEvents* dst) { - if (dst->empty()) { - *dst = src; - return; - } - auto iter = dst->begin(); - while (iter != dst->end()) { - if (!src.contains(iter->first)) { - // This is safe because the post-increment is sequenced after the full - // expression that contains it. 
- dst->erase(iter++); - } else { - iter->second.Combine(src.at(iter->first)); - iter++; - } - } -} - -std::vector ToNonOverlappedEvents( - const std::vector& overlapped_events) { - std::vector event_boundaries = - GenerateEventBoundaries(overlapped_events); - std::vector result; - if (event_boundaries.empty()) return result; - result.reserve(event_boundaries.size()); - PriorityTracker priority_tracker; - for (int64_t i = 0, end = (event_boundaries.size() - 1); i < end; i++) { - EventType highest_priority = priority_tracker.Update(event_boundaries[i]); - result.push_back({highest_priority, tsl::profiler::Timespan::FromEndPoints( - event_boundaries[i].time_ps, - event_boundaries[i + 1].time_ps)}); - } - return result; -} - -// Converts from overlapped step-events to non-overlapped step-events. -StepEvents ToNonOverlappedStepEvents(const StepEvents& overlapped_step_events) { - StepEvents non_overlapped_step_events; - for (const auto& step_events : overlapped_step_events) { - const auto& step_id = step_events.first; - const auto& step_details = step_events.second; - non_overlapped_step_events.try_emplace(step_id, - step_details.ToNonOverlapped()); - } - return non_overlapped_step_events; -} - -void StepDetails::AddMarker(const StepMarker& m) { markers_.push_back(m); } - -void StepDetails::AddEvent(const EventTypeSpan& e) { events_.push_back(e); } - -void StepDetails::AggregateDeviceMemoryTransfers( - const std::vector& device_memory_transfers) { - if (device_memory_transfers.size() != device_memory_transfers_.size()) { - return; // Sanity check. 
- } - for (size_t i = 0; i < device_memory_transfers.size(); ++i) { - device_memory_transfers_[i].set_occurrence( - device_memory_transfers_[i].occurrence() + - device_memory_transfers[i].occurrence()); - device_memory_transfers_[i].set_bytes_transferred( - device_memory_transfers_[i].bytes_transferred() + - device_memory_transfers[i].bytes_transferred()); - device_memory_transfers_[i].set_time_us( - device_memory_transfers_[i].time_us() + - device_memory_transfers[i].time_us()); - } -} - -void StepDetails::AddCollectiveOpEvent(uint64 core_id, const AllReduceInfo& e) { - *collectives_[core_id].add_all_reduce_info() = e; -} - -void StepDetails::AddDeviceMemoryTransferEvent( - EventType event_type, const tsl::profiler::Timespan& time_span, - uint64 bytes) { - int index = 0; - switch (event_type) { - case HOST_TO_DEVICE: - index = 0; - break; - case DEVICE_TO_HOST: - index = 1; - break; - case DEVICE_TO_DEVICE: - index = 2; - break; - default: - return; - } - device_memory_transfers_[index].set_occurrence( - device_memory_transfers_[index].occurrence() + 1); - device_memory_transfers_[index].set_time_us( - device_memory_transfers_[index].time_us() + - time_span.duration_ps() / 1000000.0); - device_memory_transfers_[index].set_bytes_transferred( - device_memory_transfers_[index].bytes_transferred() + bytes); -} - -tsl::profiler::Timespan StepDetails::StepTime() const { - tsl::profiler::Timespan max_host_step_time; - tsl::profiler::Timespan max_device_step_time; - for (const auto& marker : markers_) { - tsl::profiler::Timespan& cur_max_step_time = - marker.type == StepMarkerType::kDeviceStepMarker ? max_device_step_time - : max_host_step_time; - const tsl::profiler::Timespan& new_step_time = marker.span; - if (new_step_time.duration_ps() > cur_max_step_time.duration_ps()) - cur_max_step_time = new_step_time; - } - // CPU-only profile. 
- if (max_device_step_time.Empty()) { - return max_host_step_time; - } - - // If the host step time includes the device step time, use the host step - // time. This covers the case where the device is synchronized at the end of - // each step. - if (max_host_step_time.Includes(max_device_step_time)) { - return max_host_step_time; - } - return max_device_step_time; -} - -StepDetails StepDetails::ToNonOverlapped() const { - StepDetails non_overlapped_step_details; - non_overlapped_step_details.markers_ = markers_; - non_overlapped_step_details.events_ = ToNonOverlappedEvents(events_); - non_overlapped_step_details.collectives_ = collectives_; - non_overlapped_step_details.device_memory_transfers_ = - device_memory_transfers_; - non_overlapped_step_details.step_name_ = step_name_; - non_overlapped_step_details.per_core_op_metrics_db_ = per_core_op_metrics_db_; - return non_overlapped_step_details; -} - -void StepDetails::Combine(const StepDetails& other) { - markers_.insert(markers_.end(), other.markers_.begin(), other.markers_.end()); - events_.insert(events_.end(), other.events_.begin(), other.events_.end()); - collectives_.insert(other.collectives_.begin(), other.collectives_.end()); - AggregateDeviceMemoryTransfers(other.device_memory_transfers_); - for (const auto& [core_id, op_metric_db] : other.per_core_op_metrics_db_) { - per_core_op_metrics_db_[core_id] = op_metric_db; - } - if (step_name_.empty()) step_name_ = other.step_name_; -} - -std::string StepDetails::DebugString() const { - std::string result = "(["; - for (int i = 0, end = markers_.size(); i < end; i++) { - if (i > 0) absl::StrAppend(&result, ", "); - absl::StrAppend(&result, PrintStepMarker(markers_[i])); - } - absl::StrAppend(&result, "], ["); - for (int i = 0, end = events_.size(); i < end; i++) { - if (i > 0) absl::StrAppend(&result, ", "); - absl::StrAppend(&result, PrintEventTypeSpan(events_[i])); - } - return absl::StrCat(result, "])"); -} - -bool StepDetails::operator==(const StepDetails& 
other) const { - const auto& other_markers = other.Markers(); - if (markers_.size() != other_markers.size()) return false; - for (uint64 i = 0; i < markers_.size(); i++) { - if (markers_[i] != other_markers[i]) return false; - } - const auto& other_events = other.Events(); - if (events_.size() != other_events.size()) return false; - for (uint64 i = 0; i < events_.size(); i++) { - if (events_[i] != other_events[i]) return false; - } - return true; -} - -bool operator==(const StepEvents& a, const StepEvents& b) { - if (a.size() != b.size()) return false; - for (const auto& id_details : a) { - const auto a_id = id_details.first; - const auto& a_details = id_details.second; - const auto* b_details = gtl::FindOrNull(b, a_id); - if (b_details == nullptr) return false; - if (a_details != *b_details) return false; - } - return true; -} - -PrecisionStats ComputePrecisionStats( - const StepEvents& nonoverlapped_step_events) { - int64_t compute_32bit_ps = 0; - int64_t compute_16bit_ps = 0; - for (const auto& id_details : nonoverlapped_step_events) { - for (const auto& event : id_details.second.Events()) { - switch (event.type) { - case DEVICE_COMPUTE_32: - compute_32bit_ps += event.span.duration_ps(); - break; - case DEVICE_COMPUTE_16: - compute_16bit_ps += event.span.duration_ps(); - break; - default: - break; - } - } - } - PrecisionStats precision_stats; - precision_stats.set_compute_32bit_ps(compute_32bit_ps); - precision_stats.set_compute_16bit_ps(compute_16bit_ps); - return precision_stats; -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/utils/event_span.h b/tensorflow/core/profiler/utils/event_span.h index 2b7b2c75b2f700..04506b6e6c6811 100644 --- a/tensorflow/core/profiler/utils/event_span.h +++ b/tensorflow/core/profiler/utils/event_span.h @@ -16,254 +16,6 @@ limitations under the License. 
#ifndef TENSORFLOW_CORE_PROFILER_UTILS_EVENT_SPAN_H_ #define TENSORFLOW_CORE_PROFILER_UTILS_EVENT_SPAN_H_ -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/strings/string_view.h" -#include "xla/tsl/profiler/utils/timespan.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" -#include "tensorflow/core/profiler/protobuf/steps_db.pb.h" - -namespace tensorflow { -namespace profiler { - -// The various event types. Enumerations are numbered such that a bigger number -// has a higher priority than a smaller number when used in execution-time -// breakdown. -enum EventType { - // No event associated with the time. It could be that the machine was idle or - // executing some events which were not traced. - UNKNOWN_TIME = 0, - // Host is computing. - HOST_COMPUTE = 10, - // Host is preprocessing the data before the execution on device. - HOST_PREPROCESS = 20, - // Host is postprocessing the data after the execution on device. - HOST_POSTPROCESS = 30, - // Host is batching data (for inference). - HOST_BATCH_FORMATION = 40, - // Host runtime, like memory allocation and etc. - HOST_RUNTIME = 50, - // Host is compiling. - HOST_COMPILE = 60, - // Host-to-host communication. - HOST_TO_HOST = 70, - // Host-to-device communication. - HOST_TO_DEVICE = 80, - // Host is preparing to launch a computation on device. - HOST_PREPARE = 90, - // Assigns a smaller priority to DEVICE_COLLECTIVES than HOST_WAIT_INPUT, - // because if an all-reduce event is overlapped with an host-wait-input event, - // we want to count it as waiting for input. - // Collective Ops such as All-Reduce. - DEVICE_COLLECTIVES = 100, - // Host is waiting for input. - HOST_WAIT_INPUT = 110, - // Device-to-device communication. - DEVICE_TO_DEVICE = 120, - // Device-to-host communication. - DEVICE_TO_HOST = 130, - // Device is computing with 32-bit precision. 
- DEVICE_COMPUTE_32 = 140, - // Device is computing with 16-bit precision. - DEVICE_COMPUTE_16 = 150, - // Device is waiting for another device. - DEVICE_WAIT_DEVICE = 160, - // Device is waiting for host. - DEVICE_WAIT_HOST = 170, - LAST_EVENT_TYPE = DEVICE_WAIT_HOST -}; - -// Generic event types that shown to the user. -enum GenericEventType { - kFirstGenericEventType = 1, - // Device is computing. - kDeviceCompute = kFirstGenericEventType, - // Device-to-device communication. - kDeviceToDevice, - // Collective Ops such as All-Reduce and NCCL. - kDeviceCollectives, - // Host is computing. - kHostCompute, - // Host is preparing to launch a computation on device. - kHostPrepare, - // Device waiting for input from the host. - kInput, - // Device sending output to the host. - kOutput, - // Host is compling. - kCompile, - // No recognized event associated with the time. - kAllOthers, - kLastGenericEventType = kAllOthers, -}; - -// Contains the type and timespan of an event. -struct EventTypeSpan { - EventType type; // type of this event. - tsl::profiler::Timespan span; // timespan of this event. - EventTypeSpan(EventType t, tsl::profiler::Timespan s) : type(t), span(s) {} - // Equality test. - bool operator==(const EventTypeSpan& other) const { - return type == other.type && span == other.span; - } - // Inequality test. - bool operator!=(const EventTypeSpan& other) const { - return !(*this == other); - } -}; - -enum class StepMarkerType { - // "TraceContext" TraceMe events. - kExplicitHostStepMarker, - // Identified by group_events (e.g., FunctionRun, SessionRun). - kImplicitHostStepMarker, - // Derived from the result of group_events. A device step marker starts with - // the first device event of the group and ends with the last event of the - // group. - kDeviceStepMarker, -}; - -// Record of an event that is used as a step marker. -struct StepMarker { - StepMarkerType type; - std::string event_name; // name of this event. 
- std::string step_name; - tsl::profiler::Timespan span; // timespan of this event. - StepMarker(StepMarkerType step_marker_type, absl::string_view name, - tsl::profiler::Timespan s) - : type(step_marker_type), event_name(name), span(s) {} - // Equality test. - bool operator==(const StepMarker& other) const { - return type == other.type && event_name == other.event_name && - span == other.span; - } - // Inequality test. - bool operator!=(const StepMarker& other) const { return !(*this == other); } -}; - -// Details of a step. Note that this could be the result of combining the -// StepDetails of the same step executed on different cores. -class StepDetails { - public: - StepDetails() : device_memory_transfers_(3) {} - - const std::vector& Markers() const { return markers_; } - const std::vector& Events() const { return events_; } - - const absl::flat_hash_map& Collectives() const { - return collectives_; - } - const std::vector& DeviceMemoryTransfers() const { - return device_memory_transfers_; - } - - absl::flat_hash_map& PerCoreOpMetricsDb() { - return per_core_op_metrics_db_; - } - // Returns the step time. - tsl::profiler::Timespan StepTime() const; - // Adds a step-marker to this step. - void AddMarker(const StepMarker& m); - // Adds an EventTypeSpan to this step. - void AddEvent(const EventTypeSpan& e); - // Adds a collective op to this step. - void AddCollectiveOpEvent(uint64 core_id, const AllReduceInfo& e); - // Appends device memory transfer events to this step. - // Only event type of HOST_TO_DEVICE/DEVICE_TO_DEVICE/DEVICE_TO_HOST are - // allowed. - void AddDeviceMemoryTransferEvent(EventType event_type, - const tsl::profiler::Timespan& time_span, - uint64 bytes); - // Returns the step name. - std::string StepName() const { return step_name_; } - // Sets the name of this step. - void SetStepName(std::string step_name) { step_name_ = step_name; } - - // Converts from overlapped events to non-overlapped events. 
- StepDetails ToNonOverlapped() const; - - // Combines other. - void Combine(const StepDetails& other); - - // Equality test. - bool operator==(const StepDetails& other) const; - // Inequality test. - bool operator!=(const StepDetails& other) const { return !(*this == other); } - - // Returns a string that prints the content of this object. - std::string DebugString() const; - - void SetPerCoreOpMetricsDb(OpMetricsDb db, uint32 core_id) { - per_core_op_metrics_db_[core_id] = db; - } - - private: - // Accumulates the device memory transfers from another step to this step. - void AggregateDeviceMemoryTransfers( - const std::vector& device_memory_transfers); - - // All step-markers found for marking this step in the traces. There could be - // multiple step-markers for a single step for different reasons. One such - // reason is that there may be one step-marker for the same step on each core; - // so after combining the StepDetails from multiple cores, there would be - // multiple step-markers for the same step. - std::vector markers_; - // All events belonging to this step. - std::vector events_; - // Collective operation related events such as all-reduce etc. - absl::flat_hash_map collectives_; - // Device memory transfers (including time and bytes involved). - // TODO(jiesun): Consider to use IntervalSet instead of just sum up the event - // durations. - std::vector device_memory_transfers_; - std::string step_name_; - - absl::flat_hash_map per_core_op_metrics_db_; -}; - -// Map from step_id to the events happened in that step. -using StepEvents = absl::flat_hash_map; - -// Equality test for StepEvents. -bool operator==(const StepEvents& a, const StepEvents& b); - -// Returns the name of the given EventType. -std::string PrintEventType(EventType event_type); - -// Returns the string of the given GenericEventType. -absl::string_view GetGenericEventTypeStr(GenericEventType event_type); - -// Returns a string that prints the given EventTypeSpan. 
-std::string PrintEventTypeSpan(const EventTypeSpan& event_type_span); - -// Returns a string that prints the given StepMarker. -std::string PrintStepMarker(const StepMarker& step_marker); - -// Returns a string that prints the given StepEvents. -std::string PrintStepEvents(const StepEvents& step_events); - -// Unions the map of StepEvents and combines the src StepEvents into dst. -void UnionCombineStepEvents(const StepEvents& src, StepEvents* dst); - -// Intersects the map of StepEvents and combines the src StepEvents into dst. -void IntersectCombineStepEvents(const StepEvents& src, StepEvents* dst); - -// Converts from overlapped events to non-overlapped events. -std::vector ToNonOverlappedEvents( - const std::vector& overlapped_events); - -// Converts from overlapped step-events to non-overlapped step events. -StepEvents ToNonOverlappedStepEvents(const StepEvents& overlapped_step_events); - -// Returns the precision stats of the given non-overlapped step events. -PrecisionStats ComputePrecisionStats( - const StepEvents& nonoverlapped_step_events); - -} // namespace profiler -} // namespace tensorflow +#include "xprof/utils/event_span.h" // from @org_xprof // IWYU pragma: export #endif // TENSORFLOW_CORE_PROFILER_UTILS_EVENT_SPAN_H_ diff --git a/tensorflow/core/profiler/utils/gpu_event_stats.cc b/tensorflow/core/profiler/utils/gpu_event_stats.cc deleted file mode 100644 index eaa4c6ae17ae9d..00000000000000 --- a/tensorflow/core/profiler/utils/gpu_event_stats.cc +++ /dev/null @@ -1,106 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/profiler/utils/gpu_event_stats.h" - -#include - -#include "absl/strings/str_split.h" -#include "absl/strings/string_view.h" -#include "tensorflow/core/profiler/utils/xplane_schema.h" -#include "tensorflow/core/profiler/utils/xplane_visitor.h" - -namespace tensorflow { -namespace profiler { -namespace { - -const absl::string_view kAnnotationDelimiter = "::"; - -} - -GpuEventStats::GpuEventStats(const XEventVisitor* event) { - event->ForEachStat([&](const XStatVisitor& stat) { - if (!stat.Type().has_value()) return; - switch (stat.Type().value()) { - case StatType::kTfOp: - tf_op_fullname = stat.StrOrRefValue(); - break; - case StatType::kEquation: - equation = stat.StrOrRefValue(); - break; - case StatType::kTensorShapes: - tensor_shapes = stat.StrOrRefValue(); - break; - case StatType::kHloOp: - hlo_op_names = - absl::StrSplit(stat.StrOrRefValue(), kAnnotationDelimiter); - break; - case StatType::kHloModule: - hlo_module_name = stat.StrOrRefValue(); - break; - case StatType::kProgramId: - program_id = stat.IntOrUintValue(); - break; - case StatType::kKernelDetails: - kernel_details = stat.StrOrRefValue(); - break; - case StatType::kMemcpyDetails: - memcpy_details = stat.StrOrRefValue(); - break; - case StatType::kCorrelationId: - correlation_id = static_cast(stat.IntOrUintValue()); - break; - case StatType::kGroupId: - group_id = stat.IntValue(); - break; - case StatType::kIsEager: - is_eager = stat.BoolValue(); - break; - case StatType::kCudaGraphExecId: - 
cuda_graph_exec_id = stat.UintValue(); - break; - case StatType::kCudaGraphId: - cuda_graph_id_for_inner_node = stat.UintValue(); - break; - case StatType::kScopeRangeId: - scope_range_id = stat.IntValue(); - break; - default: - break; - } - }); -} - -LaunchEventStats::LaunchEventStats(const XEventVisitor* event) { - event->ForEachStat([&](const XStatVisitor& stat) { - if (!stat.Type().has_value()) return; - switch (stat.Type().value()) { - case StatType::kDeviceId: - device_id = stat.IntOrUintValue(); - break; - case StatType::kCorrelationId: - correlation_id = static_cast(stat.IntOrUintValue()); - break; - case StatType::kGroupId: - group_id = stat.IntValue(); - break; - default: - break; - } - }); -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/utils/gpu_event_stats.h b/tensorflow/core/profiler/utils/gpu_event_stats.h index 369492dcceef88..574e333ae6784f 100644 --- a/tensorflow/core/profiler/utils/gpu_event_stats.h +++ b/tensorflow/core/profiler/utils/gpu_event_stats.h @@ -16,67 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_CORE_PROFILER_UTILS_GPU_EVENT_STATS_H_ #define TENSORFLOW_CORE_PROFILER_UTILS_GPU_EVENT_STATS_H_ -#include -#include -#include - -#include "absl/strings/string_view.h" -#include "tensorflow/core/profiler/utils/xplane_visitor.h" - -namespace tensorflow { -namespace profiler { - -// Stats from a GPU stream XEvent. -struct GpuEventStats { - explicit GpuEventStats(const XEventVisitor* event); - - bool IsKernel() const { return !kernel_details.empty(); } - bool IsMemCpy() const { return !memcpy_details.empty(); } - bool IsCudaGraphExecution() const { return cuda_graph_exec_id.has_value(); } - - bool IsXlaOp() const { return !hlo_op_names.empty(); } - bool IsTfOp() const { return !tf_op_fullname.empty(); } - - // Stats from TensorFlow. - absl::string_view tf_op_fullname; - absl::string_view equation; - absl::string_view tensor_shapes; - - // Stats from XLA. 
- std::vector hlo_op_names; - absl::string_view hlo_module_name; - std::optional program_id; - - // Stats from CUPTI. - absl::string_view kernel_details; - absl::string_view memcpy_details; - std::optional correlation_id; - std::optional scope_range_id; - - // Stats derived by grouping. - std::optional group_id; - bool is_eager = false; - std::optional cuda_graph_exec_id; - std::optional cuda_graph_id_for_inner_node; -}; - -// Stats for a host-side GPU launch XEvent. -struct LaunchEventStats { - explicit LaunchEventStats(const XEventVisitor* event); - - bool IsLaunch() const { - return device_id.has_value() && correlation_id.has_value(); - } - - // Stats from CUPTI. - std::optional device_id; - std::optional correlation_id; - - // Stat derived by grouping. - std::optional group_id; -}; - -} // namespace profiler -} // namespace tensorflow +#include "xprof/utils/gpu_event_stats.h" // from @org_xprof // IWYU pragma: export #endif // TENSORFLOW_CORE_PROFILER_UTILS_GPU_EVENT_STATS_H_ diff --git a/tensorflow/core/profiler/utils/hardware_type_utils.cc b/tensorflow/core/profiler/utils/hardware_type_utils.cc deleted file mode 100644 index 22beb1d51bc860..00000000000000 --- a/tensorflow/core/profiler/utils/hardware_type_utils.cc +++ /dev/null @@ -1,347 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/core/profiler/utils/hardware_type_utils.h" - -#include - -#include "absl/container/btree_map.h" -#include "absl/log/log.h" -#include "absl/strings/match.h" -#include "absl/strings/string_view.h" -#include "xla/tsl/profiler/utils/math_utils.h" -#include "tensorflow/core/profiler/protobuf/hardware_types.pb.h" -#include "tensorflow/core/profiler/utils/xplane_schema.h" - -namespace tensorflow { -namespace profiler { -namespace { - -// The calculation methods is referred from Nvidia developer forum: -// https://forums.developer.nvidia.com/t/how-to-calculate-the-tensor-core-fp16-performance-of-h100/244727 -// Below data are calculated from the various NVidia whitepapers/specs. - -// https://resources.nvidia.com/en-us-tensor-core/gtc22-whitepaper-hopper -const GpuFlopCapabilities kComputeCap_PerSM_PerCycle_9_0 = { - .cuda_core = - { - .fp64_tflops = 128, - .fp32_tflops = 256, - .bf16_tflops = 512, - .fp16_tflops = 512, - .int8_tops = 1024, - }, - .tensor_core = - { - .fp64_tflops = 256, - .fp32_tflops = 2048, - .bf16_tflops = 4096, - .fp16_tflops = 4096, - .fp8_tflops = 8192, - .int8_tops = 8192, - }, - .has_tensor_core_sparsity_support = true, -}; - -// https://images.nvidia.com/aem-dam/Solutions/geforce/ada/nvidia-ada-gpu-architecture.pdf -const GpuFlopCapabilities kComputeCap_PerSM_PerCycle_8_9 = { - .cuda_core = - { - .fp64_tflops = 128, - .fp32_tflops = 256, - .bf16_tflops = 256, - .fp16_tflops = 256, - .int8_tops = 512, - }, - .tensor_core = - { - .fp32_tflops = 512, - .bf16_tflops = 1024, - .fp16_tflops = 1024, - .fp8_tflops = 2048, - .int8_tops = 2048, - .int4_tops = 4096, - }, - .has_tensor_core_sparsity_support = true, -}; - -// https://www.nvidia.com/content/PDF/nvidia-ampere-ga-102-gpu-architecture-whitepaper-v2.1.pdf -const GpuFlopCapabilities kComputeCap_PerSM_PerCycle_8_6 = { - .cuda_core = - { - .fp64_tflops = 128, - .fp32_tflops = 256, - .bf16_tflops 
= 256, - .fp16_tflops = 256, - .int8_tops = 512, - }, - .tensor_core = - { - .fp32_tflops = 256, - .bf16_tflops = 512, - .fp16_tflops = 1024, - .int8_tops = 2048, - .int4_tops = 4096, - }, - .has_tensor_core_sparsity_support = true, -}; - -// https://www.nvidia.com/content/PDF/nvidia-ampere-ga-102-gpu-architecture-whitepaper-v2.1.pdf -const GpuFlopCapabilities kComputeCap_PerSM_PerCycle_8_0 = { - .cuda_core = - { - .fp64_tflops = 64, - .fp32_tflops = 128, - .bf16_tflops = 256, - .fp16_tflops = 512, - .int8_tops = 512, - }, - .tensor_core = - { - .fp64_tflops = 128, - .fp32_tflops = 1024, - .bf16_tflops = 2048, - .fp16_tflops = 2048, - .int8_tops = 4096, - }, - .has_tensor_core_sparsity_support = true, -}; - -// https://images.nvidia.com/aem-dam/en-zz/Solutions/design-visualization/technologies/turing-architecture/NVIDIA-Turing-Architecture-Whitepaper.pdf -const GpuFlopCapabilities kComputeCap_PerSM_PerCycle_7_5 = { - .cuda_core = - { - .fp64_tflops = 64, - .fp32_tflops = 128, - .fp16_tflops = 256, - .int8_tops = 512, - }, - .tensor_core = - { - .fp16_tflops = 1024, - .int8_tops = 2048, - .int4_tops = 4096, - }, - .has_tensor_core_sparsity_support = false, -}; - -// https://images.nvidia.com/content/volta-architecture/pdf/volta-architecture-whitepaper.pdf -const GpuFlopCapabilities kComputeCap_PerSM_PerCycle_7_0 = { - .cuda_core = - { - .fp64_tflops = 64, - .fp32_tflops = 128, - .bf16_tflops = 0.0, - .fp16_tflops = 256, - .int8_tops = 512, - }, - .tensor_core = - { - .fp16_tflops = 1024, - }, - .has_tensor_core_sparsity_support = false, -}; - -// https://images.nvidia.com/content/pdf/tesla/whitepaper/pascal-architecture-whitepaper.pdf -const GpuFlopCapabilities kComputeCap_PerSM_PerCycle_6_1 = { - .cuda_core = - { - .fp64_tflops = 8, - .fp32_tflops = 256, - .fp16_tflops = 4, - .int8_tops = 1024, - }, - .tensor_core = {}, - .has_tensor_core_sparsity_support = false, -}; - -// https://images.nvidia.com/content/pdf/tesla/whitepaper/pascal-architecture-whitepaper.pdf 
-const GpuFlopCapabilities kComputeCap_PerSM_PerCycle_6_0 = { - .cuda_core = - { - .fp64_tflops = 64, - .fp32_tflops = 128, - .fp16_tflops = 256, - .int8_tops = 512, - }, - .tensor_core = {}, - .has_tensor_core_sparsity_support = false, -}; - -// https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/tesla-product-literature/NVIDIA-Kepler-GK110-GK210-Architecture-Whitepaper.pdf -const GpuFlopCapabilities kComputeCap_PerSM_PerCycle_5_0 = { - .cuda_core = - { - .fp64_tflops = 4, - .fp32_tflops = 256, - }, - .tensor_core = {}, - .has_tensor_core_sparsity_support = false, -}; - -// https://www.nvidia.com/content/PDF/product-specifications/GeForce_GTX_680_Whitepaper_FINAL.pdf -const GpuFlopCapabilities kComputeCap_PerSM_PerCycle_3_0 = { - .cuda_core = - { - .fp64_tflops = 128, - .fp32_tflops = 384, - }, - .tensor_core = {}, - .has_tensor_core_sparsity_support = false, -}; - -const GpuFlopCapabilities kComputeCap_PerSM_PerCycle_2_0 = { - .cuda_core = - { - .fp64_tflops = 8, - .fp32_tflops = 64, - }, - .tensor_core = {}, - .has_tensor_core_sparsity_support = false, -}; - -GpuFlopCapabilities GetNvidiaFlopCapsPerSMPerCycle(int major_comp_cap, - int minor_comp_cap) { - static const auto& kPerSMFlopCapsTable = - *new absl::btree_map{ - // TODO: Add incoming blackwell, and other old GPUS - {9000, &kComputeCap_PerSM_PerCycle_9_0}, - {8090, &kComputeCap_PerSM_PerCycle_8_9}, - {8060, &kComputeCap_PerSM_PerCycle_8_6}, - {8000, &kComputeCap_PerSM_PerCycle_8_0}, - {7050, &kComputeCap_PerSM_PerCycle_7_5}, - {7000, &kComputeCap_PerSM_PerCycle_7_0}, - {6010, &kComputeCap_PerSM_PerCycle_6_1}, - {6000, &kComputeCap_PerSM_PerCycle_6_0}, - {5000, &kComputeCap_PerSM_PerCycle_5_0}, - {3000, &kComputeCap_PerSM_PerCycle_3_0}, - {2000, &kComputeCap_PerSM_PerCycle_2_0}, - }; - - const int normalized_compute_cap = - major_comp_cap * 1000 + minor_comp_cap * 10; - GpuFlopCapabilities flops_cap{}; - auto it = kPerSMFlopCapsTable.lower_bound(normalized_compute_cap); - if (it == 
kPerSMFlopCapsTable.end()) { - LOG(WARNING) << "GPU compute capability " << major_comp_cap << "." - << minor_comp_cap << " is too old to support."; - } else { - flops_cap = *it->second; - if (it->first != normalized_compute_cap) { - LOG(WARNING) << "GPU compute capability " << major_comp_cap << "." - << minor_comp_cap - << " is not found. Use the highest compute cap known " - << (it->first / 1000) << "." << ((it->first % 1000) / 10) - << " instead."; - } - } - return flops_cap; -} - -GpuFlopCapabilities GetGpuFlopCapabilitiesPerSM( - const DeviceCapabilities& device_cap) { - GpuFlopCapabilities flops_cap{}; - if (device_cap.device_vendor() == kDeviceVendorNvidia) { - flops_cap = - GetNvidiaFlopCapsPerSMPerCycle(device_cap.compute_capability().major(), - device_cap.compute_capability().minor()); - } else { - LOG(WARNING) << "Unsupported device vendor " << device_cap.device_vendor(); - } - - flops_cap.ScaleWith(device_cap.clock_rate_in_ghz()); - return flops_cap; -} - -} // namespace - -double GetFlopMaxThroughputPerSM(const DeviceCapabilities& device_cap) { - GpuFlopCapabilities sm_flops = GetGpuFlopCapabilitiesPerSM(device_cap); - double result = std::max( - {sm_flops.cuda_core.fp32_tflops, sm_flops.cuda_core.fp16_tflops, - sm_flops.tensor_core.fp32_tflops, sm_flops.tensor_core.fp16_tflops}); - VLOG(3) << "GetFlopMaxThroughputPerSM get result: " << result << " GFLOPs"; - return result; -} - -double GetSharedMemoryBandwidthPerSM(const DeviceCapabilities& device_cap) { - // https://docs.nvidia.com/gameworks/content/developertools/desktop/analysis/report/cudaexperiments/kernellevel/memorystatisticsshared.htm - // Compute capability 2.0, each bank has bandwidth of 4 bytes per 2 cycles. - // For compute capability 3.0 and above, each bank has bandwidth 8 bytes per - // cycle. Each SM has 32 banks. - double transaction_byts_per_cycle = - device_cap.compute_capability().major() <= 2 ? 
(32 * 4 / 2) : (32 * 8); - double GiBPS = transaction_byts_per_cycle * device_cap.clock_rate_in_ghz(); - return tsl::profiler::GigaToUni(GiBPS); -} - -absl::string_view GpuModelName(const DeviceCapabilities& device_cap) { - if (device_cap.device_vendor() == kDeviceVendorNvidia) { - switch (device_cap.compute_capability().major()) { - case 2: - return "Nvidia GPU (Fermi)"; - case 3: - return "Nvidia GPU (Kepler)"; - case 5: - return "Nvidia GPU (Maxwell)"; - case 6: - return "Nvidia GPU (Pascal)"; - case 7: - if (device_cap.compute_capability().minor() < 5) { - return "Nvidia GPU (Volta)"; - } else { - return "Nvidia GPU (Turing)"; - } - case 8: - if (device_cap.compute_capability().minor() < 9) { - return "Nvidia GPU (Ampere)"; - } else { - return "Nvidia GPU (Ada Lovelace)"; - } - case 9: - return "Nvidia GPU (Hopper)"; - case 10: - return "Nvidia GPU (Blackwell)"; - default: - return "Nvidia GPU"; - } - } else if (device_cap.device_vendor() == kDeviceVendorAMD) { - switch (device_cap.compute_capability().major()) { - case 9: - return "AMD GPU - gfx-9XX series"; - case 10: - return "AMD GPU - gfx-10XX series"; - case 11: - return "AMD GPU - gfx-11XX series"; - default: - return "AMD GPU"; - } - } else { - LOG(ERROR) << "Unknown device vendor " << device_cap.device_vendor(); - return ""; - } -} - -HardwareType ParseHardwareType(absl::string_view device_type) { - if (absl::StrContains(device_type, "GPU")) return HardwareType::GPU; - if (device_type == "CPU") return HardwareType::CPU_ONLY; - if (absl::StrContains(device_type, "TPU")) return HardwareType::TPU; - return HardwareType::UNKNOWN_HARDWARE; -} - -bool HasDevice(HardwareType x) { return x > tensorflow::profiler::CPU_ONLY; } - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/utils/hardware_type_utils.h b/tensorflow/core/profiler/utils/hardware_type_utils.h index 41b1bd4b65471c..c2fc5266bc3778 100644 --- a/tensorflow/core/profiler/utils/hardware_type_utils.h +++ 
b/tensorflow/core/profiler/utils/hardware_type_utils.h @@ -16,67 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_CORE_PROFILER_UTILS_HARDWARE_TYPE_UTILS_H_ #define TENSORFLOW_CORE_PROFILER_UTILS_HARDWARE_TYPE_UTILS_H_ -#include "absl/strings/string_view.h" -#include "tensorflow/core/profiler/protobuf/hardware_types.pb.h" - -namespace tensorflow { -namespace profiler { - -struct GpuFlopCapabilities { - struct FlopCapabilityOnPrecisions { - double fp64_tflops = 0; - double fp32_tflops = 0; // also for tf32 for nvidia tensor core - double bf16_tflops = 0; - double fp16_tflops = 0; - double fp8_tflops = 0; - double int8_tops = 0; - double fp4_tflops = 0; - double int4_tops = 0; - - void ScaleWith(double scale) { - fp64_tflops *= scale; - fp32_tflops *= scale; - bf16_tflops *= scale; - fp16_tflops *= scale; - fp8_tflops *= scale; - int8_tops *= scale; - fp4_tflops *= scale; - int4_tops *= scale; - } - }; - - FlopCapabilityOnPrecisions cuda_core; - FlopCapabilityOnPrecisions tensor_core; - bool has_tensor_core_sparsity_support = false; - - void ScaleWith(double scale) { - cuda_core.ScaleWith(scale); - tensor_core.ScaleWith(scale); - } -}; - -// Get peak single precision throughput of the GPU in GFLOPS per -// streaming multiprocessor. -// TODO: Need design on how to use the sparsity capability of FLOPs. -double GetFlopMaxThroughputPerSM(const DeviceCapabilities& device_cap); - -// for Nvidia GPU, return shared memory bandwidth in Bytes Per Second on -// one single SM given the GPU core freq in device_cap. -double GetSharedMemoryBandwidthPerSM(const DeviceCapabilities& device_cap); - -// Returns the GPU model name from the given DeviceCapabilities. -// For nvidia GPUs, the name is like "Nvidia GPU (Kepler)" or "Nvidia GPU -// (Turing)". For AMD GPUs, the name is like "AMD GPU - gfx-10XX series". -// The model name here for Nvidia GPU in fact refers to its microarchitecture -// name. 
-absl::string_view GpuModelName(const DeviceCapabilities& device_cap); - -HardwareType ParseHardwareType(absl::string_view device_type); - -// Returns true if the given hardware type has a device. -bool HasDevice(HardwareType x); - -} // namespace profiler -} // namespace tensorflow +#include "xprof/utils/hardware_type_utils.h" // from @org_xprof // IWYU pragma: export #endif // TENSORFLOW_CORE_PROFILER_UTILS_HARDWARE_TYPE_UTILS_H_ diff --git a/tensorflow/core/profiler/utils/hardware_type_utils_test.cc b/tensorflow/core/profiler/utils/hardware_type_utils_test.cc deleted file mode 100644 index 9476848a650dcc..00000000000000 --- a/tensorflow/core/profiler/utils/hardware_type_utils_test.cc +++ /dev/null @@ -1,66 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/core/profiler/utils/hardware_type_utils.h" - -#include "xla/tsl/profiler/utils/math_utils.h" -#include "tensorflow/core/platform/test.h" - -namespace tensorflow { -namespace profiler { -namespace { - -TEST(HardwareTypeUtilsTest, H100PeakComputTFlops) { - DeviceCapabilities device_cap; - // For NVIDIA H100 PCIe 80 GB, according to - // https://resources.nvidia.com/en-us-data-center-overview/gtc22-whitepaper-hopper - // https://www.techpowerup.com/gpu-specs/h100-pcie-80-gb.c3899 - device_cap.set_clock_rate_in_ghz(1.620); - device_cap.set_num_cores(114); - device_cap.set_memory_size_in_bytes( - tsl::profiler::GibiToGiga(tsl::profiler::GigaToUni(80))); - device_cap.set_memory_bandwidth(tsl::profiler::GigaToUni(2.04 * 1024)); - device_cap.set_device_vendor("Nvidia"); - device_cap.mutable_compute_capability()->set_major(9); - device_cap.mutable_compute_capability()->set_minor(0); - - // Get target TFLOPS per SM and check. 
- double peak_tflops = - GetFlopMaxThroughputPerSM(device_cap) * device_cap.num_cores() / 1000.0; - EXPECT_NEAR(peak_tflops, 756, /*abs_error=*/1.0); -} - -TEST(HardwareTypeUtilsTest, A100PeakComputTFlops) { - DeviceCapabilities device_cap; - // For NVIDIA A100 SXM4 80 GB, according to: - // https://images.nvidia.com/aem-dam/en-zz/Solutions/data-center/nvidia-ampere-architecture-whitepaper.pdf - // https://www.techpowerup.com/gpu-specs/a100-sxm4-80-gb.c3746 - device_cap.set_clock_rate_in_ghz(1.410); - device_cap.set_num_cores(108); - device_cap.set_memory_size_in_bytes( - tsl::profiler::GibiToGiga(tsl::profiler::GigaToUni(80))); - device_cap.set_memory_bandwidth(tsl::profiler::GigaToUni(2.04 * 1024)); - device_cap.set_device_vendor("Nvidia"); - device_cap.mutable_compute_capability()->set_major(8); - device_cap.mutable_compute_capability()->set_minor(0); - - double peak_tflops = - GetFlopMaxThroughputPerSM(device_cap) * device_cap.num_cores() / 1000.0; - EXPECT_NEAR(peak_tflops, 312, /*abs_error=*/1.0); -} - -} // namespace -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/utils/hlo_module_map.cc b/tensorflow/core/profiler/utils/hlo_module_map.cc deleted file mode 100644 index d4683d22f33efa..00000000000000 --- a/tensorflow/core/profiler/utils/hlo_module_map.cc +++ /dev/null @@ -1,181 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/core/profiler/utils/hlo_module_map.h" - -#include -#include -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/log/log.h" -#include "absl/strings/str_cat.h" -#include "xla/service/hlo_cost_analysis.h" -#include "xla/shape.h" -#include "xla/tsl/profiler/convert/xla_op_utils.h" -#include "tsl/profiler/lib/traceme_encode.h" - -#if GOOGLE_CUDA -#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" -#endif -#include "absl/log/check.h" -#include "absl/strings/string_view.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_opcode.h" -#include "tensorflow/core/platform/path.h" -#include "tensorflow/core/profiler/utils/hlo_module_utils.h" -#include "tensorflow/core/profiler/utils/hlo_proto_map.h" -#include "tensorflow/core/profiler/utils/hlo_proto_to_module.h" - -namespace tensorflow { -namespace profiler { - -namespace { - -#if GOOGLE_CUDA -int64_t ShapeSize(const xla::Shape& shape) { - constexpr int64_t kPointerSize = 8; - return xla::ShapeUtil::ByteSizeOf(shape, kPointerSize); -} -#endif - -} // namespace - -HloInstructionWrapper::HloInstructionWrapper( - const xla::HloInstruction* instr, const xla::HloCostAnalysis* cost_analysis) - : instr_(instr), - op_full_name_( - tsl::profiler::TraceMeOp(Metadata().op_name(), Metadata().op_type())), - tf_op_name_(tsl::profiler::TfOpFullname(Metadata().op_type(), - Metadata().op_name())), - category_(instr_->ToCategory()), - expression_(tensorflow::profiler::UncachedExpression( - instr_, false, tensorflow::profiler::kMaxHlolNameSize)) { - ProcessXlaCostAnalysis(cost_analysis); -} - -HloModuleWrapper::HloModuleWrapper( - const xla::HloProto& hlo_proto, - std::function shape_func) - : HloModuleWrapper(ConvertHloProtoToModuleIgnoringErrors(hlo_proto), - shape_func) {} - -HloModuleWrapper::HloModuleWrapper( - std::unique_ptr module, - std::function shape_func) - : 
module_(std::move(module)) { - if (module_ == nullptr) return; - - const xla::HloCostAnalysis* cost_analysis = nullptr; -#if GOOGLE_CUDA - if (shape_func == nullptr) shape_func = ShapeSize; - xla::HloCostAnalysis::Options options; - options.shape_size = shape_func; - xla::gpu::GpuHloCostAnalysis gpu_cost_analysis(options); - - const xla::HloComputation* hlo_computation = module_->entry_computation(); - gpu_cost_analysis.ReserveVisitStates(hlo_computation->instruction_count()); - tsl::Status analysis_status = hlo_computation->Accept(&gpu_cost_analysis); - if (analysis_status.ok()) { - // Clear the visit state as it isn't used by anybody and it uses a lot of - // memory. - gpu_cost_analysis.DestroyVisitState(); - } else { - LOG(ERROR) << "Failed to create cost analysis: " << analysis_status; - } - cost_analysis = &gpu_cost_analysis; -#endif - - // Populate instructions_by_name_ with module. - for (const xla::HloComputation* computation : module_->computations()) { - for (const xla::HloInstruction* instr : computation->instructions()) { - instructions_by_name_.try_emplace( - instr->name(), HloInstructionWrapper(instr, cost_analysis)); - } - } - // Gather nested fusion instructions. - for (const xla::HloComputation* computation : module_->computations()) { - // Some modules still seem to have "dead" fusions computations. In this - // case, IsFusionComputation() = true but there is no parent - // FusionInstruction(). - if (computation->FusionInstruction() != nullptr) { - GatherFusionInstructions(computation->FusionInstruction()); - } - } -} - -// Function to gather all the instructions in a fusion computation. 
-void HloModuleWrapper::GatherFusionInstructions(xla::HloInstruction* inst) { - HloInstructionWrapper* fused_inst_wrapper = - GetMutableHloInstruction(inst->name()); - DCHECK(fused_inst_wrapper != nullptr); - if (!fused_inst_wrapper->FusedChildren().empty()) return; - for (auto* fused : inst->fused_instructions()) { - const auto child_inst_wrapper = GetHloInstruction(fused->name()); - DCHECK(child_inst_wrapper != nullptr); - fused_inst_wrapper->AddFusedChild(child_inst_wrapper); - if (fused->opcode() == xla::HloOpcode::kFusion) { - GatherFusionInstructions(fused); - } - } -} - -HloInstructionWrapper* HloModuleWrapper::GetMutableHloInstruction( - absl::string_view hlo_name) { - auto it = instructions_by_name_.find(hlo_name); - if (it != instructions_by_name_.end()) return &it->second; - return nullptr; -} - -const HloInstructionWrapper* HloModuleWrapper::GetHloInstruction( - absl::string_view hlo_name) const { - auto it = instructions_by_name_.find(hlo_name); - if (it != instructions_by_name_.end()) return &it->second; - return nullptr; -} - -std::string HloInstructionWrapper::source_info() const { - if (!Metadata().source_file().empty()) { - return absl::StrCat(io::Basename(Metadata().source_file()), ":", - Metadata().source_line()); - } else { - return std::string(); - } -} - -void AddHloProto(HloModuleMap& hlo_module_map, uint64_t program_id, - const xla::HloProto& hlo_proto) { - auto hlo_module = ConvertHloProtoToModule(hlo_proto); - if (!hlo_module.ok()) { - LOG(ERROR) << hlo_module.status(); - return; - } - hlo_module_map.try_emplace(program_id, - HloModuleWrapper(std::move(hlo_module).value(), - /*shape_func=*/nullptr)); -} - -void ProcessHloModuleMapFromXSpace(HloModuleMap& hlo_module_map, - const XSpace* space) { - for (auto& [program_id, hlo_proto] : ParseHloProtosFromXSpace(*space)) { - AddHloProto(hlo_module_map, program_id, *hlo_proto); - } -} - -} // namespace profiler -} // namespace tensorflow diff --git 
a/tensorflow/core/profiler/utils/hlo_module_map.h b/tensorflow/core/profiler/utils/hlo_module_map.h index ab6898af72ed84..e6c58633d5f334 100644 --- a/tensorflow/core/profiler/utils/hlo_module_map.h +++ b/tensorflow/core/profiler/utils/hlo_module_map.h @@ -16,200 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_CORE_PROFILER_UTILS_HLO_MODULE_MAP_H_ #define TENSORFLOW_CORE_PROFILER_UTILS_HLO_MODULE_MAP_H_ -#include -#include -#include -#include -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/strings/string_view.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/ir/hlo_opcode.h" -#include "xla/service/hlo.pb.h" -#include "xla/service/hlo_cost_analysis.h" -#include "xla/shape.h" -#include "xla/tsl/profiler/convert/xla_op_utils.h" -#include "tensorflow/core/profiler/utils/hlo_module_utils.h" -#include "tsl/profiler/protobuf/xplane.pb.h" - -namespace tensorflow { -namespace profiler { - -class HloInstructionInterface { - public: - virtual ~HloInstructionInterface() = default; - virtual absl::string_view Name() const = 0; - virtual xla::HloOpcode HloOpcode() const = 0; - virtual absl::string_view Category() const = 0; - virtual std::string HloOpcodeString() const = 0; - virtual const xla::OpMetadata& Metadata() const = 0; - virtual size_t flops() const = 0; - virtual size_t bytes_accessed() const = 0; - virtual std::string_view op_full_name() const = 0; - virtual std::string_view TfOpName() const = 0; - virtual std::string source_info() const = 0; - virtual bool isRoot() const = 0; - virtual bool IsFusion() const = 0; - virtual const std::string& Expression() const = 0; - - virtual void ProcessXlaCostAnalysis( - const xla::HloCostAnalysis* cost_analysis) = 0; - virtual std::string OpLocationStack(int32_t frame_id) const = 0; - virtual tsl::profiler::OpSourceInfo SourceInfo() const = 0; -}; - -// This wrapper allows caching the results of HloInstruction methods. 
-// This wrapper is not thread safe. -class HloInstructionWrapper : public HloInstructionInterface { - public: - explicit HloInstructionWrapper( - const xla::HloInstruction* instr, - const xla::HloCostAnalysis* cost_analysis = nullptr); - - // Non copyable - HloInstructionWrapper(const HloInstructionWrapper&) = delete; - HloInstructionWrapper& operator=(const HloInstructionWrapper&) = delete; - // Movable. - HloInstructionWrapper(HloInstructionWrapper&&) = default; - HloInstructionWrapper& operator=(HloInstructionWrapper&&) = default; - - absl::string_view Name() const override { return instr_->name(); } - - xla::HloOpcode HloOpcode() const override { return instr_->opcode(); } - - absl::string_view Category() const override { return category_; } - - std::string HloOpcodeString() const override { - return std::string(xla::HloOpcodeString(instr_->opcode())); - } - - const xla::OpMetadata& Metadata() const override { - return instr_->metadata(); - } - - size_t flops() const override { return flops_; } - size_t bytes_accessed() const override { return bytes_accessed_; } - - std::string_view op_full_name() const override { return op_full_name_; } - std::string_view TfOpName() const override { return tf_op_name_; } - std::string source_info() const override; - - bool isRoot() const override { return instr_->IsRoot(); } - bool IsFusion() const override { return !fused_children_.empty(); }; - - void ProcessXlaCostAnalysis( - const xla::HloCostAnalysis* cost_analysis) override { - if (cost_analysis == nullptr) return; - flops_ = cost_analysis->flop_count(*instr_); - bytes_accessed_ = cost_analysis->bytes_accessed(*instr_); - } - - const std::string& Expression() const override { return expression_; } - - void AddFusedChild(const HloInstructionWrapper* child) { - fused_children_.push_back(child); - }; - - const std::vector& FusedChildren() const { - return fused_children_; - } - - std::string OpLocationStack(int32_t frame_id) const override { - return 
GetOpLocationStack(frame_id, instr_); - } - - tsl::profiler::OpSourceInfo SourceInfo() const override { - return GetSourceInfo(instr_); - } - - private: - const xla::HloInstruction* instr_; - std::vector fused_children_; - std::string op_full_name_; - std::string tf_op_name_; - size_t flops_ = 0; - size_t bytes_accessed_ = 0; - std::string category_; - std::string expression_; -}; - -// Helper class for accessing HloModule. -class HloModuleInterface { - public: - virtual ~HloModuleInterface() = default; - - // If the module contains no instructions. - virtual bool Empty() const = 0; - virtual absl::string_view Name() const = 0; - // Function to populated nested childs= instructions in a fusion. - virtual void GatherFusionInstructions(xla::HloInstruction* inst) = 0; -}; - -// Wraps HLO module and provides an interface that maps HLO names to -// HloInstructionWrappers. -class HloModuleWrapper : public HloModuleInterface { - public: - explicit HloModuleWrapper( - const xla::HloProto& hlo_proto, - std::function shape_func = nullptr); - - explicit HloModuleWrapper( - std::unique_ptr module, - std::function shape_func); - - const HloInstructionWrapper* GetHloInstruction( - absl::string_view hlo_name) const; - HloInstructionWrapper* GetMutableHloInstruction(absl::string_view hlo_name); - - bool Empty() const override { return instructions_by_name_.empty(); } - - absl::string_view Name() const override { return module_->name(); } - void GatherFusionInstructions(xla::HloInstruction* inst) override; - - private: - std::unique_ptr module_; - - // Map of HloInstructionWrappers by name. - using HloInstructionMap = - absl::flat_hash_map; - HloInstructionMap instructions_by_name_; -}; - -// Map of HloModuleWrappers by program_id. -using HloModuleMap = - absl::flat_hash_map; - -void AddHloProto(HloModuleMap& hlo_module_map, uint64_t program_id, - const xla::HloProto& hlo_proto); - -// Process HloModuleMap from single XSpace. 
-void ProcessHloModuleMapFromXSpace(HloModuleMap& hlo_module_map, - const XSpace* space); - -// WARNING: The returned pointer will be invalidated if HloModuleMap is mutated. -inline const HloModuleWrapper* GetHloModule(const HloModuleMap* hlo_module_map, - uint64_t program_id) { - if (hlo_module_map == nullptr) return nullptr; - auto iter = hlo_module_map->find(program_id); - if (iter == hlo_module_map->end()) return nullptr; - return &iter->second; -} - -inline const HloInstructionWrapper* GetHloInstruction( - const HloModuleMap& hlo_module_map, std::optional program_id, - absl::string_view hlo_name) { - if (!program_id.has_value()) return nullptr; - const auto* hlo_module = GetHloModule(&hlo_module_map, *program_id); - if (hlo_module == nullptr) return nullptr; - return hlo_module->GetHloInstruction(hlo_name); -} - -} // namespace profiler -} // namespace tensorflow +#include "xprof/utils/hlo_module_map.h" // from @org_xprof // IWYU pragma: export #endif // TENSORFLOW_CORE_PROFILER_UTILS_HLO_MODULE_MAP_H_ diff --git a/tensorflow/core/profiler/utils/hlo_module_utils.h b/tensorflow/core/profiler/utils/hlo_module_utils.h index 2de48469253fe9..8b68816a52ebb6 100644 --- a/tensorflow/core/profiler/utils/hlo_module_utils.h +++ b/tensorflow/core/profiler/utils/hlo_module_utils.h @@ -16,103 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_CORE_PROFILER_UTILS_HLO_MODULE_UTILS_H_ #define TENSORFLOW_CORE_PROFILER_UTILS_HLO_MODULE_UTILS_H_ -#include -#include -#include - -#include "absl/algorithm/container.h" -#include "absl/strings/match.h" -#include "absl/strings/str_cat.h" -#include "xla/hlo/ir/hlo_computation.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/ir/hlo_print_options.h" -#include "xla/tsl/profiler/convert/xla_op_utils.h" - -namespace tensorflow { -namespace profiler { - -// Sometimes HLO produce a huge string (>100MB). Limit the name size to 1MB. 
-static constexpr size_t kMaxHlolNameSize = 1000000; - -inline const xla::HloInstruction* FindInstruction(const xla::HloModule& module, - std::string node_name) { - if (absl::StartsWith(node_name, "%")) { - node_name.erase(node_name.begin()); - } - for (const xla::HloComputation* computation : module.computations()) { - auto instrs = computation->instructions(); - auto it = absl::c_find_if(instrs, [&](const xla::HloInstruction* instr) { - // Try with and without "%" at the beginning of the node name. - return absl::EqualsIgnoreCase(instr->name(), node_name) || - absl::EqualsIgnoreCase(instr->name(), - absl::StrCat("%", node_name)); - }); - if (it != instrs.end()) { - return *it; - } - } - return nullptr; -} - -inline const xla::HloComputation* FindComputation( - const xla::HloModule& module, const std::string& comp_name) { - for (const xla::HloComputation* computation : module.computations()) { - if (absl::EqualsIgnoreCase(computation->name(), comp_name)) { - return computation; - } - } - return nullptr; -} - -inline std::string UncachedExpression(const xla::HloInstruction* instr, - bool skip_expression, size_t max_size) { - if (skip_expression) { - return ""; - } - static const auto* hlo_print_options = - new xla::HloPrintOptions(xla::HloPrintOptions() - .set_print_metadata(false) - .set_print_backend_config(false) - .set_print_infeed_outfeed_config(false) - .set_print_operand_shape(true) - .set_print_large_constants(false)); - std::string expression = instr->ToString(*hlo_print_options); - if (expression.size() > max_size) { - expression.resize(max_size); - } - return expression; -} - -inline std::string GetOpLocationStack(int32_t frame_id, - const xla::HloInstruction* instr) { - std::string stack_lines; - xla::HloModule* hlo_module = instr->GetModule(); - while (frame_id != 0) { - xla::HloModule::StackFrame frame = hlo_module->get_stack_frame(frame_id); - if (frame.empty()) { - break; - } - stack_lines.insert(0, absl::StrCat(frame.file_name, ":", frame.line, 
":", - frame.column, "\n")); - frame_id = frame.parent_frame_id; - } - - return stack_lines; -}; - -inline tsl::profiler::OpSourceInfo GetSourceInfo( - const xla::HloInstruction* instr) { - if (int32_t stack_frame_id = instr->metadata().stack_frame_id(); - stack_frame_id != 0) { - return {.source_file = instr->metadata().source_file(), - .source_line = instr->metadata().source_line(), - .stack_frame = GetOpLocationStack(stack_frame_id, instr)}; - } - return {.source_file = instr->metadata().source_file(), - .source_line = instr->metadata().source_line()}; -}; -} // namespace profiler -} // namespace tensorflow +#include "xprof/utils/hlo_module_utils.h" // from @org_xprof // IWYU pragma: export #endif // TENSORFLOW_CORE_PROFILER_UTILS_HLO_MODULE_UTILS_H_ diff --git a/tensorflow/core/profiler/utils/hlo_module_utils_test.cc b/tensorflow/core/profiler/utils/hlo_module_utils_test.cc deleted file mode 100644 index 18eb2a2cdce7ce..00000000000000 --- a/tensorflow/core/profiler/utils/hlo_module_utils_test.cc +++ /dev/null @@ -1,104 +0,0 @@ -/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/core/profiler/utils/hlo_module_utils.h" - -#include - -#include -#include "absl/status/statusor.h" -#include "xla/hlo/ir/hlo_computation.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/tests/hlo_test_base.h" -#include "xla/tsl/platform/statusor.h" -#include "tensorflow/core/platform/test.h" - -namespace tensorflow { -namespace profiler { -namespace { - -class HloModuleUtilsTest : public xla::HloTestBase { - protected: - absl::StatusOr> GetModuleWithStackFrames() { - const char file_name[] = "main.py"; - const char function_name[] = "func1"; - const int line_number = 10; - const int column_number = 5; - const int frame_id = 1; - const char text[] = R"( - HloModule a_module - - ENTRY main { - %c = s32[] constant(1) - ROOT %result = s32[] parameter(0) - } - )"; - TF_ASSIGN_OR_RETURN(auto module, ParseAndReturnVerifiedModule(text)); - - auto module_proto = module->ToProto(); - auto index = module_proto.mutable_stack_frame_index(); - index->add_file_names(file_name); - index->add_function_names(function_name); - auto location = index->add_file_locations(); - location->set_file_name_id(frame_id); - location->set_function_name_id(1); - location->set_line(line_number); - location->set_column(column_number); - - auto frame = index->add_stack_frames(); - frame->set_file_location_id(1); - - // Set the stack frame id of the root instruction. 
- for (auto& computation : *module_proto.mutable_computations()) { - if (computation.id() == module_proto.entry_computation_id()) { - for (auto& instruction : *computation.mutable_instructions()) { - if (instruction.id() == computation.root_id()) { - instruction.mutable_metadata()->set_stack_frame_id(frame_id); - instruction.mutable_metadata()->set_source_file(file_name); - instruction.mutable_metadata()->set_source_line(line_number); - } - } - } - } - - return xla::HloModule::CreateFromProto(module_proto, module->config()); - } -}; - -TEST_F(HloModuleUtilsTest, TestGetLocationStack) { - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr module_with_stack_frames, - GetModuleWithStackFrames()); - auto root_instruction = - module_with_stack_frames->entry_computation()->root_instruction(); - EXPECT_EQ(GetOpLocationStack(1, root_instruction), "main.py:10:5\n"); -} - -TEST_F(HloModuleUtilsTest, TestGetSourceInfo) { - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr module_with_stack_frames, - GetModuleWithStackFrames()); - auto root_instruction = - module_with_stack_frames->entry_computation()->root_instruction(); - auto source_info = GetSourceInfo(root_instruction); - EXPECT_EQ(source_info.source_file, "main.py"); - EXPECT_EQ(source_info.source_line, 10); - EXPECT_EQ(source_info.stack_frame, "main.py:10:5\n"); -} - -} // namespace -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/utils/hlo_proto_map.cc b/tensorflow/core/profiler/utils/hlo_proto_map.cc deleted file mode 100644 index 50d96c49980e74..00000000000000 --- a/tensorflow/core/profiler/utils/hlo_proto_map.cc +++ /dev/null @@ -1,172 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/profiler/utils/hlo_proto_map.h" - -#include -#include -#include -#include - -#include "absl/algorithm/container.h" -#include "absl/container/flat_hash_map.h" -#include "absl/log/log.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" -#include "xla/service/hlo.pb.h" -#include "xla/tsl/profiler/convert/xla_op_utils.h" -#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" -#include "tensorflow/core/profiler/utils/xplane_schema.h" -#include "tensorflow/core/profiler/utils/xplane_utils.h" -#include "tensorflow/core/profiler/utils/xplane_visitor.h" - -namespace tensorflow { -namespace profiler { -namespace { - -int NumHeapSimulatorTraceEvents(const xla::HloProto* hlo) { - int result = 0; - for (const auto& trace : hlo->buffer_assignment().heap_simulator_traces()) { - result += trace.events_size(); - } - return result; -} - -} // namespace - -absl::flat_hash_map> -ParseHloProtosFromXSpace(const XSpace& space) { - absl::flat_hash_map> hlo_protos; - std::vector planes = - FindPlanesWithNames(space, {kMetadataPlaneName}); - for (const XPlane* raw_plane : planes) { - if (raw_plane != nullptr) { - XPlaneVisitor plane = tsl::profiler::CreateTfXPlaneVisitor(raw_plane); - - const XStatMetadata* hlo_proto_stat_metadata = - plane.GetStatMetadataByType(StatType::kHloProto); - if (hlo_proto_stat_metadata != nullptr) { - plane.ForEachEventMetadata( - [&](const XEventMetadataVisitor& 
event_metadata) { - auto hlo_proto_stat = event_metadata.GetStat( - StatType::kHloProto, *hlo_proto_stat_metadata); - if (!hlo_proto_stat) return; - if (hlo_proto_stat->ValueCase() != XStat::kBytesValue) return; - auto hlo_proto = std::make_unique(); - absl::string_view byte_value = hlo_proto_stat->BytesValue(); - if (hlo_proto->ParseFromArray(byte_value.data(), - byte_value.size())) { - if (!hlo_protos - .try_emplace(event_metadata.Id(), std::move(hlo_proto)) - .second) { - LOG(WARNING) << "Insert failed for hlo_proto with program_id" - << event_metadata.Id(); - } - } - }); - } - } - } - return hlo_protos; -} - -bool HloProtoMap::AddHloProto(uint64_t program_id, - const xla::HloProto* hlo_proto) { - bool new_program_id = - hlo_protos_by_program_id_.try_emplace(program_id, hlo_proto).second; - absl::string_view hlo_module_name = hlo_proto->hlo_module().name(); - bool new_module_name = - hlo_protos_by_name_ - .try_emplace(tsl::profiler::HloModuleNameWithProgramId( - hlo_module_name, program_id), - hlo_proto) - .second; - return new_program_id || new_module_name; -} - -void HloProtoMap::AddHloProto(uint64_t program_id, - std::unique_ptr hlo_proto) { - if (AddHloProto(program_id, hlo_proto.get())) { - // Only add to if is new to HloProtoMap. 
- owned_hlo_protos_.push_back(std::move(hlo_proto)); - } -} - -void HloProtoMap::AddHloProtosFromXSpace(const XSpace& space) { - for (auto& [program_id, hlo_proto] : ParseHloProtosFromXSpace(space)) { - AddHloProto(program_id, std::move(hlo_proto)); - } -} - -std::vector HloProtoMap::GetModuleList() const { - std::vector module_list; - module_list.reserve(hlo_protos_by_name_.size()); - for (const auto& [name, hlo_proto] : hlo_protos_by_name_) { - module_list.push_back(name); - } - return module_list; -} - -std::vector HloProtoMap::GetSortedModuleList() const { - std::vector module_list = GetModuleList(); - absl::c_sort(module_list); - return module_list; -} - -std::vector HloProtoMap::GetSortedModuleListByHeapTraceSize() - const { - std::vector> hlo_protos( - hlo_protos_by_name_.begin(), hlo_protos_by_name_.end()); - - // Sort the hlo protos by heap trace size and then by hlo module name. - // This way trivial computations will be on the bottom of the list. - absl::c_stable_sort(hlo_protos, [](const auto& a, const auto& b) { - int num_a = tensorflow::profiler::NumHeapSimulatorTraceEvents(a.second); - int num_b = tensorflow::profiler::NumHeapSimulatorTraceEvents(b.second); - return std::tie(num_a, b.first) > std::tie(num_b, a.first); - }); - - std::vector module_list; - module_list.reserve(hlo_protos.size()); - for (const auto& [name, hlo_proto] : hlo_protos) { - module_list.push_back(name); - } - return module_list; -} - -absl::StatusOr HloProtoMap::GetHloProtoByProgramId( - uint64_t program_id) const { - auto iter = hlo_protos_by_program_id_.find(program_id); - if (iter != hlo_protos_by_program_id_.end()) { - return iter->second; - } - return absl::NotFoundError( - absl::StrCat("Program id: ", program_id, " is not found.")); -} - -absl::StatusOr HloProtoMap::GetHloProtoByModuleName( - absl::string_view module_name) const { - auto iter = hlo_protos_by_name_.find(module_name); - if (iter != hlo_protos_by_name_.end()) { - return iter->second; - } - return 
absl::NotFoundError( - absl::StrCat("Module name: ", module_name, " is not found.")); -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/utils/hlo_proto_map.h b/tensorflow/core/profiler/utils/hlo_proto_map.h index 383c3064bc85de..23259adffaedab 100644 --- a/tensorflow/core/profiler/utils/hlo_proto_map.h +++ b/tensorflow/core/profiler/utils/hlo_proto_map.h @@ -16,71 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_CORE_PROFILER_UTILS_HLO_PROTO_MAP_H_ #define TENSORFLOW_CORE_PROFILER_UTILS_HLO_PROTO_MAP_H_ -#include -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "xla/service/hlo.pb.h" -#include "tsl/profiler/protobuf/xplane.pb.h" - -namespace tensorflow { -namespace profiler { - -absl::flat_hash_map> -ParseHloProtosFromXSpace(const XSpace& space); - -class HloProtoMap { - public: - void AddHloProtosFromXSpace(const XSpace& space); - - void AddHloProto(uint64_t program_id, - std::unique_ptr hlo_proto); - - size_t size() const { return hlo_protos_by_program_id_.size(); } - - auto begin() const { return hlo_protos_by_program_id_.begin(); } - auto end() const { return hlo_protos_by_program_id_.end(); } - - bool contains(absl::string_view name) const { - return hlo_protos_by_name_.contains(name); - } - - bool contains(uint64_t program_id) const { - return hlo_protos_by_program_id_.contains(program_id); - } - - // Returns a list of module names (not sorted). - std::vector GetModuleList() const; - - // Returns a list of module names sorted alphabetically. - std::vector GetSortedModuleList() const; - - // Returns a list of hlo module names sorted first by heap trace size and then - // by hlo module name alphabetically. 
- std::vector GetSortedModuleListByHeapTraceSize() const; - - absl::StatusOr GetHloProtoByModuleName( - absl::string_view module_name) const; - - absl::StatusOr GetHloProtoByProgramId( - uint64_t program_id) const; - - private: - absl::flat_hash_map hlo_protos_by_program_id_; - absl::flat_hash_map hlo_protos_by_name_; - std::vector> owned_hlo_protos_; - - // Try to add proto to the map and returns true if the addition is successful - // (i.e., the proto is new to the map). - bool AddHloProto(uint64_t program_id, const xla::HloProto* hlo_proto); -}; - -} // namespace profiler -} // namespace tensorflow +#include "xprof/utils/hlo_proto_map.h" // from @org_xprof // IWYU pragma: export #endif // TENSORFLOW_CORE_PROFILER_UTILS_HLO_PROTO_MAP_H_ diff --git a/tensorflow/core/profiler/utils/hlo_proto_to_module.cc b/tensorflow/core/profiler/utils/hlo_proto_to_module.cc deleted file mode 100644 index 4083bbfe8bbe49..00000000000000 --- a/tensorflow/core/profiler/utils/hlo_proto_to_module.cc +++ /dev/null @@ -1,55 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/core/profiler/utils/hlo_proto_to_module.h" - -#include -#include - -#include "absl/log/log.h" -#include "absl/status/statusor.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/service/hlo.pb.h" -#include "xla/tsl/platform/statusor.h" -#include "xla/util.h" - -namespace tensorflow { -namespace profiler { - -absl::StatusOr> ConvertHloProtoToModule( - const xla::HloProto& hlo_proto) { - if (!hlo_proto.has_hlo_module()) { - return xla::Internal("No HLO module found in the HLO proto"); - } - const xla::HloModuleProto& module_proto = hlo_proto.hlo_module(); - TF_ASSIGN_OR_RETURN(auto config, xla::HloModule::CreateModuleConfigFromProto( - module_proto, xla::DebugOptions())); - TF_ASSIGN_OR_RETURN(auto module, - xla::HloModule::CreateFromProto(module_proto, config)); - return module; -} - -std::unique_ptr ConvertHloProtoToModuleIgnoringErrors( - const xla::HloProto& hlo_proto) { - auto module = ConvertHloProtoToModule(hlo_proto); - if (!module.ok()) { - LOG(ERROR) << module.status(); - return nullptr; - } - return std::move(module).value(); -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/utils/hlo_proto_to_module.h b/tensorflow/core/profiler/utils/hlo_proto_to_module.h index 4cf3fa6383367d..954ed71345c9bd 100644 --- a/tensorflow/core/profiler/utils/hlo_proto_to_module.h +++ b/tensorflow/core/profiler/utils/hlo_proto_to_module.h @@ -16,22 +16,6 @@ limitations under the License. 
#ifndef TENSORFLOW_CORE_PROFILER_UTILS_HLO_PROTO_TO_MODULE_H_ #define TENSORFLOW_CORE_PROFILER_UTILS_HLO_PROTO_TO_MODULE_H_ -#include - -#include "absl/status/statusor.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/service/hlo.pb.h" - -namespace tensorflow { -namespace profiler { - -absl::StatusOr> ConvertHloProtoToModule( - const xla::HloProto& hlo_proto); - -std::unique_ptr ConvertHloProtoToModuleIgnoringErrors( - const xla::HloProto& hlo_proto); - -} // namespace profiler -} // namespace tensorflow +#include "xprof/utils/hlo_proto_to_module.h" // from @org_xprof // IWYU pragma: export #endif // TENSORFLOW_CORE_PROFILER_UTILS_HLO_PROTO_TO_MODULE_H_ diff --git a/tensorflow/core/profiler/utils/host_offload_utils.cc b/tensorflow/core/profiler/utils/host_offload_utils.cc deleted file mode 100644 index 7f135985d0b1c6..00000000000000 --- a/tensorflow/core/profiler/utils/host_offload_utils.cc +++ /dev/null @@ -1,199 +0,0 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ -#include "tensorflow/core/profiler/utils/host_offload_utils.h" - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/log/log.h" -#include "absl/strings/match.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_format.h" -#include "absl/strings/str_split.h" -#include "absl/strings/string_view.h" -#include "xla/tsl/profiler/utils/timespan.h" -#include "tensorflow/core/profiler/utils/trace_utils.h" -#include "tensorflow/core/profiler/utils/xplane_builder.h" -#include "tensorflow/core/profiler/utils/xplane_schema.h" -#include "tensorflow/core/profiler/utils/xplane_visitor.h" - -namespace tensorflow { -namespace profiler { - -bool HostOffloadEventProcessor::IsHostOffloadOpName( - const XEventVisitor& event) const { - static constexpr absl::string_view keywords[] = {"copy-start", - "copy-done", - "dynamic-slice-start", - "dynamic-slice-done", - "dynamic-update-slice-start", - "dynamic-update-slice-done"}; - - for (const auto& keyword : keywords) { - // The host_memory_label_ S(5) is used by instructions to designate tensors - // that are on the host. - if (absl::StrContains(event.DisplayName(), keyword) && - absl::StrContains(event.Name(), host_memory_label_)) { - return true; - } - } - return false; -} - -std::string HostOffloadEventProcessor::GetOffloadInstructionID( - absl::string_view op_name) const { - std::vector op_name_vec = absl::StrSplit(op_name, '.'); - - // If no dot is found, or it's at the beginning or end of the string, return - // a 0. Hlo opnames are not expected to have a dot followed by 0. 
- if (op_name_vec.size() < 2) { - return "0"; - } - return op_name_vec.back(); -} - -std::string HostOffloadEventProcessor::GetOffloadInstructionName( - absl::string_view op_name) const { - // TODO(b/342469268): Get the display ID and name from the HloInstruction, not - // just the event name. - std::string display_id = GetOffloadInstructionID(op_name); - - size_t startPos = op_name.find("-start"); - size_t donePos = op_name.find("-done"); - - absl::string_view display_opname; - if (startPos != absl::string_view::npos) { - display_opname = op_name.substr(0, startPos); - } else if (donePos != absl::string_view::npos) { - display_opname = op_name.substr(0, donePos); - } else { - // Invalid input format: neither "-start" nor "-done" found - LOG(WARNING) << "Invalid op name: " << op_name; - display_opname = op_name; - } - return absl::StrCat("offload-", display_opname, ".", display_id); -} - -void HostOffloadEventProcessor::ProcessHostOffloadOpEvent( - const XEventVisitor& event, std::optional group_id) { - std::string display_opname = GetOffloadInstructionName(event.DisplayName()); - - auto [iter, inserted] = seen_events_.try_emplace(display_opname); - std::queue& events = iter->second; - - if (absl::StrContains(event.DisplayName(), "-start")) { - // For start events, just push them into the queue. - events.push(&event); - return; - } else if (absl::StrContains(event.DisplayName(), "-done")) { - // for done events, pop the start event and create the new event. - // Not all start events may be traced. In this case we just skip the - // corresponding done event. - if (events.empty()) { - LOG(INFO) << "No corresponding start event found for " - << event.DisplayName(); - return; - } - const XEventVisitor* start_event = events.front(); - events.pop(); - - // At this point, we have the corresponding start and end event. - // Create the new event. 
- tsl::profiler::Timespan event_span = tsl::profiler::Timespan::FromEndPoints( - start_event->GetTimespan().begin_ps(), event.GetTimespan().end_ps()); - - // Find the line with the smallest event end time frontier that can fit this - // new event without overlapping with its other events. - int line_builder_index = -1; - uint64_t minimum_end_time_frontier = event_span.begin_ps(); - for (int i = 0; i < host_offload_op_line_builders_.size(); ++i) { - if (host_offload_op_line_builders_[i].event_end_time_frontier_ns <= - minimum_end_time_frontier) { - line_builder_index = i; - minimum_end_time_frontier = - host_offload_op_line_builders_[i].event_end_time_frontier_ns; - } - } - - constexpr int kMaxHostOffloadOpLinesSize = - kThreadIdHostOffloadOpEnd - kThreadIdHostOffloadOpStart + 1; - - // If no existing lines can fit this new event, create a new line. - if (line_builder_index == -1) { - if (host_offload_op_line_builders_.size() < kMaxHostOffloadOpLinesSize) { - XLineBuilder lb = plane_builder_->GetOrCreateLine( - kThreadIdHostOffloadOpStart + - host_offload_op_line_builders_.size()); - lb.SetName(absl::StrFormat("%s row %d", kHostOffloadOpLineName, - host_offload_op_line_builders_.size())); - lb.SetTimestampNs(start_timestamp_ns_); - host_offload_op_line_builders_.push_back( - {std::move(lb), event_span.end_ps()}); - } - // If we have reached the maximum number of lines, just use the last line. - line_builder_index = host_offload_op_line_builders_.size() - 1; - } - - // Update the event end time frontier for the line. 
- host_offload_op_line_builders_[line_builder_index] - .event_end_time_frontier_ns = - std::max(host_offload_op_line_builders_[line_builder_index] - .event_end_time_frontier_ns, - event_span.end_ps()); - - XEventMetadata* host_offload_copy_metadata = - plane_builder_->CreateEventMetadata(); - host_offload_copy_metadata->set_display_name(display_opname); - XEventBuilder event_builder = - host_offload_op_line_builders_[line_builder_index] - .line_builder.AddEvent(*host_offload_copy_metadata); - event_builder.SetTimespan(event_span); - - // We mark the events as async so that they are displayed on new sub-lines - // below other async events. - const XStatMetadata& async_stat = *plane_builder_->GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kIsAsync)); - event_builder.AddStatValue(async_stat, 1); - - // Set metadata stats for the event. - const XStatMetadata& raw_bytes_stat = - *plane_builder_->GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kRawBytesAccessed)); - event.Metadata().ForEachStat([&](const XStatVisitor& stat) { - if (stat.Type() == StatType::kRawBytesAccessed) { - event_builder.AddStatValue(raw_bytes_stat, stat.IntValue()); - } - }); - const XStatMetadata& shape_with_layout_str = - *plane_builder_->GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kShapeWithLayout)); - // Use the shape from start_event, since it contains the shape of end event. - start_event->Metadata().ForEachStat([&](const XStatVisitor& stat) { - if (stat.Type() == StatType::kShapeWithLayout) { - event_builder.AddStatValue(shape_with_layout_str, stat.StrOrRefValue()); - } - }); - } -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/utils/host_offload_utils.h b/tensorflow/core/profiler/utils/host_offload_utils.h deleted file mode 100644 index dbf308fbfe1e41..00000000000000 --- a/tensorflow/core/profiler/utils/host_offload_utils.h +++ /dev/null @@ -1,72 +0,0 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_CORE_PROFILER_UTILS_HOST_OFFLOAD_UTILS_H_ -#define TENSORFLOW_CORE_PROFILER_UTILS_HOST_OFFLOAD_UTILS_H_ - -#include -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" -#include "xla/layout.h" -#include "tensorflow/core/profiler/utils/xplane_builder.h" -#include "tensorflow/core/profiler/utils/xplane_visitor.h" -#include "tsl/profiler/protobuf/xplane.pb.h" - -namespace tensorflow { -namespace profiler { - -struct LineBuilderAndEventEndTimeFrontier { - XLineBuilder line_builder; - uint64_t event_end_time_frontier_ns; -}; - -class HostOffloadEventProcessor { - public: - HostOffloadEventProcessor(XPlaneBuilder* plane_builder, - uint64_t start_timestamp_ns) - : plane_builder_(plane_builder), - start_timestamp_ns_(start_timestamp_ns) {} - ~HostOffloadEventProcessor() = default; - - void ProcessHostOffloadOpEvent(const XEventVisitor& event, - std::optional group_id); - - bool IsHostOffloadOpName(const XEventVisitor& event) const; - - private: - std::string GetOffloadInstructionID(absl::string_view op_name) const; - std::string GetOffloadInstructionName(absl::string_view op_name) const; - - absl::flat_hash_map> - seen_events_; - std::string host_memory_label_ = - absl::StrCat("S(", xla::Layout::kHostMemorySpace, ")"); - - XPlaneBuilder* 
plane_builder_; - uint64_t start_timestamp_ns_; - - std::vector - host_offload_op_line_builders_; -}; - -} // namespace profiler -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_PROFILER_UTILS_HOST_OFFLOAD_UTILS_H_ diff --git a/tensorflow/core/profiler/utils/html_utils.h b/tensorflow/core/profiler/utils/html_utils.h index 215d9f51d5bec2..9dbf42507b4321 100644 --- a/tensorflow/core/profiler/utils/html_utils.h +++ b/tensorflow/core/profiler/utils/html_utils.h @@ -16,21 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_CORE_PROFILER_UTILS_HTML_UTILS_H_ #define TENSORFLOW_CORE_PROFILER_UTILS_HTML_UTILS_H_ -#include - -#include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" - -namespace tensorflow { -namespace profiler { - -// Creates a html that links to the given url with the given text. -inline std::string AnchorElement(absl::string_view url, - absl::string_view text) { - return absl::StrCat("", text, ""); -} - -} // namespace profiler -} // namespace tensorflow +#include "xprof/utils/html_utils.h" // from @org_xprof // IWYU pragma: export #endif // TENSORFLOW_CORE_PROFILER_UTILS_HTML_UTILS_H_ diff --git a/tensorflow/core/profiler/utils/kernel_stats_utils.cc b/tensorflow/core/profiler/utils/kernel_stats_utils.cc deleted file mode 100644 index be88b216465220..00000000000000 --- a/tensorflow/core/profiler/utils/kernel_stats_utils.cc +++ /dev/null @@ -1,352 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/profiler/utils/kernel_stats_utils.h" - -#include -#include -#include -#include - -#include "absl/algorithm/container.h" -#include "absl/log/check.h" -#include "absl/log/log.h" -#include "absl/strings/match.h" -#include "absl/strings/numbers.h" -#include "absl/strings/str_split.h" -#include "absl/strings/string_view.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/profiler/protobuf/kernel_stats.pb.h" - -namespace tensorflow { -namespace profiler { - -namespace { - -// The maximum number of Kernels displayed on Kernel Stats page. -const int kMaxNumOfKernels = 1000; - -// A list of patterns to help determine if a kernel uses Tensor Core. -// A kernel uses Tensor Core if its kernel name contains any of these patterns. -// Some examples of kernel names: volta_h884gemm, turing_fp16_s1688cudnn_fp16 -constexpr absl::string_view kTensorCoreKernelNamePatterns[] = { - "16816", - "c1688", - "conv1x1", - "conv2d_c1_k1", - "dgrad_1x1_stride_2x2", - "direct_group", - "first_layer_wgrad_kernel", - "h1688", - "h884", - "hmma", - "i16832", - "i8816", - "s884", - "s1688", - "xmma_gemm", - "xmma_implicit_gemm", - "xmma_sparse_conv", - "xmma_sparse_gemm", - "xmma_warp_specialized_implicit_gemm"}; - -} // namespace - -void ParseKernelLaunchParams(absl::string_view xstat_kernel_details, - KernelReport* kernel) { - const std::vector params = - absl::StrSplit(xstat_kernel_details, absl::ByAnyChar(" \n")); - - constexpr uint32 kNumDimensions = 3; - for (uint32 dim = 0; dim < kNumDimensions; ++dim) { - kernel->add_block_dim(1); - kernel->add_grid_dim(1); - } - - // Process tokens. 
- for (const auto& param : params) { - const std::vector key_value = absl::StrSplit(param, ':'); - if (key_value.size() != 2) { - // Unrecognized token. - continue; - } - absl::string_view key = key_value[0]; - absl::string_view value_str = key_value[1]; - uint32 value = 0; - double pct = 0.0; - // Cases that consume a pair of tokens "key:value". - if (key == "regs" && absl::SimpleAtoi(value_str, &value)) { - kernel->set_registers_per_thread(value); - } else if (key == "static_shared" && absl::SimpleAtoi(value_str, &value)) { - kernel->set_static_shmem_bytes(value); - } else if (key == "dynamic_shared" && absl::SimpleAtoi(value_str, &value)) { - kernel->set_dynamic_shmem_bytes(value); - } else if (key == "block") { - const std::vector& block = - absl::StrSplit(value_str, ','); - uint32 tmp[3]; - if (block.size() == 3 && absl::SimpleAtoi(block[0], &tmp[0]) && - absl::SimpleAtoi(block[1], &tmp[1]) && - absl::SimpleAtoi(block[2], &tmp[2])) { - std::copy_n(tmp, 3, kernel->mutable_block_dim()->begin()); - } - } else if (key == "grid") { - const std::vector& grid = - absl::StrSplit(value_str, ','); - uint32 tmp[3]; - if (grid.size() == 3 && absl::SimpleAtoi(grid[0], &tmp[0]) && - absl::SimpleAtoi(grid[1], &tmp[1]) && - absl::SimpleAtoi(grid[2], &tmp[2])) { - std::copy_n(tmp, 3, kernel->mutable_grid_dim()->begin()); - } - } else if (key == "occ_pct" && absl::SimpleAtod(value_str, &pct)) { - kernel->set_occupancy_pct(pct); - } - } -} - -bool IsKernelUsingTensorCore(absl::string_view kernel_name) { - VLOG(1) << "kernel name: " << kernel_name; - for (absl::string_view pattern : kTensorCoreKernelNamePatterns) { - if (absl::StrContains(kernel_name, pattern)) { - return true; - } - } - return false; -} - -// This list is not exhaustive. -bool IsOpTensorCoreEligible(absl::string_view tf_op_name) { - // Disable formatting to keep inline comments vertically aligned. - // clang-format off - return false - // Using EndsWith to match Fused operations. 
- || absl::EndsWith(tf_op_name, "Conv2D") - || absl::EndsWith(tf_op_name, "Conv2DBackpropFilter") - || absl::EndsWith(tf_op_name, "Conv2DBackpropInput") - || absl::EndsWith(tf_op_name, "Conv3D") - || absl::EndsWith(tf_op_name, "DepthwiseConv2dNative") - || absl::EndsWith(tf_op_name, "DepthwiseConv2dNativeBackpropFilter") - || absl::EndsWith(tf_op_name, "DepthwiseConv2dNativeBackpropInput") - // Using Contains to match V2/V3 suffixes. - || absl::StrContains(tf_op_name, "BatchMatMul") - // MatMul requires exact matching. - || absl::EndsWith(tf_op_name, "/MatMul") - || absl::EndsWith(tf_op_name, "FusedMatMul") - // cuDNN operations. - || absl::EndsWith(tf_op_name, "/CudnnRNN") - || absl::StrContains(tf_op_name, "CudnnRNNV") - || absl::StrContains(tf_op_name, "CudnnRNNForward") - || absl::StrContains(tf_op_name, "CudnnRNNBackprop") - // Special cases. - || absl::EndsWith(tf_op_name, "XlaDot") - || absl::EndsWith(tf_op_name, "XlaDotV2"); - // clang-format on -} - -bool IsEinsumTensorCoreEligible(absl::string_view equation) { - if (equation.empty()) { - return false; - } - const std::vector input_output = - absl::StrSplit(equation, "->"); - if (input_output.size() != 2) { - return false; - } - const std::vector lhs_rhs = - absl::StrSplit(input_output[0], ','); - return lhs_rhs.size() == 2; -} - -bool KernelReportLessThanComparator::operator()(const KernelReport& lhs, - const KernelReport& rhs) const { - // Disable formatting to keep vertical alignment for better readability, - // and make it easier to reorder columns. 
- // clang-format off - auto lhs_tuple = std::make_tuple( - lhs.name(), - lhs.grid_dim(0), - lhs.grid_dim(1), - lhs.grid_dim(2), - lhs.block_dim(0), - lhs.block_dim(1), - lhs.block_dim(2), - lhs.registers_per_thread(), - lhs.static_shmem_bytes(), - lhs.dynamic_shmem_bytes(), - lhs.is_kernel_using_tensor_core(), - lhs.is_op_tensor_core_eligible(), - lhs.op_name()); - - auto rhs_tuple = std::make_tuple( - rhs.name(), - rhs.grid_dim(0), - rhs.grid_dim(1), - rhs.grid_dim(2), - rhs.block_dim(0), - rhs.block_dim(1), - rhs.block_dim(2), - rhs.registers_per_thread(), - rhs.static_shmem_bytes(), - rhs.dynamic_shmem_bytes(), - rhs.is_kernel_using_tensor_core(), - rhs.is_op_tensor_core_eligible(), - rhs.op_name()); - // clang-format on - return lhs_tuple < rhs_tuple; -} - -bool KernelReportEqualToComparator::operator()(const KernelReport& lhs, - const KernelReport& rhs) const { - // Disable formatting to keep vertical alignment for better readability, - // and make it easier to reorder columns. - // clang-format off - // Put the most expensive string comparisons last. 
- return ( - lhs.is_kernel_using_tensor_core() == rhs.is_kernel_using_tensor_core() && - lhs.is_op_tensor_core_eligible() == rhs.is_op_tensor_core_eligible() && - lhs.block_dim(0) == rhs.block_dim(0) && - lhs.block_dim(1) == rhs.block_dim(1) && - lhs.block_dim(2) == rhs.block_dim(2) && - lhs.grid_dim(0) == rhs.grid_dim(0) && - lhs.grid_dim(1) == rhs.grid_dim(1) && - lhs.grid_dim(2) == rhs.grid_dim(2) && - lhs.registers_per_thread() == rhs.registers_per_thread() && - lhs.static_shmem_bytes() == rhs.static_shmem_bytes() && - lhs.dynamic_shmem_bytes() == rhs.dynamic_shmem_bytes() && - lhs.name() == rhs.name() && - lhs.op_name() == rhs.op_name()); - // clang-format on -} - -void SortAndKeepTopKDurationKernelReportsInDb(KernelStatsDb* kernel_stats_db) { - auto comp = [](const KernelReport& lhs, const KernelReport& rhs) { - return lhs.total_duration_ns() > rhs.total_duration_ns() || - (lhs.total_duration_ns() == rhs.total_duration_ns() && - KernelReportLessThanComparator()(lhs, rhs)); - }; - - // Sort and keep at most kernel reports. 
- if (kernel_stats_db->reports_size() > kMaxNumOfKernels) { - std::partial_sort( - kernel_stats_db->mutable_reports()->begin(), - kernel_stats_db->mutable_reports()->begin() + kMaxNumOfKernels, - kernel_stats_db->mutable_reports()->end(), comp); - kernel_stats_db->mutable_reports()->erase( - kernel_stats_db->mutable_reports()->begin() + kMaxNumOfKernels, - kernel_stats_db->mutable_reports()->end()); - } else { - std::sort(kernel_stats_db->mutable_reports()->begin(), - kernel_stats_db->mutable_reports()->end(), comp); - } -} - -void CopyTopKDurationKernelReportsToDb(const KernelReportMap& reports, - KernelStatsDb* dst) { - std::vector> - kernels_to_sort; - kernels_to_sort.reserve(reports.size()); - for (const auto& report_value : reports) { - kernels_to_sort.push_back( - std::make_pair(&report_value.first, &report_value.second)); - } - - auto comp = - [](const std::pair& lhs, - const std::pair& rhs) { - return lhs.second->total_duration_ns > rhs.second->total_duration_ns || - (lhs.second->total_duration_ns == - rhs.second->total_duration_ns && - KernelReportLessThanComparator()(*lhs.first, *rhs.first)); - }; - - // Sort and copy at most kernels to . - if (kernels_to_sort.size() > kMaxNumOfKernels) { - absl::c_partial_sort(kernels_to_sort, - kernels_to_sort.begin() + kMaxNumOfKernels, comp); - } else { - absl::c_sort(kernels_to_sort, comp); - } - - int copy_size = - std::min(kMaxNumOfKernels, static_cast(kernels_to_sort.size())); - for (int i = 0; i < copy_size; i++) { - KernelReport* report = dst->add_reports(); - *report = *kernels_to_sort[i].first; - const KernelReportValue& kernel_value = *kernels_to_sort[i].second; - // Set value using KernelReportValue. 
- report->set_occurrences(kernel_value.occurrences); - report->set_min_duration_ns(kernel_value.min_duration_ns); - report->set_max_duration_ns(kernel_value.max_duration_ns); - report->set_total_duration_ns(kernel_value.total_duration_ns); - } -} - -void InsertOrUpdateKernelReport(const KernelReport& kernel, - const KernelReportValue& value, - KernelReportMap* dst) { - KernelReportValue& element = (*dst)[kernel]; - if (element.occurrences == 0) { - element = value; - } else { - element.total_duration_ns += value.total_duration_ns; - element.min_duration_ns = - std::min(element.min_duration_ns, value.min_duration_ns); - element.max_duration_ns = - std::max(element.max_duration_ns, value.max_duration_ns); - element.occurrences += value.occurrences; - } -} - -void MergeKernelReports(const KernelReportMap& reports, KernelReportMap* dst) { - for (auto& kernel_value : reports) { - InsertOrUpdateKernelReport(kernel_value.first, kernel_value.second, dst); - } -} - -KernelStatsByOpName GroupKernelReportsByOpName( - const KernelStatsDb& kernel_stats_db) { - KernelStatsByOpName op_level_kernel_stats; - for (const KernelReport& kernel_report : kernel_stats_db.reports()) { - auto ret = op_level_kernel_stats.emplace(kernel_report.op_name(), - OpLevelKernelStats()); - if (ret.second) { - // Inserted. Add a new op in . - OpLevelKernelStats& stats = ret.first->second; - stats.is_op_tensor_core_eligible = - kernel_report.is_op_tensor_core_eligible(); - stats.total_duration_ns += kernel_report.total_duration_ns(); - if (kernel_report.is_kernel_using_tensor_core()) { - stats.tensor_core_duration_ns += kernel_report.total_duration_ns(); - } - } else { - // Not inserted. Aggregate kernel stats to op level. - OpLevelKernelStats& stats = ret.first->second; - // Verifies operations with the same name have the same TensorCore - // eligibility. 
- DCHECK_EQ(stats.is_op_tensor_core_eligible, - kernel_report.is_op_tensor_core_eligible()); - stats.total_duration_ns += kernel_report.total_duration_ns(); - if (kernel_report.is_kernel_using_tensor_core()) { - stats.tensor_core_duration_ns += kernel_report.total_duration_ns(); - } - } - } - return op_level_kernel_stats; -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/utils/kernel_stats_utils.h b/tensorflow/core/profiler/utils/kernel_stats_utils.h index 1afecd6d54b1f0..6e625d9835e91f 100644 --- a/tensorflow/core/profiler/utils/kernel_stats_utils.h +++ b/tensorflow/core/profiler/utils/kernel_stats_utils.h @@ -16,121 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_CORE_PROFILER_UTILS_KERNEL_STATS_UTILS_H_ #define TENSORFLOW_CORE_PROFILER_UTILS_KERNEL_STATS_UTILS_H_ -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/hash/hash.h" -#include "absl/strings/string_view.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/profiler/protobuf/kernel_stats.pb.h" - -namespace tensorflow { -namespace profiler { - -// Populates kernel launch information from a kKernelDetails XStat. -void ParseKernelLaunchParams(absl::string_view xstat_kernel_details, - KernelReport* kernel); - -// Returns true if kernel uses TensorCores. -bool IsKernelUsingTensorCore(absl::string_view kernel_name); - -// Returns true if operation is eligible to use TensorCores. -bool IsOpTensorCoreEligible(absl::string_view tf_op_name); - -// Returns true if Einsum equation is eligible to use TensorCores. -bool IsEinsumTensorCoreEligible(absl::string_view equation); - -// Less than comparator for Kernel Reports. -struct KernelReportLessThanComparator { - bool operator()(const KernelReport& lhs, const KernelReport& rhs) const; -}; - -// Equal to comparator for Kernel Reports. 
-struct KernelReportEqualToComparator { - bool operator()(const KernelReport& lhs, const KernelReport& rhs) const; -}; - -// Sorts kernel reports by total duration descendingly. -// Keeps only the top kernel reports with long kernel duration in the given -// KernelStatsDb. Kernel reports with shorter kernel duration are dropped. -void SortAndKeepTopKDurationKernelReportsInDb(KernelStatsDb* kernel_stats_db); - -struct KernelReportValue { - uint64 total_duration_ns = 0; - uint64 min_duration_ns = 0; - uint64 max_duration_ns = 0; - uint64 occurrences = 0; -}; - -struct KernelKeyWrap { - const KernelReport* key; - template - friend H AbslHashValue(H h, KernelKeyWrap wrap) { - // Kernel reports are grouped by these fields, hence they are used as - // hashing criteria. - // clang-format off - return H::combine( - std::move(h), - wrap.key->is_kernel_using_tensor_core(), - wrap.key->is_op_tensor_core_eligible(), - wrap.key->block_dim(0), - wrap.key->block_dim(1), - wrap.key->block_dim(2), - wrap.key->grid_dim(0), - wrap.key->grid_dim(1), - wrap.key->grid_dim(2), - wrap.key->registers_per_thread(), - wrap.key->static_shmem_bytes(), - wrap.key->dynamic_shmem_bytes(), - wrap.key->name(), - wrap.key->op_name()); - // clang-format on - } -}; - -struct KernelHash { - size_t operator()(const KernelReport& key) const { - return absl::Hash()(KernelKeyWrap{&key}); - } -}; - -using KernelReportMap = - absl::flat_hash_map; - -// Copies the top kernel reports with long kernel duration into the given -// KernelStatsDb. -void CopyTopKDurationKernelReportsToDb(const KernelReportMap& reports, - KernelStatsDb* dst); - -// Inserts or aggregates KernelReports into the given KernelReportMap. -void InsertOrUpdateKernelReport(const KernelReport& kernel, - const KernelReportValue& value, - KernelReportMap* dst); - -// Aggregates values from one KernelReportMap into another. 
-void MergeKernelReports(const KernelReportMap& reports, KernelReportMap* dst); - -// Kernel stats aggregated at TF operation level. -struct OpLevelKernelStats { - // Whether op is eligible to use TensorCore. - bool is_op_tensor_core_eligible = false; - // The accumulated duration of all the kernels launched in this op. - uint64 total_duration_ns = 0; - // The accumulated duration of all the kernels using TensorCore in this op. - // If this value is not 0, at least one of the kernels launched by this op - // is using TensorCore. - uint64 tensor_core_duration_ns = 0; -}; - -using KernelStatsByOpName = - absl::flat_hash_map; - -// Groups KernelReport in by tensorflow operation name. -KernelStatsByOpName GroupKernelReportsByOpName( - const KernelStatsDb& kernel_stats_db); - -} // namespace profiler -} // namespace tensorflow +#include "xprof/utils/kernel_stats_utils.h" // from @org_xprof // IWYU pragma: export #endif // TENSORFLOW_CORE_PROFILER_UTILS_KERNEL_STATS_UTILS_H_ diff --git a/tensorflow/core/profiler/utils/kernel_stats_utils_test.cc b/tensorflow/core/profiler/utils/kernel_stats_utils_test.cc deleted file mode 100644 index a8cf90adf62a9b..00000000000000 --- a/tensorflow/core/profiler/utils/kernel_stats_utils_test.cc +++ /dev/null @@ -1,175 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/core/profiler/utils/kernel_stats_utils.h" - -#include - -#include -#include "absl/strings/string_view.h" -#include "xla/backends/profiler/gpu/cupti_buffer_events.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/profiler/protobuf/kernel_stats.pb.h" - -namespace tensorflow { -namespace profiler { -namespace { - -using ::testing::FieldsAre; - -TEST(KernelStatsUtilsTest, TestGroupKernelReportsByOpName) { - KernelStatsDb kernel_stats_db; - KernelReport* kernel_report_1 = kernel_stats_db.add_reports(); - kernel_report_1->set_name("op1_kernel1"); - kernel_report_1->set_op_name("op1"); - kernel_report_1->set_total_duration_ns(1000); - kernel_report_1->set_is_kernel_using_tensor_core(true); - kernel_report_1->set_is_op_tensor_core_eligible(true); - - KernelReport* kernel_report_2 = kernel_stats_db.add_reports(); - kernel_report_2->set_name("op1_kernel2"); - kernel_report_2->set_op_name("op1"); - kernel_report_2->set_total_duration_ns(1000); - kernel_report_2->set_is_kernel_using_tensor_core(false); - kernel_report_2->set_is_op_tensor_core_eligible(true); - - KernelReport* kernel_report_3 = kernel_stats_db.add_reports(); - kernel_report_3->set_name("op2_kernel1"); - kernel_report_3->set_op_name("op2"); - kernel_report_3->set_total_duration_ns(100); - kernel_report_3->set_is_kernel_using_tensor_core(false); - kernel_report_3->set_is_op_tensor_core_eligible(false); - - KernelStatsByOpName kernel_stats_by_op_name = - GroupKernelReportsByOpName(kernel_stats_db); - - // Verifies there are two OpLevelKernelStats - ASSERT_EQ(kernel_stats_by_op_name.size(), 2); - auto iter1 = kernel_stats_by_op_name.find("op1"); - auto iter2 = kernel_stats_by_op_name.find("op2"); - ASSERT_NE(iter1, kernel_stats_by_op_name.end()); - ASSERT_NE(iter2, kernel_stats_by_op_name.end()); - const OpLevelKernelStats& op1_stats = iter1->second; - const OpLevelKernelStats& op2_stats 
= iter2->second; - - EXPECT_EQ(op1_stats.is_op_tensor_core_eligible, true); - EXPECT_EQ(op1_stats.total_duration_ns, 2000); - EXPECT_EQ(op1_stats.tensor_core_duration_ns, 1000); - - EXPECT_EQ(op2_stats.is_op_tensor_core_eligible, false); - EXPECT_EQ(op2_stats.total_duration_ns, 100); - EXPECT_EQ(op2_stats.tensor_core_duration_ns, 0); -} - -TEST(KernelStatsUtilsTest, KernelDetailsXStatParser) { - xla::profiler::KernelDetails kernel_info; - kernel_info.registers_per_thread = 10; - kernel_info.static_shared_memory_usage = 128; - kernel_info.dynamic_shared_memory_usage = 256; - kernel_info.block_x = 32; - kernel_info.block_y = 8; - kernel_info.block_z = 4; - kernel_info.grid_x = 3; - kernel_info.grid_y = 2; - kernel_info.grid_z = 1; - const double occupancy_pct = 50.0; - std::string xstat_kernel_details = ToXStat(kernel_info, occupancy_pct); - KernelReport kernel; - ParseKernelLaunchParams(xstat_kernel_details, &kernel); - // Verifies that the parser can parse kKernelDetails XStat. - EXPECT_EQ(kernel.registers_per_thread(), 10); - EXPECT_EQ(kernel.static_shmem_bytes(), 128); - EXPECT_EQ(kernel.dynamic_shmem_bytes(), 256); - EXPECT_EQ(kernel.block_dim()[0], 32); - EXPECT_EQ(kernel.block_dim()[1], 8); - EXPECT_EQ(kernel.block_dim()[2], 4); - EXPECT_EQ(kernel.grid_dim()[0], 3); - EXPECT_EQ(kernel.grid_dim()[1], 2); - EXPECT_EQ(kernel.grid_dim()[2], 1); -} - -TEST(KernelStatsUtilsTest, KernelDetailsTokenizer) { - KernelReport kernel; - - // Test odd token count (3): { "odd", "grid", "3,2,1" } - absl::string_view kernel_details_0 = "odd grid:3,2,1"; - ParseKernelLaunchParams(kernel_details_0, &kernel); - EXPECT_EQ(kernel.grid_dim()[0], 3); - EXPECT_EQ(kernel.grid_dim()[1], 2); - EXPECT_EQ(kernel.grid_dim()[2], 1); - - // Test odd token count (3): { "block", "6,5,4", "odd" } - absl::string_view kernel_details_1 = "block:6,5,4 odd "; - ParseKernelLaunchParams(kernel_details_1, &kernel); - EXPECT_EQ(kernel.block_dim()[0], 6); - EXPECT_EQ(kernel.block_dim()[1], 5); - 
EXPECT_EQ(kernel.block_dim()[2], 4); - - // Test odd token count (3): { "block", "1,2,3", "odd", "grid", "4,5,6" } - absl::string_view kernel_details_2 = "block:1,2,3 odd grid:4,5,6"; - ParseKernelLaunchParams(kernel_details_2, &kernel); - EXPECT_EQ(kernel.block_dim()[0], 1); - EXPECT_EQ(kernel.block_dim()[1], 2); - EXPECT_EQ(kernel.block_dim()[2], 3); - EXPECT_EQ(kernel.grid_dim()[0], 4); - EXPECT_EQ(kernel.grid_dim()[1], 5); - EXPECT_EQ(kernel.grid_dim()[2], 6); - - // Test even token count (4): { "static_shared", "7", "dynamic_shared", "8" } - absl::string_view kernel_details_3 = "static_shared:7 dynamic_shared:8"; - ParseKernelLaunchParams(kernel_details_3, &kernel); - EXPECT_EQ(kernel.static_shmem_bytes(), 7); - EXPECT_EQ(kernel.dynamic_shmem_bytes(), 8); -} - -TEST(KernelStatsUtilsTest, TestInsertOrUpdateKernelReport) { - KernelReport kr; - kr.set_name("op1_kernel1"); - kr.set_op_name("op1"); - // Must provide dummy dims since KernelReportMap's comparator assumes array of - // size 3; values here were suggested by autocomplete - kr.add_block_dim(32); - kr.add_block_dim(8); - kr.add_block_dim(4); - kr.add_grid_dim(3); - kr.add_grid_dim(2); - kr.add_grid_dim(1); - - KernelReportValue krv1; - krv1.total_duration_ns = 1700; - krv1.min_duration_ns = 500; - krv1.max_duration_ns = 1200; - krv1.occurrences = 2; - - KernelReportValue krv2; - krv2.total_duration_ns = 900; - krv2.min_duration_ns = 900; - krv2.max_duration_ns = 900; - krv2.occurrences = 1; - - KernelReportMap dst1; - InsertOrUpdateKernelReport(kr, krv1, &dst1); - InsertOrUpdateKernelReport(kr, krv2, &dst1); - EXPECT_THAT(dst1[kr], FieldsAre(2600, 500, 1200, 3)); - - KernelReportMap dst2; - InsertOrUpdateKernelReport(kr, krv2, &dst2); - InsertOrUpdateKernelReport(kr, krv1, &dst2); - EXPECT_THAT(dst2[kr], FieldsAre(2600, 500, 1200, 3)); -} - -} // namespace -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/utils/op_metrics_db_utils.cc 
b/tensorflow/core/profiler/utils/op_metrics_db_utils.cc deleted file mode 100644 index 7ff1c33c762f80..00000000000000 --- a/tensorflow/core/profiler/utils/op_metrics_db_utils.cc +++ /dev/null @@ -1,370 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/profiler/utils/op_metrics_db_utils.h" - -#include -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/log/check.h" -#include "absl/strings/string_view.h" -#include "xla/tsl/profiler/utils/math_utils.h" -#include "xla/tsl/profiler/utils/tf_op_utils.h" -#include "xla/tsl/profiler/utils/xplane_schema.h" -#include "xla/tsl/profiler/utils/xplane_visitor.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" -#include "tensorflow/core/profiler/utils/xplane_visitor.h" - -namespace tensorflow { -namespace profiler { - -const absl::string_view kIdle = "IDLE"; -const uint32_t kSparseCoreIndexStart = 1000000; -const int64_t kSingleOccurrence = 1; - -namespace { - -constexpr uint64_t kRootSymbolId = 0; - -using tsl::profiler::StatType; -using tsl::profiler::XEventMetadataVisitor; -using tsl::profiler::XStatVisitor; - -class DeviceTfOpMetricsDbBuilder : public OpMetricsDbBuilder { - public: - explicit 
DeviceTfOpMetricsDbBuilder(OpMetricsDb* db) - : OpMetricsDbBuilder(db) {} - - void UpdateTfOpMetricsWithDeviceOpMetrics( - absl::string_view tf_op_name, absl::string_view tf_op_type, - const OpMetrics& device_op_metrics) { - OpMetrics* tf_op_metrics = OpMetricsDbBuilder::LookupOrInsertNewOpMetrics( - /*hlo_module_id=*/0, tf_op_name); - if (tf_op_metrics->category().empty()) { - tf_op_metrics->set_category(tf_op_type == tsl::profiler::kUnknownOp - ? "Unknown" - : std::string(tf_op_type)); - } - tf_op_metrics->set_is_eager(device_op_metrics.is_eager()); - // The occurrences of a TF-op is the maximum among the occurrences of all - // device ops that it contains. - tf_op_metrics->set_occurrences(std::max(tf_op_metrics->occurrences(), - device_op_metrics.occurrences())); - tf_op_metrics->set_time_ps(tf_op_metrics->time_ps() + - device_op_metrics.time_ps()); - tf_op_metrics->set_self_time_ps(tf_op_metrics->self_time_ps() + - device_op_metrics.self_time_ps()); - tf_op_metrics->set_flops(tf_op_metrics->flops() + - device_op_metrics.flops()); - tf_op_metrics->set_bytes_accessed(tf_op_metrics->bytes_accessed() + - device_op_metrics.bytes_accessed()); - } -}; - -void SetOpMetadataFromHloEventMetadata( - const XEventMetadataVisitor& hlo_event_metadata, OpMetrics* op_metrics) { - if (hlo_event_metadata.HasDisplayName()) { - op_metrics->set_name(std::string(hlo_event_metadata.DisplayName())); - op_metrics->set_long_name(std::string(hlo_event_metadata.Name())); - } else { - op_metrics->set_name(std::string(hlo_event_metadata.Name())); - } - hlo_event_metadata.ForEachStat([&](const XStatVisitor& stat) { - if (stat.Type().has_value()) { - switch (static_cast(*stat.Type())) { - case StatType::kProgramId: - op_metrics->set_hlo_module_id(stat.IntOrUintValue()); - break; - case StatType::kHloCategory: - op_metrics->set_category(std::string(stat.StrOrRefValue())); - break; - case StatType::kTfOp: - op_metrics->set_provenance(std::string(stat.StrOrRefValue())); - break; - case 
StatType::kFlops: - op_metrics->set_flops(stat.IntOrUintValue()); - break; - case StatType::kModelFlops: - op_metrics->set_model_flops(stat.IntOrUintValue()); - break; - case StatType::kBytesAccessed: - op_metrics->set_bytes_accessed(stat.IntOrUintValue()); - break; - case StatType::kMemoryAccessBreakdown: { - tensorflow::profiler::MemoryAccessBreakdown breakdown; - const auto& value = stat.BytesValue(); - if (breakdown.ParseFromArray(value.data(), value.size())) { - *op_metrics->mutable_memory_accessed_breakdown() = - breakdown.memory_accessed(); - } - break; - } - case StatType::kDeduplicatedName: - op_metrics->set_deduplicated_name(std::string(stat.StrOrRefValue())); - break; - default: - break; - } - } - }); - hlo_event_metadata.ForEachChild( - [&](const XEventMetadataVisitor& child_hlo_event_metadata) { - OpMetrics* child = op_metrics->mutable_children()->add_metrics_db(); - child->set_occurrences(1); - SetOpMetadataFromHloEventMetadata(child_hlo_event_metadata, child); - }); -} - -void SetOpMetricsFromHloEvent(const tsl::profiler::XEventVisitor& hlo_event, - OpMetrics* op_metrics) { - uint64_t duration_ps = hlo_event.DurationPs(); - uint64_t min_duration_ps = duration_ps; - uint64_t self_duration_ps = duration_ps; - uint64_t dma_stall_ps = 0; - hlo_event.ForEachStat([&](const XStatVisitor& stat) { - if (!stat.Type()) return; - switch (static_cast(*stat.Type())) { - case StatType::kMinDurationPs: - min_duration_ps = stat.IntValue(); - break; - case StatType::kSelfDurationPs: - self_duration_ps = stat.IntValue(); - break; - case StatType::kDmaStallDurationPs: - dma_stall_ps = stat.IntValue(); - break; - default: - break; - } - }); - if (op_metrics->occurrences() == 0) { - SetOpMetadataFromHloEventMetadata(hlo_event.Metadata(), op_metrics); - op_metrics->set_occurrences( - std::max(kSingleOccurrence, hlo_event.NumOccurrences())); - op_metrics->set_time_ps(duration_ps); - op_metrics->set_min_time_ps(min_duration_ps); - 
op_metrics->set_self_time_ps(self_duration_ps); - op_metrics->set_dma_stall_ps(dma_stall_ps); - op_metrics->set_num_cores(1); - } else { - op_metrics->set_occurrences(op_metrics->occurrences() + - hlo_event.NumOccurrences()); - op_metrics->set_time_ps(op_metrics->time_ps() + duration_ps); - op_metrics->set_min_time_ps( - std::min(op_metrics->min_time_ps(), min_duration_ps)); - op_metrics->set_self_time_ps(op_metrics->self_time_ps() + self_duration_ps); - op_metrics->set_dma_stall_ps(op_metrics->dma_stall_ps() + dma_stall_ps); - } -} - -void MergeOpMetrics(const OpMetrics& src, OpMetrics& dst) { - if (dst.occurrences() == 0) { - dst = src; - } else { - dst.set_occurrences(src.occurrences() + dst.occurrences()); - dst.set_time_ps(src.time_ps() + dst.time_ps()); - dst.set_min_time_ps( - std::min(src.min_time_ps(), dst.min_time_ps())); - dst.set_self_time_ps(src.self_time_ps() + dst.self_time_ps()); - dst.set_dma_stall_ps(src.dma_stall_ps() + dst.dma_stall_ps()); - } -} - -void AdjustFlopsAndBytesAccessed(OpMetrics& op_metrics) { - op_metrics.set_flops(op_metrics.flops() * op_metrics.occurrences()); - if (op_metrics.model_flops() > 0) { - op_metrics.set_model_flops(op_metrics.model_flops() * - op_metrics.occurrences()); - } else { - op_metrics.set_model_flops(op_metrics.flops()); - } - op_metrics.set_bytes_accessed(op_metrics.bytes_accessed() * - op_metrics.occurrences()); - for (auto& memory_access : *op_metrics.mutable_memory_accessed_breakdown()) { - memory_access.set_bytes_accessed(memory_access.bytes_accessed() * - op_metrics.occurrences()); - } -} - -} // namespace - -OpMetricsDbBuilder::OpMetricsDbBuilder(OpMetricsDb* db) : db_(db) { - DCHECK_NE(db_, nullptr); - DCHECK_EQ(db_->metrics_db_size(), db->metrics_db_size()); -} - -OpMetrics* OpMetricsDbBuilder::LookupOrInsertNewOpMetrics( - uint64 hlo_module_id, absl::string_view name) { - OpMetrics*& op_metrics = op_metrics_map_[hlo_module_id][name]; - if (op_metrics == nullptr) { - op_metrics = 
db_->add_metrics_db(); - op_metrics->set_hlo_module_id(hlo_module_id); - op_metrics->set_name(name.data(), name.size()); - } - return op_metrics; -} - -void XEventsOpMetricsDbBuilder::AddOpMetric( - const tsl::profiler::XEventVisitor& event) { - AddOpMetric(FromXEvent(event), GetOpKeyFromXEvent(event)); -} - -void XEventsOpMetricsDbBuilder::AddOpMetric(const OpMetrics& op_metrics, - const OpKey& key) { - if (!key.program_id.has_value() || !key.symbol_id.has_value() || - key.symbol_id == kRootSymbolId) - return; - MergeOpMetrics( - op_metrics, - flat_op_metric_[key.program_id.value()][key.symbol_id.value()]); -} - -OpMetricsDb XEventsOpMetricsDbBuilder::Finalize(uint64_t total_time_ps) { - OpMetricsDb db = Finalize(); - SetTotalTimePs(db, total_time_ps); - AddIdleOp(db); - return db; -} - -OpMetricsDb XEventsOpMetricsDbBuilder::Finalize() { - OpMetricsDb db; - uint64_t total_op_time_ps = 0; - for (auto& [program_id, op_metric_by_symbol] : flat_op_metric_) { - for (auto& [symbol_id, op_metrics] : op_metric_by_symbol) { - AdjustFlopsAndBytesAccessed(op_metrics); - total_op_time_ps += op_metrics.self_time_ps(); - db.add_metrics_db()->Swap(&op_metrics); - } - } - db.set_total_op_time_ps(total_op_time_ps); - return db; -} - -double IdleTimeRatio(const OpMetricsDb& db) { - return 1.0 - - tsl::profiler::SafeDivide(db.total_op_time_ps(), db.total_time_ps()); -} - -uint64 IdleTimePs(const OpMetricsDb& db) { - DCHECK_GE(db.total_time_ps(), db.total_op_time_ps()); - return db.total_time_ps() - db.total_op_time_ps(); -} - -void SetIdleOp(uint64_t idle_time_ps, OpMetrics& metrics) { - metrics.set_name(std::string(kIdle)); - metrics.set_category(std::string(kIdle)); - metrics.set_occurrences(0); - metrics.set_time_ps(idle_time_ps); - metrics.set_self_time_ps(idle_time_ps); -} - -void AddIdleOp(OpMetricsDb& db) { - uint64 idle_time_ps = IdleTimePs(db); - SetIdleOp(idle_time_ps, *db.add_metrics_db()); -} - -std::optional HostInfeedEnqueueRatio(const OpMetricsDb& db) { - if 
(db.total_host_infeed_enq_start_timestamp_ps_diff() > 0) { - // We use total_host_infeed_enq_start_timestamp_ps_diff to approximate the - // total host time. - return tsl::profiler::SafeDivide( - db.total_host_infeed_enq_duration_ps(), - db.total_host_infeed_enq_start_timestamp_ps_diff()); - } - return std::nullopt; -} - -OpMetricsDb CreateTfMetricsDbFromDeviceOpMetricsDb( - const OpMetricsDb& device_op_metrics_db, bool with_idle) { - OpMetricsDb tf_op_metrics_db; - DeviceTfOpMetricsDbBuilder builder(&tf_op_metrics_db); - for (const auto& device_op_metrics : device_op_metrics_db.metrics_db()) { - if (IsIdleOp(device_op_metrics)) { - if (with_idle) { - builder.UpdateTfOpMetricsWithDeviceOpMetrics(kIdle, kIdle, - device_op_metrics); - } - } else if (device_op_metrics.provenance().empty()) { - builder.UpdateTfOpMetricsWithDeviceOpMetrics(device_op_metrics.name(), - tsl::profiler::kUnknownOp, - device_op_metrics); - } else { - tsl::profiler::TfOp tf_op = - tsl::profiler::ParseTfOpFullname(device_op_metrics.provenance()); - builder.UpdateTfOpMetricsWithDeviceOpMetrics(tf_op.name, tf_op.type, - device_op_metrics); - } - } - tf_op_metrics_db.set_total_op_time_ps( - device_op_metrics_db.total_op_time_ps()); - - tf_op_metrics_db.set_total_time_ps( - with_idle ? device_op_metrics_db.total_time_ps() - : device_op_metrics_db.total_op_time_ps()); - - return tf_op_metrics_db; -} - -OpMetrics FromXEvent(const tsl::profiler::XEventVisitor& xevent) { - OpMetrics op_metrics; - std::optional stat = xevent.GetStat(StatType::kStepIdleTimePs); - if (stat.has_value()) { - // TODO(b/397774568) : Remove this once the SparseCore OpMetricsDb is - // implemented. 
- uint64_t idle_time_ps = stat->IntOrUintValue(); - op_metrics.set_self_time_ps(xevent.DurationPs() - idle_time_ps); - op_metrics.set_name("sparse_core_busy_ops"); - op_metrics.set_category("sparse_core_busy_ops"); - return op_metrics; - } - SetOpMetricsFromHloEvent(xevent, &op_metrics); - return op_metrics; -} - -XEventsOpMetricsDbBuilder::OpKey GetOpKeyFromXEvent( - const XEventVisitor& event) { - std::optional stat = event.GetStat(StatType::kStepIdleTimePs); - if (stat.has_value()) { - return {.program_id = std::numeric_limits::max(), - .symbol_id = std::numeric_limits::max()}; - } - - XEventsOpMetricsDbBuilder::OpKey op_key; - DCHECK(event.metadata() != nullptr); - event.Metadata().ForEachStat([&](const XStatVisitor& stat) { - if (stat.Type().has_value()) { - switch (static_cast(*stat.Type())) { - case StatType::kProgramId: - op_key.program_id = stat.IntOrUintValue(); - break; - case StatType::kSymbolId: - op_key.symbol_id = stat.IntOrUintValue(); - break; - default: - break; - } - } - }); - return op_key; -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/utils/op_metrics_db_utils.h b/tensorflow/core/profiler/utils/op_metrics_db_utils.h index 4eca439960b0c2..5ed177ac3780d9 100644 --- a/tensorflow/core/profiler/utils/op_metrics_db_utils.h +++ b/tensorflow/core/profiler/utils/op_metrics_db_utils.h @@ -16,136 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_CORE_PROFILER_UTILS_OP_METRICS_DB_UTILS_H_ #define TENSORFLOW_CORE_PROFILER_UTILS_OP_METRICS_DB_UTILS_H_ -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/strings/string_view.h" -#include "xla/tsl/platform/macros.h" -#include "xla/tsl/profiler/utils/xplane_visitor.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" - -namespace tensorflow { -namespace profiler { - -// The name of OpMetrics to represent the idle time. 
-TF_CONST_INIT extern const absl::string_view kIdle; -// The core index to add to sparse core index in op metrics. -TF_CONST_INIT extern const uint32_t kSparseCoreIndexStart; - -// Helps build an op metrics database (borrowed). -// Enables fast lookup of existing ops and prevents the creation of duplicate -// ops. It is the user's responsibility to ensure an op metrics database -// outlives its builder, and that no ops are added to the database outside of -// the builder. -class OpMetricsDbBuilder { - public: - // Create with a borrowed op database. - // REQUIRED: The op database must be empty. - explicit OpMetricsDbBuilder(OpMetricsDb* db); - - protected: - // Looks up the given OP name. If it is already in the database, - // return its OpMetrics; otherwise, insert a new one. - OpMetrics* LookupOrInsertNewOpMetrics(uint64 hlo_module_id, - absl::string_view name); - - OpMetricsDb* db() { return db_; } - - private: - // Map op (hlo_module_id, name) to the corresponding metrics in the op - // database. - absl::flat_hash_map> - op_metrics_map_; - - // The op database. - OpMetricsDb* db_; -}; - -// Helps build an op metrics database (borrowed) from XEvents, -class XEventsOpMetricsDbBuilder { - public: - struct OpKey { - std::optional program_id; - std::optional symbol_id; - }; - // DEPRECATED: Use the OpKey version below. - // Add OpMetric from XEventVisitor. - void AddOpMetric(const tsl::profiler::XEventVisitor& xevent); - - // Add an OpMetric to the builder based on the provided key. - void AddOpMetric(const OpMetrics& op_metrics, const OpKey& key); - - // Finalize OpMetricDb and add total time and Idle op. - OpMetricsDb Finalize(uint64_t total_time); - - // Finalize OpMetricDb, but the total time is unknown at the moment, So ignore - // the total time and Idle Op and will be handled by the caller. 
- OpMetricsDb Finalize(); - - private: - using OpMetricBySymbol = - absl::flat_hash_map; - absl::flat_hash_map - flat_op_metric_; -}; - -// Constructs an OpMetrics from the provided XEventVisitor. -OpMetrics FromXEvent(const tsl::profiler::XEventVisitor& xevent); - -// Returns the OpKey for the provided XEventVisitor. -XEventsOpMetricsDbBuilder::OpKey GetOpKeyFromXEvent( - const tsl::profiler::XEventVisitor& event); - -// Sets the total time for OpMetricsDb, ensuring idle time is not negative. -inline void SetTotalTimePs(OpMetricsDb& db, uint64_t total_time_ps) { - db.set_total_time_ps(std::max(db.total_op_time_ps(), total_time_ps)); -} - -// Returns the total time in OpMetricsDb, optionally excluding the idle time. -inline uint64_t TotalTimePs(const OpMetricsDb& db, bool exclude_idle = false) { - return exclude_idle ? db.total_op_time_ps() : db.total_time_ps(); -} - -// Returns the ratio of time that is idle (no op execution) over total time. -double IdleTimeRatio(const OpMetricsDb& db); - -// Returns the idle time in picoseconds. -uint64 IdleTimePs(const OpMetricsDb& db); - -// Populates an OpMetrics record representing idle time, i.e., the amount of -// time spent without any op execution. -void SetIdleOp(uint64_t idle_time_ps, OpMetrics& metrics); - -// Adds an OpMetrics record representing idle time, i.e., the amount of time -// spent without any op execution. -// REQUIRED: All ops must have been added to the database and the total time -// must have been set. -void AddIdleOp(OpMetricsDb& db); - -// Returns true if the given metrics represents idle time. -inline bool IsIdleOp(const OpMetrics& metrics) { - return metrics.category() == kIdle; -} - -// Returns the time spent in children (nested) ops. -inline uint64_t ChildrenTimePs(const OpMetrics& metrics) { - return metrics.time_ps() - metrics.self_time_ps(); -} - -// Returns the ratio of time spent sending data from the host to the device -// relative to the total time the host was active. 
-std::optional HostInfeedEnqueueRatio(const OpMetricsDb& db); - -// Converts from the device op metrics to Tf-op metrics. -OpMetricsDb CreateTfMetricsDbFromDeviceOpMetricsDb( - const OpMetricsDb& device_op_metrics_db, bool with_idle = true); - -} // namespace profiler -} // namespace tensorflow +#include "xprof/utils/op_metrics_db_utils.h" // from @org_xprof // IWYU pragma: export #endif // TENSORFLOW_CORE_PROFILER_UTILS_OP_METRICS_DB_UTILS_H_ diff --git a/tensorflow/core/profiler/utils/op_metrics_db_utils_test.cc b/tensorflow/core/profiler/utils/op_metrics_db_utils_test.cc deleted file mode 100644 index 07d85e1411e0a1..00000000000000 --- a/tensorflow/core/profiler/utils/op_metrics_db_utils_test.cc +++ /dev/null @@ -1,220 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/core/profiler/utils/op_metrics_db_utils.h" - -#include -#include -#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" -#include "xla/tsl/profiler/utils/xplane_visitor.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" -#include "tensorflow/core/profiler/utils/xplane_builder.h" -#include "tensorflow/core/profiler/utils/xplane_schema.h" -#include "tsl/profiler/protobuf/xplane.pb.h" - -namespace tensorflow { -namespace profiler { -namespace { -#if defined(PLATFORM_GOOGLE) -using ::testing::EqualsProto; -using ::testing::proto::IgnoringRepeatedFieldOrdering; -#endif - -constexpr double kMaxError = 1E-10; - -TEST(OpMetricsDbTest, IdleTimeRatio) { - OpMetricsDb metrics_db_0; - metrics_db_0.set_total_time_ps(100000000); - metrics_db_0.set_total_op_time_ps(60000000); - EXPECT_NEAR(0.4, IdleTimeRatio(metrics_db_0), kMaxError); - - OpMetricsDb metrics_db_1; - metrics_db_1.set_total_time_ps(200000000); - metrics_db_1.set_total_op_time_ps(150000000); - EXPECT_NEAR(0.25, IdleTimeRatio(metrics_db_1), kMaxError); - - OpMetricsDb metrics_db_2; - metrics_db_1.set_total_time_ps(0); - metrics_db_1.set_total_op_time_ps(0); - EXPECT_NEAR(1.0, IdleTimeRatio(metrics_db_2), kMaxError); -} - -TEST(OpMetricsDbTest, FromXEventHandlesMissingOccurrences) { - XPlane raw_plane; - XPlaneBuilder plane(&raw_plane); - XLineBuilder line = plane.GetOrCreateLine(0); - XEventMetadata* event_metadata = plane.GetOrCreateEventMetadata("metadata"); - event_metadata->set_display_name("display_name"); - XStatsBuilder stats(event_metadata, &plane); - stats.AddStatValue( - *plane.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kProgramId)), 1); - stats.AddStatValue( - *plane.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kSymbolId)), 2); - stats.AddStatValue(*plane.GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kDeduplicatedName)), - 
"deduplicated_name"); - stats.AddStatValue( - *plane.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kTfOp)), "tf_op"); - stats.AddStatValue( - *plane.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kHloCategory)), - "tf_op_category"); - stats.AddStatValue( - *plane.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kFlops)), 3); - stats.AddStatValue( - *plane.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kModelFlops)), 4); - stats.AddStatValue( - *plane.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kBytesAccessed)), - 5); - XEventBuilder event = line.AddEvent(*event_metadata); - event.SetOffsetPs(0); - event.SetDurationPs(100); - tsl::profiler::XPlaneVisitor plane_visitor = - tsl::profiler::CreateTfXPlaneVisitor(&raw_plane); - tsl::profiler::XEventVisitor event_visitor( - &plane_visitor, &raw_plane.lines(0), &raw_plane.lines(0).events(0)); - OpMetrics op_metrics = FromXEvent(event_visitor); - -#if defined(PLATFORM_GOOGLE) - EXPECT_THAT(op_metrics, EqualsProto(R"pb( - occurrences: 1 - time_ps: 100 - self_time_ps: 100 - dma_stall_ps: 0 - hlo_module_id: 1 - flops: 3 - model_flops: 4 - bytes_accessed: 5 - name: "display_name" - long_name: "metadata" - deduplicated_name: "deduplicated_name" - category: "tf_op_category" - provenance: "tf_op" - min_time_ps: 100 - num_cores: 1 - )pb")); -#endif -} - -TEST(OpMetricsDbTest, GetOpKeyFromXEvent) { - XPlane raw_plane; - XPlaneBuilder plane(&raw_plane); - XEventMetadata* event_metadata = plane.GetOrCreateEventMetadata("metadata"); - event_metadata->set_display_name("display_name"); - XStatsBuilder stats(event_metadata, &plane); - stats.AddStatValue( - *plane.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kProgramId)), 1); - stats.AddStatValue( - *plane.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kSymbolId)), 2); - XLineBuilder line = plane.GetOrCreateLine(0); - XEventBuilder event = line.AddEvent(*event_metadata); - event.SetOffsetPs(0); - event.SetDurationPs(100); - tsl::profiler::XPlaneVisitor plane_visitor = - 
tsl::profiler::CreateTfXPlaneVisitor(&raw_plane); - tsl::profiler::XEventVisitor event_visitor( - &plane_visitor, &raw_plane.lines(0), &raw_plane.lines(0).events(0)); - XEventsOpMetricsDbBuilder::OpKey op_key = GetOpKeyFromXEvent(event_visitor); - EXPECT_EQ(op_key.program_id, 1); - EXPECT_EQ(op_key.symbol_id, 2); -} - -TEST(OpMetricsDbTest, XEventsOpMetricsDbBuilder) { - XPlane raw_plane; - XPlaneBuilder plane(&raw_plane); - XLineBuilder line = plane.GetOrCreateLine(0); - { - XEventMetadata* event_metadata = plane.GetOrCreateEventMetadata("m1"); - event_metadata->set_display_name("display_name1"); - XStatsBuilder stats(event_metadata, &plane); - stats.AddStatValue( - *plane.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kProgramId)), - 1); - stats.AddStatValue( - *plane.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kSymbolId)), 1); - XEventBuilder event = line.AddEvent(*event_metadata); - event.SetOffsetPs(0); - event.SetDurationPs(100); - XEventBuilder event2 = line.AddEvent(*event_metadata); - event2.SetOffsetPs(100); - event2.SetDurationPs(100); - } - { - XEventMetadata* event_metadata = plane.GetOrCreateEventMetadata("m2"); - event_metadata->set_display_name("display_name2"); - XStatsBuilder stats(event_metadata, &plane); - stats.AddStatValue( - *plane.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kProgramId)), - 1); - stats.AddStatValue( - *plane.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kSymbolId)), 2); - XEventBuilder event = line.AddEvent(*event_metadata); - event.SetOffsetPs(0); - event.SetDurationPs(100); - } - { - XEventMetadata* event_metadata = plane.GetOrCreateEventMetadata("m3"); - event_metadata->set_display_name("display_name3"); - XStatsBuilder stats(event_metadata, &plane); - stats.AddStatValue( - *plane.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kSymbolId)), 1); - XEventBuilder event = line.AddEvent(*event_metadata); - event.SetOffsetPs(0); - event.SetDurationPs(100); - } - - XEventsOpMetricsDbBuilder builder; - 
XEventsOpMetricsDbBuilder legacy_builder; - tsl::profiler::XPlaneVisitor plane_visitor = - tsl::profiler::CreateTfXPlaneVisitor(&raw_plane); - plane_visitor.ForEachLine([&](const tsl::profiler::XLineVisitor& line) { - line.ForEachEvent([&](const tsl::profiler::XEventVisitor& event) { - builder.AddOpMetric(FromXEvent(event), GetOpKeyFromXEvent(event)); - legacy_builder.AddOpMetric(event); - }); - }); -#if defined(PLATFORM_GOOGLE) - OpMetricsDb legacy_db = legacy_builder.Finalize(); - OpMetricsDb db = builder.Finalize(); - EXPECT_THAT(db, IgnoringRepeatedFieldOrdering(EqualsProto(legacy_db))); - EXPECT_THAT(db, IgnoringRepeatedFieldOrdering(EqualsProto(R"pb( - metrics_db { - hlo_module_id: 1 - self_time_ps: 200 - occurrences: 2 - name: "display_name1" - long_name: "m1" - time_ps: 200 - min_time_ps: 100 - num_cores: 1 - } - metrics_db { - hlo_module_id: 1 - self_time_ps: 100 - occurrences: 1 - name: "display_name2" - long_name: "m2" - time_ps: 100 - min_time_ps: 100 - num_cores: 1 - } - total_op_time_ps: 300 - )pb"))); -#endif -} - -} // namespace -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/utils/op_utils.cc b/tensorflow/core/profiler/utils/op_utils.cc deleted file mode 100644 index 72b55ba1a76c9b..00000000000000 --- a/tensorflow/core/profiler/utils/op_utils.cc +++ /dev/null @@ -1,183 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/core/profiler/utils/op_utils.h" - -#include -#include - -#include "absl/log/check.h" -#include "absl/strings/string_view.h" -#include "xla/hlo/ir/hlo_opcode.h" -#include "xla/tsl/platform/types.h" -#include "xla/tsl/profiler/utils/tf_op_utils.h" -#include "xla/tsl/profiler/utils/timespan.h" -#include "tensorflow/core/profiler/convert/op_metrics_db_combiner.h" -#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" -#include "tensorflow/core/profiler/utils/hlo_module_map.h" -#include "tsl/platform/protobuf.h" - -namespace tensorflow { -namespace profiler { -using tsl::uint64; - -namespace {} // namespace - -// Annotate the op_metrics with the metadata from the instr_wrapper. -void EnterOpMetadata(OpMetrics* op_metrics, - const HloInstructionWrapper* instr_wrapper) { - if (op_metrics->name().empty() && op_metrics->category().empty() && - op_metrics->provenance().empty()) { - op_metrics->set_name(std::string(instr_wrapper->Name())); - op_metrics->set_category(std::string(instr_wrapper->Category())); - op_metrics->set_deduplicated_name( - instr_wrapper->Metadata().deduplicated_name()); - op_metrics->set_provenance(std::string(instr_wrapper->op_full_name())); - op_metrics->set_num_cores(1); - op_metrics->set_occurrences(op_metrics->occurrences() + 1); - op_metrics->set_flops(op_metrics->flops() + instr_wrapper->flops()); - op_metrics->set_bytes_accessed(op_metrics->bytes_accessed() + - instr_wrapper->bytes_accessed()); - op_metrics->set_long_name(instr_wrapper->Expression()); - } -} - -void AddFusionChildrenToOpMetricsFromHloInstruction( - OpMetrics* op_metrics, const HloInstructionWrapper* instr_wrapper) { - if (instr_wrapper->FusedChildren().empty()) return; - for (const HloInstructionWrapper* child : instr_wrapper->FusedChildren()) { - if (child->HloOpcode() == xla::HloOpcode::kParameter || - child->HloOpcode() == xla::HloOpcode::kTuple) - continue; - 
OpMetrics* child_op_metrics = - op_metrics->mutable_children()->add_metrics_db(); - // DeviceOpMetricsDbBuilder children_db_builder( - // op_metrics->mutable_children()); - EnterOpMetadata(child_op_metrics, child); - // children_db_builder.EnterOpMetadata(child_op_metrics, child); - AddFusionChildrenToOpMetricsFromHloInstruction(child_op_metrics, child); - } -} - -void EnterOpMetadataFromHloModuleMap(OpMetrics* op_metrics, - const HloModuleMap& hlo_module_map) { - const HloInstructionWrapper* instr_wrapper = GetHloInstruction( - hlo_module_map, op_metrics->hlo_module_id(), op_metrics->name()); - if (instr_wrapper != nullptr) { - AddFusionChildrenToOpMetricsFromHloInstruction(op_metrics, instr_wrapper); - } -} - -void HostOpMetricsDbBuilder::EnterOp(absl::string_view name, - absl::string_view category, bool is_eager, - uint64 time_ps, uint64 children_time_ps) { - uint64 self_time_ps = time_ps - children_time_ps; - DCHECK_GE(time_ps, self_time_ps); - OpMetrics* op_metrics = LookupOrInsertNewOpMetrics(/*hlo_module_id=*/0, name); - if (op_metrics->category().empty()) - op_metrics->set_category(category.data(), category.size()); - op_metrics->set_num_cores(1); - op_metrics->set_is_eager(op_metrics->is_eager() || is_eager); - op_metrics->set_occurrences(op_metrics->occurrences() + 1); - op_metrics->set_time_ps(op_metrics->time_ps() + time_ps); - op_metrics->set_self_time_ps(op_metrics->self_time_ps() + self_time_ps); - db()->set_total_op_time_ps(db()->total_op_time_ps() + self_time_ps); -} - -void HostOpMetricsDbBuilder::EnterHostInfeedEnqueue( - tsl::profiler::Timespan host_infeed_enqueue) { - if (!last_host_infeed_enqueue_.Empty()) { - // Expect non-overlapping InfeedEnqueue timespans sorted by time. 
- DCHECK_GE(host_infeed_enqueue.end_ps(), - last_host_infeed_enqueue_.begin_ps()); - db()->set_total_host_infeed_enq_duration_ps( - db()->total_host_infeed_enq_duration_ps() + - last_host_infeed_enqueue_.duration_ps()); - db()->set_total_host_infeed_enq_start_timestamp_ps_diff( - db()->total_host_infeed_enq_start_timestamp_ps_diff() + - (host_infeed_enqueue.begin_ps() - - last_host_infeed_enqueue_.begin_ps())); - } - last_host_infeed_enqueue_ = host_infeed_enqueue; -} - -void DeviceOpMetricsDbBuilder::EnterOpMetadataFromHloModuleMap( - uint64 program_id, absl::string_view op_name, - const HloModuleMap& hlo_module_map) { - OpMetrics* op_metrics = LookupOrInsertNewOpMetrics(program_id, op_name); - tensorflow::profiler::EnterOpMetadataFromHloModuleMap(op_metrics, - hlo_module_map); -} - -void DeviceOpMetricsDbBuilder::EnterOpMetadata( - uint64 program_id, absl::string_view program_name, - absl::string_view category, absl::string_view provenance, - absl::string_view deduplicated_name, bool is_eager, - absl::string_view long_name) { - // We only need to add xla metadata once to each new op, as they are the - // same across occurrences. - OpMetrics* op_metrics = LookupOrInsertNewOpMetrics(program_id, program_name); - if (op_metrics->occurrences() > 0 || !op_metrics->category().empty() || - !op_metrics->provenance().empty()) - return; - op_metrics->set_category(category == tsl::profiler::kUnknownOp - ? 
"unknown" - : std::string(category)); - op_metrics->set_provenance(std::string(provenance)); - if (!deduplicated_name.empty()) { - op_metrics->set_deduplicated_name(std::string(deduplicated_name)); - } - if (!long_name.empty()) { - op_metrics->set_long_name(std::string(long_name)); - } - op_metrics->set_is_eager(op_metrics->is_eager() || is_eager); -} - -void DeviceOpMetricsDbBuilder::EnterOp( - uint64 program_id, absl::string_view name, absl::string_view category, - absl::string_view provenance, absl::string_view deduplicated_name, - bool is_eager, uint64 occurrences, uint64 time_ps, uint64 children_time_ps, - int64_t flops, int64_t bytes_accessed, - // NOLINTNEXTLINE: clang-tidy missing-includes false positive - const tsl::protobuf::RepeatedPtrField& - memory_accessed_breakdown, - int64_t model_flops) { - EnterOpMetadata(program_id, name, category, provenance, deduplicated_name, - is_eager); - uint64 self_time_ps = time_ps - children_time_ps; - DCHECK_GE(time_ps, self_time_ps); - OpMetrics* op_metrics = LookupOrInsertNewOpMetrics(program_id, name); - op_metrics->set_num_cores(1); - op_metrics->set_occurrences(op_metrics->occurrences() + occurrences); - op_metrics->set_time_ps(op_metrics->time_ps() + time_ps); - op_metrics->set_self_time_ps(op_metrics->self_time_ps() + self_time_ps); - op_metrics->set_flops(op_metrics->flops() + flops * occurrences); - if (model_flops == 0) { - // If ModelsFlops is 0, use the same value as device flops. 
- op_metrics->set_model_flops(op_metrics->flops()); - } else { - op_metrics->set_model_flops(op_metrics->model_flops() + - model_flops * occurrences); - } - op_metrics->set_bytes_accessed(op_metrics->bytes_accessed() + - bytes_accessed * occurrences); - CombineMemoryAccessedBreakdown( - memory_accessed_breakdown, - op_metrics->mutable_memory_accessed_breakdown()); - db()->set_total_op_time_ps(db()->total_op_time_ps() + self_time_ps); -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/utils/op_utils.h b/tensorflow/core/profiler/utils/op_utils.h index 9363574a474f99..b5edd9288a4652 100644 --- a/tensorflow/core/profiler/utils/op_utils.h +++ b/tensorflow/core/profiler/utils/op_utils.h @@ -16,94 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_CORE_PROFILER_UTILS_OP_UTILS_H_ #define TENSORFLOW_CORE_PROFILER_UTILS_OP_UTILS_H_ -#include - -#include "absl/strings/string_view.h" -#include "xla/tsl/platform/types.h" -#include "xla/tsl/profiler/utils/timespan.h" -#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" -#include "tensorflow/core/profiler/utils/hlo_module_map.h" -#include "tensorflow/core/profiler/utils/op_metrics_db_utils.h" -#include "tsl/platform/protobuf.h" - -namespace tensorflow { -namespace profiler { -using tsl::uint64; - -// Annotate the op_metrics with the metadata from the instr_wrapper. -void EnterOpMetadata(OpMetrics* op_metrics, - const HloInstructionWrapper* instr_wrapper); -void EnterOpMetadataFromHloModuleMap(OpMetrics* op_metrics, - const HloModuleMap& hlo_module_map); - -void AddFusionChildrenToOpMetricsFromHloInstruction( - OpMetrics* op_metrics, const HloInstructionWrapper* instr_wrapper); - -class HostOpMetricsDbBuilder : public OpMetricsDbBuilder { - public: - explicit HostOpMetricsDbBuilder(OpMetricsDb* db) : OpMetricsDbBuilder(db) {} - - // A function that will be called when the end of an OP is - // observed on a trace, where: - // name = the OP name. 
- // category = the OP category. - // is_eager = whether this OP is eagerly executed. - // time_ps = the total execution time of the OP in picoseconds, including - // the execution time of its children. - // children_time_ps = the execution time of the children of this OP in - // picoseconds - void EnterOp(absl::string_view name, absl::string_view category, - bool is_eager, uint64 time_ps, uint64 children_time_ps); - - // Updates total_host_infeed_enq_duration_ps_ and - // total_host_infeed_enq_duration_ps_. - void EnterHostInfeedEnqueue(tsl::profiler::Timespan host_infeed_enqueue); - - private: - // The tsl::profiler::Timespan of the last InfeedEnqueue op on this thread. - tsl::profiler::Timespan last_host_infeed_enqueue_; -}; - -class DeviceOpMetricsDbBuilder : public OpMetricsDbBuilder { - public: - explicit DeviceOpMetricsDbBuilder(OpMetricsDb* db) : OpMetricsDbBuilder(db) {} - - // A function that will be called when the end of an OP is - // observed on a trace, where: - // program_id = the ID of the program that contains this OP. - // name = the OP name. - // category = the OP category. - // provenance = the provenance of this OP (e.g. original TF OP). - // is_eager = whether this OP is eagerly executed. - // occurrences = the number of occurrences of this OP. - // time_ps = the total execution time of the OP in picoseconds, including - // the execution time of its children. - // children_time_ps = the execution time of the children of this OP in - // picoseconds. - // flops = the number of floating-point operations computed. - // bytes_accessed = the sum of bytes read and bytes written by this OP. - // memory_accessed_breakdown = the breakdown of memory accessed by operation - // type and memory space. 
- void EnterOp(uint64 program_id, absl::string_view name, - absl::string_view category, absl::string_view provenance, - absl::string_view deduplicated_name, bool is_eager, - uint64 occurrences, uint64 time_ps, uint64 children_time_ps, - int64_t flops, int64_t bytes_accessed, - const tsl::protobuf::RepeatedPtrField& - memory_accessed_breakdown = {}, - int64_t model_flops = 0); - - void EnterOpMetadata(uint64 program_id, absl::string_view program_name, - absl::string_view category, absl::string_view provenance, - absl::string_view deduplicated_name, bool is_eager, - absl::string_view long_name = ""); - - void EnterOpMetadataFromHloModuleMap(uint64 program_id, - absl::string_view op_name, - const HloModuleMap& hlo_module_map); -}; - -} // namespace profiler -} // namespace tensorflow +#include "xprof/utils/op_utils.h" // from @org_xprof // IWYU pragma: export #endif // TENSORFLOW_CORE_PROFILER_UTILS_OP_UTILS_H_ diff --git a/tensorflow/core/profiler/utils/step_intersection.cc b/tensorflow/core/profiler/utils/step_intersection.cc deleted file mode 100644 index 8eb967fafba1e2..00000000000000 --- a/tensorflow/core/profiler/utils/step_intersection.cc +++ /dev/null @@ -1,305 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ -#include "tensorflow/core/profiler/utils/step_intersection.h" - -#include -#include -#include - -#include "absl/algorithm/container.h" -#include "absl/container/flat_hash_map.h" -#include "absl/log/check.h" -#include "absl/strings/str_cat.h" -#include "xla/tsl/profiler/utils/timespan.h" -#include "tensorflow/core/lib/gtl/map_util.h" -#include "tensorflow/core/platform/logging.h" - -namespace tensorflow { -namespace profiler { - -namespace { - -// Returns the timespan in this step (across all cores). -tsl::profiler::Timespan StepTimespan(const PerCoreStepInfo& percore_stepinfo) { - uint64 min_ps = kuint64max; - uint64 max_ps = 0; - for (const auto& core_stepinfo : percore_stepinfo.step_info_per_core()) { - const auto& stepinfo = core_stepinfo.second; - uint64 begin_ps = stepinfo.begin_ps(); - uint64 end_ps = begin_ps + stepinfo.duration_ps(); - min_ps = std::min(min_ps, begin_ps); - max_ps = std::max(max_ps, end_ps); - } - return (min_ps < max_ps) - ? tsl::profiler::Timespan::FromEndPoints(min_ps, max_ps) - : tsl::profiler::Timespan(); -} - -// Returns the timespan across all steps in the given step_db. -tsl::profiler::Timespan AllStepsTimespan(const StepDatabaseResult& step_db) { - uint64 min_ps = kuint64max; - uint64 max_ps = 0; - for (const auto& step : step_db.step_sequence()) { - tsl::profiler::Timespan timespan = StepTimespan(step); - uint64 begin_ps = timespan.begin_ps(); - uint64 end_ps = timespan.end_ps(); - min_ps = std::min(min_ps, begin_ps); - max_ps = std::max(max_ps, end_ps); - } - return (min_ps < max_ps) - ? tsl::profiler::Timespan::FromEndPoints(min_ps, max_ps) - : tsl::profiler::Timespan(); -} - -struct AlignmentInfo { - StepsAlignment alignment; - double similarity; -}; - -// Computes the similarity between the given two steps. The closer their -// timespans are, the larger is the similarity. 
-double StepSimilarity(const PerCoreStepInfo& subordinate_step, - const PerCoreStepInfo& chief_step) { - tsl::profiler::Timespan subordinate_timespan = StepTimespan(subordinate_step); - tsl::profiler::Timespan chief_timespan = StepTimespan(chief_step); - return chief_timespan.OverlappedDurationPs(subordinate_timespan); -} - -// If the subordinate steps and the chief steps are aligned at the given anchor -// points (i.e. at the subordinate_anchor step on the subordinate sequence, at -// the chief_anchor step on the chief sequence), returns the corresponding -// AlignmentInfo. -AlignmentInfo ComputeAlignmentInfo(const StepDatabaseResult& subordinate, - uint32 subordinate_anchor, - const StepDatabaseResult& chief, - uint32 chief_anchor) { - // Assumes that the step at subordinate_anchor on the subordinate sequence is - // aligned with the step at the chief_anchor on the chief sequence. Then the - // number of steps before the anchor is the minimum of the number of steps - // before the anchor in the subordinate and that before the anchor in the - // chief. Similarly, the number of steps after the anchor is the minimum of - // the number of steps after the anchor in the subordinate and that after the - // anchor in the chief. - uint32 pre_anchor_steps = std::min(subordinate_anchor, chief_anchor); - uint32 post_anchor_steps = - std::min(subordinate.step_sequence_size() - subordinate_anchor, - chief.step_sequence_size() - chief_anchor); - // total number of steps aligned = pre_anchor_steps + post_anchor_steps. - uint32 alignment_steps = pre_anchor_steps + post_anchor_steps; - - double similarity = 0; - // Where the aligned steps begin on the subordinate sequence. - uint32 begin_subordinate_idx = subordinate_anchor - pre_anchor_steps; - // Where the aligned steps begin on the chief sequence. - uint32 begin_chief_idx = chief_anchor - pre_anchor_steps; - - for (uint32 i = 0; i < alignment_steps; i++) { - // Accumulates the similarity at each step. 
- similarity += - StepSimilarity(subordinate.step_sequence(begin_subordinate_idx + i), - chief.step_sequence(begin_chief_idx + i)); - } - StepsAlignment alignment = {begin_subordinate_idx, begin_chief_idx, - alignment_steps}; - return {alignment, similarity}; -} - -// Returns the best alignment for aligning subordinate against chief. -StepsAlignment FindStepsAlignment(const StepDatabaseResult& subordinate, - const StepDatabaseResult& chief) { - double max_similarity = -1; - StepsAlignment alignment = {0, 0, 0}; - if (subordinate.step_sequence_size() == 0 || chief.step_sequence_size() == 0) - return alignment; - for (auto c = 0; c < chief.step_sequence_size(); c++) { - AlignmentInfo info = - ComputeAlignmentInfo(subordinate, /*subordinate_anchor=*/0, chief, c); - if (info.similarity <= max_similarity) continue; - max_similarity = info.similarity; - alignment = info.alignment; - } - for (auto s = 1; s < subordinate.step_sequence_size(); s++) { - // s starts at 1 instead of 0, because the loop above already considers - // (s=0, c=0). 
- AlignmentInfo info = - ComputeAlignmentInfo(subordinate, s, chief, /*chief_anchor=*/0); - if (info.similarity <= max_similarity) continue; - max_similarity = info.similarity; - alignment = info.alignment; - } - return alignment; -} - -std::string StringStepsAlignment(const StepsAlignment& alignment) { - return absl::StrCat( - "[begin_subordinate_idx: ", alignment.begin_subordinate_idx, - ", begin_chief_idx: ", alignment.begin_chief_idx, - ", num_steps: ", alignment.num_steps, "]"); -} - -std::string StringDstStepNumbers(const std::vector& step_numbers) { - std::string str; - absl::StrAppend(&str, "["); - for (auto i = 0; i < step_numbers.size(); i++) { - if (i > 0) absl::StrAppend(&str, ", "); - absl::StrAppend(&str, step_numbers[i]); - } - absl::StrAppend(&str, "]"); - return str; -} - -std::string StringSrcToDstIndexMap(uint32 src_first_step_idx, - uint32 num_steps) { - std::string str; - absl::StrAppend(&str, "["); - for (auto i = 0; i < num_steps; i++) { - if (i > 0) absl::StrAppend(&str, ", "); - absl::StrAppend(&str, src_first_step_idx + i, ":", i); - } - absl::StrAppend(&str, "]"); - return str; -} - -} // namespace - -StepIntersection::StepIntersection( - uint32 max_steps, - const absl::flat_hash_map& - perhost_stepdb) { - empty_intersect_ = false; - - // Figures out the host with the shortest timespan among their steps (called - // this host the "chief"). - chief_host_id_ = kuint32max; - uint64 min_duration_ps = kuint64max; - const StepDatabaseResult* chief_step_db = nullptr; - for (const auto& hostid_stepdb : perhost_stepdb) { - auto host_id = hostid_stepdb.first; - const auto& step_db = hostid_stepdb.second; - tsl::profiler::Timespan timespan = AllStepsTimespan(*step_db); - if (timespan.duration_ps() < min_duration_ps) { - chief_host_id_ = host_id; - chief_step_db = step_db; - min_duration_ps = timespan.duration_ps(); - } - } - if (chief_host_id_ == kuint32max) { - // There is no step at all on any host. 
- steps_dropped_ = 0; - begin_chief_idx_ = 0; - end_chief_idx_ = 0; - return; - } - - uint32 max_begin_chief_idx = 0; - uint32 min_end_chief_idx = kuint32max; - // Aligns the steps in all hosts with those in the chief. - for (const auto& hostid_stepdb : perhost_stepdb) { - auto host_id = hostid_stepdb.first; - const auto& step_db = hostid_stepdb.second; - if (host_id == chief_host_id_) { - // Simply aligns with itself. - perhost_alignment_[host_id] = { - /*begin_subordinate_idx=*/0, /*begin_chief_idx=*/0, - static_cast(step_db->step_sequence_size())}; - } else { - perhost_alignment_[host_id] = - FindStepsAlignment(*step_db, *chief_step_db); - } - // Intersects this host's alignment with other hosts' alignments. - uint32 host_begin_chief_idx = perhost_alignment_[host_id].begin_chief_idx; - max_begin_chief_idx = std::max(max_begin_chief_idx, host_begin_chief_idx); - uint32 host_end_chief_idx = perhost_alignment_[host_id].begin_chief_idx + - perhost_alignment_[host_id].num_steps; - min_end_chief_idx = std::min(min_end_chief_idx, host_end_chief_idx); - } - if (max_begin_chief_idx > min_end_chief_idx) { - // The intersection is empty. - steps_dropped_ = 0; - begin_chief_idx_ = 0; - end_chief_idx_ = 0; - empty_intersect_ = true; - return; - } - - begin_chief_idx_ = max_begin_chief_idx; - - // Takes max_steps into account. - uint32 num_steps = min_end_chief_idx - max_begin_chief_idx; - if (num_steps > max_steps) { - steps_dropped_ = num_steps - max_steps; - // TODO(ckluk): Drops from both ends to avoid incomplete steps at the - // beginning and end of the profile. - end_chief_idx_ = max_begin_chief_idx + max_steps; - } else { - steps_dropped_ = 0; - end_chief_idx_ = min_end_chief_idx; - } -} - -std::vector StepIntersection::DstStepNumbers() const { - // TODO(ckluk): Honors training-loop boundaries (if more than one loop - // sampled). 
- std::vector result; - result.reserve(NumSteps()); - for (uint32 i = 0; i < NumSteps(); i++) { - result.push_back(i); - } - return result; -} - -uint32 StepIntersection::FirstStepIndex(uint32 host_id) const { - const auto* alignment = gtl::FindOrNull(perhost_alignment_, host_id); - if (alignment == nullptr) return 0; - DCHECK(alignment->begin_chief_idx <= begin_chief_idx_); - uint32 shift = begin_chief_idx_ - alignment->begin_chief_idx; - uint32 begin_subordinate_idx = alignment->begin_subordinate_idx + shift; - return begin_subordinate_idx; -} - -std::string StepIntersection::DebugString() const { - std::string str; - absl::StrAppend(&str, "chief host id_: ", chief_host_id_, "\n"); - absl::StrAppend(&str, "begin_chief_idx_: ", begin_chief_idx_, - ", num_steps: ", NumSteps(), "\n"); - absl::StrAppend( - &str, "DstStepNumbers(): ", StringDstStepNumbers(DstStepNumbers()), "\n"); - - std::vector host_ids; - host_ids.reserve(perhost_alignment_.size()); - for (const auto& hostid_alignment : perhost_alignment_) { - auto host_id = hostid_alignment.first; - host_ids.push_back(host_id); - } - absl::c_sort(host_ids); - - absl::StrAppend(&str, "perhost_alignment:\n"); - for (const auto host_id : host_ids) { - const auto* ptr = gtl::FindOrNull(perhost_alignment_, host_id); - if (ptr == nullptr) continue; - absl::StrAppend(&str, "host: ", host_id, - ", step-alignment: ", StringStepsAlignment(*ptr), "\n"); - } - absl::StrAppend(&str, "SrcToDstIndexMap():\n"); - for (const auto host_id : host_ids) { - absl::StrAppend(&str, "host: ", host_id, ", src-to-dst-index-map: ", - StringSrcToDstIndexMap(FirstStepIndex(host_id), NumSteps()), - "\n"); - } - return str; -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/utils/step_intersection.h b/tensorflow/core/profiler/utils/step_intersection.h index 777b0528c30a05..d1932a5c5e43be 100644 --- a/tensorflow/core/profiler/utils/step_intersection.h +++ 
b/tensorflow/core/profiler/utils/step_intersection.h @@ -16,72 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_CORE_PROFILER_UTILS_STEP_INTERSECTION_H_ #define TENSORFLOW_CORE_PROFILER_UTILS_STEP_INTERSECTION_H_ -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/profiler/protobuf/steps_db.pb.h" - -namespace tensorflow { -namespace profiler { - -// Description of how two step sequences are aligned. -struct StepsAlignment { - uint32 begin_subordinate_idx; // where the alignment begins on the - // subordinate steps. - uint32 begin_chief_idx; // where the alignment begins on the chief steps. - uint32 num_steps; // aligned for how many steps. -}; - -class StepIntersection { - public: - StepIntersection( - uint32 max_steps, - const absl::flat_hash_map& - perhost_stepdb); - - // Returns the number of steps in the intersection. - uint32 NumSteps() const { return end_chief_idx_ - begin_chief_idx_; } - - // Returns the value of empty_intersect_ (see the explanation of - // empty_intersect_ below). - bool EmptyIntersect() const { return empty_intersect_; } - - // Returns the step numbers for the destination (i.e. the intersection - // result). - std::vector DstStepNumbers() const; - - // Returns the index to the step in the given host that corresponds to the - // first step in the intersection. - uint32 FirstStepIndex(uint32 host_id) const; - - // Returns the number of steps dropped due to the max_steps constraint - // specified in the constructor. - uint32 StepsDropped() const { return steps_dropped_; } - - std::string DebugString() const; - - private: - absl::flat_hash_map perhost_alignment_; - uint32 - chief_host_id_; // the host whose step sequence is selected as the chief. - uint32 steps_dropped_; // number of steps dropped. 
- // If NumSteps() is 0, empty_intersect indicates one of two possible reasons: - // (i) At least one host has some steps, but the intersection over all hosts - // is empty. In this case, empty_intersect is true, - // (ii) None of the hosts has any steps. In this case, empty_intersect is - // false. - // If NumSteps() > 0, empty_intersect is don't care. - bool empty_intersect_; - // The begin and end indices to the chief step sequence for this step - // intersection. Note that the begin index is inclusive but the end index is - // exclusive. - uint32 begin_chief_idx_; - uint32 end_chief_idx_; -}; - -} // namespace profiler -} // namespace tensorflow +#include "xprof/utils/step_intersection.h" // from @org_xprof // IWYU pragma: export #endif // TENSORFLOW_CORE_PROFILER_UTILS_STEP_INTERSECTION_H_ diff --git a/tensorflow/core/profiler/utils/step_intersection_test.cc b/tensorflow/core/profiler/utils/step_intersection_test.cc deleted file mode 100644 index 2115581ff1a270..00000000000000 --- a/tensorflow/core/profiler/utils/step_intersection_test.cc +++ /dev/null @@ -1,260 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/core/profiler/utils/step_intersection.h" - -#include - -#include "absl/container/flat_hash_map.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/platform/types.h" - -namespace tensorflow { -namespace profiler { -namespace { - -using PerHostStepDb = - absl::flat_hash_map; - -constexpr uint64 kStepDurationPs = 2000000000; -constexpr uint32 kNumStepsPerHost = 10; -constexpr uint64 kStepGapPs = 0; -constexpr uint32 kNumCoresPerHost = 8; - -PerCoreStepInfo CreateOneTestStep(uint32 host_id, uint32 num_steps, - uint32 step_idx, uint64 step_begin_ps) { - PerCoreStepInfo result; - uint32 step_num = - step_idx * host_id; // creates the situation where each host has a - // different step number for the same step. - result.set_step_num(step_num); - StepInfoResult info; - info.set_step_num(step_num); - if (host_id == 0 && step_idx == (num_steps - 1)) { - // Makes the last step on host_id is little bit shorter so that host-0 will - // be chosen as the chief. - info.set_duration_ps(kStepDurationPs - 1); - } else { - info.set_duration_ps(kStepDurationPs); - } - info.set_begin_ps(step_begin_ps); - // Don't care about the rest of the fields in StepInfoResult. - for (uint32 core_id = 0; core_id < kNumCoresPerHost; core_id++) { - (*result.mutable_step_info_per_core())[core_id] = info; - // Don't care about the rest of the fields in PerCoreStepInfo. 
- } - return result; -} - -PerHostStepDb CreateTestSteps(uint32 num_hosts, uint64 shift_ps) { - PerHostStepDb result; - uint64 first_step_begin_ps = 0; - for (uint32 host_id = 0; host_id < num_hosts; host_id++) { - StepDatabaseResult step_db; - uint64 step_begin_ps = first_step_begin_ps; - for (uint32 step_idx = 0; step_idx < kNumStepsPerHost; step_idx++) { - *step_db.add_step_sequence() = - CreateOneTestStep(host_id, kNumStepsPerHost, step_idx, step_begin_ps); - step_begin_ps += (kStepDurationPs + kStepGapPs); - } - result[host_id] = step_db; - first_step_begin_ps += shift_ps; - } - return result; -} - -PerHostStepDb CreateEmptyIntersectTestSteps() { - PerHostStepDb result; - - uint64 step_begin_ps; - uint32 host_id; - - // Host-0 - host_id = 0; - step_begin_ps = 0; - uint64 host_0_num_steps = 10; - StepDatabaseResult step_db_0; - for (uint32 step_idx = 0; step_idx < host_0_num_steps; step_idx++) { - *step_db_0.add_step_sequence() = - CreateOneTestStep(host_id, host_0_num_steps, step_idx, step_begin_ps); - step_begin_ps += (kStepDurationPs + kStepGapPs); - } - result[host_id] = step_db_0; - - // Host-1 - host_id = 1; - step_begin_ps = (host_0_num_steps - 2) * (kStepDurationPs + kStepGapPs); - uint64 host_1_num_steps = 5; - StepDatabaseResult step_db_1; - for (uint32 step_idx = 0; step_idx < host_1_num_steps; step_idx++) { - *step_db_1.add_step_sequence() = - CreateOneTestStep(host_id, host_1_num_steps, step_idx, step_begin_ps); - step_begin_ps += (kStepDurationPs + kStepGapPs); - } - result[host_id] = step_db_1; - - // Host-2 - host_id = 2; - step_begin_ps = (host_0_num_steps + host_1_num_steps - 4) * - (kStepDurationPs + kStepGapPs); - uint64 host_2_num_steps = 10; - StepDatabaseResult step_db_2; - for (uint32 step_idx = 0; step_idx < host_2_num_steps; step_idx++) { - *step_db_2.add_step_sequence() = - CreateOneTestStep(host_id, host_2_num_steps, step_idx, step_begin_ps); - step_begin_ps += (kStepDurationPs + kStepGapPs); - } - result[host_id] = step_db_2; - - 
return result; -} - -PerHostStepDb CreateNoStep(uint32 num_hosts) { - PerHostStepDb result; - for (uint32 host_id = 0; host_id < num_hosts; host_id++) { - StepDatabaseResult step_db; - result[host_id] = step_db; - } - return result; -} - -absl::flat_hash_map Convert( - const PerHostStepDb& perhost_stepdb) { - absl::flat_hash_map result; - for (const auto& hostid_stepdb : perhost_stepdb) { - auto host_id = hostid_stepdb.first; - const auto& step_db = hostid_stepdb.second; - result[host_id] = &step_db; - } - return result; -} - -TEST(StepIntersectionTest, EachHostShiftedBy1StepDuration) { - uint32 num_hosts = 4; - uint64 shift_ps = kStepDurationPs; - - PerHostStepDb perhost_stepdb = CreateTestSteps(num_hosts, shift_ps); - StepIntersection intersection = - StepIntersection(kNumStepsPerHost, Convert(perhost_stepdb)); - EXPECT_EQ(intersection.StepsDropped(), 0); - uint32 dst_num_steps = kNumStepsPerHost - num_hosts + 1; - EXPECT_EQ(intersection.NumSteps(), dst_num_steps); - - uint32 src_first_step_index = intersection.FirstStepIndex(0); - EXPECT_EQ(src_first_step_index, num_hosts - 1); - std::vector dst_step_numbers = intersection.DstStepNumbers(); - for (uint32 i = 0; i < dst_num_steps; i++) { - EXPECT_EQ(dst_step_numbers[i], i); - } -} - -TEST(StepIntersectionTest, ExactlyNoShift) { - uint32 num_hosts = 4; - uint64 shift_ps = 0; - - PerHostStepDb perhost_stepdb = CreateTestSteps(num_hosts, shift_ps); - StepIntersection intersection = - StepIntersection(kNumStepsPerHost, Convert(perhost_stepdb)); - EXPECT_EQ(intersection.StepsDropped(), 0); - uint32 dst_num_steps = kNumStepsPerHost; - EXPECT_EQ(intersection.NumSteps(), dst_num_steps); - - std::vector dst_step_numbers = intersection.DstStepNumbers(); - for (uint32 i = 0; i < dst_num_steps; i++) { - EXPECT_EQ(dst_step_numbers[i], i); - } - for (uint32 host_id = 0; host_id < num_hosts; host_id++) { - uint32 src_first_step_index = intersection.FirstStepIndex(host_id); - EXPECT_EQ(src_first_step_index, 0); - } -} - 
-TEST(StepIntersectionTest, EachHostShiftedByJustABit) { - uint32 num_hosts = 4; - uint64 shift_ps = 100; - - PerHostStepDb perhost_stepdb = CreateTestSteps(num_hosts, shift_ps); - StepIntersection intersection = - StepIntersection(kNumStepsPerHost, Convert(perhost_stepdb)); - EXPECT_EQ(intersection.StepsDropped(), 0); - uint32 dst_num_steps = kNumStepsPerHost; - EXPECT_EQ(intersection.NumSteps(), dst_num_steps); - - std::vector dst_step_numbers = intersection.DstStepNumbers(); - for (uint32 i = 0; i < dst_num_steps; i++) { - EXPECT_EQ(dst_step_numbers[i], i); - } - for (uint32 host_id = 0; host_id < num_hosts; host_id++) { - uint32 src_first_step_index = intersection.FirstStepIndex(host_id); - EXPECT_EQ(src_first_step_index, 0); - } -} - -TEST(StepIntersectionTest, SingleHost) { - uint32 num_hosts = 1; - uint64 shift_ps = 0; - - PerHostStepDb perhost_stepdb = CreateTestSteps(num_hosts, shift_ps); - StepIntersection intersection = - StepIntersection(kNumStepsPerHost, Convert(perhost_stepdb)); - EXPECT_EQ(intersection.StepsDropped(), 0); - uint32 dst_num_steps = kNumStepsPerHost; - EXPECT_EQ(intersection.NumSteps(), dst_num_steps); - - std::vector dst_step_numbers = intersection.DstStepNumbers(); - for (uint32 i = 0; i < dst_num_steps; i++) { - EXPECT_EQ(dst_step_numbers[i], i); - } - for (uint32 host_id = 0; host_id < num_hosts; host_id++) { - uint32 src_first_step_index = intersection.FirstStepIndex(host_id); - EXPECT_EQ(src_first_step_index, 0); - } -} - -TEST(StepIntersectionTest, WithMaxSteps) { - uint32 num_hosts = 4; - uint64 shift_ps = 0; - uint32 max_steps = 3; - - PerHostStepDb perhost_stepdb = CreateTestSteps(num_hosts, shift_ps); - StepIntersection intersection = - StepIntersection(max_steps, Convert(perhost_stepdb)); - EXPECT_EQ(intersection.StepsDropped(), kNumStepsPerHost - max_steps); - EXPECT_EQ(intersection.NumSteps(), max_steps); -} - -TEST(StepIntersectionTest, NoStep) { - uint32 num_hosts = 4; - uint32 max_steps = 100; - PerHostStepDb 
perhost_stepdb = CreateNoStep(num_hosts); - StepIntersection intersection = - StepIntersection(max_steps, Convert(perhost_stepdb)); - EXPECT_EQ(intersection.NumSteps(), 0); - EXPECT_FALSE(intersection.EmptyIntersect()); -} - -TEST(StepIntersectionTest, EmptyIntersection) { - uint32 max_steps = 100; - PerHostStepDb perhost_stepdb = CreateEmptyIntersectTestSteps(); - StepIntersection intersection = - StepIntersection(max_steps, Convert(perhost_stepdb)); - EXPECT_EQ(intersection.StepsDropped(), 0); - EXPECT_EQ(intersection.NumSteps(), 0); - EXPECT_TRUE(intersection.EmptyIntersect()); -} - -} // namespace -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/utils/tfstreamz_utils.cc b/tensorflow/core/profiler/utils/tfstreamz_utils.cc deleted file mode 100644 index 2d3b5fa4a1bc8e..00000000000000 --- a/tensorflow/core/profiler/utils/tfstreamz_utils.cc +++ /dev/null @@ -1,132 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ -#include "tensorflow/core/profiler/utils/tfstreamz_utils.h" - -#include -#include -#include -#include -#include - -#include "absl/status/status.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_join.h" -#include "absl/strings/string_view.h" -#include "absl/strings/substitute.h" -#include "tensorflow/core/framework/summary.pb.h" -#include "tensorflow/core/lib/monitoring/collected_metrics.h" -#include "tensorflow/core/lib/monitoring/metric_def.h" -#include "tensorflow/core/lib/monitoring/types.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/profiler/protobuf/tfstreamz.pb.h" -#include "tensorflow/core/profiler/utils/xplane_builder.h" -#include "tsl/profiler/protobuf/xplane.pb.h" - -namespace tensorflow { -namespace profiler { - -namespace { - -std::string ConstructXStatName(absl::string_view name, - const monitoring::Point& point) { - if (point.labels.empty()) { - return std::string(name); - } - return absl::Substitute( - "$0{$1}", name, - absl::StrJoin( - point.labels, ", ", - [](std::string* out, const monitoring::Point::Label& label) { - absl::StrAppend(out, label.name, "=", label.value); - })); -} - -tfstreamz::Percentiles ToProto(const monitoring::Percentiles& percentiles) { - tfstreamz::Percentiles output; - output.set_unit_of_measure( - static_cast(percentiles.unit_of_measure)); - output.set_start_nstime(percentiles.start_nstime); - output.set_end_nstime(percentiles.end_nstime); - output.set_min_value(percentiles.min_value); - output.set_max_value(percentiles.max_value); - output.set_mean(percentiles.mean); - output.set_stddev(percentiles.stddev); - output.set_num_samples(percentiles.num_samples); - output.set_total_samples(percentiles.total_samples); - output.set_accumulator(percentiles.accumulator); - for (const auto& pp : percentiles.points) { - auto* percentile_point = output.add_points(); - 
percentile_point->set_percentile(pp.percentile); - percentile_point->set_value(pp.value); - } - return output; -} - -} // namespace - -absl::Status SerializeToXPlane(const std::vector& snapshots, - XPlane* plane, uint64 line_start_time_ns) { - XPlaneBuilder xplane(plane); - XLineBuilder line = xplane.GetOrCreateLine(0); // This plane has single line. - line.SetTimestampNs(line_start_time_ns); - - // For each snapshot, create a virtual event. - for (const auto& snapshot : snapshots) { - XEventMetadata* event_metadata = - xplane.GetOrCreateEventMetadata("TFStreamz Snapshot"); - XEventBuilder xevent = line.AddEvent(*event_metadata); - xevent.SetTimestampNs(snapshot.start_time_ns); - xevent.SetEndTimestampNs(snapshot.end_time_ns); - auto& metric_descriptor_map = snapshot.metrics->metric_descriptor_map; - for (const auto& point_set : snapshot.metrics->point_set_map) { - const std::string& metric_name = point_set.first; - // Each metrics have multiple points corresponding to different labels. - for (const auto& point : point_set.second->points) { - // Generates one KPI metric for each point. 
- std::string stat_name = ConstructXStatName(metric_name, *point); - auto* metadata = xplane.GetOrCreateStatMetadata(stat_name); - auto it = metric_descriptor_map.find(metric_name); - if (it != metric_descriptor_map.end()) { - metadata->set_description(it->second->description); - } - switch (point->value_type) { - case monitoring::ValueType::kInt64: - xevent.AddStatValue(*metadata, point->int64_value); - break; - case monitoring::ValueType::kBool: - xevent.AddStatValue(*metadata, point->bool_value); - break; - case monitoring::ValueType::kString: - xevent.AddStatValue(*metadata, *xplane.GetOrCreateStatMetadata( - point->string_value)); - break; - case monitoring::ValueType::kDouble: - xevent.AddStatValue(*metadata, point->double_value); - break; - case monitoring::ValueType::kHistogram: - xevent.AddStatValue(*metadata, point->histogram_value); - break; - case monitoring::ValueType::kPercentiles: - xevent.AddStatValue(*metadata, ToProto(point->percentiles_value)); - break; - } - } - } - } - return absl::OkStatus(); -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/utils/tfstreamz_utils.h b/tensorflow/core/profiler/utils/tfstreamz_utils.h index abaafbc6e3c990..dffca153ac07b0 100644 --- a/tensorflow/core/profiler/utils/tfstreamz_utils.h +++ b/tensorflow/core/profiler/utils/tfstreamz_utils.h @@ -15,27 +15,6 @@ limitations under the License. #ifndef TENSORFLOW_CORE_PROFILER_UTILS_TFSTREAMZ_UTILS_H_ #define TENSORFLOW_CORE_PROFILER_UTILS_TFSTREAMZ_UTILS_H_ -#include -#include - -#include "absl/status/status.h" -#include "tensorflow/core/lib/monitoring/collected_metrics.h" -#include "tensorflow/core/platform/types.h" -#include "tsl/profiler/protobuf/xplane.pb.h" - -namespace tensorflow { -namespace profiler { - -struct TfStreamzSnapshot { - std::unique_ptr metrics; - uint64 start_time_ns; // time before collection. - uint64 end_time_ns; // time after collection. 
-}; - -absl::Status SerializeToXPlane(const std::vector& snapshots, - XPlane* plane, uint64 line_start_time_ns); - -} // namespace profiler -} // namespace tensorflow +#include "xprof/utils/tfstreamz_utils.h" // from @org_xprof // IWYU pragma: export #endif // TENSORFLOW_CORE_PROFILER_UTILS_TFSTREAMZ_UTILS_H_ diff --git a/tensorflow/core/profiler/utils/tpu_step_breakdown_utils.h b/tensorflow/core/profiler/utils/tpu_step_breakdown_utils.h index 731481a4da8612..e803bbc1b41244 100644 --- a/tensorflow/core/profiler/utils/tpu_step_breakdown_utils.h +++ b/tensorflow/core/profiler/utils/tpu_step_breakdown_utils.h @@ -1,4 +1,4 @@ -/* Copyright 2024 The OpenXLA Authors. +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,61 +15,6 @@ limitations under the License. #ifndef TENSORFLOW_CORE_PROFILER_UTILS_TPU_STEP_BREAKDOWN_UTILS_H_ #define TENSORFLOW_CORE_PROFILER_UTILS_TPU_STEP_BREAKDOWN_UTILS_H_ -#include - -#include "tensorflow/core/profiler/protobuf/steps_db.pb.h" - -namespace tensorflow { -namespace profiler { - -// Total duration of infeed from host or SparseCoreV0 to TensorCore. -inline uint64_t InfeedDurationPs(const TpuStepBreakdown& tpu) { - return tpu.infeed_duration_ps() + tpu.wait_for_scv0_duration_ps() + - tpu.scv0_infeed_transform_ps(); -} - -// Total duration of outfeed from TensorCore to host or SparseCoreV0. -inline uint64_t OutfeedDurationPs(const TpuStepBreakdown& tpu) { - return tpu.host_outfeed_ps() + tpu.scv0_outfeed_ps(); -} - -// Total duration of infeed from host to SparseCoreV0. -inline uint64_t ScV0InfeedDurationPs(const TpuStepBreakdown& tpu) { - return tpu.wait_for_scv0_duration_ps() * tpu.scv0_infeed_percent() / 100.0; -} - -// Total duration of SparseCoreV0 compute. 
-inline uint64_t ScV0ComputeDurationPs(const TpuStepBreakdown& tpu) { - return tpu.wait_for_scv0_duration_ps() - ScV0InfeedDurationPs(tpu); -} - -// Total duration of infeed from host to TensorCore or SparseCoreV0. -inline uint64_t TcPlusScV0InfeedDurationPs(const TpuStepBreakdown& tpu) { - return tpu.infeed_duration_ps() + ScV0InfeedDurationPs(tpu); -} - -// Total duration of send and recv ops. -inline uint64_t SendRecvDurationPs(const TpuStepBreakdown& tpu) { - return tpu.send_duration_ps() + tpu.recv_duration_ps(); -} - -// Total duration of host send and host recv ops. -inline uint64_t HostSendRecvDurationPs(const TpuStepBreakdown& tpu) { - return tpu.host_send_duration_ps() + tpu.host_recv_duration_ps(); -} - -// Total duration TensorCore spends waiting for host. -inline uint64_t WaitForHostDurationPs(const TpuStepBreakdown& tpu) { - return tpu.infeed_duration_ps() + tpu.host_outfeed_ps() + - HostSendRecvDurationPs(tpu) + tpu.tc_idle_ps(); -} - -// Total duration TensorCore spends waiting for host or SparseCoreV0. -inline uint64_t WaitForHostOrScV0DurationPs(const TpuStepBreakdown& tpu) { - return WaitForHostDurationPs(tpu) + tpu.wait_for_scv0_duration_ps(); -} - -} // namespace profiler -} // namespace tensorflow +#include "xprof/utils/tpu_step_breakdown_utils.h" // from @org_xprof // IWYU pragma: export #endif // TENSORFLOW_CORE_PROFILER_UTILS_TPU_STEP_BREAKDOWN_UTILS_H_ diff --git a/tensorflow/core/profiler/utils/tpu_step_details_utils.h b/tensorflow/core/profiler/utils/tpu_step_details_utils.h index 23c1609dc797b7..8ce4f3a2bef490 100644 --- a/tensorflow/core/profiler/utils/tpu_step_details_utils.h +++ b/tensorflow/core/profiler/utils/tpu_step_details_utils.h @@ -1,4 +1,4 @@ -/* Copyright 2024 The OpenXLA Authors. +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,35 +15,6 @@ limitations under the License. 
#ifndef TENSORFLOW_CORE_PROFILER_UTILS_TPU_STEP_DETAILS_UTILS_H_ #define TENSORFLOW_CORE_PROFILER_UTILS_TPU_STEP_DETAILS_UTILS_H_ -#include "tensorflow/core/profiler/protobuf/tpu_input_pipeline.pb.h" - -namespace tensorflow { -namespace profiler { - -inline double ComputeTimeMs(const PerTpuStepDetails& details) { - return details.tc_compute_time_ms() + details.scv0_compute_time_ms(); -} - -inline double InfeedTimeMs(const PerTpuStepDetails& details) { - return details.tc_infeed_time_ms() + details.scv0_infeed_time_ms(); -} - -inline double AllReduceTimeMs(const PerTpuStepDetails& details) { - return details.all_reduce_compute_time_ms() + - details.all_reduce_sync_time_ms(); -} - -inline double NonIdleTimeMs(const PerTpuStepDetails& details) { - return ComputeTimeMs(details) + InfeedTimeMs(details) + - AllReduceTimeMs(details) + details.tc_outfeed_time_ms(); -} - -// Time spent by a training step on TPU. -inline double StepTimeMs(const PerTpuStepDetails& details) { - return NonIdleTimeMs(details) + details.tc_idle_time_ms(); -} - -} // namespace profiler -} // namespace tensorflow +#include "xprof/utils/tpu_step_details_utils.h" // from @org_xprof // IWYU pragma: export #endif // TENSORFLOW_CORE_PROFILER_UTILS_TPU_STEP_DETAILS_UTILS_H_ diff --git a/tensorflow/core/profiler/utils/xprof_gpu_cost_analysis.cc b/tensorflow/core/profiler/utils/xprof_gpu_cost_analysis.cc deleted file mode 100644 index 321cf041502c7b..00000000000000 --- a/tensorflow/core/profiler/utils/xprof_gpu_cost_analysis.cc +++ /dev/null @@ -1,147 +0,0 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/profiler/utils/xprof_gpu_cost_analysis.h" - -#include -#include -#include -#include - -#include "absl/status/status.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_opcode.h" -#include "xla/primitive_util.h" -#include "xla/service/gpu/cublas_cudnn.h" -#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" -#include "xla/service/hlo_cost_analysis.h" -#include "xla/tsl/platform/errors.h" - -namespace tensorflow { -namespace profiler { - -namespace { - -std::vector GetInputBitwidths(const xla::HloInstruction& hlo) { - std::vector input_bitwidths; - for (const auto& operand : hlo.operands()) { - switch (operand->shape().element_type()) { - case xla::PRIMITIVE_TYPE_INVALID: - case xla::TUPLE: - case xla::OPAQUE_TYPE: - case xla::TOKEN: - break; - default: - input_bitwidths.push_back( - xla::primitive_util::BitWidth(operand->shape().element_type())); - } - } - return input_bitwidths; -} - -} // namespace - -absl::Status XProfGpuCostAnalysis::HandleCustomCall( - const xla::HloInstruction* hlo) { - TF_RETURN_IF_ERROR(xla::gpu::GpuHloCostAnalysis::HandleCustomCall(hlo)); - - if (xla::gpu::IsCublasGemm(*hlo)) { - // The naming conventions and meanings of gemm parameters are documented at: - // https://docs.nvidia.com/cuda/cublas/index.html#using-the-cublaslt-api - // as inherited from GpuHloCostAnalysis, we only normalize the flops based - // on the datatype of A and B, which are supposed of same bitwidth. 
- int dot_operands_bitwidth = - xla::primitive_util::BitWidth(hlo->operand(0)->shape().element_type()); - uint32_t flop_rate_adjustment = 1; - switch (dot_operands_bitwidth) { - case 8: - flop_rate_adjustment = 2; - break; - case 4: - flop_rate_adjustment = 4; - break; - default: - break; - } - float model_flops = current_properties_[kFlopsKey]; - current_properties_[kDeviceFlopsAdjustment] = - model_flops - model_flops / flop_rate_adjustment; - } - return absl::OkStatus(); -} - -absl::Status XProfGpuCostAnalysis::DefaultPostprocess( - const xla::HloInstruction* hlo) { - uint32_t flop_rate_adjustment = 1; - float model_flops = current_properties_[kFlopsKey]; - - // Calculate adjustment of device flops based on input bit widths. - // This provide most general adjustment for all ops, and for all gpus. - std::vector input_bitwidths = GetInputBitwidths(*hlo); - if (!input_bitwidths.empty()) { - int max_input_bitwidth = - *std::max_element(input_bitwidths.begin(), input_bitwidths.end()); - if (model_flops) { - // for int8/fp8, 2x flops assumed comparing with fp16 flops(most of - // recent GPU models); for int4, 4x of model flops assumed comparing - // with fp16 flops. (like Nvidia T4, 3090). It will be more precise - // after adjustment based on specific GPUs mentioned above. - switch (max_input_bitwidth) { - case 8: - flop_rate_adjustment = 2; - break; - case 4: - flop_rate_adjustment = 4; - break; - default: - break; - } - } - } - current_properties_[kDeviceFlopsAdjustment] = - model_flops - model_flops / flop_rate_adjustment; - return absl::OkStatus(); -} - -absl::Status XProfGpuCostAnalysis::Postprocess(const xla::HloInstruction* hlo) { - if (hlo == nullptr) { - return absl::OkStatus(); - } - - switch (hlo->opcode()) { - case xla::HloOpcode::kCustomCall: - // Already handled specially in HandleCustomCall(), skip here. - // Add more OpCode here if it is handled specially in future. 
- break; - default: - DefaultPostprocess(hlo).IgnoreError(); - break; - } - - return xla::gpu::GpuHloCostAnalysis::Postprocess(hlo); -} - -std::unique_ptr -XProfGpuCostAnalysis::CreateNestedCostAnalysis() { - return std::make_unique(options_); -} - -int64_t XProfGpuCostAnalysis::GetDeviceFlopsAdjustment( - const xla::HloInstruction& hlo) { - return GetPropertyForHlo(hlo, kDeviceFlopsAdjustment, hlo_properties_); -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/utils/xprof_gpu_cost_analysis.h b/tensorflow/core/profiler/utils/xprof_gpu_cost_analysis.h index 76b50f5997d9c6..3814be42d65646 100644 --- a/tensorflow/core/profiler/utils/xprof_gpu_cost_analysis.h +++ b/tensorflow/core/profiler/utils/xprof_gpu_cost_analysis.h @@ -16,42 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_CORE_PROFILER_UTILS_XPROF_GPU_COST_ANALYSIS_H_ #define TENSORFLOW_CORE_PROFILER_UTILS_XPROF_GPU_COST_ANALYSIS_H_ -#include -#include - -#include "absl/status/status.h" -#include "absl/strings/string_view.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" -#include "xla/service/hlo_cost_analysis.h" - -namespace tensorflow { -namespace profiler { - -// XProfGpuCostAnalysis provides additional cost analysis for XProf, which -// normalizes the flops to the device flops based on input bit widths. 
-class XProfGpuCostAnalysis : public xla::gpu::GpuHloCostAnalysis { - public: - explicit XProfGpuCostAnalysis(const xla::HloCostAnalysis::Options& options) - : xla::gpu::GpuHloCostAnalysis(options) {} - - absl::Status HandleCustomCall(const xla::HloInstruction* hlo) override; - - absl::Status Postprocess(const xla::HloInstruction* hlo) override; - - int64_t GetDeviceFlopsAdjustment(const xla::HloInstruction& hlo); - - protected: - std::unique_ptr CreateNestedCostAnalysis() override; - - absl::Status DefaultPostprocess(const xla::HloInstruction* hlo); - - private: - static inline constexpr absl::string_view kDeviceFlopsAdjustment = - "device_flops_adjustment"; -}; - -} // namespace profiler -} // namespace tensorflow +#include "xprof/utils/xprof_gpu_cost_analysis.h" // from @org_xprof // IWYU pragma: export #endif // TENSORFLOW_CORE_PROFILER_UTILS_XPROF_GPU_COST_ANALYSIS_H_ diff --git a/tensorflow/core/profiler/utils/xprof_gpu_cost_analysis_test.cc b/tensorflow/core/profiler/utils/xprof_gpu_cost_analysis_test.cc deleted file mode 100644 index c71d1a9dfb5730..00000000000000 --- a/tensorflow/core/profiler/utils/xprof_gpu_cost_analysis_test.cc +++ /dev/null @@ -1,189 +0,0 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/core/profiler/utils/xprof_gpu_cost_analysis.h" - -#include - -#include -#include "absl/strings/string_view.h" -#include "xla/hlo/ir/hlo_computation.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/testlib/test_helpers.h" -#include "xla/service/hlo_cost_analysis.h" -#include "xla/shape.h" -#include "xla/shape_util.h" -#include "xla/tests/hlo_test_base.h" -#include "xla/tsl/platform/statusor.h" -#include "xla/xla_data.pb.h" - -namespace tensorflow { -namespace profiler { - -class XprofGpuHloCostAnalysisTest : public xla::HloTestBase { - xla::HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const { - return [&](const xla::Shape& shape) { - constexpr int64_t kPointerSize = 8; - return xla::ShapeUtil::ByteSizeOf(shape, kPointerSize); - }; - } - - public: - xla::HloCostAnalysis::Options options_{ - ShapeSizeBytesFunction(), - /*per_second_rates=*/{}, - /*min_latencies_seconds=*/{}, - /*count_multiple_input_accesses=*/true}; - XProfGpuCostAnalysis analysis_{options_}; - XprofGpuHloCostAnalysisTest() : xla::HloTestBase() {} -}; - -TEST_F(XprofGpuHloCostAnalysisTest, Fp16GemmNoAdjustment) { - absl::string_view hlo_string = R"( -HloModule r - -ENTRY e { - arg0 = f16[65536,32800] parameter(0) - arg1 = f16[32800,32] parameter(1) - gemm = (f16[65536,32], s8[0]) custom-call(arg0, arg1), - custom_call_target="__cublas$gemm", - backend_config="{ - \"gemm_backend_config\": { - \"alpha_real\":1, - \"beta\":0, - \"dot_dimension_numbers\":{ - \"lhs_contracting_dimensions\":[\"1\"], - \"rhs_contracting_dimensions\":[\"0\"], - \"lhs_batch_dimensions\":[], - \"rhs_batch_dimensions\":[] - }, - \"alpha_imag\":0, - \"precision_config\":{ - \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] - }, - \"epilogue\":\"DEFAULT\" - } - }" - ROOT get-tuple-element = f16[65536,32] - get-tuple-element((f16[65536,32], s8[0]) gemm), index=0 -} -)"; - TF_ASSERT_OK_AND_ASSIGN(auto 
module, - ParseAndReturnVerifiedModule(hlo_string)); - ASSERT_IS_OK(module->entry_computation()->Accept(&analysis_)); - xla::HloComputation* comp = module->entry_computation(); - const xla::HloInstruction* fp16gemm = comp->GetInstructionWithName("gemm"); - // flops of gemm A * B = rows(A) * cols(B) * cols(A) * 2 - // where 2 is for the add and multiply - int64_t gold_flops = 65536LL * 32800 * 32 * 2; - EXPECT_EQ(analysis_.flop_count(*fp16gemm), gold_flops); - EXPECT_EQ(analysis_.GetDeviceFlopsAdjustment(*fp16gemm), 0); -} - -TEST_F(XprofGpuHloCostAnalysisTest, S8GemmAdjustment) { - absl::string_view hlo_string = R"( -HloModule r - -ENTRY e { - arg0 = s8[65536,32800] parameter(0) - arg1 = s8[32800,32] parameter(1) - gemm = (s32[65536,32], s8[0]) custom-call(arg0, arg1), - custom_call_target="__cublas$gemm", - backend_config="{ - \"gemm_backend_config\": { - \"alpha_real\":1, - \"beta\":0, - \"dot_dimension_numbers\":{ - \"lhs_contracting_dimensions\":[\"1\"], - \"rhs_contracting_dimensions\":[\"0\"], - \"lhs_batch_dimensions\":[], - \"rhs_batch_dimensions\":[] - }, - \"alpha_imag\":0, - \"precision_config\":{ - \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] - }, - \"epilogue\":\"DEFAULT\" - } - }" - ROOT get-tuple-element = s32[65536,32] - get-tuple-element((s32[65536,32], s8[0]) gemm), index=0 -} -)"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); - ASSERT_IS_OK(module->entry_computation()->Accept(&analysis_)); - xla::HloComputation* comp = module->entry_computation(); - const xla::HloInstruction* s8gemm = comp->GetInstructionWithName("gemm"); - int64_t gold_flops = 65536LL * 32800 * 32 * 2; - EXPECT_EQ(analysis_.flop_count(*s8gemm), gold_flops); - // Matmul of int8 * int8 -> int32, normalized it to equivalent fp16 flops by - // dividing by 2 as all inputs are 8 bits - EXPECT_EQ(analysis_.GetDeviceFlopsAdjustment(*s8gemm), gold_flops / 2); -} - -// test special handling logic when fp32 parameter is also used 
-TEST_F(XprofGpuHloCostAnalysisTest, Fp8GemmWithFp32ParameterAdjustment) { - absl::string_view hlo_string = R"( -HloModule r - -ENTRY e { - arg0 = f8e4m3fn[2048,5120]{1,0} parameter(0) - arg1 = f8e4m3fn[5120,5120]{0,1} parameter(1) - arg2 = f32[] parameter(2) - arg3 = f32[] parameter(3) - gemm = (bf16[2048,5120]{1,0}, s8[33554432]{0}) - custom-call(arg0, arg1, arg2, arg3), - custom_call_target="__cublas$lt$matmul$f8", - backend_config="{ - \"gemm_backend_config\": { - \"alpha_real\":1, - \"beta\":0, - \"dot_dimension_numbers\":{ - \"lhs_contracting_dimensions\":[\"1\"], - \"rhs_contracting_dimensions\":[\"0\"], - \"lhs_batch_dimensions\":[], - \"rhs_batch_dimensions\":[] - }, - \"alpha_imag\":0, - \"precision_config\":{ - \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] - }, - \"epilogue\":\"DEFAULT\", - \"lhs_stride\":\"10485760\", - \"rhs_stride\":\"26214400\", - \"grad_x\":false, - \"grad_y\":false, - \"damax_output\":false - } - }" - ROOT get-tuple-element = bf16[2048,5120]{1,0} - get-tuple-element((bf16[2048,5120]{1,0}, s8[33554432]{0}) gemm), index=0 -} -)"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); - ASSERT_IS_OK(module->entry_computation()->Accept(&analysis_)); - xla::HloComputation* comp = module->entry_computation(); - const xla::HloInstruction* fp8_gemm = comp->GetInstructionWithName("gemm"); - int64_t gold_flops = 2048LL * 5120 * 5120 * 2; - EXPECT_EQ(analysis_.flop_count(*fp8_gemm), gold_flops); - // Matmul of int8 * int8 -> int32, normalized it to equivalent fp16 flops by - // dividing by 2 as all inputs are 8 bits - EXPECT_EQ(analysis_.GetDeviceFlopsAdjustment(*fp8_gemm), gold_flops / 2); -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/public/session_options.h b/tensorflow/core/public/session_options.h index 92134528dbf975..3335046aa58d16 100644 --- a/tensorflow/core/public/session_options.h +++ b/tensorflow/core/public/session_options.h @@ -18,7 +18,6 @@ limitations under 
the License. #include -#include "tensorflow/core/platform/types.h" #include "tensorflow/core/protobuf/config.pb.h" namespace tsl { diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index ddcfd2d0bec899..d7e3b8bcde89fe 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -93,7 +93,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 2180 // Updated: 2025/3/28 +#define TF_GRAPH_DEF_VERSION 2209 // Updated: 2025/4/26 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // diff --git a/tensorflow/core/tfrt/fallback/fallback_state.cc b/tensorflow/core/tfrt/fallback/fallback_state.cc index 6e3b98511f9e20..0a19bd1c5d7159 100644 --- a/tensorflow/core/tfrt/fallback/fallback_state.cc +++ b/tensorflow/core/tfrt/fallback/fallback_state.cc @@ -115,13 +115,13 @@ absl::StatusOr> FallbackState::CreateWithDeviceMgr( const SessionOptions &session_options, const tensorflow::FunctionDefLibrary &fdef_lib, - absl::Nonnull device_mgr) { + DynamicDeviceMgr *absl_nonnull device_mgr) { return std::make_unique(session_options, device_mgr, fdef_lib); } FallbackState::FallbackState(const SessionOptions &session_options, std::variant>, - absl::Nonnull> + DynamicDeviceMgr *absl_nonnull> device_mgr, const tensorflow::FunctionDefLibrary &fdef_lib) : session_options_(session_options), @@ -132,8 +132,8 @@ FallbackState::FallbackState(const SessionOptions &session_options, std::get>>(device_mgr)) : std::vector>()), device_manager_ptr_( - std::holds_alternative>(device_mgr) - ? std::get>(device_mgr) + std::holds_alternative(device_mgr) + ? 
std::get(device_mgr) : &device_manager_), func_lib_def_(OpRegistry::Global(), fdef_lib), pflr_(device_manager_ptr_, session_options.env, &session_options.config, diff --git a/tensorflow/core/tfrt/fallback/fallback_state.h b/tensorflow/core/tfrt/fallback/fallback_state.h index ffbf0695bafbad..703955448e6b90 100644 --- a/tensorflow/core/tfrt/fallback/fallback_state.h +++ b/tensorflow/core/tfrt/fallback/fallback_state.h @@ -56,11 +56,11 @@ class FallbackState { static absl::StatusOr> CreateWithDeviceMgr( const SessionOptions &session_options, const tensorflow::FunctionDefLibrary &fdef_lib, - absl::Nonnull device_mgr); + DynamicDeviceMgr *absl_nonnull device_mgr); FallbackState(const SessionOptions &session_options, std::variant>, - absl::Nonnull> + DynamicDeviceMgr *absl_nonnull> device_mgr, const tensorflow::FunctionDefLibrary &fdef_lib); @@ -93,7 +93,7 @@ class FallbackState { private: SessionOptions session_options_; DynamicDeviceMgr device_manager_; - absl::Nonnull device_manager_ptr_; + DynamicDeviceMgr *absl_nonnull device_manager_ptr_; DeviceSet device_set_; FunctionLibraryDefinition func_lib_def_; ProcessFunctionLibraryRuntime pflr_; diff --git a/tensorflow/core/tfrt/fallback/fallback_state_test.cc b/tensorflow/core/tfrt/fallback/fallback_state_test.cc index 3546992cfa7614..21f961556e3b49 100644 --- a/tensorflow/core/tfrt/fallback/fallback_state_test.cc +++ b/tensorflow/core/tfrt/fallback/fallback_state_test.cc @@ -49,7 +49,7 @@ TEST(FallbackStateTest, CreateWithCpuDeviceVector) { session_options, "/job:localhost/replica:0/task:0", &devices)); std::variant>, - absl::Nonnull> + DynamicDeviceMgr* absl_nonnull> device_variant = std::move(devices); auto fallback_state = std::make_unique( @@ -70,7 +70,7 @@ TEST(FallbackStateTest, CreateWithDynamicDeviceMgr) { auto static_device_mgr = std::make_unique(std::move(devices)); - absl::Nonnull device_mgr_ptr(static_device_mgr.get()); + DynamicDeviceMgr* absl_nonnull device_mgr_ptr(static_device_mgr.get()); auto 
fallback_state = std::make_unique( session_options, device_mgr_ptr, fdef_lib); diff --git a/tensorflow/core/tfrt/mlrt/interpreter/builtin_kernels.cc b/tensorflow/core/tfrt/mlrt/interpreter/builtin_kernels.cc index 8ca71ba8e25b88..fdac986b64e4c6 100644 --- a/tensorflow/core/tfrt/mlrt/interpreter/builtin_kernels.cc +++ b/tensorflow/core/tfrt/mlrt/interpreter/builtin_kernels.cc @@ -142,10 +142,8 @@ void CaseOp::Invoke() { mlrt::bc::Vector attribute_function_indices = function_indices(); if (argument_branch_idx >= attribute_function_indices.size()) { - execution_context().Fail(absl::InvalidArgumentError( - absl::StrCat("Case branch number ", argument_branch_idx, - " exceeds limit ", attribute_function_indices.size()))); - return; + // Consistent with the behavior of the legacy TFRT case kernel behavior. + argument_branch_idx = attribute_function_indices.size() - 1; } auto function = diff --git a/tensorflow/core/tfrt/mlrt/interpreter/interpreter_test.cc b/tensorflow/core/tfrt/mlrt/interpreter/interpreter_test.cc index 34b53ac1fb1bd6..568949ac3792c0 100644 --- a/tensorflow/core/tfrt/mlrt/interpreter/interpreter_test.cc +++ b/tensorflow/core/tfrt/mlrt/interpreter/interpreter_test.cc @@ -2825,6 +2825,43 @@ TEST(KernelTest, Case) { } } +TEST(KernelTest, CaseInvalidBranchIndexShallChooseLastBranch) { + auto buffer = CreateCaseExecutable(); + + bc::Executable executable(buffer.data()); + + KernelRegistry registry; + RegisterBuiltinKernels(registry); + LoadedExecutable loaded_executable(executable, registry); + + ExecutionContext execution_context(&loaded_executable); + + auto function = loaded_executable.GetFunction("main"); + ASSERT_TRUE(function); + + Value inputs[3]; + + constexpr int32_t kBranch0In = 123; + constexpr int32_t kBranch1In = 456; + + // Test Invalid Branch 10 + { + inputs[0].Set(10); + inputs[1].Set(kBranch0In); + inputs[2].Set(kBranch1In); + Value output; + + std::vector last_uses = {true, true, true}; + execution_context.Call(function, last_uses, 
absl::MakeSpan(inputs), + absl::Span(&output, 1)); + + Execute(execution_context); + + ASSERT_TRUE(output.HasValue()); + EXPECT_EQ(kBranch1In, output.Get()); + } +} + struct TestPromiseReturnOp : PromiseReturnOpBase { using PromiseReturnOpBase::PromiseReturnOpBase; diff --git a/tensorflow/core/tfrt/runtime/stream_test.cc b/tensorflow/core/tfrt/runtime/stream_test.cc index bcb8a14a553675..e0b73ebcb1c1fd 100644 --- a/tensorflow/core/tfrt/runtime/stream_test.cc +++ b/tensorflow/core/tfrt/runtime/stream_test.cc @@ -50,7 +50,29 @@ using ::testing::Pair; using ::testing::UnorderedElementsAre; using ::testing::status::StatusIs; -TEST(StreamTest, Simple) { +class TestStreamControllerInterface : public StreamControllerInterface { + public: + TestStreamControllerInterface() + : StreamControllerInterface("test_controller_address") {} +}; + +class StreamTest : public ::testing::Test { + protected: + static void SetUpTestSuite() { + GetGlobalStreamInterfaceFactory().RegisterController( + []() { return std::make_unique(); }); + } +}; + +TEST_F(StreamTest, Initialize) { + TF_ASSERT_OK_AND_ASSIGN( + auto controller_interface, + GetGlobalStreamInterfaceFactory().CreateControllerStreamInterface()); + EXPECT_EQ(controller_interface->controller_address(), + "test_controller_address"); +} + +TEST_F(StreamTest, Simple) { StreamCallbackId callback_id(1234); StepId step_id(5678); @@ -97,7 +119,7 @@ TEST(StreamTest, Simple) { EXPECT_THAT(status, StatusIs(absl::StatusCode::kAlreadyExists)); } -TEST(StreamTest, MultipleWriters) { +TEST_F(StreamTest, MultipleWriters) { StreamCallbackId callback_id(1234); StepId step_id(5678); @@ -146,22 +168,6 @@ TEST(StreamTest, MultipleWriters) { } } -class TestStreamControllerInterface : public StreamControllerInterface { - public: - TestStreamControllerInterface() - : StreamControllerInterface("test_controller_address") {} -}; - -TEST(StreamControllerInterface, Initialize) { - GetGlobalStreamInterfaceFactory().RegisterController( - []() { return 
std::make_unique(); }); - TF_ASSERT_OK_AND_ASSIGN( - auto controller_interface, - GetGlobalStreamInterfaceFactory().CreateControllerStreamInterface()); - EXPECT_EQ(controller_interface->controller_address(), - "test_controller_address"); -} - class TestStreamWorkerInterface : public StreamWorkerInterface { public: explicit TestStreamWorkerInterface(std::string worker_address) diff --git a/tensorflow/core/tfrt/saved_model/python/saved_model_load_and_run.cc b/tensorflow/core/tfrt/saved_model/python/saved_model_load_and_run.cc index 448e05d411d165..c2dcfbebe87734 100644 --- a/tensorflow/core/tfrt/saved_model/python/saved_model_load_and_run.cc +++ b/tensorflow/core/tfrt/saved_model/python/saved_model_load_and_run.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/tfrt/saved_model/python/saved_model_load_and_run.h" +#include + #include #include #include diff --git a/tensorflow/core/tfrt/tfrt_session/tfrt_session.cc b/tensorflow/core/tfrt/tfrt_session/tfrt_session.cc index 05f1a9c3521822..ea46dc949dbd77 100644 --- a/tensorflow/core/tfrt/tfrt_session/tfrt_session.cc +++ b/tensorflow/core/tfrt/tfrt_session/tfrt_session.cc @@ -503,7 +503,7 @@ class TfrtSession : public tensorflow::Session { compile_options.device_target = device_target_; compile_options.tpu_fuse_ops = tpu_use_tpu_runner_; compile_options.hoist_invariant_ops = true; - compile_options.sink_in_invariant_ops = false; + compile_options.sink_in_invariant_ops = true; compile_options.cost_threshold = 1024; if (use_gpu_) { @@ -779,18 +779,22 @@ void TfrtSessionFactory::RegisterInitializer(RuntimeInitializer initializer) { absl::Status TfrtSessionFactory::InitializeLocked( const TfrtSessionOptions& options) { mutex_.AssertHeld(); + if (options.backend_compiler) { + backend_compiler_ = options.backend_compiler; + } if (options.use_tpu) { - DCHECK(!options.backend_compiler); DCHECK(!options.use_gpu); 
device_target_ = TfrtDeviceInfraTarget::kTpurt; - tpu_use_tpu_runner_ = true; + if (!options.backend_compiler) { + tpu_use_tpu_runner_ = true; + } } else if (options.use_gpu) { - DCHECK(!options.backend_compiler); device_target_ = TfrtDeviceInfraTarget::kGpu; - use_gpu_ = true; - } else if (options.backend_compiler) { - backend_compiler_ = options.backend_compiler; + if (!options.backend_compiler) { + use_gpu_ = true; + } } + LOG(INFO) << "Start initializing TfrtSession"; if (options.runtime != nullptr) { runtime_ = options.runtime; diff --git a/tensorflow/core/tpu/kernels/BUILD b/tensorflow/core/tpu/kernels/BUILD index 8f14a5abac0c29..eb4c221c8f4384 100644 --- a/tensorflow/core/tpu/kernels/BUILD +++ b/tensorflow/core/tpu/kernels/BUILD @@ -192,6 +192,7 @@ cc_library( "//tensorflow/core/platform:errors", "//tensorflow/core/platform:statusor", "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", @@ -200,10 +201,12 @@ cc_library( "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/hlo/builder:xla_builder", "@local_xla//xla/hlo/builder:xla_computation", + "@local_xla//xla/hlo/builder/lib:arithmetic", "@local_xla//xla/hlo/builder/lib:slicing", "@local_xla//xla/stream_executor/tpu:c_api_decl", "@local_xla//xla/stream_executor/tpu:tpu_api", "@local_xla//xla/stream_executor/tpu:tpu_ops_c_api_hdrs", + "@local_xla//xla/tsl/platform:errors", ], alwayslink = 1, ) @@ -1504,6 +1507,9 @@ cc_library( "@com_google_absl//absl/numeric:bits", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", + "@local_xla//xla:shape_util", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder:xla_computation", "@local_xla//xla/stream_executor/tpu:status_helper", "@local_xla//xla/stream_executor/tpu:tpu_api", "@local_xla//xla/stream_executor/tpu:tpu_ops_c_api_hdrs", diff --git a/tensorflow/core/tpu/kernels/image_resize_ops.cc 
b/tensorflow/core/tpu/kernels/image_resize_ops.cc index 5bd2f21f55d7bb..7e255bab054550 100644 --- a/tensorflow/core/tpu/kernels/image_resize_ops.cc +++ b/tensorflow/core/tpu/kernels/image_resize_ops.cc @@ -50,8 +50,8 @@ class TpuCustomResizeOp : public XlaOpKernel { TF_ASSIGN_OR_RETURN(xla::Shape input_shape, ctx->InputXlaShape(0)); xla::Shape output_shape = TensorShapeToXLAShape(ctx->output_xla_type(0), ctx->InputShape(0)); - output_shape.mutable_dimensions()[1] = out_size[0]; - output_shape.mutable_dimensions()[2] = out_size[1]; + output_shape.set_dimensions(1, out_size[0]); + output_shape.set_dimensions(2, out_size[1]); output_shape.set_dynamic_dimension(0, input_shape.is_dynamic_dimension(0)); output_shape.set_dynamic_dimension(3, input_shape.is_dynamic_dimension(3)); return output_shape; @@ -75,7 +75,7 @@ class TpuCustomResizeOp : public XlaOpKernel { if (input_shape.dimensions(1) / output_shape.dimensions(1) > 3 && input_shape.dimensions(2) / output_shape.dimensions(2) > 3) { auto intermediate_shape = output_shape; - intermediate_shape.mutable_dimensions()[1] = input_shape.dimensions(1); + intermediate_shape.set_dimensions(1, input_shape.dimensions(1)); input = xla::CustomCall(ctx->builder(), target, {ctx->Input(0)}, intermediate_shape, OpaqueField()); } diff --git a/tensorflow/core/tpu/kernels/infeed_ops.cc b/tensorflow/core/tpu/kernels/infeed_ops.cc index 17953799fff16c..d59c6c4b6d4683 100644 --- a/tensorflow/core/tpu/kernels/infeed_ops.cc +++ b/tensorflow/core/tpu/kernels/infeed_ops.cc @@ -89,7 +89,7 @@ absl::StatusOr TransposeTensor(OpKernelContext* ctx, const Tensor& input_tensor, const xla::Shape& xla_shape) { tsl::profiler::TraceMe trace_me("TransposeTensor", /*level=*/2); - const int64_t rank = xla_shape.dimensions_size(); + const int64_t rank = xla_shape.dimensions().size(); std::vector permutation(rank); std::vector transposed_shapes(rank); for (int64_t i = 0; i < rank; ++i) { diff --git a/tensorflow/core/tpu/kernels/sparse_core_ops_utils.cc 
b/tensorflow/core/tpu/kernels/sparse_core_ops_utils.cc index 52051b9e32fc3d..182f5bf29ca32b 100644 --- a/tensorflow/core/tpu/kernels/sparse_core_ops_utils.cc +++ b/tensorflow/core/tpu/kernels/sparse_core_ops_utils.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -31,6 +32,10 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "tensorflow/compiler/jit/flags.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" +#include "xla/shape.h" +#include "xla/shape_util.h" #include "xla/stream_executor/tpu/status_helper.h" #include "xla/stream_executor/tpu/tpu_api.h" #include "xla/stream_executor/tpu/tpu_ops_c_api.h" @@ -224,4 +229,319 @@ ABSL_ATTRIBUTE_WEAK int64_t GetXlaSparseCoreStackingTableShardLimit() { return sparse_core_flags->tf_xla_sparse_core_stacking_table_shard_limit_bytes; } +xla::XlaOp ApplyWeightClippingToTable(xla::XlaBuilder* builder, + xla::XlaOp table, float clip_weight_min, + float clip_weight_max) { + xla::XlaOp clip_weight_min_op = xla::ConstantR0(builder, clip_weight_min); + xla::XlaOp clip_weight_max_op = xla::ConstantR0(builder, clip_weight_max); + xla::XlaOp clipped_table = + xla::Clamp(clip_weight_min_op, table, clip_weight_max_op); + return clipped_table; +} + +xla::XlaComputation BuildSgdOptimizerComputation(const int32_t feature_width, + const float clip_weight_min, + const float clip_weight_max) { + auto sgd_optimizer_builder = + std::make_unique("sgd_optimizer_builder"); + + xla::Shape per_row_shape = + xla::ShapeUtil::MakeShapeWithType({1, feature_width}); + + xla::XlaOp gradient = + xla::Parameter(sgd_optimizer_builder.get(), 0, per_row_shape, "gradient"); + + xla::XlaOp embedding_table = xla::Parameter(sgd_optimizer_builder.get(), 1, + per_row_shape, "embedding_table"); + + xla::XlaOp learning_rate = xla::Parameter(sgd_optimizer_builder.get(), 2, + per_row_shape, 
"learning_rate"); + + xla::XlaOp updated_embedding_table = + embedding_table - learning_rate * gradient; + + // Apply the weight clipping. + xla::XlaOp clipped_embedding_table = ApplyWeightClippingToTable( + sgd_optimizer_builder.get(), updated_embedding_table, clip_weight_min, + clip_weight_max); + + xla::XlaOp updated_tables = + xla::Tuple(sgd_optimizer_builder.get(), {clipped_embedding_table}); + + return sgd_optimizer_builder->Build(updated_tables).value(); +} + +xla::XlaComputation BuildAdagradOptimizerComputation( + const int32_t feature_width, const float clip_weight_min, + const float clip_weight_max) { + auto adagrad_optimizer_builder = + std::make_unique("adagrad_optimizer_builder"); + + xla::Shape per_row_shape = + xla::ShapeUtil::MakeShapeWithType({1, feature_width}); + + xla::XlaOp gradient = xla::Parameter(adagrad_optimizer_builder.get(), 0, + per_row_shape, "gradient"); + + xla::XlaOp embedding_table = xla::Parameter( + adagrad_optimizer_builder.get(), 1, per_row_shape, "embedding_table"); + + xla::XlaOp accumulator = xla::Parameter(adagrad_optimizer_builder.get(), 2, + per_row_shape, "accumulator"); + + xla::XlaOp learning_rate = xla::Parameter(adagrad_optimizer_builder.get(), 3, + per_row_shape, "learning_rate"); + + xla::XlaOp new_accumulator = accumulator + gradient * gradient; + + xla::XlaOp updated_embedding_table = + embedding_table - learning_rate * gradient / xla::Sqrt(new_accumulator); + + // Apply the weight clipping. 
+ xla::XlaOp clipped_embedding_table = ApplyWeightClippingToTable( + adagrad_optimizer_builder.get(), updated_embedding_table, clip_weight_min, + clip_weight_max); + + xla::XlaOp updated_tables = + xla::Tuple(adagrad_optimizer_builder.get(), + {clipped_embedding_table, new_accumulator}); + return adagrad_optimizer_builder->Build(updated_tables).value(); +} + +xla::XlaComputation BuildAdagradMomentumOptimizerComputation( + const int32_t feature_width, const bool use_nesterov, const float exponent, + const float beta1, const float beta2, const float epsilon, + const float clip_weight_min, const float clip_weight_max) { + auto adagrad_momentum_optimizer_builder = + std::make_unique("adagrad_momentum_optimizer_builder"); + + xla::Shape per_row_shape = + xla::ShapeUtil::MakeShapeWithType({1, feature_width}); + + xla::XlaOp gradient = xla::Parameter(adagrad_momentum_optimizer_builder.get(), + 0, per_row_shape, "gradient"); + xla::XlaOp embedding_table = + xla::Parameter(adagrad_momentum_optimizer_builder.get(), 1, per_row_shape, + "embedding_table"); + xla::XlaOp accumulator = + xla::Parameter(adagrad_momentum_optimizer_builder.get(), 2, per_row_shape, + "accumulator"); + xla::XlaOp momenta = xla::Parameter(adagrad_momentum_optimizer_builder.get(), + 3, per_row_shape, "momenta"); + xla::XlaOp learning_rate = + xla::Parameter(adagrad_momentum_optimizer_builder.get(), 4, per_row_shape, + "learning_rate"); + + xla::XlaOp beta1_op = + xla::ConstantR0(adagrad_momentum_optimizer_builder.get(), beta1); + xla::XlaOp beta2_op = + xla::ConstantR0(adagrad_momentum_optimizer_builder.get(), beta2); + xla::XlaOp epsilon_op = + xla::ConstantR0(adagrad_momentum_optimizer_builder.get(), epsilon); + + // If beta_2 == 1: + // accumulator(t) = accumulator(t-1) + gradient(t) ^ 2 + // Else: + // accumulator(t) = beta_2 * accumulator(t-1) + + // (1-beta_2) * gradient(t) ^ 2 + xla::XlaOp exponent_op = xla::ConstantR0( + adagrad_momentum_optimizer_builder.get(), 1.0f / exponent); + xla::XlaOp 
one = + xla::ConstantR0(adagrad_momentum_optimizer_builder.get(), 1.0f); + + xla::XlaOp new_accumulator = xla::Select( + xla::Eq(beta2_op, one), accumulator + gradient * gradient, + beta2_op * accumulator + (one - beta2_op) * gradient * gradient); + + // scaled_gradient = (accumulator + epsilon)^(-1/k) * gradient + xla::XlaOp scaled_gradients = + Pow(new_accumulator + epsilon_op, xla::Neg(exponent_op)) * gradient; + + // momenta(t) = beta1 * momenta(t-1) + scaled_gradient(t) + xla::XlaOp new_momenta = beta1_op * momenta + scaled_gradients; + + // Table update: + // non-nesterov: update = momenta_t + // nesterov: update = beta_1 * momenta_t + scaled_gradient + // weights(t) = weights(t-1) - lr * update + xla::XlaOp updated_embedding_table; + if (use_nesterov) { + updated_embedding_table = + embedding_table - + learning_rate * (beta1_op * new_momenta + scaled_gradients); + } else { + updated_embedding_table = embedding_table - learning_rate * new_momenta; + } + + // Apply the weight clipping. 
+ xla::XlaOp clipped_embedding_table = ApplyWeightClippingToTable( + adagrad_momentum_optimizer_builder.get(), updated_embedding_table, + clip_weight_min, clip_weight_max); + + xla::XlaOp updated_tables = + xla::Tuple(adagrad_momentum_optimizer_builder.get(), + {clipped_embedding_table, new_accumulator, new_momenta}); + return adagrad_momentum_optimizer_builder->Build(updated_tables).value(); +} + +xla::XlaComputation BuildAdamOptimizerComputation( + const int32_t feature_width, const bool use_sum_inside_sqrt, + const float beta1, const float beta2, const float epsilon, + const float clip_weight_min, const float clip_weight_max) { + auto adam_optimizer_builder = + std::make_unique("adam_optimizer_builder"); + + xla::Shape per_row_shape = + xla::ShapeUtil::MakeShapeWithType({1, feature_width}); + + xla::XlaOp gradient = xla::Parameter(adam_optimizer_builder.get(), 0, + per_row_shape, "gradient"); + xla::XlaOp embedding_table = xla::Parameter(adam_optimizer_builder.get(), 1, + per_row_shape, "embedding_table"); + xla::XlaOp momenta = + xla::Parameter(adam_optimizer_builder.get(), 2, per_row_shape, "momenta"); + xla::XlaOp velocity = xla::Parameter(adam_optimizer_builder.get(), 3, + per_row_shape, "velocity"); + xla::XlaOp learning_rate = xla::Parameter(adam_optimizer_builder.get(), 4, + per_row_shape, "learning_rate"); + + xla::XlaOp beta1_op = xla::ConstantR0(adam_optimizer_builder.get(), beta1); + xla::XlaOp beta2_op = xla::ConstantR0(adam_optimizer_builder.get(), beta2); + xla::XlaOp epsilon_op = + xla::ConstantR0(adam_optimizer_builder.get(), epsilon); + + // Depending on sum_inside_sqrt, the denominator is either: + // sum_inside_sqrt==true: sqrt(v + eps^2) + // sum_inside_sqrt==false: sqrt(v) + eps + // To simplify the for loop below, write the sqrt denominator as: + // sqrt(v + e1) + e2 + // and set e1 and e2 appropriately: + xla::XlaOp zero = xla::ConstantR0(adam_optimizer_builder.get(), 0.0f); + xla::XlaOp one = xla::ConstantR0(adam_optimizer_builder.get(), 
1.0f); + xla::XlaOp e1 = use_sum_inside_sqrt ? epsilon_op * epsilon_op : zero; + xla::XlaOp e2 = use_sum_inside_sqrt ? zero : epsilon_op; + + // momentum(t) = beta_1 * momentum(t-1) + // + (1-beta_1)*gradient(t) + xla::XlaOp new_momenta = beta1_op * momenta + (one - beta1_op) * gradient; + + // velocity(t) = beta_2 * velocity(t-1) + // + (1-beta_2)*gradient(t)*gradient(t) + xla::XlaOp new_velocity = + beta2_op * velocity + (one - beta2_op) * gradient * gradient; + + xla::XlaOp updated_embedding_table = + embedding_table - + learning_rate * new_momenta / (xla::Sqrt(new_velocity + e1) + e2); + + // Apply the weight clipping. + xla::XlaOp clipped_embedding_table = ApplyWeightClippingToTable( + adam_optimizer_builder.get(), updated_embedding_table, clip_weight_min, + clip_weight_max); + + xla::XlaOp updated_tables = + xla::Tuple(adam_optimizer_builder.get(), + {clipped_embedding_table, new_momenta, new_velocity}); + return adam_optimizer_builder->Build(updated_tables).value(); +} + +xla::XlaComputation BuildFtrlOptimizerComputation( + int32_t feature_width, bool multiply_linear_by_learning_rate, float beta, + float learning_rate_power, float l1_regularization_strength, + float l2_regularization_strength, float clip_weight_min, + float clip_weight_max) { + auto ftrl_optimizer_builder = + std::make_unique("ftrl_optimizer_builder"); + + xla::Shape per_row_shape = + xla::ShapeUtil::MakeShapeWithType({1, feature_width}); + + xla::XlaOp gradient = xla::Parameter(ftrl_optimizer_builder.get(), 0, + per_row_shape, "gradient"); + + xla::XlaOp embedding_table = xla::Parameter(ftrl_optimizer_builder.get(), 1, + per_row_shape, "embedding_table"); + xla::XlaOp accumulator = xla::Parameter(ftrl_optimizer_builder.get(), 2, + per_row_shape, "accumulator"); + xla::XlaOp linear = + xla::Parameter(ftrl_optimizer_builder.get(), 3, per_row_shape, "linear"); + xla::XlaOp learning_rate = xla::Parameter(ftrl_optimizer_builder.get(), 4, + per_row_shape, "learning_rate"); + + // accumulator(t) = 
accumulator(t-1) + gradient(t) ^ 2 + xla::XlaOp new_accumulator = accumulator + gradient * gradient; + + xla::XlaOp learning_rate_power_op = + xla::ConstantR0(ftrl_optimizer_builder.get(), learning_rate_power); + + xla::XlaOp power_old = Pow(accumulator, xla::Neg(learning_rate_power_op)); + xla::XlaOp power_new = Pow(new_accumulator, xla::Neg(learning_rate_power_op)); + xla::XlaOp delta_p = power_new - power_old; + + xla::XlaOp zero = xla::ConstantR0(ftrl_optimizer_builder.get(), 0.0f); + + xla::XlaOp two = xla::ConstantR0(ftrl_optimizer_builder.get(), 2.0f); + + xla::XlaOp l1_regularization_strength_op = + xla::ConstantR0(ftrl_optimizer_builder.get(), l1_regularization_strength); + + xla::XlaOp l2_regularization_strength_op = + xla::ConstantR0(ftrl_optimizer_builder.get(), l2_regularization_strength); + + xla::XlaOp beta_op = xla::ConstantR0(ftrl_optimizer_builder.get(), beta); + + // Note: + // min(|linear(t)|, lr*l1)*sgn(linear(t)) + // can be written as + // clamp( -lr*l1, linear(t), lr*l1) + // assuming lr>0 and l1>0. 
+ xla::XlaOp new_linear; + xla::XlaOp numer; + xla::XlaOp denom; + if (multiply_linear_by_learning_rate) { + new_linear = linear + learning_rate * gradient - delta_p * embedding_table; + // if multiply_linear: + // linear(t) = linear(t-1) + lr*g - delta_p * table(t-1) + // Update numerator: + // N = min(|linear(t)|, lr*l1)*sgn(linear(t)) - linear(t) + // Update denomninator: + // D = power(t) + 2*lr*l2 + beta + // table(t) = N / D + numer = xla::Select( + xla::Eq(l1_regularization_strength_op, zero), xla::Neg(new_linear), + xla::Clamp(xla::Neg(learning_rate * l1_regularization_strength_op), + new_linear, learning_rate * l1_regularization_strength_op) - + new_linear); + denom = power_new + two * learning_rate * l2_regularization_strength_op + + beta_op; + } else { + new_linear = linear + gradient - delta_p * embedding_table / learning_rate; + // if NOT multiply_linear: + // linear(t) = linear(t-1) + g - (1/lr) * delta_p * table(t-1) + // Update numerator: + // N = min(|linear(t)|, l1)*sgn(linear(t)) - linear(t) + // Update denomninator: + // D = (1/lr) * (power(t) + beta) + 2*l2 + // table(t) = N / D + numer = xla::Select(xla::Eq(l1_regularization_strength_op, zero), + xla::Neg(new_linear), + xla::Clamp(xla::Neg(l1_regularization_strength_op), + new_linear, l1_regularization_strength_op) - + new_linear); + denom = (power_new + beta_op) / learning_rate + + two * l2_regularization_strength_op; + } + xla::XlaOp updated_embedding_table = numer / denom; + + // Apply the weight clipping. 
+ xla::XlaOp clipped_embedding_table = ApplyWeightClippingToTable( + ftrl_optimizer_builder.get(), updated_embedding_table, clip_weight_min, + clip_weight_max); + + xla::XlaOp updated_tables = + xla::Tuple(ftrl_optimizer_builder.get(), + {clipped_embedding_table, new_accumulator, new_linear}); + return ftrl_optimizer_builder->Build(updated_tables).value(); +} + } // namespace tensorflow diff --git a/tensorflow/core/tpu/kernels/sparse_core_ops_utils.h b/tensorflow/core/tpu/kernels/sparse_core_ops_utils.h index dc9b028edca4cc..72419504760aa6 100644 --- a/tensorflow/core/tpu/kernels/sparse_core_ops_utils.h +++ b/tensorflow/core/tpu/kernels/sparse_core_ops_utils.h @@ -23,6 +23,8 @@ limitations under the License. #include "absl/status/status.h" #include "tensorflow/compiler/jit/flags.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/types.h" @@ -70,6 +72,32 @@ absl::Status GetMaxIdsAndUniquesExternal(const std::string& program_key, int64_t* max_ids_per_partition, int64_t* max_unique_ids_per_partition); +xla::XlaOp ApplyWeightClippingToTable(xla::XlaBuilder* builder, + xla::XlaOp table, float clip_weight_min, + float clip_weight_max); + +xla::XlaComputation BuildSgdOptimizerComputation(int32_t feature_width, + float clip_weight_min, + float clip_weight_max); + +xla::XlaComputation BuildAdagradOptimizerComputation(int32_t feature_width, + float clip_weight_min, + float clip_weight_max); + +xla::XlaComputation BuildAdagradMomentumOptimizerComputation( + int32_t feature_width, bool use_nesterov, float exponent, float beta1, + float beta2, float epsilon, float clip_weight_min, float clip_weight_max); + +xla::XlaComputation BuildAdamOptimizerComputation( + int32_t feature_width, bool use_sum_inside_sqrt, float beta1, float beta2, + float epsilon, float clip_weight_min, float clip_weight_max); + +xla::XlaComputation BuildFtrlOptimizerComputation( + int32_t 
feature_width, bool multiply_linear_by_learning_rate, float beta, + float learning_rate_power, float l1_regularization_strength, + float l2_regularization_strength, float clip_weight_min, + float clip_weight_max); + } // namespace tensorflow #endif // TENSORFLOW_CORE_TPU_KERNELS_SPARSE_CORE_OPS_UTILS_H_ diff --git a/tensorflow/core/tpu/kernels/sparse_core_xla_ops.cc b/tensorflow/core/tpu/kernels/sparse_core_xla_ops.cc index ecfb757e71d335..9ae04c246d9f92 100644 --- a/tensorflow/core/tpu/kernels/sparse_core_xla_ops.cc +++ b/tensorflow/core/tpu/kernels/sparse_core_xla_ops.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include +#include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" @@ -28,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "xla/hlo/builder/lib/arithmetic.h" #include "xla/hlo/builder/lib/slicing.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/builder/xla_computation.h" @@ -37,7 +39,9 @@ limitations under the License. #include "xla/stream_executor/tpu/c_api_decl.h" #include "xla/stream_executor/tpu/tpu_api.h" #include "xla/stream_executor/tpu/tpu_ops_c_api.h" +#include "xla/tsl/platform/errors.h" #include "xla/tsl/platform/macros.h" +#include "xla/tsl/platform/statusor.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_requires.h" @@ -64,9 +68,52 @@ namespace tensorflow { namespace { // Get the SparseCore logical replica count. 
-absl::StatusOr GetSparseCoresPerChip() { - return stream_executor::tpu::OpsApiFn()->TpuTopology_AvailableCoresPerChipFn( - /*tpu_core_type=*/TpuCoreTypeEnum::kEmbeddingV2); +absl::StatusOr GetSparseCoresPerLogicalDevice() { + return stream_executor::tpu::OpsApiFn() + ->TpuTopology_MaybeAvailableSparseCoresPerLogicalDeviceFn( + /*tpu_core_type=*/TpuCoreTypeEnum::kEmbeddingV2); +} + +// Helper function to get the number of sparsecores per device from the topology +// if available, or, if not, from the op's attribute. +void GetAndSetSparseCoresPerLogicalDevice(OpKernelConstruction* ctx, + int64_t& num_sparsecores_per_device) { + // Try to get the number of sparsecores per chip from topology. + absl::StatusOr num_from_topology = GetSparseCoresPerLogicalDevice(); + int64_t num_sparsecores_from_attribute; + OP_REQUIRES_OK(ctx, ctx->GetAttr("num_sparsecores_per_device", + &num_sparsecores_from_attribute)); + + if (num_from_topology.ok()) { + num_sparsecores_per_device = num_from_topology.value(); + // Verify that the attribute is consistent with the topology value. + OP_REQUIRES( + ctx, + num_sparsecores_from_attribute == -1 || + num_sparsecores_from_attribute == num_sparsecores_per_device, + absl::InvalidArgumentError(absl::StrCat( + "The num_sparsecores_per_device discovered from the topology: ", + num_sparsecores_per_device, + " is not consistent with the op's attribute value: ", + num_sparsecores_from_attribute))); + } else { + // Fall back to the attribute if topology is not available or failed. + num_sparsecores_per_device = num_sparsecores_from_attribute; + } + + // Validate the final value. + OP_REQUIRES( + ctx, num_sparsecores_per_device == 2 || num_sparsecores_per_device == 4, + absl::InvalidArgumentError( + absl::StrCat("num_sparsecores_per_device must be 2 or 4, but got: ", + num_sparsecores_per_device))); +} + +// Returns the number of ops in the tuple. 
+absl::StatusOr GetTupleOpSize(xla::XlaBuilder* builder, + xla::XlaOp tuple_op) { + TF_ASSIGN_OR_RETURN(xla::Shape tuple_shape, builder->GetShape(tuple_op)); + return tuple_shape.tuple_shapes().size(); } // This TensorFlow op performs the embedding lookup on SparseCore. It takes the @@ -143,8 +190,6 @@ class XlaSparseDenseMatmulOp : public XlaOpKernel { // Pack the input tensors as a tuple. This is a intermediate stage before // switching to SparseTensor type. - xla::XlaOp coo_tensor_input = - xla::Tuple(builder, {row_ids, col_ids, values}); new_frontend_attributes.mutable_map()->insert( {"_xla_sharding_strategy", "mod"}); @@ -160,7 +205,7 @@ class XlaSparseDenseMatmulOp : public XlaOpKernel { xla::XlaOp result = xla::CustomCall( builder, "SparseDenseMatmulOp", - {coo_tensor_input, embedding_table, offsets, activation_init}, + {row_ids, col_ids, values, embedding_table, offsets, activation_init}, xla::ShapeUtil::MakeTupleShape({activation_shape, row_pointers_shape, sorted_ids_shape, sorted_ids_shape, sorted_gains_shape})); @@ -201,12 +246,10 @@ class XlaSparseDenseMatmulWithCsrInputOp : public XlaOpKernel { : XlaOpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("table_name", &table_name_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("input_size", &input_size_)); - OP_REQUIRES_VALUE(num_sparsecores_per_chip_, ctx, GetSparseCoresPerChip()); - OP_REQUIRES(ctx, input_size_ % num_sparsecores_per_chip_ == 0, - errors::InvalidArgument("input_size_ ", input_size_, - " not divisible by the number " - "of sparsecores per chip ", - num_sparsecores_per_chip_)); + + // Try to get the number of sparsecores per chip from topology. And fall + // back to the attribute if the topology is not available. + GetAndSetSparseCoresPerLogicalDevice(ctx, num_sparsecores_per_device_); // Get and save quantization config params, if they were configured. // num_buckets == 0 indicate no quantization configs were provided. 
@@ -247,7 +290,7 @@ class XlaSparseDenseMatmulWithCsrInputOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { int64_t per_sparse_core_batch_size = - input_size_ / num_sparsecores_per_chip_; + input_size_ / num_sparsecores_per_device_; int64_t max_ids_per_partition = 0; int64_t max_unique_ids_per_partition = 0; @@ -340,9 +383,9 @@ class XlaSparseDenseMatmulWithCsrInputOp : public XlaOpKernel { ctx->SetOutput(0, result); } - private: + protected: int input_size_; - int64_t num_sparsecores_per_chip_; + int64_t num_sparsecores_per_device_; std::optional quantization_config_low_; std::optional quantization_config_high_; std::optional quantization_config_num_buckets_; @@ -357,6 +400,191 @@ class XlaSparseDenseMatmulWithCsrInputOp : public XlaOpKernel { REGISTER_XLA_OP(Name("XlaSparseDenseMatmulWithCsrInput"), XlaSparseDenseMatmulWithCsrInputOp); +// Similar to XlaSparseDenseMatmulWithCsrInputOp, but with an additional field +// `sorted_pos_ids` in the input Csr, `weights` which is a tensor of shape +// [num_weights] to be used by the `combiner_computation`. It produces the same +// embedding look up result as `XlaSparseDenseMatmulWithCsrInputOp`. 
+class XlaSparseDenseMatmulCustomCombinerOnTcWithCsrInputOp + : public XlaSparseDenseMatmulWithCsrInputOp { + public: + explicit XlaSparseDenseMatmulCustomCombinerOnTcWithCsrInputOp( + OpKernelConstruction* ctx) + : XlaSparseDenseMatmulWithCsrInputOp(ctx) { + const NameAttrList* name_attr; + OP_REQUIRES_OK(ctx, ctx->GetAttr("max_valency", &max_valency_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("num_weights", &num_weights_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("combiner_computation", &name_attr)); + combiner_computation_ = *name_attr; + } + + ~XlaSparseDenseMatmulCustomCombinerOnTcWithCsrInputOp() override = default; + + absl::StatusOr BuildTcCustomCombinerComputation( + XlaOpKernelContext* ctx, const int32_t feature_width) { + XlaCompiler::CompileOptions options; + options.use_tuple_arg = false; + options.always_return_tuple = false; + options.is_entry_computation = false; + + XlaCompiler* compiler = ctx->compiler(); + XlaCompiler::CompilationResult custom_combiner_computation_result; + + XlaCompiler::Argument valencies_arg; + XlaCompiler::Argument vectors_arg; + + valencies_arg.kind = XlaCompiler::Argument::kParameter; + valencies_arg.type = DT_INT32; + valencies_arg.shape = xla::ShapeUtil::MakeShape(xla::S32, {input_size_}); + valencies_arg.name = "valencies"; + vectors_arg.kind = XlaCompiler::Argument::kParameter; + vectors_arg.type = DT_FLOAT; + vectors_arg.shape = xla::ShapeUtil::MakeShape( + xla::F32, {input_size_, max_valency_, feature_width}); + vectors_arg.name = "vectors"; + + std::vector arguments = {valencies_arg, vectors_arg}; + + // Don't add the weights argument if it's not needed. This helps avoid + // issues of passing around zero-sized tensors and Xla values. 
+ if (num_weights_ > 0) { + XlaCompiler::Argument weights_arg; + weights_arg.kind = XlaCompiler::Argument::kParameter; + weights_arg.type = DT_FLOAT; + weights_arg.shape = + xla::ShapeUtil::MakeShape(xla::F32, {input_size_, num_weights_}); + weights_arg.name = "weights"; + arguments.push_back(weights_arg); + } + + TF_RETURN_IF_ERROR( + compiler->CompileFunction(options, combiner_computation_, arguments, + &custom_combiner_computation_result)); + return std::move(*custom_combiner_computation_result.computation); + } + + void Compile(XlaOpKernelContext* ctx) override { + int64_t per_sparse_core_batch_size = + input_size_ / num_sparsecores_per_device_; + int64_t max_ids_per_partition = 0; + int64_t max_unique_ids_per_partition = 0; + + xla::XlaBuilder* builder = ctx->builder(); + xla::XlaOp row_pointers = ctx->Input("row_pointers"); + xla::XlaOp sorted_sample_ids = ctx->Input("sorted_sample_ids"); + xla::XlaOp sorted_token_ids = ctx->Input("sorted_token_ids"); + xla::XlaOp sorted_pos_ids = ctx->Input("sorted_pos_ids"); + xla::XlaOp sorted_gains = ctx->Input("sorted_gains"); + xla::XlaOp embedding_table = ctx->Input("embedding_table"); + + OP_REQUIRES_VALUE(xla::Shape embedding_table_shape, ctx, + ctx->InputXlaShape("embedding_table")); + const int32_t feature_width = embedding_table_shape.dimensions(1); + + OP_REQUIRES_OK( + ctx, GetMaxIdsAndUniques(per_sparse_core_batch_size, feature_width, + &max_ids_per_partition, + &max_unique_ids_per_partition)); + // Log max_ids and max_uniques for offline analysis. We do this here since + // these values are fixed at TPU compile time and remain fixed during + // training. 
+ max_ids_per_partition_gauge_->GetCell(device_name_, table_name_) + ->Set(max_ids_per_partition); + max_unique_ids_per_partition_gauge_->GetCell(device_name_, table_name_) + ->Set(max_unique_ids_per_partition); + LOG(INFO) << "Lowering " + "XlaSparseDenseMatmulCustomCombinerOnTcWithCsrInputOp to HLO: " + << "table_name = '" << table_name_ + << "', max_ids = " << max_ids_per_partition + << ", max_uniques = " << max_unique_ids_per_partition; + + xla::FrontendAttributes tc_frontend_attributes; + xla::FrontendAttributes sc_frontend_attributes; + + sc_frontend_attributes.mutable_map()->insert( + {"_xla_compute_type", "sparse"}); + + sc_frontend_attributes.mutable_map()->insert( + {"_xla_sharding_strategy", "mod"}); + + sc_frontend_attributes.mutable_map()->insert( + {"_xla_pad_value", absl::StrCat(kXlaPadValue)}); + + sc_frontend_attributes.mutable_map()->insert( + {"_xla_max_ids_per_partition", absl::StrCat(max_ids_per_partition)}); + + sc_frontend_attributes.mutable_map()->insert( + {"_xla_max_unique_ids_per_partition", + absl::StrCat(max_unique_ids_per_partition)}); + + sc_frontend_attributes.mutable_map()->insert( + {"_xla_max_valency", absl::StrCat(max_valency_)}); + + if (quantization_config_low_.has_value()) { + sc_frontend_attributes.mutable_map()->insert( + {"_xla_quantization_high_value", + absl::StrCat(quantization_config_high_.value())}); + sc_frontend_attributes.mutable_map()->insert( + {"_xla_quantization_low_value", + absl::StrCat(quantization_config_low_.value())}); + sc_frontend_attributes.mutable_map()->insert( + {"_xla_quantization_num_buckets_value", + absl::StrCat(quantization_config_num_buckets_.value())}); + } + + tc_frontend_attributes = + builder->SwapFrontendAttributes(sc_frontend_attributes); + + // Emit the custom call that performs the SC embedding lookup. 
+ xla::Shape valencies_shape = + xla::ShapeUtil::MakeShape(xla::S32, {input_size_}); + xla::Shape vectors_shape = xla::ShapeUtil::MakeShape( + xla::F32, {input_size_, max_valency_, feature_width}); + xla::Shape gains_shape = + xla::ShapeUtil::MakeShape(xla::F32, {input_size_, max_valency_}); + xla::XlaOp sc_lookup_result_tuple = xla::CustomCall( + builder, "SparseDenseMatmulCustomCombinerTcCombinerMegachipOp", + {row_pointers, sorted_token_ids, sorted_sample_ids, sorted_pos_ids, + sorted_gains, embedding_table}, + xla::ShapeUtil::MakeTupleShape( + {valencies_shape, vectors_shape, gains_shape})); + + // Emit the custom combiner computation into an HLO computation. + OP_REQUIRES_VALUE(xla::XlaComputation custom_combiner_tc_computation, ctx, + BuildTcCustomCombinerComputation(ctx, feature_width)); + + builder->SetFrontendAttributes(tc_frontend_attributes); + + xla::XlaOp valencies = xla::GetTupleElement(sc_lookup_result_tuple, 0); + xla::XlaOp vectors = xla::GetTupleElement(sc_lookup_result_tuple, 1); + + std::vector tc_combiner_args = {valencies, vectors}; + if (num_weights_ > 0) { + xla::XlaOp weights = ctx->Input("weights"); + tc_combiner_args.push_back(xla::Broadcast(weights, {input_size_})); + } + + xla::XlaOp tc_activations = + xla::Call(builder, custom_combiner_tc_computation, tc_combiner_args); + + ctx->SetOutput(0, tc_activations); + ctx->SetOutput(1, valencies); + ctx->SetOutput(2, vectors); + } + + private: + int max_valency_; + int num_weights_; + NameAttrList combiner_computation_; + + XlaSparseDenseMatmulCustomCombinerOnTcWithCsrInputOp( + const XlaSparseDenseMatmulCustomCombinerOnTcWithCsrInputOp&) = delete; + void operator=(const XlaSparseDenseMatmulCustomCombinerOnTcWithCsrInputOp&) = + delete; +}; + +REGISTER_XLA_OP(Name("XlaSparseDenseMatmulCustomCombinerOnTcWithCsrInput"), + XlaSparseDenseMatmulCustomCombinerOnTcWithCsrInputOp); + // Base class for all the minibatch with CSR input optimizer kernel. 
class XlaSparseDenseMatmulGradWithCsrInputBase : public XlaOpKernel { public: @@ -366,6 +594,10 @@ class XlaSparseDenseMatmulGradWithCsrInputBase : public XlaOpKernel { OP_REQUIRES_OK(ctx, ctx->GetAttr("clip_weight_min", &clip_weight_min_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("clip_weight_max", &clip_weight_max_)); + // Try to get the number of sparsecores per chip from topology. And fall + // back to the attribute if the topology is not available. + GetAndSetSparseCoresPerLogicalDevice(ctx, num_sparsecores_per_device_); + OP_REQUIRES(ctx, clip_weight_min_ <= clip_weight_max_, absl::InvalidArgumentError( absl::StrCat("clip_weight_min must be smaller or equal to " @@ -385,15 +617,6 @@ class XlaSparseDenseMatmulGradWithCsrInputBase : public XlaOpKernel { virtual xla::Shape get_tables_shape(xla::Shape embedding_table_shape) = 0; - xla::XlaOp apply_weight_clipping_to_table(xla::XlaBuilder* builder, - xla::XlaOp table) { - xla::XlaOp clip_weight_min = xla::ConstantR0(builder, clip_weight_min_); - xla::XlaOp clip_weight_max = xla::ConstantR0(builder, clip_weight_max_); - xla::XlaOp clipped_table = - xla::Clamp(clip_weight_min, table, clip_weight_max); - return clipped_table; - } - virtual absl::Status GetMaxIdsAndUniques( int64_t num_samples_per_sparse_core, int64_t feature_width, int64_t* max_ids_per_partition, int64_t* max_unique_ids_per_partition) { @@ -423,16 +646,14 @@ class XlaSparseDenseMatmulGradWithCsrInputBase : public XlaOpKernel { errors::InvalidArgument( "activations input has non static or non-rank 2 shape: ", activation_shape.ToString())); - OP_REQUIRES_VALUE(int64_t num_sparsecores_per_chip, ctx, - GetSparseCoresPerChip()); int64 num_samples_per_chip = activation_shape.dimensions(0); - OP_REQUIRES(ctx, num_samples_per_chip % num_sparsecores_per_chip == 0, + OP_REQUIRES(ctx, num_samples_per_chip % num_sparsecores_per_device_ == 0, errors::InvalidArgument( "num_samples_per_chip ", num_samples_per_chip, " not divisible by the number of sparsecores per chip ", 
- num_sparsecores_per_chip)); + num_sparsecores_per_device_)); int64_t per_sparse_core_batch_size = - num_samples_per_chip / num_sparsecores_per_chip; + num_samples_per_chip / num_sparsecores_per_device_; int64_t max_ids_per_partition = 0; int64_t max_unique_ids_per_partition = 0; OP_REQUIRES_VALUE(xla::Shape embedding_table_shape, ctx, @@ -509,6 +730,7 @@ class XlaSparseDenseMatmulGradWithCsrInputBase : public XlaOpKernel { private: std::string table_name_; + int64_t num_sparsecores_per_device_; XlaSparseDenseMatmulGradWithCsrInputBase( const XlaSparseDenseMatmulGradWithCsrInputBase&) = delete; @@ -522,6 +744,11 @@ class XlaSparseDenseMatmulGradWithCsrInputOp : public XlaOpKernel { const NameAttrList* name_attr; OP_REQUIRES_OK(ctx, ctx->GetAttr("custom_computation", &name_attr)); OP_REQUIRES_OK(ctx, ctx->GetAttr("table_name", &table_name_)); + + // Try to get the number of sparsecores per chip from topology. And fall + // back to the attribute if the topology is not available. + GetAndSetSparseCoresPerLogicalDevice(ctx, num_sparsecores_per_device_); + custom_computation_ = *name_attr; } @@ -555,17 +782,16 @@ class XlaSparseDenseMatmulGradWithCsrInputOp : public XlaOpKernel { absl::InvalidArgumentError(absl::StrCat( "activations input has non static or non-rank 2 shape: ", activation_shape.ToString()))); - OP_REQUIRES_VALUE(int64_t num_sparsecores_per_chip, ctx, - GetSparseCoresPerChip()); + int64_t num_samples_per_chip = activation_shape.dimensions(0); - OP_REQUIRES(ctx, num_samples_per_chip % num_sparsecores_per_chip == 0, + OP_REQUIRES(ctx, num_samples_per_chip % num_sparsecores_per_device_ == 0, absl::InvalidArgumentError(absl::StrCat( "num_samples_per_chip ", num_samples_per_chip, " not divisible by the number of sparsecores per chip ", - num_sparsecores_per_chip))); + num_sparsecores_per_device_))); int64_t per_sparse_core_batch_size = - num_samples_per_chip / num_sparsecores_per_chip; + num_samples_per_chip / num_sparsecores_per_device_; int64_t 
max_ids_per_partition = 0; int64_t max_unique_ids_per_partition = 0; @@ -679,6 +905,7 @@ class XlaSparseDenseMatmulGradWithCsrInputOp : public XlaOpKernel { private: std::string table_name_; NameAttrList custom_computation_; + int64_t num_sparsecores_per_device_; XlaSparseDenseMatmulGradWithCsrInputOp( const XlaSparseDenseMatmulGradWithCsrInputOp&) = delete; void operator=(const XlaSparseDenseMatmulGradWithCsrInputOp&) = delete; @@ -687,6 +914,703 @@ class XlaSparseDenseMatmulGradWithCsrInputOp : public XlaOpKernel { REGISTER_XLA_OP(Name("XlaSparseDenseMatmulGradWithCsrInput"), XlaSparseDenseMatmulGradWithCsrInputOp); +class XlaSparseDenseMatmulCustomCombinerOnTcGradWithCsrInputBase + : public XlaOpKernel { + public: + explicit XlaSparseDenseMatmulCustomCombinerOnTcGradWithCsrInputBase( + OpKernelConstruction* ctx) + : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("table_name", &table_name_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("max_valency", &max_valency_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("num_weights", &num_weights_)); + + // Not all subclasses have the weight range attributes. We parse these + // attributes anyway (otherwise we lose the op construction context) and + // record possible errors. The main compile method can choose to report + // errors or not (depending if the attributes are expected to be present). 
+ clip_weight_range_status_.Update( + ctx->GetAttr("clip_weight_min", &clip_weight_min_)); + clip_weight_range_status_.Update( + ctx->GetAttr("clip_weight_max", &clip_weight_max_)); + if (clip_weight_range_status_.ok() && clip_weight_min_ > clip_weight_max_) { + clip_weight_range_status_ = absl::InvalidArgumentError(absl::StrCat( + "clip_weight_min must be smaller or equal to " + "clip_weight_max but got clip_weight_min as ", + clip_weight_min_, " and clip_weight_max as ", clip_weight_max_, ".")); + } + + const NameAttrList* name_attr; + OP_REQUIRES_OK(ctx, + ctx->GetAttr("combiner_table_vjp_computation", &name_attr)); + combiner_lookups_custom_vjp_computation_ = *name_attr; + OP_REQUIRES_OK( + ctx, ctx->GetAttr("combiner_weights_vjp_computation", &name_attr)); + combiner_weights_custom_vjp_computation_ = *name_attr; + } + + virtual absl::Status HandleClipWeightRangeStatus() { + // Most subclasses require the weight range attributes, and we return the + // status as-is. + return clip_weight_range_status_; + } + + // Returns an xla::Tuple of all table-shaped optimizer inputs. + virtual absl::StatusOr GetTablesInput( + XlaOpKernelContext* ctx, xla::XlaBuilder* builder) = 0; + + // Returns an xla::Tuple of all hyperparameter optimizer inputs. + virtual absl::StatusOr GetHyperparametersInput( + XlaOpKernelContext* ctx, xla::XlaBuilder* builder) = 0; + + // Returns the optimizer computation. + virtual absl::StatusOr BuildOptimizerComputation( + XlaOpKernelContext* ctx, int32_t feature_width) = 0; + + absl::StatusOr GetNumTablesInput(XlaOpKernelContext* ctx) { + // No side effects should remain from this builder -- we derive the number + // of inputs by inspecting the tuple XlaOp, which should be optimized away + // as the results are not consumed. 
+ xla::XlaBuilder* builder = ctx->builder(); + TF_ASSIGN_OR_RETURN(xla::XlaOp tuple_op, GetTablesInput(ctx, builder)); + return GetTupleOpSize(builder, tuple_op); + } + + absl::StatusOr GetNumHyperparametersInput(XlaOpKernelContext* ctx) { + // No side effects should remain from this builder -- see comments above. + xla::XlaBuilder* builder = ctx->builder(); + TF_ASSIGN_OR_RETURN(xla::XlaOp tuple_op, + GetHyperparametersInput(ctx, builder)); + return GetTupleOpSize(builder, tuple_op); + } + + std::vector BuildVjpArguments(XlaOpKernelContext* ctx, + int32_t input_size, + int32_t feature_width) { + std::vector arguments; + + XlaCompiler::Argument valencies_arg; + XlaCompiler::Argument vectors_arg; + XlaCompiler::Argument weights_arg; + XlaCompiler::Argument activation_gradients_arg; + + valencies_arg.kind = XlaCompiler::Argument::kParameter; + valencies_arg.type = DT_INT32; + valencies_arg.shape = xla::ShapeUtil::MakeShape(xla::S32, {input_size}); + valencies_arg.name = "valencies"; + + vectors_arg.kind = XlaCompiler::Argument::kParameter; + vectors_arg.type = DT_FLOAT; + vectors_arg.shape = xla::ShapeUtil::MakeShape( + xla::F32, {input_size, max_valency_, feature_width}); + vectors_arg.name = "vectors"; + + weights_arg.kind = XlaCompiler::Argument::kParameter; + weights_arg.type = DT_FLOAT; + weights_arg.shape = + xla::ShapeUtil::MakeShape(xla::F32, {input_size, num_weights_}); + weights_arg.name = "weights"; + arguments.push_back(weights_arg); + + activation_gradients_arg.kind = XlaCompiler::Argument::kParameter; + activation_gradients_arg.type = DT_FLOAT; + activation_gradients_arg.shape = + xla::ShapeUtil::MakeShape(xla::F32, {input_size, feature_width}); + activation_gradients_arg.name = "activation_gradients"; + arguments.push_back(activation_gradients_arg); + + if (num_weights_ > 0) { + arguments = {valencies_arg, vectors_arg, weights_arg, + activation_gradients_arg}; + } else { + // Don't add the weights argument if it's not needed. 
This helps avoid + // issues of passing around zero-sized tensors and Xla values. + arguments = {valencies_arg, vectors_arg, activation_gradients_arg}; + } + + return arguments; + } + + absl::StatusOr BuildCombinerVjpComputation( + XlaOpKernelContext* ctx, int32_t input_size, int32_t feature_width, + const NameAttrList& computation) { + XlaCompiler::CompileOptions options; + options.use_tuple_arg = false; + options.always_return_tuple = false; + options.is_entry_computation = false; + + XlaCompiler* compiler = ctx->compiler(); + XlaCompiler::CompilationResult vjp_computation_result; + + TF_RETURN_IF_ERROR(compiler->CompileFunction( + options, computation, BuildVjpArguments(ctx, input_size, feature_width), + &vjp_computation_result)); + return std::move(*vjp_computation_result.computation); + } + + absl::StatusOr EmitTensorCoreComputations( + XlaOpKernelContext* ctx, xla::XlaBuilder* builder, int32_t input_size, + int32_t feature_width) { + xla::XlaOp weights = ctx->Input("weights"); + xla::XlaOp preserved_weights = ctx->Input("preserved_weights"); + xla::XlaOp activation_gradients = ctx->Input("activation_gradients"); + xla::XlaOp valencies = ctx->Input("preserved_valencies"); + xla::XlaOp vectors = ctx->Input("preserved_vectors"); + + // Build the required computations for the custom combiner. + TF_ASSIGN_OR_RETURN( + xla::XlaComputation combiner_vectors_vjp, + BuildCombinerVjpComputation(ctx, input_size, feature_width, + combiner_lookups_custom_vjp_computation_)); + TF_ASSIGN_OR_RETURN( + xla::XlaComputation combiner_weights_vjp, + BuildCombinerVjpComputation(ctx, input_size, feature_width, + combiner_weights_custom_vjp_computation_)); + + // The updated weights are the last output in the list. 
+ const int32_t kUpdatedWeightsIndex = ctx->num_outputs() - 1; + + std::vector vjp_args; + if (num_weights_ > 0) { + xla::XlaOp broadcasted_preserved_weights = + xla::Broadcast(preserved_weights, {input_size}); + vjp_args = {valencies, vectors, broadcasted_preserved_weights, + activation_gradients}; + } else { + vjp_args = {valencies, vectors, activation_gradients}; + } + + // Compute the lookup gradients based on the activation gradients. This + // result will be passed to SC to drive the embedding table update. + xla::XlaOp lookup_gradients = + xla::Call(builder, combiner_vectors_vjp, vjp_args); + + // Compute the weights gradients based on the activation gradients. + if (num_weights_ > 0) { + // The weights VJP returns a tensor of shape f32[input_size, num_weights]. + xla::XlaOp weights_gradients_all_samples = + xla::Call(builder, combiner_weights_vjp, vjp_args); + // Local reduction, which aggregates the contributions from all samples + // and returns a tensor of shape f32[num_weights]. + xla::XlaOp per_replica_reduced_weights_gradients = xla::Reduce( + weights_gradients_all_samples, xla::ConstantR0(builder, 0.0), + xla::CreateScalarAddComputation(xla::F32, builder), {0}); + // Global reduction, which aggregates the contributions from all replicas + // and returns a tensor of shape f32[num_weights]. + // Here we assume that all replicas participate in the all-reduce (using + // default value of `replica_groups`) and that all-reduce from different + // modules do not participate in this reduction (using default value of + // `channel_id`). + xla::XlaOp global_reduced_weights_gradients = + xla::AllReduce(per_replica_reduced_weights_gradients, + xla::CreateScalarAddComputation(xla::F32, builder)); + // Use SGD optimizer on the weights. + // TODO(peitianpan): Add support for more optimizers. 
+ xla::XlaOp learning_rate = ctx->Input("combiner_weights_learning_rate"); + xla::XlaOp updated_weights = + weights - learning_rate * global_reduced_weights_gradients; + ctx->SetOutput(kUpdatedWeightsIndex, updated_weights); + } else { + // The caller is not supposed to rely on this output if num_weights is 0. + ctx->SetOutput(kUpdatedWeightsIndex, xla::ConstantR0(builder, 0)); + } + + return lookup_gradients; + } + + absl::Status EmitSparseCoreComputations(XlaOpKernelContext* ctx, + xla::XlaBuilder* builder, + xla::XlaOp lookup_gradients, + int32_t max_ids_per_partition, + int32_t max_unique_ids_per_partition, + int32_t feature_width) { + xla::XlaOp row_pointers = ctx->Input("row_pointers"); + xla::XlaOp sorted_sample_ids = ctx->Input("sorted_sample_ids"); + xla::XlaOp sorted_token_ids = ctx->Input("sorted_token_ids"); + xla::XlaOp sorted_pos_ids = ctx->Input("sorted_pos_ids"); + xla::XlaOp sorted_gains = ctx->Input("sorted_gains"); + + xla::FrontendAttributes original_frontend_attributes = + builder->frontend_attributes(); + + xla::FrontendAttributes tuple_frontend_attributes; + + tuple_frontend_attributes.mutable_map()->insert( + {"_xla_compute_type", "sparse"}); + + builder->SetFrontendAttributes(tuple_frontend_attributes); + + TF_ASSIGN_OR_RETURN(xla::XlaOp tables, GetTablesInput(ctx, builder)); + TF_ASSIGN_OR_RETURN(xla::XlaOp hyperparameters, + GetHyperparametersInput(ctx, builder)); + + TF_ASSIGN_OR_RETURN(xla::Shape tables_shape, builder->GetShape(tables)); + if (tables_shape.tuple_shapes().size() + 1 != ctx->num_outputs()) { + return absl::InvalidArgumentError( + absl::StrCat("Expecting ", tables_shape.tuple_shapes().size() + 1, + " outputs but got ", ctx->num_outputs())); + } + + TF_ASSIGN_OR_RETURN(xla::XlaComputation optimizer, + BuildOptimizerComputation(ctx, feature_width)); + + xla::FrontendAttributes custom_call_frontend_attributes; + + custom_call_frontend_attributes.mutable_map()->insert( + {"_xla_compute_type", "sparse"}); + + 
custom_call_frontend_attributes.mutable_map()->insert( + {"_xla_sharding_strategy", "mod"}); + + custom_call_frontend_attributes.mutable_map()->insert( + {"_xla_pad_value", absl::StrCat(kXlaPadValue)}); + + custom_call_frontend_attributes.mutable_map()->insert( + {"_xla_max_ids_per_partition", absl::StrCat(max_ids_per_partition)}); + + custom_call_frontend_attributes.mutable_map()->insert( + {"_xla_max_unique_ids_per_partition", + absl::StrCat(max_unique_ids_per_partition)}); + + builder->SetFrontendAttributes(custom_call_frontend_attributes); + + xla::XlaOp updated_tables = xla::CustomCallWithComputation( + builder, + "SparseDenseMatmulCustomCombinerTcCombinerGradOptimizerUpdateMegachipO" + "p", + {row_pointers, sorted_token_ids, sorted_sample_ids, sorted_pos_ids, + sorted_gains, tables, lookup_gradients, hyperparameters}, + optimizer, tables_shape); + + builder->SetFrontendAttributes(tuple_frontend_attributes); + + // Updated embedding table. + for (int i = 0; i < tables_shape.tuple_shapes().size(); ++i) { + ctx->SetOutput(i, xla::GetTupleElement(updated_tables, i)); + } + + builder->SetFrontendAttributes(original_frontend_attributes); + return absl::OkStatus(); + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::XlaBuilder* builder = ctx->builder(); + + OP_REQUIRES_OK(ctx, HandleClipWeightRangeStatus()); + + // Get the shape of the gradient. 
+ OP_REQUIRES_VALUE(xla::Shape activation_shape, ctx, + ctx->InputXlaShape("activation_gradients")); + OP_REQUIRES( + ctx, + activation_shape.is_static() && activation_shape.dimensions_size() == 2, + absl::InvalidArgumentError(absl::StrCat( + "activations input has non static or non-rank 2 shape: ", + activation_shape.ToString()))); + OP_REQUIRES_VALUE(int64_t num_sparsecores_per_device, ctx, + GetSparseCoresPerLogicalDevice()); + int64_t num_samples_per_chip = activation_shape.dimensions(0); + OP_REQUIRES(ctx, num_samples_per_chip % num_sparsecores_per_device == 0, + absl::InvalidArgumentError(absl::StrCat( + "num_samples_per_chip ", num_samples_per_chip, + " not divisible by the number of sparsecores per chip ", + num_sparsecores_per_device))); + + int64_t per_sparse_core_batch_size = + num_samples_per_chip / num_sparsecores_per_device; + int64_t max_ids_per_partition = 0; + int64_t max_unique_ids_per_partition = 0; + + const int32_t feature_width = activation_shape.dimensions(1); + OP_REQUIRES_OK( + ctx, GetMaxIdsAndUniquesExternal(kUnknownProgramKey, table_name_, + per_sparse_core_batch_size, + feature_width, &max_ids_per_partition, + &max_unique_ids_per_partition)); + LOG(INFO) + << "Lowering XlaSparseDenseMatmulCustomCombinerOnTcGradWithCsrInputOp " + << "to HLO: table_name = '" << table_name_ + << "', max_ids = " << max_ids_per_partition + << ", max_uniques = " << max_unique_ids_per_partition; + + // Emit the two custom combiner VJP computations onto TC. + int32_t input_size = activation_shape.dimensions(0); + OP_REQUIRES_VALUE( + xla::XlaOp lookup_gradients, ctx, + EmitTensorCoreComputations(ctx, builder, input_size, feature_width)); + + // Pass the TC activation gradients back to SC for back-propagation with + // optimizer. 
+ OP_REQUIRES_OK(ctx, + EmitSparseCoreComputations( + ctx, builder, lookup_gradients, max_ids_per_partition, + max_unique_ids_per_partition, feature_width)); + } + + protected: + int32_t max_valency_; + int32_t num_weights_; + float clip_weight_min_; + float clip_weight_max_; + std::string table_name_; + NameAttrList combiner_weights_custom_vjp_computation_; + NameAttrList combiner_lookups_custom_vjp_computation_; + + absl::Status clip_weight_range_status_; + + private: + XlaSparseDenseMatmulCustomCombinerOnTcGradWithCsrInputBase( + const XlaSparseDenseMatmulCustomCombinerOnTcGradWithCsrInputBase&) = + delete; + void operator=( + const XlaSparseDenseMatmulCustomCombinerOnTcGradWithCsrInputBase&) = + delete; +}; + +// TC custom combiner VJP + SC back-propagation with a custom optimizer. +class XlaSparseDenseMatmulCustomCombinerOnTcGradWithCsrInputOp + : public XlaSparseDenseMatmulCustomCombinerOnTcGradWithCsrInputBase { + public: + explicit XlaSparseDenseMatmulCustomCombinerOnTcGradWithCsrInputOp( + OpKernelConstruction* ctx) + : XlaSparseDenseMatmulCustomCombinerOnTcGradWithCsrInputBase(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("N", &num_tables_)); + + const NameAttrList* name_attr; + OP_REQUIRES_OK(ctx, + ctx->GetAttr("optimizer_custom_computation", &name_attr)); + optimizer_custom_computation_ = *name_attr; + } + + absl::Status HandleClipWeightRangeStatus() override { + // The custom optimizer BWD op does not require the weight clip range. 
+ return absl::OkStatus(); + } + + absl::StatusOr GetTablesInput(XlaOpKernelContext* ctx, + xla::XlaBuilder* builder) override { + std::vector tables_input; + std::vector tables_shapes; + TF_RETURN_IF_ERROR(ctx->InputList("tables", &tables_input, &tables_shapes)); + if (num_tables_ != tables_shapes.size()) { + return absl::InvalidArgumentError(absl::StrCat("Expecting ", num_tables_, + " tables, but got ", + tables_shapes.size())); + } + return xla::Tuple(builder, tables_input); + } + + absl::StatusOr GetHyperparametersInput( + XlaOpKernelContext* ctx, xla::XlaBuilder* builder) override { + std::vector hyperparameters_input; + std::vector hyperparameters_shapes; + TF_RETURN_IF_ERROR(ctx->InputList("hyperparameters", &hyperparameters_input, + &hyperparameters_shapes)); + return xla::Tuple(builder, hyperparameters_input); + } + + absl::StatusOr BuildOptimizerComputation( + XlaOpKernelContext* ctx, int32_t feature_width) override { + XlaCompiler::CompileOptions options; + + // We don't use tuple args and always return tuple for this computation. + options.use_tuple_arg = false; + options.always_return_tuple = true; + options.is_entry_computation = false; + + XlaCompiler* compiler = ctx->compiler(); + + XlaCompiler::CompilationResult custom_computation_result; + + // The number of arguments is the number of tables + the number of + // hyperparameters + 1 for the activation gradients. + TF_ASSIGN_OR_RETURN(const int32_t num_tables_inputs, + GetNumTablesInput(ctx)); + TF_ASSIGN_OR_RETURN(const int32_t num_hyperparameters_inputs, + GetNumHyperparametersInput(ctx)); + int32_t num_arguments = 1 + num_tables_inputs + num_hyperparameters_inputs; + + std::vector arguments(num_arguments); + + // For all the arguments, we use the float type and the shape is + // {1, feature_width}. 
+ for (int32_t i = 0; i < num_arguments; ++i) { + arguments[i].kind = XlaCompiler::Argument::kParameter; + arguments[i].type = DT_FLOAT; + arguments[i].shape = + xla::ShapeUtil::MakeShape(xla::F32, {1, feature_width}); + } + + TF_RETURN_IF_ERROR( + compiler->CompileFunction(options, optimizer_custom_computation_, + arguments, &custom_computation_result)); + + return std::move(*custom_computation_result.computation); + } + + private: + int32_t num_tables_; + NameAttrList optimizer_custom_computation_; + + XlaSparseDenseMatmulCustomCombinerOnTcGradWithCsrInputOp( + const XlaSparseDenseMatmulCustomCombinerOnTcGradWithCsrInputOp&) = delete; + void operator=( + const XlaSparseDenseMatmulCustomCombinerOnTcGradWithCsrInputOp&) = delete; +}; + +REGISTER_XLA_OP(Name("XlaSparseDenseMatmulCustomCombinerOnTcGradWithCsrInput"), + XlaSparseDenseMatmulCustomCombinerOnTcGradWithCsrInputOp); + +// TC custom combiner VJP + SC back-propagation with the SGD optimizer. +class XlaSparseDenseMatmulCustomCombinerOnTcGradWithSgdAndCsrInputOp + : public XlaSparseDenseMatmulCustomCombinerOnTcGradWithCsrInputBase { + public: + explicit XlaSparseDenseMatmulCustomCombinerOnTcGradWithSgdAndCsrInputOp( + OpKernelConstruction* ctx) + : XlaSparseDenseMatmulCustomCombinerOnTcGradWithCsrInputBase(ctx) {} + + ~XlaSparseDenseMatmulCustomCombinerOnTcGradWithSgdAndCsrInputOp() override = + default; + + absl::StatusOr GetTablesInput(XlaOpKernelContext* ctx, + xla::XlaBuilder* builder) override { + return xla::Tuple(builder, {ctx->Input("embedding_table")}); + } + + absl::StatusOr GetHyperparametersInput( + XlaOpKernelContext* ctx, xla::XlaBuilder* builder) override { + return xla::Tuple(builder, {ctx->Input("learning_rate")}); + } + + absl::StatusOr BuildOptimizerComputation( + XlaOpKernelContext* ctx, const int32_t feature_width) override { + return BuildSgdOptimizerComputation(feature_width, clip_weight_min_, + clip_weight_max_); + } + + private: + 
XlaSparseDenseMatmulCustomCombinerOnTcGradWithSgdAndCsrInputOp( + const XlaSparseDenseMatmulCustomCombinerOnTcGradWithSgdAndCsrInputOp&) = + delete; + void operator=( + const XlaSparseDenseMatmulCustomCombinerOnTcGradWithSgdAndCsrInputOp&) = + delete; +}; + +REGISTER_XLA_OP( + Name("XlaSparseDenseMatmulCustomCombinerOnTcGradWithSgdAndCsrInput"), + XlaSparseDenseMatmulCustomCombinerOnTcGradWithSgdAndCsrInputOp); + +// TC custom combiner VJP + SC back-propagation with the Adagrad optimizer. +class XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdagradAndCsrInputOp + : public XlaSparseDenseMatmulCustomCombinerOnTcGradWithCsrInputBase { + public: + explicit XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdagradAndCsrInputOp( + OpKernelConstruction* ctx) + : XlaSparseDenseMatmulCustomCombinerOnTcGradWithCsrInputBase(ctx) {} + + ~XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdagradAndCsrInputOp() + override = default; + + absl::StatusOr GetTablesInput(XlaOpKernelContext* ctx, + xla::XlaBuilder* builder) override { + return xla::Tuple( + builder, {ctx->Input("embedding_table"), ctx->Input("accumulator")}); + } + + absl::StatusOr GetHyperparametersInput( + XlaOpKernelContext* ctx, xla::XlaBuilder* builder) override { + return xla::Tuple(builder, {ctx->Input("learning_rate")}); + } + + absl::StatusOr BuildOptimizerComputation( + XlaOpKernelContext* ctx, const int32_t feature_width) override { + return BuildAdagradOptimizerComputation(feature_width, clip_weight_min_, + clip_weight_max_); + } + + private: + XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdagradAndCsrInputOp( + const XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdagradAndCsrInputOp&) = // NOLINT + delete; + void operator=( + const XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdagradAndCsrInputOp&) = // NOLINT + delete; +}; + +REGISTER_XLA_OP( + Name("XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdagradAndCsrInput"), + XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdagradAndCsrInputOp); + +// TC custom 
combiner VJP + SC back-propagation with the AdagradMomentum +// optimizer. +class XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdagradMomentumAndCsrInputOp + : public XlaSparseDenseMatmulCustomCombinerOnTcGradWithCsrInputBase { + public: + explicit XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdagradMomentumAndCsrInputOp( // NOLINT + OpKernelConstruction* ctx) + : XlaSparseDenseMatmulCustomCombinerOnTcGradWithCsrInputBase(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("use_nesterov", &use_nesterov_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("exponent", &exponent_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("beta1", &beta1_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("beta2", &beta2_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("epsilon", &epsilon_)); + } + + ~XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdagradMomentumAndCsrInputOp() + override = default; + + absl::StatusOr GetTablesInput(XlaOpKernelContext* ctx, + xla::XlaBuilder* builder) override { + return xla::Tuple(builder, + {ctx->Input("embedding_table"), ctx->Input("accumulator"), + ctx->Input("momenta")}); + } + + absl::StatusOr GetHyperparametersInput( + XlaOpKernelContext* ctx, xla::XlaBuilder* builder) override { + return xla::Tuple(builder, {ctx->Input("learning_rate")}); + } + + absl::StatusOr BuildOptimizerComputation( + XlaOpKernelContext* ctx, const int32_t feature_width) override { + return BuildAdagradMomentumOptimizerComputation( + feature_width, use_nesterov_, exponent_, beta1_, beta2_, epsilon_, + clip_weight_min_, clip_weight_max_); + } + + private: + bool use_nesterov_; + float exponent_; + float beta1_; + float beta2_; + float epsilon_; + + XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdagradMomentumAndCsrInputOp( + const XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdagradMomentumAndCsrInputOp&) = // NOLINT + delete; + void operator=( + const XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdagradMomentumAndCsrInputOp&) = // NOLINT + delete; +}; + +REGISTER_XLA_OP( + 
Name("XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdagradMomentumAndCsrIn" + "put"), + XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdagradMomentumAndCsrInputOp); + +// TC custom combiner VJP + SC back-propagation with the Adam optimizer. +class XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdamAndCsrInputOp + : public XlaSparseDenseMatmulCustomCombinerOnTcGradWithCsrInputBase { + public: + explicit XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdamAndCsrInputOp( + OpKernelConstruction* ctx) + : XlaSparseDenseMatmulCustomCombinerOnTcGradWithCsrInputBase(ctx) { + OP_REQUIRES_OK(ctx, + ctx->GetAttr("use_sum_inside_sqrt", &use_sum_inside_sqrt_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("beta1", &beta1_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("beta2", &beta2_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("epsilon", &epsilon_)); + } + + ~XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdamAndCsrInputOp() override = + default; + + absl::StatusOr GetTablesInput(XlaOpKernelContext* ctx, + xla::XlaBuilder* builder) override { + return xla::Tuple(builder, {ctx->Input("embedding_table"), + ctx->Input("momenta"), ctx->Input("velocity")}); + } + + absl::StatusOr GetHyperparametersInput( + XlaOpKernelContext* ctx, xla::XlaBuilder* builder) override { + return xla::Tuple(builder, {ctx->Input("learning_rate")}); + } + + absl::StatusOr BuildOptimizerComputation( + XlaOpKernelContext* ctx, const int32_t feature_width) override { + return BuildAdamOptimizerComputation(feature_width, use_sum_inside_sqrt_, + beta1_, beta2_, epsilon_, + clip_weight_min_, clip_weight_max_); + } + + private: + bool use_sum_inside_sqrt_; + float beta1_; + float beta2_; + float epsilon_; + + XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdamAndCsrInputOp( + const XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdamAndCsrInputOp&) = + delete; + void operator=( + const XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdamAndCsrInputOp&) = + delete; +}; + +REGISTER_XLA_OP( + 
Name("XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdamAndCsrInput"), + XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdamAndCsrInputOp); + +// TC custom combiner VJP + SC back-propagation with the FTRL optimizer. +class XlaSparseDenseMatmulCustomCombinerOnTcGradWithFtrlAndCsrInputOp + : public XlaSparseDenseMatmulCustomCombinerOnTcGradWithCsrInputBase { + public: + explicit XlaSparseDenseMatmulCustomCombinerOnTcGradWithFtrlAndCsrInputOp( + OpKernelConstruction* ctx) + : XlaSparseDenseMatmulCustomCombinerOnTcGradWithCsrInputBase(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("multiply_linear_by_learning_rate", + &multiply_linear_by_learning_rate_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("beta", &beta_)); + OP_REQUIRES_OK(ctx, + ctx->GetAttr("learning_rate_power", &learning_rate_power_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("l1_regularization_strength", + &l1_regularization_strength_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("l2_regularization_strength", + &l2_regularization_strength_)); + } + + ~XlaSparseDenseMatmulCustomCombinerOnTcGradWithFtrlAndCsrInputOp() override = + default; + + absl::StatusOr GetTablesInput(XlaOpKernelContext* ctx, + xla::XlaBuilder* builder) override { + return xla::Tuple(builder, + {ctx->Input("embedding_table"), ctx->Input("accumulator"), + ctx->Input("linear")}); + } + + absl::StatusOr GetHyperparametersInput( + XlaOpKernelContext* ctx, xla::XlaBuilder* builder) override { + return xla::Tuple(builder, {ctx->Input("learning_rate")}); + } + + absl::StatusOr BuildOptimizerComputation( + XlaOpKernelContext* ctx, const int32_t feature_width) override { + return BuildFtrlOptimizerComputation( + feature_width, multiply_linear_by_learning_rate_, beta_, + learning_rate_power_, l1_regularization_strength_, + l2_regularization_strength_, clip_weight_min_, clip_weight_max_); + } + + private: + bool multiply_linear_by_learning_rate_; + float beta_; + float learning_rate_power_; + float l1_regularization_strength_; + float l2_regularization_strength_; + + 
XlaSparseDenseMatmulCustomCombinerOnTcGradWithFtrlAndCsrInputOp( + const XlaSparseDenseMatmulCustomCombinerOnTcGradWithFtrlAndCsrInputOp&) = + delete; + void operator=( + const XlaSparseDenseMatmulCustomCombinerOnTcGradWithFtrlAndCsrInputOp&) = + delete; +}; + +REGISTER_XLA_OP( + Name("XlaSparseDenseMatmulCustomCombinerOnTcGradWithFtrlAndCsrInput"), + XlaSparseDenseMatmulCustomCombinerOnTcGradWithFtrlAndCsrInputOp); + // This TensorFlow op calculates the gradients and performs SGD update on the // embedding table on SparseCore. It takes the activation gradients, input // sparse tensor represented by the `row_pointers`, `sorted_embedding_ids`, @@ -703,36 +1627,8 @@ class XlaSparseDenseMatmulGradWithSgdAndCsrInputOp xla::XlaComputation build_optimizer_computation( const int32_t feature_width) override { - xla::XlaComputation sgd_optimizer = [&] { - auto sgd_optimizer_builder = - std::make_unique("sgd_optimizer_builder"); - - xla::Shape per_row_shape = - xla::ShapeUtil::MakeShapeWithType({1, feature_width}); - - xla::XlaOp gradient = xla::Parameter(sgd_optimizer_builder.get(), 0, - per_row_shape, "gradient"); - - xla::XlaOp embedding_table = xla::Parameter( - sgd_optimizer_builder.get(), 1, per_row_shape, "embedding_table"); - - xla::XlaOp learning_rate = xla::Parameter(sgd_optimizer_builder.get(), 2, - per_row_shape, "learning_rate"); - - xla::XlaOp updated_embedding_table = - embedding_table - learning_rate * gradient; - - // Apply the weight clipping. 
- xla::XlaOp clipped_embedding_table = apply_weight_clipping_to_table( - sgd_optimizer_builder.get(), updated_embedding_table); - - xla::XlaOp updated_tables = - xla::Tuple(sgd_optimizer_builder.get(), {clipped_embedding_table}); - - return sgd_optimizer_builder->Build(updated_tables).value(); - }(); - - return sgd_optimizer; + return BuildSgdOptimizerComputation(feature_width, clip_weight_min_, + clip_weight_max_); } xla::XlaOp get_tables_input(XlaOpKernelContext* ctx) override { @@ -772,42 +1668,8 @@ class XlaSparseDenseMatmulGradWithAdagradAndCsrInputOp xla::XlaComputation build_optimizer_computation( const int32_t feature_width) override { - xla::XlaComputation adagrad_optimizer = [&] { - auto adagrad_optimizer_builder = - std::make_unique("adagrad_optimizer_builder"); - - xla::Shape per_row_shape = - xla::ShapeUtil::MakeShapeWithType({1, feature_width}); - - xla::XlaOp gradient = xla::Parameter(adagrad_optimizer_builder.get(), 0, - per_row_shape, "gradient"); - - xla::XlaOp embedding_table = xla::Parameter( - adagrad_optimizer_builder.get(), 1, per_row_shape, "embedding_table"); - - xla::XlaOp accumulator = xla::Parameter(adagrad_optimizer_builder.get(), - 2, per_row_shape, "accumulator"); - - xla::XlaOp learning_rate = xla::Parameter( - adagrad_optimizer_builder.get(), 3, per_row_shape, "learning_rate"); - - xla::XlaOp new_accumulator = accumulator + gradient * gradient; - - xla::XlaOp updated_embedding_table = - embedding_table - - learning_rate * gradient / xla::Sqrt(new_accumulator); - - // Apply the weight clipping. 
- xla::XlaOp clipped_embedding_table = apply_weight_clipping_to_table( - adagrad_optimizer_builder.get(), updated_embedding_table); - - xla::XlaOp updated_tables = - xla::Tuple(adagrad_optimizer_builder.get(), - {clipped_embedding_table, new_accumulator}); - return adagrad_optimizer_builder->Build(updated_tables).value(); - }(); - - return adagrad_optimizer; + return BuildAdagradOptimizerComputation(feature_width, clip_weight_min_, + clip_weight_max_); } xla::XlaOp get_tables_input(XlaOpKernelContext* ctx) override { @@ -857,82 +1719,9 @@ class XlaSparseDenseMatmulGradWithAdagradMomentumAndCsrInputOp xla::XlaComputation build_optimizer_computation( const int32_t feature_width) override { - xla::XlaComputation adagrad_momentum_optimizer = [&] { - auto adagrad_momentum_optimizer_builder = - std::make_unique( - "adagrad_momentum_optimizer_builder"); - - xla::Shape per_row_shape = - xla::ShapeUtil::MakeShapeWithType({1, feature_width}); - - xla::XlaOp gradient = - xla::Parameter(adagrad_momentum_optimizer_builder.get(), 0, - per_row_shape, "gradient"); - xla::XlaOp embedding_table = - xla::Parameter(adagrad_momentum_optimizer_builder.get(), 1, - per_row_shape, "embedding_table"); - xla::XlaOp accumulator = - xla::Parameter(adagrad_momentum_optimizer_builder.get(), 2, - per_row_shape, "accumulator"); - xla::XlaOp momenta = - xla::Parameter(adagrad_momentum_optimizer_builder.get(), 3, - per_row_shape, "momenta"); - xla::XlaOp learning_rate = - xla::Parameter(adagrad_momentum_optimizer_builder.get(), 4, - per_row_shape, "learning_rate"); - - xla::XlaOp beta1 = - xla::ConstantR0(adagrad_momentum_optimizer_builder.get(), beta1_); - xla::XlaOp beta2 = - xla::ConstantR0(adagrad_momentum_optimizer_builder.get(), beta2_); - xla::XlaOp epsilon = - xla::ConstantR0(adagrad_momentum_optimizer_builder.get(), epsilon_); - - // If beta_2 == 1: - // accumulator(t) = accumulator(t-1) + gradient(t) ^ 2 - // Else: - // accumulator(t) = beta_2 * accumulator(t-1) + - // (1-beta_2) * 
gradient(t) ^ 2 - xla::XlaOp exponent = xla::ConstantR0( - adagrad_momentum_optimizer_builder.get(), 1.0f / exponent_); - xla::XlaOp one = - xla::ConstantR0(adagrad_momentum_optimizer_builder.get(), 1.0f); - - xla::XlaOp new_accumulator = xla::Select( - xla::Eq(beta2, one), accumulator + gradient * gradient, - beta2 * accumulator + (one - beta2) * gradient * gradient); - - // scaled_gradient = (accumulator + epsilon)^(-1/k) * gradient - xla::XlaOp scaled_gradients = - Pow(new_accumulator + epsilon, xla::Neg(exponent)) * gradient; - - // momenta(t) = beta1 * momenta(t-1) + scaled_gradient(t) - xla::XlaOp new_momenta = beta1 * momenta + scaled_gradients; - - // Table update: - // non-nesterov: update = momenta_t - // nesterov: update = beta_1 * momenta_t + scaled_gradient - // weights(t) = weights(t-1) - lr * update - xla::XlaOp updated_embedding_table; - if (use_nesterov_) { - updated_embedding_table = - embedding_table - - learning_rate * (beta1 * new_momenta + scaled_gradients); - } else { - updated_embedding_table = embedding_table - learning_rate * new_momenta; - } - - // Apply the weight clipping. 
- xla::XlaOp clipped_embedding_table = apply_weight_clipping_to_table( - adagrad_momentum_optimizer_builder.get(), updated_embedding_table); - - xla::XlaOp updated_tables = - xla::Tuple(adagrad_momentum_optimizer_builder.get(), - {clipped_embedding_table, new_accumulator, new_momenta}); - return adagrad_momentum_optimizer_builder->Build(updated_tables).value(); - }(); - - return adagrad_momentum_optimizer; + return BuildAdagradMomentumOptimizerComputation( + feature_width, use_nesterov_, exponent_, beta1_, beta2_, epsilon_, + clip_weight_min_, clip_weight_max_); } xla::XlaOp get_tables_input(XlaOpKernelContext* ctx) override { @@ -986,64 +1775,9 @@ class XlaSparseDenseMatmulGradWithAdamAndCsrInputOp xla::XlaComputation build_optimizer_computation( const int32_t feature_width) override { - xla::XlaComputation adam_optimizer = [&] { - auto adam_optimizer_builder = - std::make_unique("adam_optimizer_builder"); - - xla::Shape per_row_shape = - xla::ShapeUtil::MakeShapeWithType({1, feature_width}); - - xla::XlaOp gradient = xla::Parameter(adam_optimizer_builder.get(), 0, - per_row_shape, "gradient"); - xla::XlaOp embedding_table = xla::Parameter( - adam_optimizer_builder.get(), 1, per_row_shape, "embedding_table"); - xla::XlaOp momenta = xla::Parameter(adam_optimizer_builder.get(), 2, - per_row_shape, "momenta"); - xla::XlaOp velocity = xla::Parameter(adam_optimizer_builder.get(), 3, - per_row_shape, "velocity"); - xla::XlaOp learning_rate = xla::Parameter(adam_optimizer_builder.get(), 4, - per_row_shape, "learning_rate"); - - xla::XlaOp beta1 = xla::ConstantR0(adam_optimizer_builder.get(), beta1_); - xla::XlaOp beta2 = xla::ConstantR0(adam_optimizer_builder.get(), beta2_); - xla::XlaOp epsilon = - xla::ConstantR0(adam_optimizer_builder.get(), epsilon_); - - // Depending on sum_inside_sqrt, the denominator is either: - // sum_inside_sqrt==true: sqrt(v + eps^2) - // sum_inside_sqrt==false: sqrt(v) + eps - // To simplify the for loop below, write the sqrt denominator as: 
- // sqrt(v + e1) + e2 - // and set e1 and e2 appropriately: - xla::XlaOp zero = xla::ConstantR0(adam_optimizer_builder.get(), 0.0f); - xla::XlaOp one = xla::ConstantR0(adam_optimizer_builder.get(), 1.0f); - xla::XlaOp e1 = use_sum_inside_sqrt_ ? epsilon * epsilon : zero; - xla::XlaOp e2 = use_sum_inside_sqrt_ ? zero : epsilon; - - // momentum(t) = beta_1 * momentum(t-1) - // + (1-beta_1)*gradient(t) - xla::XlaOp new_momenta = beta1 * momenta + (one - beta1) * gradient; - - // velocity(t) = beta_2 * velocity(t-1) - // + (1-beta_2)*gradient(t)*gradient(t) - xla::XlaOp new_velocity = - beta2 * velocity + (one - beta2) * gradient * gradient; - - xla::XlaOp updated_embedding_table = - embedding_table - - learning_rate * new_momenta / (xla::Sqrt(new_velocity + e1) + e2); - - // Apply the weight clipping. - xla::XlaOp clipped_embedding_table = apply_weight_clipping_to_table( - adam_optimizer_builder.get(), updated_embedding_table); - - xla::XlaOp updated_tables = - xla::Tuple(adam_optimizer_builder.get(), - {clipped_embedding_table, new_momenta, new_velocity}); - return adam_optimizer_builder->Build(updated_tables).value(); - }(); - - return adam_optimizer; + return BuildAdamOptimizerComputation(feature_width, use_sum_inside_sqrt_, + beta1_, beta2_, epsilon_, + clip_weight_min_, clip_weight_max_); } xla::XlaOp get_tables_input(XlaOpKernelContext* ctx) override { @@ -1101,104 +1835,10 @@ class XlaSparseDenseMatmulGradWithFtrlAndCsrInputOp xla::XlaComputation build_optimizer_computation( const int32_t feature_width) override { - xla::XlaComputation ftrl_optimizer = [&] { - auto ftrl_optimizer_builder = - std::make_unique("ftrl_optimizer_builder"); - - xla::Shape per_row_shape = - xla::ShapeUtil::MakeShapeWithType({1, feature_width}); - - xla::XlaOp gradient = xla::Parameter(ftrl_optimizer_builder.get(), 0, - per_row_shape, "gradient"); - - xla::XlaOp embedding_table = xla::Parameter( - ftrl_optimizer_builder.get(), 1, per_row_shape, "embedding_table"); - xla::XlaOp 
accumulator = xla::Parameter(ftrl_optimizer_builder.get(), 2, - per_row_shape, "accumulator"); - xla::XlaOp linear = xla::Parameter(ftrl_optimizer_builder.get(), 3, - per_row_shape, "linear"); - xla::XlaOp learning_rate = xla::Parameter(ftrl_optimizer_builder.get(), 4, - per_row_shape, "learning_rate"); - - // accumulator(t) = accumulator(t-1) + gradient(t) ^ 2 - xla::XlaOp new_accumulator = accumulator + gradient * gradient; - - xla::XlaOp learning_rate_power = - xla::ConstantR0(ftrl_optimizer_builder.get(), learning_rate_power_); - - xla::XlaOp power_old = Pow(accumulator, xla::Neg(learning_rate_power)); - xla::XlaOp power_new = - Pow(new_accumulator, xla::Neg(learning_rate_power)); - xla::XlaOp delta_p = power_new - power_old; - - xla::XlaOp zero = xla::ConstantR0(ftrl_optimizer_builder.get(), 0.0f); - - xla::XlaOp two = xla::ConstantR0(ftrl_optimizer_builder.get(), 2.0f); - - xla::XlaOp l1_regularization_strength = xla::ConstantR0( - ftrl_optimizer_builder.get(), l1_regularization_strength_); - - xla::XlaOp l2_regularization_strength = xla::ConstantR0( - ftrl_optimizer_builder.get(), l2_regularization_strength_); - - xla::XlaOp beta = xla::ConstantR0(ftrl_optimizer_builder.get(), beta_); - - // Note: - // min(|linear(t)|, lr*l1)*sgn(linear(t)) - // can be written as - // clamp( -lr*l1, linear(t), lr*l1) - // assuming lr>0 and l1>0. 
- xla::XlaOp new_linear; - xla::XlaOp numer; - xla::XlaOp denom; - if (multiply_linear_by_learning_rate_) { - new_linear = - linear + learning_rate * gradient - delta_p * embedding_table; - // if multiply_linear: - // linear(t) = linear(t-1) + lr*g - delta_p * table(t-1) - // Update numerator: - // N = min(|linear(t)|, lr*l1)*sgn(linear(t)) - linear(t) - // Update denomninator: - // D = power(t) + 2*lr*l2 + beta - // table(t) = N / D - numer = xla::Select( - xla::Eq(l1_regularization_strength, zero), xla::Neg(new_linear), - xla::Clamp(xla::Neg(learning_rate * l1_regularization_strength), - new_linear, learning_rate * l1_regularization_strength) - - new_linear); - denom = - power_new + two * learning_rate * l2_regularization_strength + beta; - } else { - new_linear = - linear + gradient - delta_p * embedding_table / learning_rate; - // if NOT multiply_linear: - // linear(t) = linear(t-1) + g - (1/lr) * delta_p * table(t-1) - // Update numerator: - // N = min(|linear(t)|, l1)*sgn(linear(t)) - linear(t) - // Update denomninator: - // D = (1/lr) * (power(t) + beta) + 2*l2 - // table(t) = N / D - numer = xla::Select(xla::Eq(l1_regularization_strength, zero), - xla::Neg(new_linear), - xla::Clamp(xla::Neg(l1_regularization_strength), - new_linear, l1_regularization_strength) - - new_linear); - denom = (power_new + beta) / learning_rate + - two * l2_regularization_strength; - } - xla::XlaOp updated_embedding_table = numer / denom; - - // Apply the weight clipping. 
- xla::XlaOp clipped_embedding_table = apply_weight_clipping_to_table( - ftrl_optimizer_builder.get(), updated_embedding_table); - - xla::XlaOp updated_tables = - xla::Tuple(ftrl_optimizer_builder.get(), - {clipped_embedding_table, new_accumulator, new_linear}); - return ftrl_optimizer_builder->Build(updated_tables).value(); - }(); - - return ftrl_optimizer; + return BuildFtrlOptimizerComputation( + feature_width, multiply_linear_by_learning_rate_, beta_, + learning_rate_power_, l1_regularization_strength_, + l2_regularization_strength_, clip_weight_min_, clip_weight_max_); } xla::XlaOp get_tables_input(XlaOpKernelContext* ctx) override { diff --git a/tensorflow/core/tpu/ops/sparse_core_ops.cc b/tensorflow/core/tpu/ops/sparse_core_ops.cc index 8d0186f145b36a..db71374caa06d9 100644 --- a/tensorflow/core/tpu/ops/sparse_core_ops.cc +++ b/tensorflow/core/tpu/ops/sparse_core_ops.cc @@ -21,6 +21,51 @@ limitations under the License. namespace tensorflow { +namespace { + +absl::Status ValidateSparseDenseMatmulCustomCombinerGradWithCsrInputShape( + shape_inference::InferenceContext* c, const int weights_index, + const int preserved_valencies_index, const int preserved_vectors_index, + const int preserved_weights_index, const int activation_gradients_index, + const int tables_index, const int num_tables) { + shape_inference::ShapeHandle shape; + int num_weights; + int max_valency_int; + TF_RETURN_IF_ERROR(c->GetAttr("num_weights", &num_weights)); + TF_RETURN_IF_ERROR(c->GetAttr("max_valency", &max_valency_int)); + // Only check the shape of the weights when num_weights > 0 to avoid + // issues of 0-shaped values. 
+ if (num_weights > 0) { + TF_RETURN_IF_ERROR(c->Merge(c->input(weights_index), + c->MakeShape({c->MakeDim(num_weights)}), + &shape)); + TF_RETURN_IF_ERROR(c->Merge(c->input(preserved_weights_index), + c->MakeShape({c->MakeDim(num_weights)}), + &shape)); + } + // Check that the preserved tensors have the expected shapes: + // valencies: [input_size]; + // vectors: [input_size, max_valency, feature_width]; + auto input_size = c->Dim(c->input(activation_gradients_index), 0); + auto max_valency = c->MakeDim(max_valency_int); + auto feature_width = c->Dim(c->input(tables_index), 1); + TF_RETURN_IF_ERROR(c->Merge(c->input(preserved_valencies_index), + c->MakeShape({input_size}), &shape)); + TF_RETURN_IF_ERROR( + c->Merge(c->input(preserved_vectors_index), + c->MakeShape({input_size, max_valency, feature_width}), &shape)); + // `updated_tables` refers to both the embedding table and the associated + // slot variables. They all have the same embedding table shape. + for (int i = 0; i < num_tables; ++i) { + c->set_output(i, c->input(tables_index)); + } + // `updated_weights` simply have a 1D shape of `num_weights`. 
+ c->set_output(num_tables, c->MakeShape({c->MakeDim(num_weights)})); + return absl::OkStatus(); +} + +} // namespace + REGISTER_OP("XlaSparseDenseMatmul") .Input("row_ids: int32") .Input("col_ids: uint32") @@ -75,6 +120,7 @@ REGISTER_OP("XlaSparseDenseMatmulWithCsrInput") .Attr("quantization_config_high: float") .Attr("quantization_config_num_buckets: int >= 0") .Attr("table_name: string") + .Attr("num_sparsecores_per_device: int = -1") .SetShapeFn([](shape_inference::InferenceContext* c) -> absl::Status { int input_size; TF_RETURN_IF_ERROR(c->GetAttr("input_size", &input_size)); @@ -95,6 +141,82 @@ REGISTER_OP("XlaSparseDenseMatmulWithCsrInput") return absl::OkStatus(); }); +REGISTER_OP("XlaSparseDenseMatmulCustomCombinerOnTcWithCsrInput") + .Input("row_pointers: int32") + .Input("sorted_sample_ids: int32") + .Input("sorted_token_ids: int32") + .Input("sorted_pos_ids: int32") + .Input("sorted_gains: float32") + .Input("embedding_table: float32") + .Input("weights: float32") + .Output("activations: float32") + .Output("preserved_valencies: int32") + .Output("preserved_vectors: float32") + .Attr("input_size: int >= 0") + .Attr("max_valency: int >= 0") + .Attr("num_weights: int >= 0") + .Attr("combiner_computation: func") + .Attr("quantization_config_low: float") + .Attr("quantization_config_high: float") + .Attr("quantization_config_num_buckets: int >= 0") + .Attr("table_name: string") + .Attr("num_sparsecores_per_device: int = -1") + .SetShapeFn([](shape_inference::InferenceContext* c) -> absl::Status { + constexpr int kRowPointersIndex = 0; + constexpr int kSortedSampleIdsIndex = 1; + constexpr int kEmbeddingTableIndex = 5; + constexpr int kEmbeddingTableRank = 2; + constexpr int kWeightsIndex = 6; + constexpr int kWeightsRank = 1; + constexpr int kOutputActivationsIndex = 0; + constexpr int kPreservedValenciesIndex = 1; + constexpr int kPreservedVectorsIndex = 2; + // This input_size is per-chip batch size. 
+ int input_size; + TF_RETURN_IF_ERROR(c->GetAttr("input_size", &input_size)); + int max_valency; + TF_RETURN_IF_ERROR(c->GetAttr("max_valency", &max_valency)); + int num_weights; + TF_RETURN_IF_ERROR(c->GetAttr("num_weights", &num_weights)); + + shape_inference::ShapeHandle rank; + for (int i = kRowPointersIndex; i < kEmbeddingTableIndex; ++i) { + TF_RETURN_IF_ERROR( + c->WithRank(c->input(i), kSortedSampleIdsIndex, &rank)); + } + TF_RETURN_IF_ERROR(c->WithRank(c->input(kEmbeddingTableIndex), + kEmbeddingTableRank, &rank)); + for (int i = kSortedSampleIdsIndex + 1; i < kEmbeddingTableIndex; ++i) { + shape_inference::ShapeHandle merged; + TF_RETURN_IF_ERROR( + c->Merge(c->input(i), c->input(kSortedSampleIdsIndex), &merged)); + } + if (num_weights > 0) { + TF_RETURN_IF_ERROR( + c->WithRank(c->input(kWeightsIndex), kWeightsRank, &rank)); + shape_inference::DimensionHandle weights_dim; + TF_RETURN_IF_ERROR(c->WithValue(c->Dim(c->input(kWeightsIndex), 0), + num_weights, &weights_dim)); + } + + shape_inference::DimensionHandle input_size_dim = c->MakeDim(input_size); + shape_inference::DimensionHandle max_valency_dim = + c->MakeDim(max_valency); + shape_inference::DimensionHandle feature_width_dim = + c->Dim(c->input(kEmbeddingTableIndex), 1); + shape_inference::ShapeHandle output_activations_shape; + TF_RETURN_IF_ERROR(c->ReplaceDim(c->input(kEmbeddingTableIndex), 0, + c->MakeDim(input_size), + &output_activations_shape)); + c->set_output(kOutputActivationsIndex, output_activations_shape); + c->set_output(kPreservedValenciesIndex, c->MakeShape({input_size_dim})); + c->set_output( + kPreservedVectorsIndex, + c->MakeShape({input_size_dim, max_valency_dim, feature_width_dim})); + + return absl::OkStatus(); + }); + REGISTER_OP("XlaSparseDenseMatmulGradWithSgdAndCsrInput") .Input("row_pointers: int32") .Input("sorted_sample_ids: int32") @@ -108,6 +230,7 @@ REGISTER_OP("XlaSparseDenseMatmulGradWithSgdAndCsrInput") .Attr("clip_weight_min: float = -inf") 
.Attr("clip_weight_max: float = inf") .Attr("table_name: string") + .Attr("num_sparsecores_per_device: int = -1") .SetShapeFn([](shape_inference::InferenceContext* c) -> absl::Status { c->set_output(0, c->input(6)); return absl::OkStatus(); @@ -128,6 +251,7 @@ REGISTER_OP("XlaSparseDenseMatmulGradWithAdagradAndCsrInput") .Attr("clip_weight_min: float = -inf") .Attr("clip_weight_max: float = inf") .Attr("table_name: string") + .Attr("num_sparsecores_per_device: int = -1") .SetShapeFn([](shape_inference::InferenceContext* c) -> absl::Status { c->set_output(0, c->input(6)); c->set_output(1, c->input(7)); @@ -156,6 +280,7 @@ REGISTER_OP("XlaSparseDenseMatmulGradWithAdagradMomentumAndCsrInput") .Attr("clip_weight_min: float = -inf") .Attr("clip_weight_max: float = inf") .Attr("table_name: string") + .Attr("num_sparsecores_per_device: int = -1") .SetShapeFn([](shape_inference::InferenceContext* c) -> absl::Status { c->set_output(0, c->input(6)); c->set_output(1, c->input(7)); @@ -184,6 +309,7 @@ REGISTER_OP("XlaSparseDenseMatmulGradWithAdamAndCsrInput") .Attr("clip_weight_min: float = -inf") .Attr("clip_weight_max: float = inf") .Attr("table_name: string") + .Attr("num_sparsecores_per_device: int = -1") .SetShapeFn([](shape_inference::InferenceContext* c) -> absl::Status { c->set_output(0, c->input(6)); c->set_output(1, c->input(7)); @@ -213,6 +339,7 @@ REGISTER_OP("XlaSparseDenseMatmulGradWithFtrlAndCsrInput") .Attr("clip_weight_min: float = -inf") .Attr("clip_weight_max: float = inf") .Attr("table_name: string") + .Attr("num_sparsecores_per_device: int = -1") .SetShapeFn([](shape_inference::InferenceContext* c) -> absl::Status { c->set_output(0, c->input(6)); c->set_output(1, c->input(7)); @@ -344,6 +471,7 @@ REGISTER_OP("XlaSparseDenseMatmulWithStaticBufferSize") .Attr("max_ids_per_sparse_core: int >= 1") .Attr("max_unique_ids_per_sparse_core: int >= 1") .Attr("table_name: string") + .Attr("num_sparsecores_per_device: int = -1") 
.SetShapeFn([](shape_inference::InferenceContext* c) -> absl::Status { int input_size; TF_RETURN_IF_ERROR(c->GetAttr("input_size", &input_size)); @@ -379,6 +507,7 @@ REGISTER_OP("XlaSparseDenseMatmulGradWithSgdAndStaticBufferSize") .Attr("max_ids_per_sparse_core: int >= 1") .Attr("max_unique_ids_per_sparse_core: int >= 1") .Attr("table_name: string") + .Attr("num_sparsecores_per_device: int = -1") .SetShapeFn([](shape_inference::InferenceContext* c) -> absl::Status { c->set_output(0, c->input(6)); return absl::OkStatus(); @@ -401,6 +530,7 @@ REGISTER_OP("XlaSparseDenseMatmulGradWithAdagradAndStaticBufferSize") .Attr("max_ids_per_sparse_core: int >= 1") .Attr("max_unique_ids_per_sparse_core: int >= 1") .Attr("table_name: string") + .Attr("num_sparsecores_per_device: int = -1") .SetShapeFn([](shape_inference::InferenceContext* c) -> absl::Status { c->set_output(0, c->input(6)); c->set_output(1, c->input(7)); @@ -431,6 +561,7 @@ REGISTER_OP("XlaSparseDenseMatmulGradWithAdagradMomentumAndStaticBufferSize") .Attr("max_ids_per_sparse_core: int >= 1") .Attr("max_unique_ids_per_sparse_core: int >= 1") .Attr("table_name: string") + .Attr("num_sparsecores_per_device: int = -1") .SetShapeFn([](shape_inference::InferenceContext* c) -> absl::Status { c->set_output(0, c->input(6)); c->set_output(1, c->input(7)); @@ -461,6 +592,7 @@ REGISTER_OP("XlaSparseDenseMatmulGradWithAdamAndStaticBufferSize") .Attr("max_ids_per_sparse_core: int >= 1") .Attr("max_unique_ids_per_sparse_core: int >= 1") .Attr("table_name: string") + .Attr("num_sparsecores_per_device: int = -1") .SetShapeFn([](shape_inference::InferenceContext* c) -> absl::Status { c->set_output(0, c->input(6)); c->set_output(1, c->input(7)); @@ -492,6 +624,7 @@ REGISTER_OP("XlaSparseDenseMatmulGradWithFtrlAndStaticBufferSize") .Attr("max_ids_per_sparse_core: int >= 1") .Attr("max_unique_ids_per_sparse_core: int >= 1") .Attr("table_name: string") + .Attr("num_sparsecores_per_device: int = -1") 
.SetShapeFn([](shape_inference::InferenceContext* c) -> absl::Status { c->set_output(0, c->input(6)); c->set_output(1, c->input(7)); @@ -513,6 +646,7 @@ REGISTER_OP("XlaSparseDenseMatmulGradWithCsrInput") .Attr("M: int >= 1") .Attr("custom_computation: func") .Attr("table_name: string") + .Attr("num_sparsecores_per_device: int = -1") .SetShapeFn([](shape_inference::InferenceContext* c) -> absl::Status { int num_tables; TF_RETURN_IF_ERROR(c->GetAttr("N", &num_tables)); @@ -522,4 +656,271 @@ REGISTER_OP("XlaSparseDenseMatmulGradWithCsrInput") return absl::OkStatus(); }); +REGISTER_OP("XlaSparseDenseMatmulCustomCombinerOnTcGradWithSgdAndCsrInput") + .Input("row_pointers: int32") + .Input("sorted_sample_ids: int32") + .Input("sorted_token_ids: int32") + .Input("sorted_pos_ids: int32") + .Input("sorted_gains: float32") + .Input("weights: float32") + .Input("preserved_valencies: int32") + .Input("preserved_vectors: float32") + .Input("preserved_weights: float32") + .Input("activation_gradients: float32") + .Input("learning_rate: float32") + .Input("combiner_weights_learning_rate: float32") + .Input("embedding_table: float32") + .Output("updated_embedding_table: float32") + .Output("updated_weights: float32") + .Attr("clip_weight_min: float = -inf") + .Attr("clip_weight_max: float = inf") + .Attr("max_valency: int >= 0") + .Attr("num_weights: int >= 0") + .Attr("combiner_table_vjp_computation: func") + .Attr("combiner_weights_vjp_computation: func") + .Attr("table_name: string") + .SetShapeFn([](shape_inference::InferenceContext* c) -> absl::Status { + constexpr int kWeightsIndex = 5; + constexpr int kPreservedValenciesIndex = 6; + constexpr int kPreservedVectorsIndex = 7; + constexpr int kPreservedWeightsIndex = 8; + constexpr int kActivationGradientsIndex = 9; + constexpr int kTablesIndex = 12; + constexpr int kNumTables = 1; + TF_RETURN_IF_ERROR( + ValidateSparseDenseMatmulCustomCombinerGradWithCsrInputShape( + c, kWeightsIndex, kPreservedValenciesIndex, + 
kPreservedVectorsIndex, kPreservedWeightsIndex, + kActivationGradientsIndex, kTablesIndex, kNumTables)); + return absl::OkStatus(); + }); + +REGISTER_OP("XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdagradAndCsrInput") + .Input("row_pointers: int32") + .Input("sorted_sample_ids: int32") + .Input("sorted_token_ids: int32") + .Input("sorted_pos_ids: int32") + .Input("sorted_gains: float32") + .Input("weights: float32") + .Input("preserved_valencies: int32") + .Input("preserved_vectors: float32") + .Input("preserved_weights: float32") + .Input("activation_gradients: float32") + .Input("learning_rate: float32") + .Input("combiner_weights_learning_rate: float32") + .Input("embedding_table: float32") + .Input("accumulator: float32") + .Output("updated_embedding_table: float32") + .Output("updated_accumulator: float32") + .Output("updated_weights: float32") + .Attr("clip_weight_min: float = -inf") + .Attr("clip_weight_max: float = inf") + .Attr("max_valency: int >= 0") + .Attr("num_weights: int >= 0") + .Attr("combiner_table_vjp_computation: func") + .Attr("combiner_weights_vjp_computation: func") + .Attr("table_name: string") + .SetShapeFn([](shape_inference::InferenceContext* c) -> absl::Status { + constexpr int kWeightsIndex = 5; + constexpr int kPreservedValenciesIndex = 6; + constexpr int kPreservedVectorsIndex = 7; + constexpr int kPreservedWeightsIndex = 8; + constexpr int kActivationGradientsIndex = 9; + constexpr int kTablesIndex = 12; + constexpr int kNumTables = 2; + TF_RETURN_IF_ERROR( + ValidateSparseDenseMatmulCustomCombinerGradWithCsrInputShape( + c, kWeightsIndex, kPreservedValenciesIndex, + kPreservedVectorsIndex, kPreservedWeightsIndex, + kActivationGradientsIndex, kTablesIndex, kNumTables)); + return absl::OkStatus(); + }); + +REGISTER_OP( + "XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdagradMomentumAndCsrInput") + .Input("row_pointers: int32") + .Input("sorted_sample_ids: int32") + .Input("sorted_token_ids: int32") + .Input("sorted_pos_ids: 
int32") + .Input("sorted_gains: float32") + .Input("weights: float32") + .Input("preserved_valencies: int32") + .Input("preserved_vectors: float32") + .Input("preserved_weights: float32") + .Input("activation_gradients: float32") + .Input("learning_rate: float32") + .Input("combiner_weights_learning_rate: float32") + .Input("embedding_table: float32") + .Input("accumulator: float32") + .Input("momenta: float32") + .Output("updated_embedding_table: float32") + .Output("updated_accumulator: float32") + .Output("updated_momenta: float32") + .Output("updated_weights: float32") + .Attr("use_nesterov: bool") + .Attr("exponent: float") + .Attr("beta1: float") + .Attr("beta2: float") + .Attr("epsilon: float") + .Attr("clip_weight_min: float = -inf") + .Attr("clip_weight_max: float = inf") + .Attr("max_valency: int >= 0") + .Attr("num_weights: int >= 0") + .Attr("combiner_table_vjp_computation: func") + .Attr("combiner_weights_vjp_computation: func") + .Attr("table_name: string") + .SetShapeFn([](shape_inference::InferenceContext* c) -> absl::Status { + constexpr int kWeightsIndex = 5; + constexpr int kPreservedValenciesIndex = 6; + constexpr int kPreservedVectorsIndex = 7; + constexpr int kPreservedWeightsIndex = 8; + constexpr int kActivationGradientsIndex = 9; + constexpr int kTablesIndex = 12; + constexpr int kNumTables = 3; + TF_RETURN_IF_ERROR( + ValidateSparseDenseMatmulCustomCombinerGradWithCsrInputShape( + c, kWeightsIndex, kPreservedValenciesIndex, + kPreservedVectorsIndex, kPreservedWeightsIndex, + kActivationGradientsIndex, kTablesIndex, kNumTables)); + return absl::OkStatus(); + }); + +REGISTER_OP("XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdamAndCsrInput") + .Input("row_pointers: int32") + .Input("sorted_sample_ids: int32") + .Input("sorted_token_ids: int32") + .Input("sorted_pos_ids: int32") + .Input("sorted_gains: float32") + .Input("weights: float32") + .Input("preserved_valencies: int32") + .Input("preserved_vectors: float32") + 
.Input("preserved_weights: float32") + .Input("activation_gradients: float32") + .Input("learning_rate: float32") + .Input("combiner_weights_learning_rate: float32") + .Input("embedding_table: float32") + .Input("momenta: float32") + .Input("velocity: float32") + .Output("updated_embedding_table: float32") + .Output("updated_momenta: float32") + .Output("updated_velocity: float32") + .Output("updated_weights: float32") + .Attr("use_sum_inside_sqrt: bool") + .Attr("beta1: float") + .Attr("beta2: float") + .Attr("epsilon: float") + .Attr("clip_weight_min: float = -inf") + .Attr("clip_weight_max: float = inf") + .Attr("max_valency: int >= 0") + .Attr("num_weights: int >= 0") + .Attr("combiner_table_vjp_computation: func") + .Attr("combiner_weights_vjp_computation: func") + .Attr("table_name: string") + .SetShapeFn([](shape_inference::InferenceContext* c) -> absl::Status { + constexpr int kWeightsIndex = 5; + constexpr int kPreservedValenciesIndex = 6; + constexpr int kPreservedVectorsIndex = 7; + constexpr int kPreservedWeightsIndex = 8; + constexpr int kActivationGradientsIndex = 9; + constexpr int kTablesIndex = 12; + constexpr int kNumTables = 3; + TF_RETURN_IF_ERROR( + ValidateSparseDenseMatmulCustomCombinerGradWithCsrInputShape( + c, kWeightsIndex, kPreservedValenciesIndex, + kPreservedVectorsIndex, kPreservedWeightsIndex, + kActivationGradientsIndex, kTablesIndex, kNumTables)); + return absl::OkStatus(); + }); + +REGISTER_OP("XlaSparseDenseMatmulCustomCombinerOnTcGradWithFtrlAndCsrInput") + .Input("row_pointers: int32") + .Input("sorted_sample_ids: int32") + .Input("sorted_token_ids: int32") + .Input("sorted_pos_ids: int32") + .Input("sorted_gains: float32") + .Input("weights: float32") + .Input("preserved_valencies: int32") + .Input("preserved_vectors: float32") + .Input("preserved_weights: float32") + .Input("activation_gradients: float32") + .Input("learning_rate: float32") + .Input("combiner_weights_learning_rate: float32") + .Input("embedding_table: 
float32") + .Input("accumulator: float32") + .Input("linear: float32") + .Output("updated_embedding_table: float32") + .Output("updated_accumulator: float32") + .Output("updated_linear: float32") + .Output("updated_weights: float32") + .Attr("multiply_linear_by_learning_rate: bool") + .Attr("beta: float") + .Attr("learning_rate_power: float") + .Attr("l1_regularization_strength: float") + .Attr("l2_regularization_strength: float") + .Attr("clip_weight_min: float = -inf") + .Attr("clip_weight_max: float = inf") + .Attr("max_valency: int >= 0") + .Attr("num_weights: int >= 0") + .Attr("combiner_table_vjp_computation: func") + .Attr("combiner_weights_vjp_computation: func") + .Attr("table_name: string") + .SetShapeFn([](shape_inference::InferenceContext* c) -> absl::Status { + constexpr int kWeightsIndex = 5; + constexpr int kPreservedValenciesIndex = 6; + constexpr int kPreservedVectorsIndex = 7; + constexpr int kPreservedWeightsIndex = 8; + constexpr int kActivationGradientsIndex = 9; + constexpr int kTablesIndex = 12; + constexpr int kNumTables = 3; + TF_RETURN_IF_ERROR( + ValidateSparseDenseMatmulCustomCombinerGradWithCsrInputShape( + c, kWeightsIndex, kPreservedValenciesIndex, + kPreservedVectorsIndex, kPreservedWeightsIndex, + kActivationGradientsIndex, kTablesIndex, kNumTables)); + return absl::OkStatus(); + }); + +REGISTER_OP("XlaSparseDenseMatmulCustomCombinerOnTcGradWithCsrInput") + .Input("row_pointers: int32") + .Input("sorted_sample_ids: int32") + .Input("sorted_token_ids: int32") + .Input("sorted_pos_ids: int32") + .Input("sorted_gains: float32") + .Input("weights: float32") + // We need to preserve the outputs of the SC forward pass and feed them into + // the VJP computations in the backward pass. 
+ .Input("preserved_valencies: int32") + .Input("preserved_vectors: float32") + .Input("preserved_weights: float32") + .Input("activation_gradients: float32") + .Input("tables: N * float32") + .Input("hyperparameters: M * float32") + .Input("combiner_weights_learning_rate: float32") + .Output("updated_tables: N * float32") + .Output("updated_weights: float32") + .Attr("N: int >= 1") + .Attr("M: int >= 1") + .Attr("max_valency: int >= 0") + .Attr("num_weights: int >= 0") + .Attr("combiner_table_vjp_computation: func") + .Attr("combiner_weights_vjp_computation: func") + .Attr("optimizer_custom_computation: func") + .Attr("table_name: string") + .SetShapeFn([](shape_inference::InferenceContext* c) -> absl::Status { + constexpr int kWeightsIndex = 5; + constexpr int kPreservedValenciesIndex = 6; + constexpr int kPreservedVectorsIndex = 7; + constexpr int kPreservedWeightsIndex = 8; + constexpr int kActivationGradientsIndex = 9; + constexpr int kTablesIndex = 10; + int num_tables; + TF_RETURN_IF_ERROR(c->GetAttr("N", &num_tables)); + TF_RETURN_IF_ERROR( + ValidateSparseDenseMatmulCustomCombinerGradWithCsrInputShape( + c, kWeightsIndex, kPreservedValenciesIndex, + kPreservedVectorsIndex, kPreservedWeightsIndex, + kActivationGradientsIndex, kTablesIndex, num_tables)); + return absl::OkStatus(); + }); + } // namespace tensorflow diff --git a/tensorflow/core/tpu/tpu_execute.cc b/tensorflow/core/tpu/tpu_execute.cc index d077a0bec849ff..12009787a1094c 100644 --- a/tensorflow/core/tpu/tpu_execute.cc +++ b/tensorflow/core/tpu/tpu_execute.cc @@ -134,10 +134,10 @@ absl::Status FixTupleTableAsync(se::Stream* stream, // "bounded_shape". 
bool DynamicShapeIsCompatible(const xla::Shape& dynamic_shape, const xla::Shape& bounded_shape) { - if (dynamic_shape.dimensions_size() != bounded_shape.dimensions_size()) { + if (dynamic_shape.dimensions().size() != bounded_shape.dimensions().size()) { return false; } - for (int64_t i = 0; i < dynamic_shape.dimensions_size(); ++i) { + for (int64_t i = 0; i < dynamic_shape.dimensions().size(); ++i) { if (dynamic_shape.dimensions(i) > bounded_shape.dimensions(i)) { return false; } diff --git a/tensorflow/core/transforms/BUILD b/tensorflow/core/transforms/BUILD index 9f887a843ba4be..5ebaa61c329d3e 100644 --- a/tensorflow/core/transforms/BUILD +++ b/tensorflow/core/transforms/BUILD @@ -15,16 +15,11 @@ package( gentbl_cc_library( name = "PassIncGen", - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "--name", - "TFGraph", - ], - "passes.h.inc", - ), - ], + tbl_outs = {"passes.h.inc": [ + "-gen-pass-decls", + "--name", + "TFGraph", + ]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "passes.td", deps = [ @@ -69,12 +64,7 @@ cc_library( gentbl_cc_library( name = "PDLLUtilsIncGen", - tbl_outs = [ - ( - ["-x=cpp"], - "utils/pdll/PDLLUtils.h.inc", - ), - ], + tbl_outs = {"utils/pdll/PDLLUtils.h.inc": ["-x=cpp"]}, tblgen = "@llvm-project//mlir:mlir-pdll", td_file = "utils/pdll/utils.pdll", deps = [ diff --git a/tensorflow/core/transforms/constant_folding/BUILD b/tensorflow/core/transforms/constant_folding/BUILD index 9f18c597b42715..d3fe6d77e8769c 100644 --- a/tensorflow/core/transforms/constant_folding/BUILD +++ b/tensorflow/core/transforms/constant_folding/BUILD @@ -28,6 +28,8 @@ cc_library( "//tensorflow/core/transforms:op_cat_helper", "//tensorflow/core/transforms:utils", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", "@llvm-project//mlir:Dialect", diff --git a/tensorflow/core/transforms/constant_folding/pass.cc 
b/tensorflow/core/transforms/constant_folding/pass.cc index 68f3a0f0a23a65..d230fb54e1ae3b 100644 --- a/tensorflow/core/transforms/constant_folding/pass.cc +++ b/tensorflow/core/transforms/constant_folding/pass.cc @@ -16,7 +16,9 @@ limitations under the License. #include "tensorflow/core/transforms/constant_folding/pass.h" #include -#include +#include +#include +#include #include #include #include @@ -26,6 +28,8 @@ limitations under the License. #include #include "absl/container/flat_hash_set.h" +#include "absl/log/log.h" +#include "absl/status/status.h" #include "absl/strings/match.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/STLExtras.h" diff --git a/tensorflow/core/transforms/remapper/BUILD b/tensorflow/core/transforms/remapper/BUILD index 1b0c4ca4ded504..455abdaf32d851 100644 --- a/tensorflow/core/transforms/remapper/BUILD +++ b/tensorflow/core/transforms/remapper/BUILD @@ -13,12 +13,7 @@ package( gentbl_cc_library( name = "MklPDLLPatternsIncGen", - tbl_outs = [ - ( - ["-x=cpp"], - "pdll/MklPDLLPatterns.h.inc", - ), - ], + tbl_outs = {"pdll/MklPDLLPatterns.h.inc": ["-x=cpp"]}, tblgen = "@llvm-project//mlir:mlir-pdll", td_file = "pdll/mkl_patterns.pdll", deps = [ diff --git a/tensorflow/core/transforms/remapper/pdll/mkl_patterns.pdll b/tensorflow/core/transforms/remapper/pdll/mkl_patterns.pdll index 3003dc282418b2..e4286a6fff0d5a 100644 --- a/tensorflow/core/transforms/remapper/pdll/mkl_patterns.pdll +++ b/tensorflow/core/transforms/remapper/pdll/mkl_patterns.pdll @@ -16,10 +16,10 @@ #include "tensorflow/core/transforms/utils/pdll/utils.pdll" Constraint AttrIsF32OrBF16(attr: Attr) [{ - TypeAttr type_attr = attr.dyn_cast(); + TypeAttr type_attr = llvm::dyn_cast(attr); if (!type_attr) return failure(); - return success(type_attr.getValue().isa() || - type_attr.getValue().isa()); + return success(llvm::isa(type_attr.getValue()) || + llvm::isa(type_attr.getValue())); }]; Rewrite ReplaceMulWith_MklSwish(op: Op, arg: Value, controls: ValueRange) -> Op [{ diff 
--git a/tensorflow/core/util/example_proto_fast_parsing.cc b/tensorflow/core/util/example_proto_fast_parsing.cc index b4fac84e7aa017..18d58405287bbf 100644 --- a/tensorflow/core/util/example_proto_fast_parsing.cc +++ b/tensorflow/core/util/example_proto_fast_parsing.cc @@ -1582,9 +1582,10 @@ absl::Status FastParseSingleExample(const Config& config, return errors::InvalidArgument("Key: ", feature_name, ". ", suffix); }; - auto parse_error = [feature_name] { - return errors::InvalidArgument("Key: ", feature_name, - ". Can't parse serialized Example."); + auto parse_error = [feature_name](absl::string_view description) { + return errors::InvalidArgument( + "Key: ", feature_name, + ". Can't parse serialized Example: ", description); }; DataType example_dtype; @@ -1619,27 +1620,30 @@ absl::Status FastParseSingleExample(const Config& config, case DT_INT64: { auto out_p = out->flat().data(); LimitedArraySlice slice(out_p, num_elements); - if (!feature.ParseInt64List(&slice)) return parse_error(); + if (!feature.ParseInt64List(&slice)) + return parse_error("Parsing int64_list failed."); if (slice.EndDistance() != 0) { - return parse_error(); + return parse_error("Some int64_list slice was not parsed."); } break; } case DT_FLOAT: { auto out_p = out->flat().data(); LimitedArraySlice slice(out_p, num_elements); - if (!feature.ParseFloatList(&slice)) return parse_error(); + if (!feature.ParseFloatList(&slice)) + return parse_error("Parsing float_list failed."); if (slice.EndDistance() != 0) { - return parse_error(); + return parse_error("Some float_list slice was not parsed."); } break; } case DT_STRING: { auto out_p = out->flat().data(); LimitedArraySlice slice(out_p, num_elements); - if (!feature.ParseBytesList(&slice)) return parse_error(); + if (!feature.ParseBytesList(&slice)) + return parse_error("Parsing bytes_list failed."); if (slice.EndDistance() != 0) { - return parse_error(); + return parse_error("Some bytes_list slice was not parsed."); } break; } @@ -1697,22 
+1701,25 @@ absl::Status FastParseSingleExample(const Config& config, case DT_INT64: { // TODO(mrry): Use the fact that the `int64_list` is packed to read // out the length and pre-allocate the output tensor. - if (!feature.ParseInt64List(&int64_list)) return parse_error(); + if (!feature.ParseInt64List(&int64_list)) + return parse_error("Parsing int64_list failed."); num_elements = int64_list.size(); break; } case DT_FLOAT: { - if (!feature.ParseFloatList(&float_list)) return parse_error(); + if (!feature.ParseFloatList(&float_list)) + return parse_error("Parsing float_list failed."); num_elements = float_list.size(); break; } case DT_STRING: { int actual_num_elements = 0; if (!feature.GetNumElementsInBytesList(&actual_num_elements)) { - return parse_error(); + return parse_error("Could not get num elements in bytes_list."); } bytes_list.reserve(actual_num_elements); - if (!feature.ParseBytesList(&bytes_list)) return parse_error(); + if (!feature.ParseBytesList(&bytes_list)) + return parse_error("Parsing bytes_list failed."); num_elements = bytes_list.size(); break; } @@ -1778,7 +1785,9 @@ absl::Status FastParseSingleExample(const Config& config, } case DT_FLOAT: { if (!out->CopyFrom(float_list.tensor(), out_shape)) { - return parse_error(); + return parse_error(absl::StrCat("Size of float_list is ", + float_list.tensor().dims(), + ", expected ", out_shape.dims())); } break; } diff --git a/tensorflow/core/util/example_proto_fast_parsing_test.cc b/tensorflow/core/util/example_proto_fast_parsing_test.cc index 1eab3903134ed7..b44978fcc84816 100644 --- a/tensorflow/core/util/example_proto_fast_parsing_test.cc +++ b/tensorflow/core/util/example_proto_fast_parsing_test.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include #include +#include "absl/strings/str_cat.h" #include "tensorflow/core/example/example.pb.h" #include "tensorflow/core/example/feature.pb.h" #include "tensorflow/core/lib/random/philox_random.h" @@ -44,7 +45,7 @@ string SerializedToReadable(string serialized) { string result; result += '"'; for (char c : serialized) - result += strings::StrCat("\\x", strings::Hex(c, strings::kZeroPad2)); + absl::StrAppend(&result, "\\x", absl::Hex(c, absl::kZeroPad2)); result += '"'; return result; } diff --git a/tensorflow/core/util/gpu_device_functions.h b/tensorflow/core/util/gpu_device_functions.h index 665e3938326e1f..ead9477563b4f9 100644 --- a/tensorflow/core/util/gpu_device_functions.h +++ b/tensorflow/core/util/gpu_device_functions.h @@ -196,7 +196,7 @@ __device__ const unsigned kGpuWarpAll = 0xffffffff; __device__ inline unsigned GpuLaneId() { unsigned int lane_id; #if GOOGLE_CUDA -#if __clang__ +#if __clang__ && !__NVCC__ return __nvvm_read_ptx_sreg_laneid(); #else // __clang__ asm("mov.u32 %0, %%laneid;" : "=r"(lane_id)); diff --git a/tensorflow/core/util/use_cudnn.h b/tensorflow/core/util/use_cudnn.h index ba13b74016ce7e..18cb057b30b7d1 100644 --- a/tensorflow/core/util/use_cudnn.h +++ b/tensorflow/core/util/use_cudnn.h @@ -27,7 +27,6 @@ namespace tensorflow { using tsl::CudnnDisableConv1x1Optimization; using tsl::CudnnRnnUseAutotune; using tsl::CudnnUseAutotune; -using tsl::CudnnUseFrontend; using tsl::CudnnUseRuntimeFusion; using tsl::DebugCudnnRnn; using tsl::DebugCudnnRnnAlgo; diff --git a/tensorflow/dtensor/build_defs.bzl b/tensorflow/dtensor/build_defs.bzl index a40457b340595b..8c05193d751e9f 100644 --- a/tensorflow/dtensor/build_defs.bzl +++ b/tensorflow/dtensor/build_defs.bzl @@ -11,8 +11,6 @@ ALL_BACKENDS = [ TPU_V3_DONUT_BACKEND = "tpu_v3_2x2" # 8 TPU devices; includes TFRT and non-TFRT targets TPU_V4_DONUT_BACKEND = "tpu_v4_2x2" # 8 TPU devices for non-Megacore targets and 4 for Megacore targets GPU_2DEVS_BACKEND = "2gpus" # 2 Physical GPUs. 
-PATHWAYS = "pw" -PATHWAYS_V3_DONUT_BACKEND = "pw_v3_2x2" # LINT.ThenChange( # python/tests/test_backend_name.py:backend_name, # python/tests/test_backend_name.oss.py:backend_name @@ -42,10 +40,6 @@ def _get_configurations( ], TPU_V4_DONUT_BACKEND: [ ], - PATHWAYS: [ - ], - PATHWAYS_V3_DONUT_BACKEND: [ - ], } configurations = [ dict(suffix = "cpu", backend = "cpu", tags = [], flags = [], env = {}, deps = []), diff --git a/tensorflow/dtensor/cc/BUILD b/tensorflow/dtensor/cc/BUILD index a2f0f88aa8b777..c3493376bfb2bb 100644 --- a/tensorflow/dtensor/cc/BUILD +++ b/tensorflow/dtensor/cc/BUILD @@ -1,7 +1,6 @@ #include "third_party/absl/strings/str_cat.h" #DTensor C++ runtime and libraries. -load("//tensorflow:tensorflow.bzl", "if_google", "if_libtpu") load("//tensorflow:tensorflow.default.bzl", "tf_kernel_library") load( "//tensorflow/core/platform:build_config.bzl", @@ -289,13 +288,7 @@ cc_library( cc_library( name = "default_parallel_executor_lib", - deps = if_libtpu( - if_false = if_google( - google_value = ["//tensorflow/dtensor/cc/google:default_parallel_executor"], - oss_value = [":default_parallel_executor"], - ), - if_true = [":default_parallel_executor"], - ), + deps = [":default_parallel_executor"], ) cc_library( diff --git a/tensorflow/dtensor/cc/dtensor_device.cc b/tensorflow/dtensor/cc/dtensor_device.cc index d55f2291b932d2..25e83bbe37409d 100644 --- a/tensorflow/dtensor/cc/dtensor_device.cc +++ b/tensorflow/dtensor/cc/dtensor_device.cc @@ -135,15 +135,9 @@ class DTensorDevice { static StatusOr Create(absl::string_view name, bool is_async, int in_flight_nodes_limit) { std::string use_parallel_executor; - TF_RETURN_IF_ERROR(tsl::ReadStringFromEnvVar( - "DTENSOR_USE_PARALLEL_EXECUTOR", "", &use_parallel_executor)); - std::unique_ptr parallel_executor; - if (!use_parallel_executor.empty()) { - TF_ASSIGN_OR_RETURN(parallel_executor, CreateDefaultParallelExecutor()); - } auto eager_executor = std::make_unique( is_async, /*enable_streaming_enqueue=*/true, 
in_flight_nodes_limit); - return new DTensorDevice(name, std::move(parallel_executor), + return new DTensorDevice(name, /*parallel_executor=*/nullptr, std::move(eager_executor), is_async, in_flight_nodes_limit); } diff --git a/tensorflow/dtensor/cc/xla_spmd/layout_to_xla_sharding.cc b/tensorflow/dtensor/cc/xla_spmd/layout_to_xla_sharding.cc index 12d032bde9c56b..027f53cc3fc3e2 100644 --- a/tensorflow/dtensor/cc/xla_spmd/layout_to_xla_sharding.cc +++ b/tensorflow/dtensor/cc/xla_spmd/layout_to_xla_sharding.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/dtensor/cc/xla_spmd/layout_to_xla_sharding.h" +#include #include #include diff --git a/tensorflow/dtensor/mlir/BUILD b/tensorflow/dtensor/mlir/BUILD index d3bdcb73f0839e..08ffac8e75b859 100644 --- a/tensorflow/dtensor/mlir/BUILD +++ b/tensorflow/dtensor/mlir/BUILD @@ -20,16 +20,10 @@ package( gentbl_cc_library( name = "tensorflow_dtensor_ops_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-op-decls"], - "ir/tf_dtensor.h.inc", - ), - ( - ["-gen-op-defs"], - "ir/tf_dtensor.cc.inc", - ), - ], + tbl_outs = { + "ir/tf_dtensor.h.inc": ["-gen-op-decls"], + "ir/tf_dtensor.cc.inc": ["-gen-op-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "ir/tf_dtensor.td", td_srcs = [ @@ -48,13 +42,10 @@ gentbl_cc_library( gentbl_cc_library( name = "dtensor_passes_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [( - [ - "-gen-pass-decls", - "-name=DTensor", - ], - "dtensor_passes.h.inc", - )], + tbl_outs = {"dtensor_passes.h.inc": [ + "-gen-pass-decls", + "-name=DTensor", + ]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "Passes.td", deps = ["@llvm-project//mlir:PassBaseTdFiles"], diff --git a/tensorflow/dtensor/mlir/dtensor_dialect/BUILD b/tensorflow/dtensor/mlir/dtensor_dialect/BUILD index c0cca6cf3846f4..e32cb17a0bdaa2 100644 --- a/tensorflow/dtensor/mlir/dtensor_dialect/BUILD +++ 
b/tensorflow/dtensor/mlir/dtensor_dialect/BUILD @@ -33,24 +33,12 @@ td_library( gentbl_cc_library( name = "DialectIncGen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-op-decls"], - "ir/ops.h.inc", - ), - ( - ["-gen-op-defs"], - "ir/ops.cc.inc", - ), - ( - ["-gen-dialect-decls"], - "ir/dialect.h.inc", - ), - ( - ["-gen-dialect-defs"], - "ir/dialect.cc.inc", - ), - ], + tbl_outs = { + "ir/ops.h.inc": ["-gen-op-decls"], + "ir/ops.cc.inc": ["-gen-op-defs"], + "ir/dialect.h.inc": ["-gen-dialect-decls"], + "ir/dialect.cc.inc": ["-gen-dialect-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "ir/dtensor_ops.td", deps = [":dtensor_td_files"], diff --git a/tensorflow/dtensor/mlir/dtensor_layout_to_xla_sharding_op.cc b/tensorflow/dtensor/mlir/dtensor_layout_to_xla_sharding_op.cc index 15b065cf0ec2f5..3235badda66bd0 100644 --- a/tensorflow/dtensor/mlir/dtensor_layout_to_xla_sharding_op.cc +++ b/tensorflow/dtensor/mlir/dtensor_layout_to_xla_sharding_op.cc @@ -43,16 +43,32 @@ namespace { using mlir::TF::DTensorLayout; class RemoveDTensorLayoutAfterConstOrBlockArgPattern - : public mlir::OpRewritePattern::SplitMatchAndRewrite { + : public mlir::OpRewritePattern { public: - using SplitMatchAndRewrite::SplitMatchAndRewrite; + using OpRewritePattern::OpRewritePattern; - mlir::LogicalResult match(DTensorLayout layout_op) const override; - - void rewrite(DTensorLayout layout_op, - mlir::PatternRewriter& rewriter) const override { + mlir::LogicalResult matchAndRewrite( + DTensorLayout layout_op, mlir::PatternRewriter& rewriter) const override { + if (match(layout_op).failed()) { + return mlir::failure(); + } rewriter.replaceAllUsesWith(layout_op, layout_op.getInput()); rewriter.eraseOp(layout_op); + return mlir::success(); + } + + private: + mlir::LogicalResult match(DTensorLayout layout_op) const { + auto input = layout_op.getInput(); + if (mlir::isa(input)) { + return mlir::success(); + } + mlir::Operation* input_op = 
input.getDefiningOp(); + if (input_op != nullptr) { + return mlir::success(input_op->hasTrait()); + } else { + return layout_op->emitOpError() << "Can't find defining op for " << input; + } } }; @@ -63,20 +79,6 @@ class DTensorLayoutToXlaShardingOpPass void runOnOperation() override; }; -mlir::LogicalResult RemoveDTensorLayoutAfterConstOrBlockArgPattern::match( - DTensorLayout layout_op) const { - auto input = layout_op.getInput(); - if (mlir::isa(input)) { - return mlir::success(); - } - mlir::Operation* input_op = input.getDefiningOp(); - if (input_op != nullptr) { - return mlir::success(input_op->hasTrait()); - } else { - return layout_op->emitOpError() << "Can't find defining op for " << input; - } -} - void DTensorLayoutToXlaShardingOpPass::runOnOperation() { mlir::RewritePatternSet patterns(&getContext()); // Some patterns in tf2xla requires operands to be ConstantLike. diff --git a/tensorflow/dtensor/mlir/ir/tf_dtensor.td b/tensorflow/dtensor/mlir/ir/tf_dtensor.td index 999d8df041e74d..11a6ea761e00aa 100644 --- a/tensorflow/dtensor/mlir/ir/tf_dtensor.td +++ b/tensorflow/dtensor/mlir/ir/tf_dtensor.td @@ -31,17 +31,17 @@ include "mlir/IR/OpBase.td" //===----------------------------------------------------------------------===// class DTensor_DTensorAttr : - Attr()">, + Attr($_self)">, "DTensor " # description # " attribute">; def DTensor_LayoutAttr : DTensor_DTensorAttr<"Layout", "layout"> { let returnType = "mlir::dtensor::LayoutAttr::Layout"; - let convertFromStorage = "$_self.cast().getValue()"; + let convertFromStorage = "llvm::cast($_self).getValue()"; } def DTensor_MeshAttr : DTensor_DTensorAttr<"Mesh", "mesh"> { let returnType = "mlir::dtensor::MeshAttr::Mesh"; - let convertFromStorage = "$_self.cast().getValue()"; + let convertFromStorage = "llvm::cast($_self).getValue()"; } //===----------------------------------------------------------------------===// diff --git a/tensorflow/dtensor/mlir/shape_utils.cc b/tensorflow/dtensor/mlir/shape_utils.cc index 
0864fe28ba074a..27f9fc5f7f0a70 100644 --- a/tensorflow/dtensor/mlir/shape_utils.cc +++ b/tensorflow/dtensor/mlir/shape_utils.cc @@ -73,7 +73,7 @@ StatusOr> ExtractGlobalInputShape( return errors::Internal("global_shape does not have static rank"); return *global_shape; } - return ExtractGlobalOutputShape(input_value.get().cast()); + return ExtractGlobalOutputShape(cast(input_value.get())); } // If we reach this point, we're working with a function argument. @@ -85,7 +85,7 @@ StatusOr> ExtractGlobalInputShape( operand_index, op->getName()) .str()); - auto block_arg = input_value.get().dyn_cast(); + auto block_arg = mlir::dyn_cast(input_value.get()); auto global_shape_attr = enclosing_function.getArgAttrOfType( block_arg.getArgNumber(), kGlobalShapeDialectAttr); @@ -127,7 +127,7 @@ StatusOr> ExtractGlobalOutputShape( .str()); auto shape_attr = global_shape_attr[output_index]; - return shape_attr.cast().getShape(); + return llvm::cast(shape_attr).getShape(); } namespace { @@ -167,7 +167,7 @@ mlir::LogicalResult InferShapeOfTFOpWithCustomOperandConstantFn( for (const auto& inferred_return_type : llvm::enumerate(inferred_return_types)) { if (auto shaped_type = - inferred_return_type.value().dyn_cast()) { + llvm::dyn_cast(inferred_return_type.value())) { if (shaped_type.hasRank()) { inferred_return_shapes[inferred_return_type.index()] = mlir::ShapedTypeComponents(shaped_type.getShape(), @@ -207,14 +207,14 @@ mlir::LogicalResult InferShapeOfTFOpWithCustomOperandConstantFn( auto op_result_as_shape_fn = [](shape_inference::InferenceContext& ic, mlir::OpResult op_result) -> shape_inference::ShapeHandle { - auto rt = op_result.getType().dyn_cast(); + auto rt = llvm::dyn_cast(op_result.getType()); if (!rt || rt.getRank() != 1 || !rt.hasStaticShape()) return {}; std::vector dims(rt.getDimSize(0), ic.UnknownDim()); mlir::Attribute attr; if (matchPattern(op_result, m_Constant(&attr))) { - auto elements = attr.dyn_cast(); + auto elements = llvm::dyn_cast(attr); if (elements) for 
(const auto& element : llvm::enumerate(elements.getValues())) @@ -241,10 +241,8 @@ absl::Status InferSPMDExpandedLocalShapeForResourceOutput( GetGlobalShapeOfValueFromDTensorLayout(*op_result)); const std::vector& local_shape = output_layout.LocalShapeFromGlobalShape(global_shape); - auto resource_type = op_result->getType() - .cast() - .getElementType() - .dyn_cast(); + auto resource_type = llvm::dyn_cast( + llvm::cast(op_result->getType()).getElementType()); auto sub_types = resource_type.getSubtypes(); auto resource_arg_sub_type = sub_types.front(); @@ -274,7 +272,7 @@ mlir::Operation* InferSPMDExpandedLocalShape(mlir::Operation* op) { const auto& return_type = std::get<0>(it); auto& op_result = std::get<1>(it); const auto element_type = - op_result.getType().cast().getElementType(); + llvm::cast(op_result.getType()).getElementType(); if (return_type.hasRank()) { op_result.setType( @@ -292,7 +290,7 @@ StatusOr> GetShapeOfValue(const mlir::Value& value, // Getting the subtype or self allows supporting extracting the underlying // shape that variant or resource tensors point to. 
mlir::Type type = GetSubtypeOrSelf(value); - if (auto ranked_type = type.dyn_cast()) { + if (auto ranked_type = llvm::dyn_cast(type)) { if (ranked_type.hasStaticShape() || !fail_on_dynamic) return ranked_type.getShape(); else @@ -303,7 +301,7 @@ StatusOr> GetShapeOfValue(const mlir::Value& value, StatusOr> GetGlobalShapeOfValueFromDTensorLayout( const mlir::Value& value) { - if (value.isa() && + if (mlir::isa(value) && mlir::isa(value.getDefiningOp())) { auto layout_op = mlir::cast(value.getDefiningOp()); if (layout_op.getGlobalShape()) return layout_op.getGlobalShape().value(); diff --git a/tensorflow/dtensor/mlir/spmd_expansion.cc b/tensorflow/dtensor/mlir/spmd_expansion.cc index ff7e1444520af0..434bf869a4ebc7 100644 --- a/tensorflow/dtensor/mlir/spmd_expansion.cc +++ b/tensorflow/dtensor/mlir/spmd_expansion.cc @@ -122,10 +122,8 @@ mlir::LogicalResult UpdateResourceArgumentType( return mlir::success(); } - auto resource_type = resource_arg.getType() - .cast() - .getElementType() - .dyn_cast(); + auto resource_type = llvm::dyn_cast( + llvm::cast(resource_arg.getType()).getElementType()); if (!resource_type) return mlir::success(); auto sub_types = resource_type.getSubtypes(); @@ -190,7 +188,7 @@ bool GetResourceArgIndexIfUsedInAssignmentOp( GetForwardedDTensorLayoutInput(assign_variable_op.getResource()); if (llvm::isa(resource)) { *resource_argument_index_for_assign_variable = - resource.cast().getArgNumber(); + cast(resource).getArgNumber(); return true; } } @@ -223,16 +221,14 @@ mlir::LogicalResult UpdateFunctionArgsUsingLayout(mlir::func::FuncOp function) { // If argument is a resource type update the subtype shape information // to reflect local shape of resources. 
- if (arg_type.isa()) { + if (isa(arg_type)) { if (mlir::failed(UpdateResourceArgumentType(argument_index, function))) return mlir::failure(); continue; } - mlir::RankedTensorType ranked_type = - function.getFunctionType() - .getInput(argument_index) - .dyn_cast(); + mlir::RankedTensorType ranked_type = llvm::dyn_cast( + function.getFunctionType().getInput(argument_index)); if (!ranked_type) continue; // If input value is non-resource type, then update the value to reflect @@ -266,7 +262,8 @@ mlir::LogicalResult UpdateFunctionWithLocalInputShapes( mlir::func::FuncOp function) { for (auto& operand : function_operands) { const int index = operand.getOperandNumber(); - auto arg_type = operand.get().getType().dyn_cast(); + auto arg_type = + llvm::dyn_cast(operand.get().getType()); if (!arg_type) continue; auto arg_local_shape = arg_type.getShape(); diff --git a/tensorflow/dtensor/mlir/value_utils.cc b/tensorflow/dtensor/mlir/value_utils.cc index aff45541759515..e9240996904fd0 100644 --- a/tensorflow/dtensor/mlir/value_utils.cc +++ b/tensorflow/dtensor/mlir/value_utils.cc @@ -23,6 +23,7 @@ limitations under the License. 
#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/FormatVariadic.h" #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project @@ -57,7 +58,8 @@ mlir::Value GetForwardedInput(mlir::Value value) { bool value_updated; do { value_updated = false; - if (mlir::BlockArgument argument = value.dyn_cast()) { + if (mlir::BlockArgument argument = + mlir::dyn_cast(value)) { mlir::Region* region = argument.getParentRegion(); if (region == nullptr) break; mlir::Operation* parent_op = region->getParentOp(); @@ -86,7 +88,7 @@ namespace ops_util = ::mlir::TF::collection_ops_util; int ValueRank(mlir::Value operand_value) { mlir::Type type = GetSubtypeOrSelf(operand_value); - const auto operand_type = type.cast(); + const auto operand_type = llvm::cast(type); if (!operand_type.hasRank()) return -1; return operand_type.getRank(); } @@ -116,7 +118,7 @@ mlir::Value IntConst(mlir::OpBuilder& builder, mlir::Location loc, } StatusOr> GetTFShapeFromType(mlir::Type type) { - auto ranked_type = type.dyn_cast(); + auto ranked_type = llvm::dyn_cast(type); if (!ranked_type) { return errors::InvalidArgument( llvm::formatv("Type {0} is not a RankedTensorType.", type).str()); @@ -166,7 +168,7 @@ mlir::Value IntConstWithMatchingType(mlir::OpBuilder& builder, mlir::Location loc, llvm::ArrayRef values, mlir::Type type) { - if (type.cast().getElementType().isInteger(64)) { + if (llvm::cast(type).getElementType().isInteger(64)) { return Int64Const(builder, loc, values); } else { llvm::SmallVector values32(values.begin(), values.end()); @@ -176,7 +178,7 @@ mlir::Value IntConstWithMatchingType(mlir::OpBuilder& builder, StatusOr ExtractConstIntFromValue(mlir::Value value) { value = GetForwardedInput(value); - if (value.isa()) + if (mlir::isa(value)) return errors::Internal("unable get constant value from block argument"); mlir::DenseIntElementsAttr attr; if 
(!matchPattern(value, m_Constant(&attr))) { @@ -195,7 +197,7 @@ StatusOr ExtractConstIntFromValue(mlir::Value value) { absl::Status ExtractConstVectorFromValue( mlir::Value value, llvm::SmallVector* out_vector) { value = GetForwardedInput(value); - if (value.isa()) + if (mlir::isa(value)) return errors::Internal("unable get constant value from block argument"); mlir::DenseIntElementsAttr attr; if (!matchPattern(value, m_Constant(&attr))) { @@ -263,7 +265,7 @@ StatusOr SelectScalarValueFromArray(mlir::OpBuilder& builder, int index, mlir::Location location, mlir::Value array) { - mlir::TensorType arrayType = array.getType().cast(); + mlir::TensorType arrayType = llvm::cast(array.getType()); if (arrayType.getRank() != 2 || arrayType.getDimSize(0) != 1) { return errors::InvalidArgument("Input array must have shape [1, N]."); } @@ -289,8 +291,8 @@ StatusOr SelectScalarValueFromArray(mlir::OpBuilder& builder, mlir::Type GetSubtypeOrSelf(mlir::Value val) { mlir::Type type = val.getType(); if (auto type_with_subtype = - mlir::getElementTypeOrSelf(val) - .dyn_cast()) { + mlir::dyn_cast( + mlir::getElementTypeOrSelf(val))) { if (type_with_subtype.GetSubtypes().size() == 1) { type = type_with_subtype.GetSubtypes().front(); } @@ -299,10 +301,8 @@ mlir::Type GetSubtypeOrSelf(mlir::Value val) { } bool IsResourceType(mlir::Value val) { - return val.getType() - .cast() - .getElementType() - .isa(); + return mlir::isa( + mlir::cast(val.getType()).getElementType()); } } // namespace dtensor diff --git a/tensorflow/dtensor/python/accelerator_util.py b/tensorflow/dtensor/python/accelerator_util.py index b1e96c169de4e1..08b6ebffef96d7 100644 --- a/tensorflow/dtensor/python/accelerator_util.py +++ b/tensorflow/dtensor/python/accelerator_util.py @@ -153,8 +153,6 @@ def initialize_accelerator_system( The default value is `localhost` in local mode, and `worker` when in the multi-client mode. All DTensor clients within the same multi-client cluster share the same job name. 
- - `DTENSOR_USE_PARALLEL_EXECUTOR`: string, with its value being `pw` to - specify that the backend is Pathways, and TensorFlow otherwise. Args: device_type: Type of accelerator to use, can be CPU, GPU, or TPU. If None, @@ -259,7 +257,7 @@ def initialize_accelerator_system( )._collective_use_nccl_communication = config.gpu_use_nccl_communication( ) - if device_type == "TPU" and not config.backend_is_pw(): + if device_type == "TPU": tpu_util.initialize_tpu_system(use_megacore=experimental_enable_megcore) _INITIALIZED_ACCELERATOR_SYSTEM_TYPE = device_type @@ -291,7 +289,7 @@ def shutdown_accelerator_system() -> None: "not supported." ) - if device_type == "TPU" and not config.backend_is_pw(): + if device_type == "TPU": tpu_util.shutdown_tpu_system() # reset TF context to stop gRPC servers. diff --git a/tensorflow/dtensor/python/config.py b/tensorflow/dtensor/python/config.py index d03491d20bbe70..35de028478687d 100644 --- a/tensorflow/dtensor/python/config.py +++ b/tensorflow/dtensor/python/config.py @@ -45,8 +45,8 @@ @tf_export("experimental.dtensor.local_devices", v1=[]) def local_devices( - device_type: str, - for_client_id: Optional[int] = None) -> List[tf_device.DeviceSpec]: + device_type: str, for_client_id: Optional[int] = None +) -> List[tf_device.DeviceSpec]: """Returns a list of device specs configured on this client.""" if device_type.upper() not in ["CPU", "GPU", "TPU"]: raise ValueError(f"Device type {device_type} is not CPU, GPU, or TPU.") @@ -61,7 +61,9 @@ def local_devices( replica=0, # replica is deprecated and mostly hard-coded now. task=for_client_id, device_type=device_type, - device_index=i) for i in range(num_local_devices(device_type)) + device_index=i, + ) + for i in range(num_local_devices(device_type)) ] @@ -89,11 +91,15 @@ def client_id() -> int: # If missing, assume running with a single client with client_id of 0. 
client_id_value = int(os.environ.get(_DT_CLIENT_ID, "0")) if client_id_value < 0: - raise ValueError(f"Environment variable {_DT_CLIENT_ID} " - f"must be >= 0, got {client_id_value}. ") + raise ValueError( + f"Environment variable {_DT_CLIENT_ID} " + f"must be >= 0, got {client_id_value}. " + ) if client_id_value >= num_clients(): - raise ValueError(f"Environment variable {_DT_CLIENT_ID} " - f"must be < {num_clients()}, got {client_id_value}") + raise ValueError( + f"Environment variable {_DT_CLIENT_ID} " + f"must be < {num_clients()}, got {client_id_value}" + ) return client_id_value @@ -110,8 +116,9 @@ def job_name() -> str: """Returns the job name used by all clients in this DTensor cluster.""" # If missing, assumes the program runs locally and use localhost as job name # per TensorFlow convention. - return os.environ.get(_DT_JOB_NAME, - "localhost" if num_clients() == 1 else "worker") + return os.environ.get( + _DT_JOB_NAME, "localhost" if num_clients() == 1 else "worker" + ) @tf_export("experimental.dtensor.full_job_name", v1=[]) @@ -160,7 +167,8 @@ def jobs() -> List[str]: raise ValueError( f"Unexpected DTENSOR_JOBS content {d_jobs}. Sort entries " "in DTENSOR_JOBS because cluster construction relies on " - "the order.") + "the order." 
+ ) return d_jobs_list @@ -212,8 +220,3 @@ def use_multi_device_mode() -> bool: def gpu_use_nccl_communication() -> bool: """Return True if environment indicates NCCL shall be used for GPU.""" return os.environ.get("DTENSOR_GPU_USE_NCCL_COMMUNICATION", "0") != "0" - - -def backend_is_pw() -> bool: - """Return True if environment indicates the backend is Pathways.""" - return os.environ.get("DTENSOR_USE_PARALLEL_EXECUTOR") == "pw" diff --git a/tensorflow/dtensor/python/tests/BUILD b/tensorflow/dtensor/python/tests/BUILD index c01d2d32aab5cd..8fc0ee9a402b95 100644 --- a/tensorflow/dtensor/python/tests/BUILD +++ b/tensorflow/dtensor/python/tests/BUILD @@ -5,8 +5,6 @@ load( "//tensorflow/dtensor:build_defs.bzl", "ALL_BACKENDS", "GPU_2DEVS_BACKEND", - "PATHWAYS", - "PATHWAYS_V3_DONUT_BACKEND", "TPU_V3_DONUT_BACKEND", "TPU_V4_DONUT_BACKEND", "dtensor_test", @@ -196,13 +194,10 @@ dtensor_test( additional_backends = [ TPU_V3_DONUT_BACKEND, GPU_2DEVS_BACKEND, - PATHWAYS, - PATHWAYS_V3_DONUT_BACKEND, ], deps = [ ":test_util", "//tensorflow/dtensor/python:api", - "//tensorflow/dtensor/python:config", "//tensorflow/dtensor/python:d_variable", "//tensorflow/dtensor/python:dtensor_device", "//tensorflow/dtensor/python:layout", @@ -234,7 +229,6 @@ dtensor_test( deps = [ ":test_util", "//tensorflow/dtensor/python:api", - "//tensorflow/dtensor/python:config", "//tensorflow/dtensor/python:d_variable", "//tensorflow/dtensor/python:dtensor_device", "//tensorflow/dtensor/python:layout", diff --git a/tensorflow/dtensor/python/tests/collective_test.py b/tensorflow/dtensor/python/tests/collective_test.py index a6af55a5f2cd18..0f720a14a3d495 100644 --- a/tensorflow/dtensor/python/tests/collective_test.py +++ b/tensorflow/dtensor/python/tests/collective_test.py @@ -21,11 +21,9 @@ # pylint: disable=g-direct-tensorflow-import from tensorflow.dtensor.python import api -from tensorflow.dtensor.python import config from tensorflow.dtensor.python import d_variable from tensorflow.dtensor.python 
import dtensor_device from tensorflow.dtensor.python import layout as layout_lib -from tensorflow.dtensor.python.tests import test_backend_util from tensorflow.dtensor.python.tests import test_util from tensorflow.python.eager.polymorphic_function import polymorphic_function from tensorflow.python.framework import constant_op @@ -103,7 +101,6 @@ def testReduceOnInt8(self): self.assertDTensorEqual(expected_result, self.scalar_layout, dtensor_result) def testTwoReducesWithAssign(self): - self.skipForPathways('TODO(b/260775095)') # FIXME(b/238384852): The purpose of this test is to validate the control # dependency added by DTensor. # However, as we have no way of testing the per-device graph @@ -247,11 +244,8 @@ def testDeviceIdTensorOnSplitHost(self): # core IDs: both are range(8). So local device IDs happen to be usable here. # TODO(b/180046115): Add a device.get_tpu_core_ids method and translate # device IDs to core IDs before setting the list here. - if not config.backend_is_pw(): - device = dtensor_device.DTensorDevice(meshes=[mesh]) - device.set_tpu_core_ids('tpu_mesh', local_ids) - else: - test_backend_util.config_test_mesh(mesh) + device = dtensor_device.DTensorDevice(meshes=[mesh]) + device.set_tpu_core_ids('tpu_mesh', local_ids) layout_x = Layout.batch_sharded(mesh, _MESH_DIM_X, 2) layout_y = Layout.batch_sharded(mesh, _MESH_DIM_Y, 2) @@ -536,7 +530,6 @@ def testAllReduceCombinerWithIndirectDependency(self): # The purpose of this test is to validate the depdency check in AllReduce # AllReduce combiner (dtensor_allreduce_combine_optimization). Specifically, # the side effects from indirect dependency. 
- self.skipForPathways('TODO(b/260775095)') self.skipForDeviceType(['TPU'], 'This test requires 8 TPU cores.', unless_device_count_equals_to=8) diff --git a/tensorflow/dtensor/python/tests/test_backend_name.py b/tensorflow/dtensor/python/tests/test_backend_name.py index 0aa665b58bc359..8c679a3dc157ce 100644 --- a/tensorflow/dtensor/python/tests/test_backend_name.py +++ b/tensorflow/dtensor/python/tests/test_backend_name.py @@ -29,8 +29,6 @@ class DTensorTestUtilBackend(enum.Enum): TPU_STREAM_EXECUTOR = 'tpu_se' TPU_V3_DONUT_BACKEND = 'tpu_v3_2x2' TPU_V4_DONUT_BACKEND = 'tpu_v4_2x2' - PATHWAYS = 'pw' - PATHWAYS_V3_DONUT_BACKEND = 'pw_v3_2x2' DTENSOR_TEST_UTIL_BACKEND = DTensorTestUtilBackend( diff --git a/tensorflow/dtensor/python/tests/test_backend_util.py b/tensorflow/dtensor/python/tests/test_backend_util.py index 02fc82a71b7543..b52bc1cfe71340 100644 --- a/tensorflow/dtensor/python/tests/test_backend_util.py +++ b/tensorflow/dtensor/python/tests/test_backend_util.py @@ -19,8 +19,6 @@ import os from tensorflow.dtensor.python import accelerator_util -from tensorflow.dtensor.python import config -from tensorflow.dtensor.python import layout as layout_lib from tensorflow.dtensor.python.tests.test_backend_name import DTENSOR_TEST_UTIL_BACKEND from tensorflow.python.platform import test as tf_test @@ -39,16 +37,6 @@ def tearDown(self): accelerator_util.shutdown_accelerator_system() -def config_test_mesh(mesh: layout_lib.Mesh): - """No Op. - - Args: - mesh: The DTensor mesh. 
- """ - if config.backend_is_pw(): - del mesh - - def slice_host_devices_for_multiworker(num_clients, client_id, ports): """Configure the current process to only use a slice of devices.""" if num_clients == 0: diff --git a/tensorflow/dtensor/python/tests/test_util.py b/tensorflow/dtensor/python/tests/test_util.py index 50aa465fd969d2..d6e875a5b5fe15 100644 --- a/tensorflow/dtensor/python/tests/test_util.py +++ b/tensorflow/dtensor/python/tests/test_util.py @@ -35,7 +35,6 @@ from tensorflow.dtensor.python.config import is_tpu_present # pylint: disable=unused-import from tensorflow.dtensor.python.config import preferred_device_type # pylint: disable=unused-import from tensorflow.dtensor.python.config import use_multi_device_mode # pylint: disable=unused-import -from tensorflow.dtensor.python.tests import test_backend_util from tensorflow.dtensor.python.tests.test_backend_name import DTENSOR_TEST_UTIL_BACKEND from tensorflow.dtensor.python.tests.test_backend_name import DTensorTestUtilBackend from tensorflow.dtensor.python.tests.test_backend_util import DTensorTestBackendConfigurator @@ -83,7 +82,9 @@ def create_device_array(shape, device_type): tf_device.DeviceSpec( # pylint: disable=g-complex-comprehension job='localhost/replica:0/task:0', device_type=device_type, - device_index=i) for i in range(device_count) + device_index=i, + ) + for i in range(device_count) ]).reshape(shape) @@ -110,8 +111,10 @@ def reset_logical_devices(device_type, count): reset_context() devices = tf_config.list_physical_devices(device_type) if device_type.upper() not in ('CPU', 'GPU'): - raise ValueError('resetting logical device for non-supported device type : ' - '%s' % device_type) + raise ValueError( + 'resetting logical device for non-supported device type : %s' + % device_type + ) if count < len(devices): devices = devices[:count] @@ -125,7 +128,8 @@ def reset_logical_devices(device_type, count): if device_type.upper() == 'GPU': dev_config = context.LogicalDeviceConfiguration( 
memory_limit=_DEFAULT_GPU_MEMORY_LIMIT, - experimental_device_ordinal=ordinal) + experimental_device_ordinal=ordinal, + ) else: dev_config = context.LogicalDeviceConfiguration() configs.append(dev_config) @@ -183,7 +187,7 @@ def tearDown(self): @staticmethod def configTestMesh( # pylint: disable=invalid-name - device_type_mesh_map: typing.Dict[typing.Text, layout_lib.Mesh] + device_type_mesh_map: typing.Dict[typing.Text, layout_lib.Mesh], ) -> layout_lib.Mesh: """Configs corresponding mesh given test context. @@ -202,8 +206,9 @@ def configTestMesh( # pylint: disable=invalid-name def get_mesh(device_type): mesh = device_type_mesh_map.get(device_type, None) if mesh is None: - raise ValueError('Requires a %s mesh to run test on %s.' % - (device_type, device_type)) + raise ValueError( + 'Requires a %s mesh to run test on %s.' % (device_type, device_type) + ) return mesh mesh = None @@ -220,15 +225,14 @@ def get_mesh(device_type): reset_logical_devices('CPU', np.prod(mesh.shape())) accelerator_util.initialize_accelerator_system('CPU') - test_backend_util.config_test_mesh(mesh) - return mesh def skipForDeviceType( # pylint: disable=invalid-name self, device_type: typing.List[str], reason: str, - unless_device_count_equals_to=None): + unless_device_count_equals_to=None, + ): """Skip the test for the specific device_type. Args: @@ -239,16 +243,22 @@ def skipForDeviceType( # pylint: disable=invalid-name of TPUs equals to the specified count. 
""" physical_device_types = set( - [d.device_type for d in tf_config.list_physical_devices()]) + [d.device_type for d in tf_config.list_physical_devices()] + ) for device in device_type: if device == 'TPU' and is_tpu_present(): if unless_device_count_equals_to is None: self.skipTest(reason) - elif len(list_local_logical_devices( - device)) != unless_device_count_equals_to: + elif ( + len(list_local_logical_devices(device)) + != unless_device_count_equals_to + ): self.skipTest(reason) - if device == 'CPU' and len( - physical_device_types) == 1 and 'CPU' in physical_device_types: + if ( + device == 'CPU' + and len(physical_device_types) == 1 + and 'CPU' in physical_device_types + ): # Make sure we skip when only `CPU` is present. self.skipTest(reason) if device == 'GPU' and 'GPU' in physical_device_types: @@ -264,19 +274,17 @@ def skipTest(self, reason): # pylint: disable=invalid-name self._backend_configurator.tearDown() super().skipTest(reason) - def skipForPathways(self, reason: str): # pylint: disable=invalid-name - if config.backend_is_pw(): - self.skipTest(reason) - def assertDTensorEqual( self, # pylint: disable=invalid-name expected_result, expected_layout, result_dtensor, - tol=DEFAULT_TOL): + tol=DEFAULT_TOL, + ): """Asserts DTensor is of the particular value.""" if issubclass( - type(result_dtensor), resource_variable_ops.BaseResourceVariable): + type(result_dtensor), resource_variable_ops.BaseResourceVariable + ): result_dtensor = result_dtensor.value() if expected_layout is not None: # This, the assertEqual, is a pure proto raw bytes comparison. To make it @@ -288,11 +296,13 @@ def assertDTensorEqual( expected_str = expected_layout.to_string() got_str = api.fetch_layout(result_dtensor).to_string() index_for_mesh = expected_str.find('mesh:') - if index_for_mesh != -1 and got_str.find( - expected_str[index_for_mesh:]) != -1: + if ( + index_for_mesh != -1 + and got_str.find(expected_str[index_for_mesh:]) != -1 + ): # the mesh part is same. 
cut them so it is more readable. expected_str = expected_str[:index_for_mesh] - got_str = got_str[:got_str.find('mesh:')] + got_str = got_str[: got_str.find('mesh:')] self.assertEqual( api.fetch_layout(result_dtensor), @@ -375,9 +385,7 @@ def product(*lists): # (("test1", ...), ("test2", ...), ...). # Function returns the product of the lists with the labels concatenated. return [ # pylint: disable=g-complex-comprehension - (''.join(p[0] - for p in elt), *sum((p[1:] - for p in elt), ())) + (''.join(p[0] for p in elt), *sum((p[1:] for p in elt), ())) for elt in itertools.product(*lists) ] diff --git a/tensorflow/java/BUILD b/tensorflow/java/BUILD index 741e089eddf082..2a6cb89cb485ad 100644 --- a/tensorflow/java/BUILD +++ b/tensorflow/java/BUILD @@ -137,7 +137,9 @@ cc_library( "//tensorflow/core:lib_internal", "//tensorflow/core:op_gen_lib", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_googlesource_code_re2//:re2", ], diff --git a/tensorflow/java/src/gen/cc/op_gen_main.cc b/tensorflow/java/src/gen/cc/op_gen_main.cc index 751289797875dc..814bf035d0e76b 100644 --- a/tensorflow/java/src/gen/cc/op_gen_main.cc +++ b/tensorflow/java/src/gen/cc/op_gen_main.cc @@ -13,7 +13,6 @@ limitations under the License. ==============================================================================*/ -#include #include #include "absl/log/check.h" diff --git a/tensorflow/java/src/gen/cc/op_generator.cc b/tensorflow/java/src/gen/cc/op_generator.cc index 11798ad56641d8..61ccbd124eaaab 100644 --- a/tensorflow/java/src/gen/cc/op_generator.cc +++ b/tensorflow/java/src/gen/cc/op_generator.cc @@ -19,10 +19,12 @@ limitations under the License. 
#include #include #include +#include #include #include #include +#include "absl/status/status.h" #include "absl/strings/ascii.h" #include "xla/tsl/platform/status.h" #include "tensorflow/core/framework/api_def.pb.h" diff --git a/tensorflow/java/src/gen/cc/op_generator.h b/tensorflow/java/src/gen/cc/op_generator.h index 048c05193d9b6a..e59c706e7355b6 100644 --- a/tensorflow/java/src/gen/cc/op_generator.h +++ b/tensorflow/java/src/gen/cc/op_generator.h @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/status/status.h" #include "tensorflow/core/framework/api_def.pb.h" #include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/lib/core/status.h" diff --git a/tensorflow/java/src/gen/cc/op_specs.cc b/tensorflow/java/src/gen/cc/op_specs.cc index d693a2b2a4ad08..bffb769004b56e 100644 --- a/tensorflow/java/src/gen/cc/op_specs.cc +++ b/tensorflow/java/src/gen/cc/op_specs.cc @@ -15,11 +15,14 @@ limitations under the License. #include "tensorflow/java/src/gen/cc/op_specs.h" +#include #include +#include #include #include #include +#include "absl/log/log.h" #include "absl/strings/match.h" #include "absl/strings/str_join.h" #include "absl/strings/strip.h" diff --git a/tensorflow/java/src/gen/cc/source_writer.cc b/tensorflow/java/src/gen/cc/source_writer.cc index a58746774996a5..b3878e85c6b132 100644 --- a/tensorflow/java/src/gen/cc/source_writer.cc +++ b/tensorflow/java/src/gen/cc/source_writer.cc @@ -16,6 +16,7 @@ limitations under the License. 
#include "tensorflow/java/src/gen/cc/source_writer.h" #include +#include #include #include diff --git a/tensorflow/java/src/main/native/BUILD b/tensorflow/java/src/main/native/BUILD index 527288dc5348ad..f7fd51a7e9ccb5 100644 --- a/tensorflow/java/src/main/native/BUILD +++ b/tensorflow/java/src/main/native/BUILD @@ -4,29 +4,37 @@ load("//tensorflow:tensorflow.bzl", "tf_copts", "tf_cuda_library") -package(default_visibility = [ - "//tensorflow/java:__pkg__", - "//tensorflow/tools/android/inference_interface:__pkg__", -]) +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [ + "//tensorflow/java:__pkg__", + # TODO(ashankar): Temporary hack for the Java API and + # //third_party/tensorflow/tools/android/inference_interface:android_tensorflow_inference_jni + # to co-exist in a single shared library. However, the hope is that + # //third_party/tensorflow/tools/android/inference_interface:android_tensorflow_jni can be + # removed once the Java API provides feature parity with it. + "//tensorflow/tools/android/inference_interface:__pkg__", + ], + licenses = ["notice"], +) -licenses(["notice"]) # Apache 2.0 +filegroup( + name = "native_srcs", + srcs = glob([ + "*.cc", + "*.h", + ]), + visibility = ["//visibility:public"], +) tf_cuda_library( name = "native", - srcs = glob(["*.cc"]) + select({ - # The Android toolchain makes "jni.h" available in the include path. - # For non-Android toolchains, generate jni.h and jni_md.h. 
- "//tensorflow:android": [], - "//conditions:default": [ - ":jni.h", - ":jni_md.h", - ], - }), + srcs = glob(["*.cc"]), hdrs = glob(["*.h"]), copts = tf_copts(), - includes = select({ - "//tensorflow:android": [], - "//conditions:default": ["."], + features = select({ + "//tensorflow:android": ["-layering_check"], + "//conditions:default": [], }), deps = select({ "//tensorflow:android": [ @@ -38,34 +46,8 @@ tf_cuda_library( "//tensorflow/core:all_kernels", "//tensorflow/core:direct_session", "//tensorflow/core:ops", + "@bazel_tools//tools/jdk:jni", ], }), alwayslink = 1, ) - -# Silly rules to make -# #include -# in the source headers work -# (in combination with the "includes" attribute of the tf_cuda_library rule -# above. Not needed when using the Android toolchain). -# -# Inspired from: -# https://github.com/bazelbuild/bazel/blob/f99a0543f8d97339d32075c7176b79f35be84606/src/main/native/BUILD -# but hopefully there is a simpler alternative to this. -genrule( - name = "copy_jni_h", - srcs = ["@bazel_tools//tools/jdk:jni_header"], - outs = ["jni.h"], - cmd = "cp -f $< $@", -) - -genrule( - name = "copy_jni_md_h", - srcs = select({ - "//tensorflow:windows": ["@bazel_tools//tools/jdk:jni_md_header-windows"], - "//tensorflow:macos": ["@bazel_tools//tools/jdk:jni_md_header-darwin"], - "//conditions:default": ["@bazel_tools//tools/jdk:jni_md_header-linux"], - }), - outs = ["jni_md.h"], - cmd = "cp -f $< $@", -) diff --git a/tensorflow/lite/acceleration/configuration/configuration.proto b/tensorflow/lite/acceleration/configuration/configuration.proto index 29d911bd1b05b1..f9e480ec2b3e6b 100644 --- a/tensorflow/lite/acceleration/configuration/configuration.proto +++ b/tensorflow/lite/acceleration/configuration/configuration.proto @@ -337,6 +337,8 @@ message XNNPackSettings { // reloaded from this cache which can reduce initialization time and the // packing memory footprint. optional string weight_cache_file_path = 3; + // Extra flags to pass to xnn_create_runtime. 
+ optional int32 runtime_flags = 4; } // CoreML Delegate settings. diff --git a/tensorflow/lite/acceleration/configuration/configuration_generated.h b/tensorflow/lite/acceleration/configuration/configuration_generated.h index 4cb4861e78f4f4..0e7d3219ef4974 100644 --- a/tensorflow/lite/acceleration/configuration/configuration_generated.h +++ b/tensorflow/lite/acceleration/configuration/configuration_generated.h @@ -1701,6 +1701,7 @@ struct XNNPackSettingsT : public ::flatbuffers::NativeTable { int32_t num_threads = 0; tflite::XNNPackFlags flags = tflite::XNNPackFlags_TFLITE_XNNPACK_DELEGATE_NO_FLAGS; std::string weight_cache_file_path{}; + int32_t runtime_flags = 0; }; struct XNNPackSettings FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { @@ -1709,7 +1710,8 @@ struct XNNPackSettings FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { VT_NUM_THREADS = 4, VT_FLAGS = 6, - VT_WEIGHT_CACHE_FILE_PATH = 8 + VT_WEIGHT_CACHE_FILE_PATH = 8, + VT_RUNTIME_FLAGS = 10 }; int32_t num_threads() const { return GetField(VT_NUM_THREADS, 0); @@ -1720,12 +1722,16 @@ struct XNNPackSettings FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { const ::flatbuffers::String *weight_cache_file_path() const { return GetPointer(VT_WEIGHT_CACHE_FILE_PATH); } + int32_t runtime_flags() const { + return GetField(VT_RUNTIME_FLAGS, 0); + } bool Verify(::flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_NUM_THREADS, 4) && VerifyField(verifier, VT_FLAGS, 4) && VerifyOffset(verifier, VT_WEIGHT_CACHE_FILE_PATH) && verifier.VerifyString(weight_cache_file_path()) && + VerifyField(verifier, VT_RUNTIME_FLAGS, 4) && verifier.EndTable(); } XNNPackSettingsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; @@ -1746,6 +1752,9 @@ struct XNNPackSettingsBuilder { void add_weight_cache_file_path(::flatbuffers::Offset<::flatbuffers::String> 
weight_cache_file_path) { fbb_.AddOffset(XNNPackSettings::VT_WEIGHT_CACHE_FILE_PATH, weight_cache_file_path); } + void add_runtime_flags(int32_t runtime_flags) { + fbb_.AddElement(XNNPackSettings::VT_RUNTIME_FLAGS, runtime_flags, 0); + } explicit XNNPackSettingsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -1761,8 +1770,10 @@ inline ::flatbuffers::Offset CreateXNNPackSettings( ::flatbuffers::FlatBufferBuilder &_fbb, int32_t num_threads = 0, tflite::XNNPackFlags flags = tflite::XNNPackFlags_TFLITE_XNNPACK_DELEGATE_NO_FLAGS, - ::flatbuffers::Offset<::flatbuffers::String> weight_cache_file_path = 0) { + ::flatbuffers::Offset<::flatbuffers::String> weight_cache_file_path = 0, + int32_t runtime_flags = 0) { XNNPackSettingsBuilder builder_(_fbb); + builder_.add_runtime_flags(runtime_flags); builder_.add_weight_cache_file_path(weight_cache_file_path); builder_.add_flags(flags); builder_.add_num_threads(num_threads); @@ -1773,13 +1784,15 @@ inline ::flatbuffers::Offset CreateXNNPackSettingsDirect( ::flatbuffers::FlatBufferBuilder &_fbb, int32_t num_threads = 0, tflite::XNNPackFlags flags = tflite::XNNPackFlags_TFLITE_XNNPACK_DELEGATE_NO_FLAGS, - const char *weight_cache_file_path = nullptr) { + const char *weight_cache_file_path = nullptr, + int32_t runtime_flags = 0) { auto weight_cache_file_path__ = weight_cache_file_path ? 
_fbb.CreateString(weight_cache_file_path) : 0; return tflite::CreateXNNPackSettings( _fbb, num_threads, flags, - weight_cache_file_path__); + weight_cache_file_path__, + runtime_flags); } ::flatbuffers::Offset CreateXNNPackSettings(::flatbuffers::FlatBufferBuilder &_fbb, const XNNPackSettingsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); @@ -4971,7 +4984,8 @@ inline bool operator==(const XNNPackSettingsT &lhs, const XNNPackSettingsT &rhs) return (lhs.num_threads == rhs.num_threads) && (lhs.flags == rhs.flags) && - (lhs.weight_cache_file_path == rhs.weight_cache_file_path); + (lhs.weight_cache_file_path == rhs.weight_cache_file_path) && + (lhs.runtime_flags == rhs.runtime_flags); } inline bool operator!=(const XNNPackSettingsT &lhs, const XNNPackSettingsT &rhs) { @@ -4991,6 +5005,7 @@ inline void XNNPackSettings::UnPackTo(XNNPackSettingsT *_o, const ::flatbuffers: { auto _e = num_threads(); _o->num_threads = _e; } { auto _e = flags(); _o->flags = _e; } { auto _e = weight_cache_file_path(); if (_e) _o->weight_cache_file_path = _e->str(); } + { auto _e = runtime_flags(); _o->runtime_flags = _e; } } inline ::flatbuffers::Offset XNNPackSettings::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const XNNPackSettingsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { @@ -5004,11 +5019,13 @@ inline ::flatbuffers::Offset CreateXNNPackSettings(::flatbuffer auto _num_threads = _o->num_threads; auto _flags = _o->flags; auto _weight_cache_file_path = _o->weight_cache_file_path.empty() ? 
0 : _fbb.CreateString(_o->weight_cache_file_path); + auto _runtime_flags = _o->runtime_flags; return tflite::CreateXNNPackSettings( _fbb, _num_threads, _flags, - _weight_cache_file_path); + _weight_cache_file_path, + _runtime_flags); } diff --git a/tensorflow/lite/acceleration/configuration/testdata/configuration.proto_prev b/tensorflow/lite/acceleration/configuration/testdata/configuration.proto_prev index 569042d3c88e7b..b8881307b3aa33 100644 --- a/tensorflow/lite/acceleration/configuration/testdata/configuration.proto_prev +++ b/tensorflow/lite/acceleration/configuration/testdata/configuration.proto_prev @@ -335,6 +335,8 @@ message XNNPackSettings { // reloaded from this cache which can reduce initialization time and the // packing memory footprint. optional string weight_cache_file_path = 3; + // Extra flags to pass to xnn_create_runtime + optional int32 runtime_flags = 4; } // CoreML Delegate settings. diff --git a/tensorflow/lite/cmake/DownloadPThreadPool.cmake b/tensorflow/lite/cmake/DownloadPThreadPool.cmake index 08441656f177de..e12799e3231a31 100644 --- a/tensorflow/lite/cmake/DownloadPThreadPool.cmake +++ b/tensorflow/lite/cmake/DownloadPThreadPool.cmake @@ -19,8 +19,8 @@ PROJECT(pthreadpool-download NONE) INCLUDE(ExternalProject) ExternalProject_Add(pthreadpool - URL https://github.com/google/pthreadpool/archive/b1aee199d54003fb557076a201bcac3398af580b.zip - URL_HASH SHA256=215724985c4845cdcadcb5f26a2a8777943927bb5a172a00e7716fe16a6f3c1b + URL https://github.com/google/pthreadpool/archive/b92447772365661680f486e39a91dfe6675adafc.zip + URL_HASH SHA256=745e56516d6a58d183eb33d9017732d87cff43ce9f78908906f9faa52633e421 SOURCE_DIR "${CMAKE_BINARY_DIR}/pthreadpool-source" BINARY_DIR "${CMAKE_BINARY_DIR}/pthreadpool" CONFIGURE_COMMAND "" diff --git a/tensorflow/lite/core/BUILD b/tensorflow/lite/core/BUILD index f52c4cc77b46fd..5e0ded435cd945 100644 --- a/tensorflow/lite/core/BUILD +++ b/tensorflow/lite/core/BUILD @@ -148,6 +148,19 @@ cc_library( alwayslink = 1, 
# TODO(b/161243354): eliminate this. ) +# This is a private target, its visibility is set to public only to be +# used by LiteRT dependencies. +# Do not use this target directly and don't consider it as a part of the public API. +# TODO(weiyiw): Refactor LiteRT deps from TFLite. +alias( + name = "private_cc_api_stable", + actual = ":cc_api_stable", + tags = ["avoid_dep"], + visibility = [ + "//visibility:public", + ], +) + # TODO(b/242310498): move logger.cc from tensorflow/lite/ to here. cc_library( name = "cc_api_stable", diff --git a/tensorflow/lite/core/acceleration/configuration/c/xnnpack_plugin.cc b/tensorflow/lite/core/acceleration/configuration/c/xnnpack_plugin.cc index 1133b1b69c0e84..8154615931a43b 100644 --- a/tensorflow/lite/core/acceleration/configuration/c/xnnpack_plugin.cc +++ b/tensorflow/lite/core/acceleration/configuration/c/xnnpack_plugin.cc @@ -40,6 +40,7 @@ static TfLiteDelegate* CreateDelegate(const void* settings) { if (xnnpack_settings->flags()) { options.flags = xnnpack_settings->flags(); } + options.runtime_flags = xnnpack_settings->runtime_flags(); if (xnnpack_settings->weight_cache_file_path()) { options.weight_cache_file_path = xnnpack_settings->weight_cache_file_path()->c_str(); diff --git a/tensorflow/lite/core/api/tensor_utils.cc b/tensorflow/lite/core/api/tensor_utils.cc index 18a643c78dc272..c5052c78f840cd 100644 --- a/tensorflow/lite/core/api/tensor_utils.cc +++ b/tensorflow/lite/core/api/tensor_utils.cc @@ -33,8 +33,8 @@ TfLiteStatus ResetVariableTensor(TfLiteTensor* tensor) { } // TODO(b/139446230): Provide a platform header to better handle these // specific scenarios. 
-#if __ANDROID__ || defined(__x86_64__) || defined(__i386__) || \ - defined(__i386) || defined(__x86__) || defined(__X86__) || \ +#if defined(__ANDROID__) || defined(__x86_64__) || defined(__i386__) || \ + defined(__i386) || defined(__x86__) || defined(__X86__) || \ defined(_X86_) || defined(_M_IX86) || defined(_M_X64) memset(tensor->data.raw, value, tensor->bytes); #else diff --git a/tensorflow/lite/core/async/task_internal_test.cc b/tensorflow/lite/core/async/task_internal_test.cc index b0dc1ae385917f..68e8004fa5d434 100644 --- a/tensorflow/lite/core/async/task_internal_test.cc +++ b/tensorflow/lite/core/async/task_internal_test.cc @@ -14,8 +14,6 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/lite/core/async/task_internal.h" -#include - #include #include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/c/common.h" diff --git a/tensorflow/lite/core/c/BUILD b/tensorflow/lite/core/c/BUILD index a2b03389d68673..98ae0d425ab146 100644 --- a/tensorflow/lite/core/c/BUILD +++ b/tensorflow/lite/core/c/BUILD @@ -53,6 +53,10 @@ filegroup( "//tensorflow/compiler/mlir/lite/core/c:lite_headers_filegroup", "//tensorflow/lite/core/async/c:types.h", ], + visibility = [ + # Temporary workaround to make visible to litert in OSS (default vis is not transformed correctly.) + "//visibility:public", + ], ) filegroup( @@ -277,7 +281,7 @@ cc_test( ) # This is a private target, its visibility is set to public only to be -# used by "tflite_custom_c_library". +# used by "tflite_custom_c_library" and LiteRT dependencies. # Do not use this target directly and don't consider it as a part of the public API. alias( name = "private_c_api_types", @@ -552,7 +556,7 @@ tflite_cc_library_with_c_headers_test( ) # This is a private target, its visibility is set to public only to be -# used by "custom_c_library_with_tflite". +# used by "custom_c_library_with_tflite" and LiteRT dependencies. 
# Do not use this target directly and don't consider it as a part of the public API. alias( name = "private_c_api_opaque_without_op_resolver", @@ -618,6 +622,7 @@ cc_test( "//tensorflow/lite/c:c_api_types", "//tensorflow/lite/delegates:delegate_test_util", "//tensorflow/lite/testing:util", + "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", ], ) diff --git a/tensorflow/lite/core/c/c_api.cc b/tensorflow/lite/core/c/c_api.cc index beb924415e298e..fcbda4e4fb0c81 100644 --- a/tensorflow/lite/core/c/c_api.cc +++ b/tensorflow/lite/core/c/c_api.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include #include #include #include // NOLINT diff --git a/tensorflow/lite/core/c/c_api_experimental.cc b/tensorflow/lite/core/c/c_api_experimental.cc index d2128efe608fbc..a07d8246b62923 100644 --- a/tensorflow/lite/core/c/c_api_experimental.cc +++ b/tensorflow/lite/core/c/c_api_experimental.cc @@ -17,9 +17,7 @@ limitations under the License. #include -#include #include -#include #include #include "tensorflow/lite/builtin_ops.h" diff --git a/tensorflow/lite/core/c/c_api_experimental_test.cc b/tensorflow/lite/core/c/c_api_experimental_test.cc index f98ddb0b2c00db..7ee05979e427db 100644 --- a/tensorflow/lite/core/c/c_api_experimental_test.cc +++ b/tensorflow/lite/core/c/c_api_experimental_test.cc @@ -27,6 +27,7 @@ limitations under the License. #include #include +#include "absl/strings/string_view.h" #include "tensorflow/lite/builtin_ops.h" #include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/core/c/c_api.h" diff --git a/tensorflow/lite/core/c/c_api_test.cc b/tensorflow/lite/core/c/c_api_test.cc index 8aeb116b692260..b9a0af08807292 100644 --- a/tensorflow/lite/core/c/c_api_test.cc +++ b/tensorflow/lite/core/c/c_api_test.cc @@ -21,6 +21,8 @@ limitations under the License. 
#include #include +#include +#include #include #include #include diff --git a/tensorflow/lite/core/c/common.cc b/tensorflow/lite/core/c/common.cc index d458b1eb29b5ab..baa6282fd5b12e 100644 --- a/tensorflow/lite/core/c/common.cc +++ b/tensorflow/lite/core/c/common.cc @@ -113,14 +113,26 @@ TfLiteQuantization TfLiteQuantizationClone(const TfLiteQuantization& src) { case kTfLiteAffineQuantization: { dst.params = calloc(1, sizeof(TfLiteAffineQuantization)); const TfLiteAffineQuantization* const src_params = - (TfLiteAffineQuantization*)(src.params); + reinterpret_cast<TfLiteAffineQuantization*>(src.params); TfLiteAffineQuantization* const dst_params = - (TfLiteAffineQuantization*)(dst.params); + reinterpret_cast<TfLiteAffineQuantization*>(dst.params); dst_params->quantized_dimension = src_params->quantized_dimension; dst_params->scale = TfLiteFloatArrayCopy(src_params->scale); dst_params->zero_point = TfLiteIntArrayCopy(src_params->zero_point); break; } + case kTfLiteBlockwiseQuantization: { + dst.params = calloc(1, sizeof(TfLiteBlockwiseQuantization)); + const TfLiteBlockwiseQuantization* const src_params = + reinterpret_cast<TfLiteBlockwiseQuantization*>(src.params); + TfLiteBlockwiseQuantization* const dst_params = + reinterpret_cast<TfLiteBlockwiseQuantization*>(dst.params); + dst_params->blocksize = src_params->blocksize; + dst_params->scale = src_params->scale; + dst_params->zero_point = src_params->zero_point; + dst_params->quantized_dimension = src_params->quantized_dimension; + break; + } } return dst; } @@ -225,7 +237,7 @@ void TfLiteTensorDataFree(TfLiteTensor* t) { void TfLiteQuantizationFree(TfLiteQuantization* quantization) { if (quantization->type == kTfLiteAffineQuantization) { TfLiteAffineQuantization* q_params = - (TfLiteAffineQuantization*)(quantization->params); + reinterpret_cast<TfLiteAffineQuantization*>(quantization->params); if (q_params->scale) { TfLiteFloatArrayFree(q_params->scale); q_params->scale = nullptr; diff --git a/tensorflow/lite/core/c/common.h b/tensorflow/lite/core/c/common.h index 87a9b1a5075051..3f1fe32b8b4f47 100644 --- a/tensorflow/lite/core/c/common.h +++ b/tensorflow/lite/core/c/common.h @@ -322,12
+322,18 @@ typedef struct TfLiteBFloat16 { const char* TfLiteTypeGetName(TfLiteType type); /// SupportedQuantizationTypes. +#ifdef __cplusplus typedef enum TfLiteQuantizationType : int { +#else +typedef enum TfLiteQuantizationType { +#endif /// No quantization. kTfLiteNoQuantization = 0, /// Affine quantization (with support for per-channel quantization). /// Corresponds to TfLiteAffineQuantization. kTfLiteAffineQuantization = 1, + /// Blockwise quantization. + kTfLiteBlockwiseQuantization = 2, } TfLiteQuantizationType; /// Structure specifying the quantization used by the tensor, if-any. @@ -353,6 +359,20 @@ typedef struct TfLiteAffineQuantization { int32_t quantized_dimension; } TfLiteAffineQuantization; +/// Parameters for blockwise quantization across the output channels dimension. +/// For a particular value in quantized_dimension, quantized values can be +/// converted back to float using: +/// `real_value = scale * (quantized_value - zero_point)` +typedef struct TfLiteBlockwiseQuantization { + // Index of the tensor containing the scales. + int32_t scale; + // Index of the tensor containing the zero points. + int32_t zero_point; + // Quantization blocksize. + int32_t blocksize; + int32_t quantized_dimension; +} TfLiteBlockwiseQuantization; + /// A union of pointers that points to memory for a given tensor. /// /// Do not access these members directly, if possible, use diff --git a/tensorflow/lite/core/c/common_test.cc b/tensorflow/lite/core/c/common_test.cc index fadc3f2bc68f08..e449b4821a4404 100644 --- a/tensorflow/lite/core/c/common_test.cc +++ b/tensorflow/lite/core/c/common_test.cc @@ -17,9 +17,11 @@ limitations under the License. 
#include #include +#include #include #include #include +#include #include #include diff --git a/tensorflow/lite/core/experimental/acceleration/mini_benchmark/c/c_api.cc b/tensorflow/lite/core/experimental/acceleration/mini_benchmark/c/c_api.cc index 51c927836135dc..1e8401aeb6f35e 100644 --- a/tensorflow/lite/core/experimental/acceleration/mini_benchmark/c/c_api.cc +++ b/tensorflow/lite/core/experimental/acceleration/mini_benchmark/c/c_api.cc @@ -14,10 +14,10 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/lite/core/experimental/acceleration/mini_benchmark/c/c_api.h" -#include +#include +#include #include #include -#include #include #include @@ -25,9 +25,9 @@ limitations under the License. #include "flatbuffers/verifier.h" // from @flatbuffers #include "tensorflow/lite/acceleration/configuration/c/delegate_plugin.h" #include "tensorflow/lite/acceleration/configuration/configuration_generated.h" +#include "tensorflow/lite/core/experimental/acceleration/mini_benchmark/c/c_api_types.h" #include "tensorflow/lite/experimental/acceleration/mini_benchmark/benchmark_result_evaluator.h" #include "tensorflow/lite/experimental/acceleration/mini_benchmark/blocking_validator_runner.h" -#include "tensorflow/lite/core/experimental/acceleration/mini_benchmark/c/c_api_types.h" #include "tensorflow/lite/experimental/acceleration/mini_benchmark/status_codes.h" #include "tensorflow/lite/experimental/acceleration/mini_benchmark/validator_runner_options.h" diff --git a/tensorflow/lite/core/interpreter_builder.cc b/tensorflow/lite/core/interpreter_builder.cc index 1c6cc8c2ac9dd9..8741022e3c2a70 100644 --- a/tensorflow/lite/core/interpreter_builder.cc +++ b/tensorflow/lite/core/interpreter_builder.cc @@ -407,6 +407,20 @@ TfLiteStatus InterpreterBuilder::ParseNodes( TfLiteStatus InterpreterBuilder::ParseQuantization( const QuantizationParameters* src_quantization, TfLiteQuantization* quantization, const 
std::vector<int>& dims) { + // Blockwise quantization. + if (src_quantization && src_quantization->details_type() == + QuantizationDetails_BlockwiseQuantization) { + auto* src_quant = src_quantization->details_as_BlockwiseQuantization(); + quantization->type = kTfLiteBlockwiseQuantization; + auto* blockwise_quantization = + reinterpret_cast<TfLiteBlockwiseQuantization*>( + calloc(1, sizeof(TfLiteBlockwiseQuantization))); + blockwise_quantization->scale = src_quant->scales(); + blockwise_quantization->quantized_dimension = 0; + blockwise_quantization->blocksize = src_quant->block_size(); + quantization->params = reinterpret_cast<void*>(blockwise_quantization); + return kTfLiteOk; + } quantization->type = kTfLiteNoQuantization; quantization->params = nullptr; if (!src_quantization || !src_quantization->scale() || diff --git a/tensorflow/lite/core/kernels/register.cc b/tensorflow/lite/core/kernels/register.cc index 216f1dece7ef8e..0c331b98ead1f9 100644 --- a/tensorflow/lite/core/kernels/register.cc +++ b/tensorflow/lite/core/kernels/register.cc @@ -354,7 +354,7 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_DYNAMIC_UPDATE_SLICE, Register_DYNAMIC_UPDATE_SLICE(), /* min_version = */ 1, - /* max_version = */ 3); + /* max_version = */ 4); AddBuiltin(BuiltinOperator_UNSORTED_SEGMENT_PROD, Register_UNSORTED_SEGMENT_PROD()); AddBuiltin(BuiltinOperator_UNSORTED_SEGMENT_MAX, diff --git a/tensorflow/lite/delegates/delegate_test.cc b/tensorflow/lite/delegates/delegate_test.cc index cfd17b4a07f884..eeb4160c6097c5 100644 --- a/tensorflow/lite/delegates/delegate_test.cc +++ b/tensorflow/lite/delegates/delegate_test.cc @@ -1067,6 +1067,7 @@ class TestDelegateWithDynamicTensors : public ::testing::Test { TfLiteRegistration reg = DynamicCopyOpRegistration(); interpreter_->AddNodeWithParameters({0}, {1, 2}, nullptr, 0, nullptr, &reg); + delegate_ = TfLiteDelegateCreate(); delegate_.Prepare = [](TfLiteContext* context, TfLiteDelegate* delegate) -> TfLiteStatus { // In this test, the delegate replaces all the
nodes if this function is diff --git a/tensorflow/lite/delegates/delegate_test_util.cc b/tensorflow/lite/delegates/delegate_test_util.cc index 1097c1408238e0..91899a74ce393e 100644 --- a/tensorflow/lite/delegates/delegate_test_util.cc +++ b/tensorflow/lite/delegates/delegate_test_util.cc @@ -158,6 +158,7 @@ SimpleDelegate::SimpleDelegate(const std::vector& nodes, automatic_shape_propagation_(automatic_shape_propagation), custom_op_(custom_op), set_output_tensor_dynamic_(set_output_tensor_dynamic) { + delegate_ = TfLiteDelegateCreate(); delegate_.Prepare = [](TfLiteContext* context, TfLiteDelegate* delegate) -> TfLiteStatus { auto* simple = static_cast(delegate->data_); diff --git a/tensorflow/lite/delegates/flex/BUILD b/tensorflow/lite/delegates/flex/BUILD index 4533a2417d0d88..a7572d9f74bdfb 100644 --- a/tensorflow/lite/delegates/flex/BUILD +++ b/tensorflow/lite/delegates/flex/BUILD @@ -166,6 +166,7 @@ cc_library( ":delegate_data", ":tflite_subgraph_execute", ":util", + "//tensorflow/core:session_options", "//tensorflow/core/tfrt/fallback:op_kernel_runner", "//tensorflow/lite:kernel_api", "//tensorflow/lite:macros", @@ -173,10 +174,14 @@ cc_library( "//tensorflow/lite:string", "//tensorflow/lite:string_util", "//tensorflow/lite:util", + "//tensorflow/lite/core:subgraph", "//tensorflow/lite/core/api", "//tensorflow/lite/core/c:common", "//tensorflow/lite/delegates/utils:simple_delegate", "//tensorflow/lite/kernels:kernel_util", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@flatbuffers", ] + if_mobile([ diff --git a/tensorflow/lite/delegates/flex/delegate.cc b/tensorflow/lite/delegates/flex/delegate.cc index eee6bd04e6de37..f7fca34d49d739 100644 --- a/tensorflow/lite/delegates/flex/delegate.cc +++ b/tensorflow/lite/delegates/flex/delegate.cc @@ -14,19 +14,24 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/lite/delegates/flex/delegate.h" +#include +#include +#include #include #include -#include -#include "absl/strings/str_cat.h" -#include "tensorflow/core/framework/variant.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/lite/context_util.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "tensorflow/core/framework/cancellation.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/platform/tstring.h" +#include "tensorflow/core/public/session_options.h" #include "tensorflow/lite/core/c/common.h" -#include "tensorflow/lite/core/macros.h" +#include "tensorflow/lite/core/subgraph.h" #include "tensorflow/lite/delegates/flex/buffer_map.h" #include "tensorflow/lite/delegates/flex/kernel.h" -#include "tensorflow/lite/delegates/flex/util.h" +#include "tensorflow/lite/delegates/utils/simple_delegate.h" +#include "tensorflow/lite/logger.h" #include "tensorflow/lite/minimal_logging.h" #include "tensorflow/lite/string_util.h" #include "tensorflow/lite/util.h" @@ -158,11 +163,9 @@ TfLiteStatus FlexDelegate::CopyFromBufferHandle( if (output->bytes != t_data.size()) { TF_LITE_KERNEL_LOG(context, - absl::StrCat("The given ", output->bytes, - " bytes are not enough to store " - "TensorFlow's aligned buffer of size ", - t_data.size(), " bytes.") - .c_str()); + "The given %zu bytes are not enough to store " + "TensorFlow's aligned buffer of size %zu bytes.", + output->bytes, t_data.size()); return kTfLiteError; } diff --git a/tensorflow/lite/delegates/flex/kernel.cc b/tensorflow/lite/delegates/flex/kernel.cc index 7a8bf163161914..9e6532d6b7b908 100644 --- a/tensorflow/lite/delegates/flex/kernel.cc +++ b/tensorflow/lite/delegates/flex/kernel.cc @@ -14,7 +14,11 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/lite/delegates/flex/kernel.h" +#include + #include +#include +#include #include #include #include @@ -22,23 +26,38 @@ limitations under the License. #include #include +#include "absl/container/inlined_vector.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" #include "flatbuffers/flexbuffers.h" // from @flatbuffers +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/statusor.h" #include "tensorflow/core/common_runtime/eager/context.h" +#include "tensorflow/core/framework/cancellation.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/platform/status.h" #include "tensorflow/core/protobuf/error_codes.pb.h" -#include "tensorflow/lite/builtin_ops.h" +#include "tensorflow/core/public/version.h" +#include "tensorflow/core/tfrt/fallback/op_kernel_runner.h" #include "tensorflow/lite/context_util.h" #include "tensorflow/lite/core/api/profiler.h" #include "tensorflow/lite/core/c/common.h" +#include "tensorflow/lite/delegates/flex/buffer_map.h" #include "tensorflow/lite/delegates/flex/delegate.h" #include "tensorflow/lite/delegates/flex/delegate_data.h" #include "tensorflow/lite/delegates/flex/util.h" #include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/logger.h" #include "tensorflow/lite/minimal_logging.h" #include "tensorflow/lite/string_type.h" +#include "tensorflow/lite/util.h" // Note: this is part of TF Lite's Flex delegation code which is to be // completed soon. 
@@ -343,7 +362,7 @@ class OpNode { tf_tensor->TotalBytes() != tensor->bytes) { TF_LITE_KERNEL_LOG(context, "FlexDelegate: Tensor %s(%d) buffer size mismatch " - "%zu(%lld) != %ld(%ld)", + "%zu(%" PRId64 ") != %zu(%" PRId64 ")", tensor->name, tensor_index, tf_tensor->TotalBytes(), tf_tensor->NumElements(), tensor->bytes, NumElements(tensor)); @@ -466,14 +485,14 @@ TfLiteStatus DelegateKernel::Init(TfLiteContext* context, op_data_->shared_info.tensor_release_map = flex_delegate_data->GetTensorReleaseMap(context); - CHECK(params->output_tensors); + TF_LITE_ENSURE(context, params->output_tensors != nullptr); std::set output_set; for (auto tensor_index : TfLiteIntArrayView(params->output_tensors)) { op_data_->subgraph_outputs.push_back(tensor_index); output_set.insert(tensor_index); } - CHECK(params->input_tensors); + TF_LITE_ENSURE(context, params->input_tensors != nullptr); for (auto tensor_index : TfLiteIntArrayView(params->input_tensors)) { op_data_->subgraph_inputs.push_back(tensor_index); } @@ -482,7 +501,7 @@ TfLiteStatus DelegateKernel::Init(TfLiteContext* context, op_data_->nodes.reserve(params->nodes_to_replace->size); - CHECK(params->nodes_to_replace); + TF_LITE_ENSURE(context, params->nodes_to_replace != nullptr); absl::Status status; // Now we explicitly disable reusing TFLite tensor buffers for certain TF ops, @@ -813,7 +832,7 @@ TfLiteStatus DelegateKernel::Eval(TfLiteContext* context, TfLiteNode* node) { tf_tensor.TotalBytes() != tensor->bytes) { TF_LITE_KERNEL_LOG(context, "FlexDelegate: Tensor %s(%d) buffer size mismatch " - "%zu(%lld) != %ld(%ld)", + "%zu(%" PRId64 ") != %zu(%" PRId64 ")", tensor->name, tensor_index, tf_tensor.TotalBytes(), tf_tensor.NumElements(), tensor->bytes, NumElements(tensor)); diff --git a/tensorflow/lite/delegates/gpu/BUILD b/tensorflow/lite/delegates/gpu/BUILD index 421a6faebfbd1d..4e216c6677ffe8 100644 --- a/tensorflow/lite/delegates/gpu/BUILD +++ b/tensorflow/lite/delegates/gpu/BUILD @@ -42,25 +42,31 @@ 
_DELEGATE_NO_GL_DEPS = select({ ":tflite_profile", #"//third_party/GL:EGL_headers", #"//third_party/GL:GLES3_headers", + # go/keep-sorted start "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", - "//tensorflow/lite:kernel_api", - "//tensorflow/lite:minimal_logging", "//tensorflow/lite/async:backend_async_kernel_interface", "//tensorflow/lite/core/async/interop/c:types", "//tensorflow/lite/core/c:common", - "//tensorflow/lite/delegates:serialization", "//tensorflow/lite/delegates/gpu/cl:util", + "//tensorflow/lite/delegates/gpu/common:data_type", + "//tensorflow/lite/delegates/gpu/common:model", "//tensorflow/lite/delegates/gpu/common:model_builder", "//tensorflow/lite/delegates/gpu/common:model_builder_helper", "//tensorflow/lite/delegates/gpu/common:quantization_util", + "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates:serialization", "//tensorflow/lite/kernels:kernel_util", "//tensorflow/lite/profiling/telemetry", - "//tensorflow/lite/profiling/telemetry:telemetry_status", + "//tensorflow/lite/profiling/telemetry/c:telemetry_setting", "//tensorflow/lite/profiling/telemetry/c:telemetry_setting_internal", + "//tensorflow/lite/profiling/telemetry:telemetry_status", + "//tensorflow/lite:kernel_api", + "//tensorflow/lite:minimal_logging", + # go/keep-sorted end ] config_setting( @@ -70,14 +76,12 @@ config_setting( config_setting( name = "tflite_gpu_extra_gles_deps", - # copybara:uncomment_begin(google-only) - # constraint_values = [ - # "//third_party/bazel_platforms/os:linux", - # ], - # copybara:uncomment_end + constraint_values = [ + "//third_party/bazel_platforms/cpu:x86_64", + "//third_party/bazel_platforms/os:linux", + ], values = { "copt": "-DTFLITE_GPU_EXTRA_GLES_DEPS", - "cpu": "k8", }, ) diff --git a/tensorflow/lite/delegates/gpu/cl/buffer.h 
b/tensorflow/lite/delegates/gpu/cl/buffer.h index 088a66aa57af2b..01d4e631247737 100644 --- a/tensorflow/lite/delegates/gpu/cl/buffer.h +++ b/tensorflow/lite/delegates/gpu/cl/buffer.h @@ -97,8 +97,9 @@ template absl::Status Buffer::WriteData(CLCommandQueue* queue, const absl::Span data) { if (size_ != sizeof(T) * data.size()) { - return absl::InvalidArgumentError( - "absl::Span data size is different from buffer allocated size."); + return absl::InvalidArgumentError(absl::StrCat( + "absl::Span data size is different from buffer allocated size: ", + size_, " vs ", sizeof(T) * data.size())); } RETURN_IF_ERROR(queue->EnqueueWriteBuffer(buffer_, size_, data.data())); return absl::OkStatus(); diff --git a/tensorflow/lite/delegates/gpu/cl/opencl_wrapper.cc b/tensorflow/lite/delegates/gpu/cl/opencl_wrapper.cc index bc682f7250525b..2eb95df35ae5fa 100644 --- a/tensorflow/lite/delegates/gpu/cl/opencl_wrapper.cc +++ b/tensorflow/lite/delegates/gpu/cl/opencl_wrapper.cc @@ -185,6 +185,11 @@ void LoadOpenCLFunctionExtensions(cl_platform_id platform_id) { // cl_arm_import_memory extension LoadFunctionExtension(platform_id, clImportMemoryARM); + + // cl_khr_semaphore extension + LoadFunctionExtension(platform_id, clCreateSemaphoreWithPropertiesKHR); + LoadFunctionExtension(platform_id, clEnqueueWaitSemaphoresKHR); + LoadFunctionExtension(platform_id, clEnqueueSignalSemaphoresKHR); } #ifdef __WINDOWS__ @@ -450,6 +455,11 @@ PFN_clGetCommandBufferInfoKHR clGetCommandBufferInfoKHR; // cl_arm_import_memory extension PFN_clImportMemoryARM clImportMemoryARM; +// cl_khr_semaphore extension +PFN_clCreateSemaphoreWithPropertiesKHR clCreateSemaphoreWithPropertiesKHR; +PFN_clEnqueueWaitSemaphoresKHR clEnqueueWaitSemaphoresKHR; +PFN_clEnqueueSignalSemaphoresKHR clEnqueueSignalSemaphoresKHR; + DEFINE_QCOM_FUNCTION_PTRS cl_mem CreateImage2DLegacy(cl_context context, cl_mem_flags flags, diff --git a/tensorflow/lite/delegates/gpu/cl/opencl_wrapper.h 
b/tensorflow/lite/delegates/gpu/cl/opencl_wrapper.h index e42fbe2454a92f..5dd762a2d51a8f 100644 --- a/tensorflow/lite/delegates/gpu/cl/opencl_wrapper.h +++ b/tensorflow/lite/delegates/gpu/cl/opencl_wrapper.h @@ -580,6 +580,23 @@ typedef cl_mem(CL_API_CALL *PFN_clImportMemoryARM)( const cl_import_properties_arm * /*properties*/, void * /*memory*/, size_t /*size*/, cl_int * /*errcode_ret*/); +// cl_khr_semaphore extension +typedef cl_semaphore_khr(CL_API_CALL *PFN_clCreateSemaphoreWithPropertiesKHR)( + cl_context /*context*/, const cl_semaphore_properties_khr * /*sema_props*/, + cl_int * /*errcode_ret*/); +typedef cl_int(CL_API_CALL *PFN_clEnqueueWaitSemaphoresKHR)( + cl_command_queue /*command_queue*/, cl_uint /*num_sema_objects*/, + const cl_semaphore_khr * /*sema_objects*/, + const cl_semaphore_payload_khr * /*sema_payload_list*/, + cl_uint /*num_events_in_wait_list*/, const cl_event * /*event_wait_list*/, + cl_event * /*event*/); +typedef cl_int(CL_API_CALL *PFN_clEnqueueSignalSemaphoresKHR)( + cl_command_queue /*command_queue*/, cl_uint /*num_sema_objects*/, + const cl_semaphore_khr * /*sema_objects*/, + const cl_semaphore_payload_khr * /*sema_payload_list*/, + cl_uint /*num_events_in_wait_list*/, const cl_event * /*event_wait_list*/, + cl_event * /*event*/); + extern PFN_clGetPlatformIDs clGetPlatformIDs; extern PFN_clGetPlatformInfo clGetPlatformInfo; extern PFN_clGetDeviceIDs clGetDeviceIDs; @@ -710,6 +727,12 @@ extern PFN_clGetCommandBufferInfoKHR clGetCommandBufferInfoKHR; // cl_arm_import_memory extension extern PFN_clImportMemoryARM clImportMemoryARM; +// cl_khr_semaphore extension +extern PFN_clCreateSemaphoreWithPropertiesKHR + clCreateSemaphoreWithPropertiesKHR; +extern PFN_clEnqueueWaitSemaphoresKHR clEnqueueWaitSemaphoresKHR; +extern PFN_clEnqueueSignalSemaphoresKHR clEnqueueSignalSemaphoresKHR; + // For convenient image creation // It uses clCreateImage if it available (clCreateImage available since cl 1.2) // otherwise it will use legacy 
clCreateImage2D diff --git a/tensorflow/lite/delegates/gpu/common/model_builder.cc b/tensorflow/lite/delegates/gpu/common/model_builder.cc index 31d12a503dafc7..8200eb9a42282d 100644 --- a/tensorflow/lite/delegates/gpu/common/model_builder.cc +++ b/tensorflow/lite/delegates/gpu/common/model_builder.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/model_builder.h" #include +#include #include #include #include @@ -31,9 +32,13 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "tensorflow/lite/builtin_ops.h" +#include "tensorflow/lite/core/api/op_resolver.h" #include "tensorflow/lite/core/c/builtin_op_data.h" #include "tensorflow/lite/core/c/c_api_types.h" #include "tensorflow/lite/core/c/common.h" +#include "tensorflow/lite/core/interpreter.h" +#include "tensorflow/lite/core/interpreter_builder.h" +#include "tensorflow/lite/core/model_builder.h" #include "tensorflow/lite/delegates/gpu/common/custom_parsers.h" #include "tensorflow/lite/delegates/gpu/common/data_type.h" #include "tensorflow/lite/delegates/gpu/common/lstm_parser.h" @@ -3369,7 +3374,7 @@ TfLiteIntArray* GetOpsToReplace( partition_helper.num_total_nodes()); } absl::StrAppend(&error_message, " operations will run on the CPU."); - TF_LITE_KERNEL_LOG(context, error_message.c_str()); + TF_LITE_KERNEL_LOG(context, "%s", error_message.c_str()); } return ConvertVectorToTfLiteIntArray(ops_to_replace); } @@ -3612,7 +3617,8 @@ TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate) { absl::Status BuildFromFlatBuffer(const tflite::FlatBufferModel& flatbuffer, const tflite::OpResolver& op_resolver, - GraphFloat32* graph, bool allow_quant_ops) { + GraphFloat32* graph, bool allow_quant_ops, + bool apply_model_transformations) { std::unique_ptr interpreter; tflite::InterpreterBuilder interpreter_builder(flatbuffer, op_resolver); if (interpreter_builder(&interpreter) != kTfLiteOk || 
!interpreter) { @@ -3638,9 +3644,11 @@ absl::Status BuildFromFlatBuffer(const tflite::FlatBufferModel& flatbuffer, return absl::InternalError("Conversion from TfLite model failed."); } - ModelTransformer transformer(graph); - if (!ApplyModelTransformations(&transformer)) { - return absl::InternalError("Graph transformations failed"); + if (apply_model_transformations) { + ModelTransformer transformer(graph); + if (!ApplyModelTransformations(&transformer)) { + return absl::InternalError("Graph transformations failed"); + } } return absl::OkStatus(); diff --git a/tensorflow/lite/delegates/gpu/common/model_builder.h b/tensorflow/lite/delegates/gpu/common/model_builder.h index 62c2310880cdd2..e6522bcb0b5e06 100644 --- a/tensorflow/lite/delegates/gpu/common/model_builder.h +++ b/tensorflow/lite/delegates/gpu/common/model_builder.h @@ -89,7 +89,8 @@ absl::Status BuildFinalModel( absl::Status BuildFromFlatBuffer(const FlatBufferModel& flatbuffer, const OpResolver& op_resolver, GraphFloat32* graph, - bool allow_quant_ops = false); + bool allow_quant_ops = false, + bool apply_model_transformations = true); // Module-internal converter, exposed for unit testing purpose only. absl::Status ConvertTfLiteTensorToTensorRef(const TfLiteTensor& tflite_tensor, diff --git a/tensorflow/lite/delegates/gpu/common/selectors/simple_selectors.cc b/tensorflow/lite/delegates/gpu/common/selectors/simple_selectors.cc index 6d3dec487e4ea1..02b4e16aa1a78e 100644 --- a/tensorflow/lite/delegates/gpu/common/selectors/simple_selectors.cc +++ b/tensorflow/lite/delegates/gpu/common/selectors/simple_selectors.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include #include #include +#include #include "tensorflow/lite/delegates/gpu/common/data_type.h" #include "tensorflow/lite/delegates/gpu/common/gpu_info.h" diff --git a/tensorflow/lite/delegates/gpu/delegate.cc b/tensorflow/lite/delegates/gpu/delegate.cc index 9fac6e598f1b1a..cfad378991585c 100644 --- a/tensorflow/lite/delegates/gpu/delegate.cc +++ b/tensorflow/lite/delegates/gpu/delegate.cc @@ -20,13 +20,8 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/delegate.h" -#include "tensorflow/lite/logger.h" - -#if defined(__ANDROID__) -#include -#endif - #include +#include #include #include #include @@ -40,28 +35,35 @@ limitations under the License. #include "absl/strings/numbers.h" #include "absl/types/span.h" #include "tensorflow/lite/builtin_ops.h" - -#if defined(__ANDROID__) -#include "tensorflow/lite/async/backend_async_kernel_interface.h" -#include "tensorflow/lite/core/async/c/task.h" -#include "tensorflow/lite/core/async/interop/c/attribute_map.h" -#include "tensorflow/lite/core/async/interop/c/constants.h" -#include "tensorflow/lite/core/async/interop/c/types.h" -#endif - #include "tensorflow/lite/core/c/common.h" -#include "tensorflow/lite/delegates/gpu/android_hardware_buffer.h" #include "tensorflow/lite/delegates/gpu/api.h" #include "tensorflow/lite/delegates/gpu/cl/api.h" #include "tensorflow/lite/delegates/gpu/cl/util.h" +#include "tensorflow/lite/delegates/gpu/common/data_type.h" +#include "tensorflow/lite/delegates/gpu/common/model.h" #include "tensorflow/lite/delegates/gpu/common/model_builder.h" #include "tensorflow/lite/delegates/gpu/common/model_builder_helper.h" #include "tensorflow/lite/delegates/gpu/common/quantization_util.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" #include "tensorflow/lite/delegates/gpu/delegate_options.h" #include "tensorflow/lite/delegates/gpu/tflite_profile.h" #include "tensorflow/lite/delegates/serialization.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include 
"tensorflow/lite/logger.h" +#include "tensorflow/lite/minimal_logging.h" +#include "tensorflow/lite/profiling/telemetry/c/telemetry_setting.h" +#include "tensorflow/lite/profiling/telemetry/telemetry.h" +#include "tensorflow/lite/profiling/telemetry/telemetry_status.h" #if defined(__ANDROID__) +#include + +#include "tensorflow/lite/async/backend_async_kernel_interface.h" +#include "tensorflow/lite/core/async/c/task.h" +#include "tensorflow/lite/core/async/interop/c/attribute_map.h" +#include "tensorflow/lite/core/async/interop/c/constants.h" +#include "tensorflow/lite/core/async/interop/c/types.h" +#include "tensorflow/lite/delegates/gpu/android_hardware_buffer.h" #include "tensorflow/lite/delegates/gpu/async_buffers.h" #include "tensorflow/lite/delegates/gpu/gl/android_sync.h" #include "tensorflow/lite/delegates/gpu/gl/egl_environment.h" @@ -71,12 +73,6 @@ limitations under the License. #include "tensorflow/lite/delegates/utils/utils.h" #endif -#include "tensorflow/lite/kernels/kernel_util.h" -#include "tensorflow/lite/minimal_logging.h" -#include "tensorflow/lite/profiling/telemetry/c/telemetry_setting_internal.h" -#include "tensorflow/lite/profiling/telemetry/telemetry.h" -#include "tensorflow/lite/profiling/telemetry/telemetry_status.h" - #ifndef CL_DELEGATE_NO_GL #include "tensorflow/lite/delegates/gpu/gl/api2.h" #endif @@ -469,7 +465,7 @@ absl::Status DelegateKernelCore::Setup( InitializeOpenClApi(&graph, &builder, &graph_is_destroyed, context, delegate_params, delegate_->serialization()); if (!status.ok()) { - TF_LITE_KERNEL_LOG(context, std::string(status.message()).c_str()); + TF_LITE_KERNEL_LOG(context, "%s", std::string(status.message()).c_str()); TF_LITE_KERNEL_LOG(context, "Falling back to OpenGL"); // Graph needs to be re-created because it is moved above. 
diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/fuse_auto_input.cc b/tensorflow/lite/delegates/gpu/gl/compiler/fuse_auto_input.cc index 985da96ebff678..7db139c4ccfa33 100644 --- a/tensorflow/lite/delegates/gpu/gl/compiler/fuse_auto_input.cc +++ b/tensorflow/lite/delegates/gpu/gl/compiler/fuse_auto_input.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include #include #include "absl/container/flat_hash_set.h" diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/object_accessor.cc b/tensorflow/lite/delegates/gpu/gl/compiler/object_accessor.cc index 00a95c816e9976..f6aee5dd889678 100644 --- a/tensorflow/lite/delegates/gpu/gl/compiler/object_accessor.cc +++ b/tensorflow/lite/delegates/gpu/gl/compiler/object_accessor.cc @@ -27,7 +27,6 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" -#include "absl/types/variant.h" #include "tensorflow/lite/delegates/gpu/common/access_type.h" #include "tensorflow/lite/delegates/gpu/common/data_type.h" #include "tensorflow/lite/delegates/gpu/common/types.h" diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/object_accessor_test.cc b/tensorflow/lite/delegates/gpu/gl/compiler/object_accessor_test.cc index fbca570d892f2f..0a057b14a80a2c 100644 --- a/tensorflow/lite/delegates/gpu/gl/compiler/object_accessor_test.cc +++ b/tensorflow/lite/delegates/gpu/gl/compiler/object_accessor_test.cc @@ -21,7 +21,6 @@ limitations under the License. 
#include #include -#include "absl/types/variant.h" #include "tensorflow/lite/delegates/gpu/common/types.h" #include "tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.h" #include "tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor.h" diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor.cc b/tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor.cc index d1a7fd78e1a87b..81b8e89f2252f0 100644 --- a/tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor.cc +++ b/tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor.cc @@ -15,15 +15,17 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor.h" +#include +#include #include #include #include +#include #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" -#include "absl/types/variant.h" #include "tensorflow/lite/delegates/gpu/common/types.h" #include "tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.h" #include "tensorflow/lite/delegates/gpu/gl/variable.h" diff --git a/tensorflow/lite/delegates/hexagon/builders/activation_builder.cc b/tensorflow/lite/delegates/hexagon/builders/activation_builder.cc index 97202338826f2f..ee17b849706819 100644 --- a/tensorflow/lite/delegates/hexagon/builders/activation_builder.cc +++ b/tensorflow/lite/delegates/hexagon/builders/activation_builder.cc @@ -16,8 +16,6 @@ limitations under the License. 
#include -#include - #include "tensorflow/lite/core/c/builtin_op_data.h" #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/delegates/hexagon/hexagon_nn/hexagon_nn.h" diff --git a/tensorflow/lite/delegates/hexagon/builders/arg_min_max_builder.cc b/tensorflow/lite/delegates/hexagon/builders/arg_min_max_builder.cc index 0b3921f1b34622..f4b39a02357651 100644 --- a/tensorflow/lite/delegates/hexagon/builders/arg_min_max_builder.cc +++ b/tensorflow/lite/delegates/hexagon/builders/arg_min_max_builder.cc @@ -14,7 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/lite/delegates/hexagon/builders/arg_min_max_builder.h" -#include +#include #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/util.h" diff --git a/tensorflow/lite/delegates/hexagon/builders/arithmetic_builder.cc b/tensorflow/lite/delegates/hexagon/builders/arithmetic_builder.cc index 4e888de5fc5eb3..c054b86735bee5 100644 --- a/tensorflow/lite/delegates/hexagon/builders/arithmetic_builder.cc +++ b/tensorflow/lite/delegates/hexagon/builders/arithmetic_builder.cc @@ -16,8 +16,6 @@ limitations under the License. #include -#include - #include "hexagon/hexagon_nn_ops.h" #include "tensorflow/lite/core/c/builtin_op_data.h" #include "tensorflow/lite/delegates/hexagon/hexagon_nn/hexagon_nn.h" diff --git a/tensorflow/lite/delegates/hexagon/builders/cast_builder.cc b/tensorflow/lite/delegates/hexagon/builders/cast_builder.cc index 7f624203dae9d0..d4b8adb105e92b 100644 --- a/tensorflow/lite/delegates/hexagon/builders/cast_builder.cc +++ b/tensorflow/lite/delegates/hexagon/builders/cast_builder.cc @@ -16,8 +16,6 @@ limitations under the License. 
#include -#include - #include "tensorflow/lite/core/c/builtin_op_data.h" #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/delegates/hexagon/hexagon_nn/hexagon_nn.h" diff --git a/tensorflow/lite/delegates/hexagon/builders/conv_2d_builder.cc b/tensorflow/lite/delegates/hexagon/builders/conv_2d_builder.cc index 17c1ce63718662..0c17d2f0baae5f 100644 --- a/tensorflow/lite/delegates/hexagon/builders/conv_2d_builder.cc +++ b/tensorflow/lite/delegates/hexagon/builders/conv_2d_builder.cc @@ -17,7 +17,7 @@ limitations under the License. #include #include -#include +#include #include "tensorflow/lite/core/c/builtin_op_data.h" #include "tensorflow/lite/delegates/hexagon/builders/op_builder.h" diff --git a/tensorflow/lite/delegates/hexagon/builders/conv_2d_helpers.cc b/tensorflow/lite/delegates/hexagon/builders/conv_2d_helpers.cc index 58c7bd76fb0239..744d048b699799 100644 --- a/tensorflow/lite/delegates/hexagon/builders/conv_2d_helpers.cc +++ b/tensorflow/lite/delegates/hexagon/builders/conv_2d_helpers.cc @@ -16,7 +16,6 @@ limitations under the License. #include #include -#include #include #include "tensorflow/lite/core/c/builtin_op_data.h" diff --git a/tensorflow/lite/delegates/hexagon/builders/hardswish_builder.cc b/tensorflow/lite/delegates/hexagon/builders/hardswish_builder.cc index 5e6ff2699fd1e2..dc42c8f51ed612 100644 --- a/tensorflow/lite/delegates/hexagon/builders/hardswish_builder.cc +++ b/tensorflow/lite/delegates/hexagon/builders/hardswish_builder.cc @@ -15,8 +15,6 @@ limitations under the License. 
#include -#include - #include "tensorflow/lite/core/c/builtin_op_data.h" #include "tensorflow/lite/delegates/hexagon/hexagon_nn/hexagon_nn.h" #include "tensorflow/lite/kernels/kernel_util.h" diff --git a/tensorflow/lite/delegates/hexagon/builders/l2_normalization_builder.cc b/tensorflow/lite/delegates/hexagon/builders/l2_normalization_builder.cc index 9e87d4109dba51..e4bb336b6e369f 100644 --- a/tensorflow/lite/delegates/hexagon/builders/l2_normalization_builder.cc +++ b/tensorflow/lite/delegates/hexagon/builders/l2_normalization_builder.cc @@ -16,8 +16,6 @@ limitations under the License. #include -#include - #include "tensorflow/lite/core/c/builtin_op_data.h" #include "tensorflow/lite/delegates/hexagon/hexagon_nn/hexagon_nn.h" #include "tensorflow/lite/kernels/kernel_util.h" diff --git a/tensorflow/lite/delegates/hexagon/builders/matmul_builder.cc b/tensorflow/lite/delegates/hexagon/builders/matmul_builder.cc index fa91b50808560e..c242ff8e7d11c8 100644 --- a/tensorflow/lite/delegates/hexagon/builders/matmul_builder.cc +++ b/tensorflow/lite/delegates/hexagon/builders/matmul_builder.cc @@ -16,7 +16,7 @@ limitations under the License. #include -#include +#include #include "hexagon/hexagon_nn_ops.h" #include "tensorflow/lite/core/c/builtin_op_data.h" diff --git a/tensorflow/lite/delegates/hexagon/builders/min_max_builder.cc b/tensorflow/lite/delegates/hexagon/builders/min_max_builder.cc index 9b6103fcc93536..772c52a7f6b4c9 100644 --- a/tensorflow/lite/delegates/hexagon/builders/min_max_builder.cc +++ b/tensorflow/lite/delegates/hexagon/builders/min_max_builder.cc @@ -14,6 +14,8 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/lite/delegates/hexagon/builders/min_max_builder.h" +#include + #include "tensorflow/lite/core/c/common.h" namespace tflite { diff --git a/tensorflow/lite/delegates/hexagon/builders/mirror_pad_builder.cc b/tensorflow/lite/delegates/hexagon/builders/mirror_pad_builder.cc index 353b8a007d65fb..bcce11acd02edd 100644 --- a/tensorflow/lite/delegates/hexagon/builders/mirror_pad_builder.cc +++ b/tensorflow/lite/delegates/hexagon/builders/mirror_pad_builder.cc @@ -16,7 +16,7 @@ limitations under the License. #include -#include +#include #include "tensorflow/lite/core/c/builtin_op_data.h" #include "tensorflow/lite/core/c/common.h" diff --git a/tensorflow/lite/delegates/hexagon/builders/neg_op_builder.cc b/tensorflow/lite/delegates/hexagon/builders/neg_op_builder.cc index 93511dc491dad0..715aa3955793c8 100644 --- a/tensorflow/lite/delegates/hexagon/builders/neg_op_builder.cc +++ b/tensorflow/lite/delegates/hexagon/builders/neg_op_builder.cc @@ -14,7 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/lite/delegates/hexagon/builders/neg_op_builder.h" -#include +#include namespace tflite { namespace delegates { diff --git a/tensorflow/lite/delegates/hexagon/builders/op_builder.cc b/tensorflow/lite/delegates/hexagon/builders/op_builder.cc index 91258a418fd326..a3cb4157a5b3eb 100644 --- a/tensorflow/lite/delegates/hexagon/builders/op_builder.cc +++ b/tensorflow/lite/delegates/hexagon/builders/op_builder.cc @@ -14,6 +14,9 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/lite/delegates/hexagon/builders/op_builder.h" +#include +#include + #include "hexagon/hexagon_nn_ops.h" #include "tensorflow/lite/builtin_ops.h" #include "tensorflow/lite/core/c/common.h" diff --git a/tensorflow/lite/delegates/hexagon/builders/pack_builder.cc b/tensorflow/lite/delegates/hexagon/builders/pack_builder.cc index 7ccdb299d5d835..9d7cc75f7a9908 100644 --- a/tensorflow/lite/delegates/hexagon/builders/pack_builder.cc +++ b/tensorflow/lite/delegates/hexagon/builders/pack_builder.cc @@ -16,8 +16,6 @@ limitations under the License. #include -#include - #include "tensorflow/lite/core/c/builtin_op_data.h" #include "tensorflow/lite/kernels/kernel_util.h" diff --git a/tensorflow/lite/delegates/hexagon/builders/pad_builder.cc b/tensorflow/lite/delegates/hexagon/builders/pad_builder.cc index d49a3de4ab9b42..4047d438f309ca 100644 --- a/tensorflow/lite/delegates/hexagon/builders/pad_builder.cc +++ b/tensorflow/lite/delegates/hexagon/builders/pad_builder.cc @@ -16,8 +16,6 @@ limitations under the License. #include -#include - #include "tensorflow/lite/core/c/builtin_op_data.h" #include "tensorflow/lite/delegates/hexagon/hexagon_nn/hexagon_nn.h" #include "tensorflow/lite/kernels/kernel_util.h" diff --git a/tensorflow/lite/delegates/hexagon/builders/pool_2d_builder.cc b/tensorflow/lite/delegates/hexagon/builders/pool_2d_builder.cc index 45529b68858c30..729d988c24935b 100644 --- a/tensorflow/lite/delegates/hexagon/builders/pool_2d_builder.cc +++ b/tensorflow/lite/delegates/hexagon/builders/pool_2d_builder.cc @@ -16,8 +16,6 @@ limitations under the License. 
#include -#include - #include "tensorflow/lite/core/c/builtin_op_data.h" #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/delegates/hexagon/hexagon_nn/hexagon_nn.h" diff --git a/tensorflow/lite/delegates/hexagon/builders/quantize_builder.cc b/tensorflow/lite/delegates/hexagon/builders/quantize_builder.cc index 6e653fd70e48fc..078f27161f34e1 100644 --- a/tensorflow/lite/delegates/hexagon/builders/quantize_builder.cc +++ b/tensorflow/lite/delegates/hexagon/builders/quantize_builder.cc @@ -16,8 +16,6 @@ limitations under the License. #include -#include - #include "tensorflow/lite/core/c/builtin_op_data.h" #include "tensorflow/lite/delegates/hexagon/hexagon_nn/hexagon_nn.h" #include "tensorflow/lite/kernels/kernel_util.h" diff --git a/tensorflow/lite/delegates/hexagon/builders/reduce_builder.cc b/tensorflow/lite/delegates/hexagon/builders/reduce_builder.cc index a41a9fb23ee72e..38e3a2e6633de2 100644 --- a/tensorflow/lite/delegates/hexagon/builders/reduce_builder.cc +++ b/tensorflow/lite/delegates/hexagon/builders/reduce_builder.cc @@ -16,7 +16,7 @@ limitations under the License. #include -#include +#include #include #include "tensorflow/lite/core/c/builtin_op_data.h" diff --git a/tensorflow/lite/delegates/hexagon/builders/reshape_builder.cc b/tensorflow/lite/delegates/hexagon/builders/reshape_builder.cc index 5946abff4d1fd8..58e2cc80f00605 100644 --- a/tensorflow/lite/delegates/hexagon/builders/reshape_builder.cc +++ b/tensorflow/lite/delegates/hexagon/builders/reshape_builder.cc @@ -16,7 +16,7 @@ limitations under the License. 
#include -#include +#include #include "tensorflow/lite/core/c/builtin_op_data.h" #include "tensorflow/lite/delegates/hexagon/hexagon_nn/hexagon_nn.h" diff --git a/tensorflow/lite/delegates/hexagon/builders/resize_bilinear_builder.cc b/tensorflow/lite/delegates/hexagon/builders/resize_bilinear_builder.cc index 5cdd5398de1b29..8c846b41595946 100644 --- a/tensorflow/lite/delegates/hexagon/builders/resize_bilinear_builder.cc +++ b/tensorflow/lite/delegates/hexagon/builders/resize_bilinear_builder.cc @@ -14,6 +14,9 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/lite/delegates/hexagon/builders/resize_bilinear_builder.h" +#include +#include + #include "tensorflow/lite/kernels/kernel_util.h" namespace tflite { diff --git a/tensorflow/lite/delegates/hexagon/builders/resize_nearest_neighbor_builder.cc b/tensorflow/lite/delegates/hexagon/builders/resize_nearest_neighbor_builder.cc index 7276e9ad4500d9..b21665f30e568d 100644 --- a/tensorflow/lite/delegates/hexagon/builders/resize_nearest_neighbor_builder.cc +++ b/tensorflow/lite/delegates/hexagon/builders/resize_nearest_neighbor_builder.cc @@ -16,8 +16,6 @@ limitations under the License. #include -#include - #include "tensorflow/lite/core/c/builtin_op_data.h" #include "tensorflow/lite/delegates/hexagon/hexagon_nn/hexagon_nn.h" #include "tensorflow/lite/kernels/kernel_util.h" diff --git a/tensorflow/lite/delegates/hexagon/builders/rsqrt_builder.cc b/tensorflow/lite/delegates/hexagon/builders/rsqrt_builder.cc index ad52495f54eaa3..f31800edb01e6c 100644 --- a/tensorflow/lite/delegates/hexagon/builders/rsqrt_builder.cc +++ b/tensorflow/lite/delegates/hexagon/builders/rsqrt_builder.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ #include +#include +#include #include "tensorflow/lite/delegates/hexagon/builders/op_builder.h" diff --git a/tensorflow/lite/delegates/hexagon/builders/slice_builder.cc b/tensorflow/lite/delegates/hexagon/builders/slice_builder.cc index 05dfd3ffeb070e..149106d4350983 100644 --- a/tensorflow/lite/delegates/hexagon/builders/slice_builder.cc +++ b/tensorflow/lite/delegates/hexagon/builders/slice_builder.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/lite/delegates/hexagon/builders/slice_builder.h" +#include #include #include "tensorflow/lite/kernels/internal/tensor.h" diff --git a/tensorflow/lite/delegates/hexagon/builders/softmax_builder.cc b/tensorflow/lite/delegates/hexagon/builders/softmax_builder.cc index 9915512856a2d1..28165875621516 100644 --- a/tensorflow/lite/delegates/hexagon/builders/softmax_builder.cc +++ b/tensorflow/lite/delegates/hexagon/builders/softmax_builder.cc @@ -16,8 +16,6 @@ limitations under the License. #include -#include - #include "tensorflow/lite/core/c/builtin_op_data.h" #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/delegates/hexagon/hexagon_nn/hexagon_nn.h" diff --git a/tensorflow/lite/delegates/hexagon/builders/space_to_depth_builder.cc b/tensorflow/lite/delegates/hexagon/builders/space_to_depth_builder.cc index 65e3899b79fe8c..6426fc36a0770b 100644 --- a/tensorflow/lite/delegates/hexagon/builders/space_to_depth_builder.cc +++ b/tensorflow/lite/delegates/hexagon/builders/space_to_depth_builder.cc @@ -16,8 +16,6 @@ limitations under the License. 
#include -#include - #include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/core/c/builtin_op_data.h" diff --git a/tensorflow/lite/delegates/hexagon/builders/split_builder.cc b/tensorflow/lite/delegates/hexagon/builders/split_builder.cc index 6ea35f60114e18..a3a0254df5cd82 100644 --- a/tensorflow/lite/delegates/hexagon/builders/split_builder.cc +++ b/tensorflow/lite/delegates/hexagon/builders/split_builder.cc @@ -16,8 +16,6 @@ limitations under the License. #include -#include - #include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/delegates/hexagon/builders/op_builder.h" diff --git a/tensorflow/lite/delegates/hexagon/builders/squared_difference.cc b/tensorflow/lite/delegates/hexagon/builders/squared_difference.cc index b040aa0a12b993..51231f07fd79bc 100644 --- a/tensorflow/lite/delegates/hexagon/builders/squared_difference.cc +++ b/tensorflow/lite/delegates/hexagon/builders/squared_difference.cc @@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "hexagon/hexagon_nn_ops.h" #include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/c/common.h" diff --git a/tensorflow/lite/delegates/hexagon/builders/strided_slice_builder.cc b/tensorflow/lite/delegates/hexagon/builders/strided_slice_builder.cc index 9eabf5334199eb..257e1910455e1d 100644 --- a/tensorflow/lite/delegates/hexagon/builders/strided_slice_builder.cc +++ b/tensorflow/lite/delegates/hexagon/builders/strided_slice_builder.cc @@ -14,6 +14,7 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/lite/delegates/hexagon/builders/strided_slice_builder.h" +#include #include #include "tensorflow/lite/core/c/builtin_op_data.h" diff --git a/tensorflow/lite/delegates/interpreter_utils.cc b/tensorflow/lite/delegates/interpreter_utils.cc index 6574082597718a..767673e51595b4 100644 --- a/tensorflow/lite/delegates/interpreter_utils.cc +++ b/tensorflow/lite/delegates/interpreter_utils.cc @@ -16,6 +16,8 @@ limitations under the License. #include "tensorflow/lite/delegates/interpreter_utils.h" #include +#include +#include #include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/c/common.h" diff --git a/tensorflow/lite/delegates/nnapi/nnapi_delegate_c_api_test.cc b/tensorflow/lite/delegates/nnapi/nnapi_delegate_c_api_test.cc index 75111df2a7e6c9..c5d3c61b10f3a6 100644 --- a/tensorflow/lite/delegates/nnapi/nnapi_delegate_c_api_test.cc +++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate_c_api_test.cc @@ -16,8 +16,8 @@ limitations under the License. #include -#include #include +#include #include #include diff --git a/tensorflow/lite/delegates/nnapi/nnapi_delegate_device_selection_test.cc b/tensorflow/lite/delegates/nnapi/nnapi_delegate_device_selection_test.cc index 4644fa74630153..cd33aa60e800da 100644 --- a/tensorflow/lite/delegates/nnapi/nnapi_delegate_device_selection_test.cc +++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate_device_selection_test.cc @@ -16,11 +16,12 @@ limitations under the License. 
#include #include +#include #include +#include #include #include #include -#include #include #include #include diff --git a/tensorflow/lite/delegates/nnapi/nnapi_delegate_errno_test.cc b/tensorflow/lite/delegates/nnapi/nnapi_delegate_errno_test.cc index 8fb14d4634e738..eb35af31ed5e33 100644 --- a/tensorflow/lite/delegates/nnapi/nnapi_delegate_errno_test.cc +++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate_errno_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include #include +#include #include #include "tensorflow/lite/core/c/common.h" diff --git a/tensorflow/lite/delegates/nnapi/nnapi_delegate_nnapi_failure_handling_test.cc b/tensorflow/lite/delegates/nnapi/nnapi_delegate_nnapi_failure_handling_test.cc index 88ed5d911249f0..e123d8e1d2161b 100644 --- a/tensorflow/lite/delegates/nnapi/nnapi_delegate_nnapi_failure_handling_test.cc +++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate_nnapi_failure_handling_test.cc @@ -14,15 +14,9 @@ limitations under the License. ==============================================================================*/ #include -#include -#include #include -#include #include -#include -#include #include -#include #include #include diff --git a/tensorflow/lite/delegates/nnapi/nnapi_delegate_signed_quantization_test.cc b/tensorflow/lite/delegates/nnapi/nnapi_delegate_signed_quantization_test.cc index 473f4fd0e04469..9d3da57a99770f 100644 --- a/tensorflow/lite/delegates/nnapi/nnapi_delegate_signed_quantization_test.cc +++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate_signed_quantization_test.cc @@ -12,8 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include +#include #include +#include #include +#include #include #include "tensorflow/lite/core/c/common.h" diff --git a/tensorflow/lite/delegates/nnapi/quant_lstm_sup.cc b/tensorflow/lite/delegates/nnapi/quant_lstm_sup.cc index 4b1f4af1a09f77..3cb6454e48f779 100644 --- a/tensorflow/lite/delegates/nnapi/quant_lstm_sup.cc +++ b/tensorflow/lite/delegates/nnapi/quant_lstm_sup.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/lite/delegates/nnapi/quant_lstm_sup.h" #include +#include +#include #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/kernels/kernel_util.h" diff --git a/tensorflow/lite/delegates/nnapi/quant_lstm_sup_test.cc b/tensorflow/lite/delegates/nnapi/quant_lstm_sup_test.cc index bcb549e2fe2ac4..ed8ddf20837a20 100644 --- a/tensorflow/lite/delegates/nnapi/quant_lstm_sup_test.cc +++ b/tensorflow/lite/delegates/nnapi/quant_lstm_sup_test.cc @@ -16,7 +16,6 @@ limitations under the License. #include #include -#include #include #include diff --git a/tensorflow/lite/delegates/utils/async_type_helpers.cc b/tensorflow/lite/delegates/utils/async_type_helpers.cc index 4f6904c45bfe27..2d8bf0b79fc325 100644 --- a/tensorflow/lite/delegates/utils/async_type_helpers.cc +++ b/tensorflow/lite/delegates/utils/async_type_helpers.cc @@ -15,7 +15,6 @@ limitations under the License. 
#include "tensorflow/lite/delegates/utils/async_type_helpers.h" #include -#include #include "tensorflow/lite/async/interop/c/attribute_map.h" #include "tensorflow/lite/async/interop/c/constants.h" diff --git a/tensorflow/lite/delegates/xnnpack/BUILD b/tensorflow/lite/delegates/xnnpack/BUILD index a0905e314a020b..9c7a9d7021fdaf 100644 --- a/tensorflow/lite/delegates/xnnpack/BUILD +++ b/tensorflow/lite/delegates/xnnpack/BUILD @@ -240,11 +240,11 @@ cc_library( ":file_util", ":flexbuffers_util", ":quantization_util", - ":tflite_with_xnnpack_dynamic_fully_connected", - ":tflite_with_xnnpack_logging", - ":tflite_with_xnnpack_qs8", - ":tflite_with_xnnpack_qu8", - ":tflite_with_xnnpack_transient_indirection_buffer", + ":tflite_with_xnnpack_dynamic_fully_connected", # buildcleaner: keep + ":tflite_with_xnnpack_logging", # buildcleaner: keep + ":tflite_with_xnnpack_qs8", # buildcleaner: keep + ":tflite_with_xnnpack_qu8", # buildcleaner: keep + ":tflite_with_xnnpack_transient_indirection_buffer", # buildcleaner: keep ":weight_cache", "//tensorflow/compiler/mlir/lite/kernels/internal:compatibility_macros", "//tensorflow/compiler/mlir/lite/tools/optimize:reduced_precision_metadata", @@ -257,7 +257,6 @@ cc_library( "//tensorflow/lite/kernels:cpu_backend_context", "//tensorflow/lite/kernels:kernel_util", "//tensorflow/lite/kernels:padding", - "//tensorflow/lite/kernels/internal:compatibility", "//tensorflow/lite/kernels/internal:tensor", "//tensorflow/lite/kernels/internal/utils:sparsity_format_converter", "//tensorflow/lite/schema:schema_fbs", @@ -301,7 +300,6 @@ cc_library( "//tensorflow/lite/kernels:cpu_backend_context", "//tensorflow/lite/kernels:kernel_util", "//tensorflow/lite/kernels:padding", - "//tensorflow/lite/kernels/internal:compatibility", "//tensorflow/lite/kernels/internal:tensor", "//tensorflow/lite/kernels/internal/utils:sparsity_format_converter", "//tensorflow/lite/schema:schema_fbs", diff --git a/tensorflow/lite/delegates/xnnpack/batch_matrix_multiply_test.cc 
b/tensorflow/lite/delegates/xnnpack/batch_matrix_multiply_test.cc index c65b6c336629c0..d1ead6f43c1046 100644 --- a/tensorflow/lite/delegates/xnnpack/batch_matrix_multiply_test.cc +++ b/tensorflow/lite/delegates/xnnpack/batch_matrix_multiply_test.cc @@ -65,10 +65,7 @@ TEST_F(BatchMatrixMultiplyTest, 3D) { .Test(xnnpack_delegate.get()); } -// TODO(b/332675940): This test is currently disabled since the TFLite default -// implementation of `BatchMatMul` can't handle per-channel quantized inputs. -TEST_F(BatchMatrixMultiplyTest, - DISABLED_DynamicallyQuantizedPerChannelWeights2D) { +TEST_F(BatchMatrixMultiplyTest, DynamicallyQuantizedPerChannelWeights2D) { const auto height = shape_rng(); const auto input_channels = channels_rng(); const auto output_channels = channels_rng(); @@ -81,10 +78,8 @@ TEST_F(BatchMatrixMultiplyTest, .Test(xnnpack_delegate.get()); } -// TODO(b/332675940): This test is currently disabled since the TFLite default -// implementation of `BatchMatMul` can't handle per-channel quantized inputs. 
TEST_F(BatchMatrixMultiplyTest, - DISABLED_DynamicallyQuantizedPerChannelWeights2DTransposeB) { + DynamicallyQuantizedPerChannelWeights2DTransposeB) { const auto height = shape_rng(); const auto input_channels = channels_rng(); const auto output_channels = channels_rng(); @@ -128,6 +123,36 @@ TEST_F(BatchMatrixMultiplyTest, .Test(xnnpack_delegate.get()); } +TEST_F(BatchMatrixMultiplyTest, DynamicallyQuantizedPerChannelWeights3D) { + const auto batch = shape_rng(); + const auto height = shape_rng(); + const auto input_channels = channels_rng(); + const auto output_channels = channels_rng(); + auto xnnpack_delegate = get_delegate(); + + BatchMatrixMultiplyTester() + .InputADims({batch, height, input_channels}) + .InputBDims({batch, input_channels, output_channels}) + .InputBQuant(BatchMatrixMultiplyTester::kChannel) + .Test(xnnpack_delegate.get()); +} + +TEST_F(BatchMatrixMultiplyTest, + DynamicallyQuantizedPerChannelWeights3DTransposeB) { + const auto batch = shape_rng(); + const auto height = shape_rng(); + const auto input_channels = channels_rng(); + const auto output_channels = channels_rng(); + auto xnnpack_delegate = get_delegate(); + + BatchMatrixMultiplyTester() + .InputADims({batch, height, input_channels}) + .InputBDims({batch, output_channels, input_channels}) + .InputBQuant(BatchMatrixMultiplyTester::kChannel) + .TransposeB(true) + .Test(xnnpack_delegate.get()); +} + TEST_F(BatchMatrixMultiplyTest, BroadcastOne3D) { const auto batch = shape_rng(); const auto height = shape_rng(); diff --git a/tensorflow/lite/delegates/xnnpack/unary_elementwise_test.cc b/tensorflow/lite/delegates/xnnpack/unary_elementwise_test.cc index 97cea8e5294ae1..afd77fad0607ad 100644 --- a/tensorflow/lite/delegates/xnnpack/unary_elementwise_test.cc +++ b/tensorflow/lite/delegates/xnnpack/unary_elementwise_test.cc @@ -34,6 +34,9 @@ ToleranceInfo GetTolerance(BuiltinOperator op) { return ToleranceInfo{.relative = 1.0e+4f}; case BuiltinOperator_GELU: return ToleranceInfo{.relative = 
5.0f, .absolute = 10.0f}; + case BuiltinOperator_COS: + case BuiltinOperator_SIN: + return ToleranceInfo{.relative = 5.0f, .absolute = 3.0f}; default: return ToleranceInfo{}; } @@ -139,11 +142,15 @@ TEST_P(UnaryTest, MultiThreading) { } BuiltinOperator all_unary_ops[] = { - BuiltinOperator_ABS, BuiltinOperator_CEIL, BuiltinOperator_ELU, - BuiltinOperator_FLOOR, BuiltinOperator_GELU, BuiltinOperator_NEG, - BuiltinOperator_HARD_SWISH, BuiltinOperator_RELU, BuiltinOperator_RELU6, - BuiltinOperator_RELU_N1_TO_1, BuiltinOperator_ROUND, BuiltinOperator_RSQRT, - BuiltinOperator_SQRT, BuiltinOperator_SQUARE, BuiltinOperator_TANH, + BuiltinOperator_ABS, BuiltinOperator_CEIL, + BuiltinOperator_COS, BuiltinOperator_ELU, + BuiltinOperator_EXP, BuiltinOperator_FLOOR, + BuiltinOperator_GELU, BuiltinOperator_NEG, + BuiltinOperator_HARD_SWISH, BuiltinOperator_RELU, + BuiltinOperator_RELU6, BuiltinOperator_RELU_N1_TO_1, + BuiltinOperator_ROUND, BuiltinOperator_RSQRT, + BuiltinOperator_SIN, BuiltinOperator_SQRT, + BuiltinOperator_SQUARE, BuiltinOperator_TANH, BuiltinOperator_LOGISTIC, }; diff --git a/tensorflow/lite/delegates/xnnpack/weights_cache_test.cc b/tensorflow/lite/delegates/xnnpack/weights_cache_test.cc index c047ca4442735e..b6ae7246fad4ea 100644 --- a/tensorflow/lite/delegates/xnnpack/weights_cache_test.cc +++ b/tensorflow/lite/delegates/xnnpack/weights_cache_test.cc @@ -48,10 +48,6 @@ TEST(XNNPACK_WEIGHTS_CACHE, WithSize) { const Model* model = GetModel(buffer.data()); DummyOpResolver resolver; - std::unique_ptr interpreter; - ASSERT_EQ(kTfLiteOk, InterpreterBuilder(model, resolver)(&interpreter)); - ASSERT_EQ(kTfLiteOk, interpreter->AllocateTensors()); - size_t four_mb = 4194304; std::unique_ptr @@ -66,6 +62,10 @@ TEST(XNNPACK_WEIGHTS_CACHE, WithSize) { delegate(TfLiteXNNPackDelegateCreate(&delegate_options), TfLiteXNNPackDelegateDelete); + std::unique_ptr interpreter; + ASSERT_EQ(kTfLiteOk, InterpreterBuilder(model, resolver)(&interpreter)); + ASSERT_EQ(kTfLiteOk, 
interpreter->AllocateTensors()); + ASSERT_EQ(kTfLiteOk, interpreter->ModifyGraphWithDelegate(delegate.get())); ASSERT_TRUE( @@ -79,10 +79,6 @@ TEST(XNNPACK_WEIGHTS_CACHE, InvokeBeforeFinalization) { const Model* model = GetModel(buffer.data()); DummyOpResolver resolver; - std::unique_ptr interpreter; - ASSERT_EQ(kTfLiteOk, InterpreterBuilder(model, resolver)(&interpreter)); - ASSERT_EQ(kTfLiteOk, interpreter->AllocateTensors()); - std::unique_ptr weights_cache(TfLiteXNNPackDelegateWeightsCacheCreate(), @@ -96,6 +92,10 @@ TEST(XNNPACK_WEIGHTS_CACHE, InvokeBeforeFinalization) { delegate(TfLiteXNNPackDelegateCreate(&delegate_options), TfLiteXNNPackDelegateDelete); + std::unique_ptr interpreter; + ASSERT_EQ(kTfLiteOk, InterpreterBuilder(model, resolver)(&interpreter)); + ASSERT_EQ(kTfLiteOk, interpreter->AllocateTensors()); + ASSERT_EQ(kTfLiteOk, interpreter->ModifyGraphWithDelegate(delegate.get())); // Invoking before finalization fails. @@ -107,10 +107,6 @@ TEST(XNNPACK_WEIGHTS_CACHE, HardFinalization) { const Model* model = GetModel(buffer.data()); DummyOpResolver resolver; - std::unique_ptr interpreter1; - ASSERT_EQ(kTfLiteOk, InterpreterBuilder(model, resolver)(&interpreter1)); - ASSERT_EQ(kTfLiteOk, interpreter1->AllocateTensors()); - std::unique_ptr weights_cache(TfLiteXNNPackDelegateWeightsCacheCreate(), @@ -123,6 +119,11 @@ TEST(XNNPACK_WEIGHTS_CACHE, HardFinalization) { std::unique_ptr delegate1(TfLiteXNNPackDelegateCreate(&delegate_options), TfLiteXNNPackDelegateDelete); + + std::unique_ptr interpreter1; + ASSERT_EQ(kTfLiteOk, InterpreterBuilder(model, resolver)(&interpreter1)); + ASSERT_EQ(kTfLiteOk, interpreter1->AllocateTensors()); + ASSERT_EQ(kTfLiteOk, interpreter1->ModifyGraphWithDelegate(delegate1.get())); ASSERT_TRUE( TfLiteXNNPackDelegateWeightsCacheFinalizeHard(weights_cache.get())); @@ -131,12 +132,12 @@ TEST(XNNPACK_WEIGHTS_CACHE, HardFinalization) { // We cannot create new instances using the same weights cache after hard // finalization. 
- std::unique_ptr interpreter2; - ASSERT_EQ(kTfLiteOk, InterpreterBuilder(model, resolver)(&interpreter2)); - ASSERT_EQ(kTfLiteOk, interpreter2->AllocateTensors()); std::unique_ptr delegate2(TfLiteXNNPackDelegateCreate(&delegate_options), TfLiteXNNPackDelegateDelete); + std::unique_ptr interpreter2; + ASSERT_EQ(kTfLiteOk, InterpreterBuilder(model, resolver)(&interpreter2)); + ASSERT_EQ(kTfLiteOk, interpreter2->AllocateTensors()); ASSERT_NE(kTfLiteOk, interpreter2->ModifyGraphWithDelegate(delegate2.get())); } @@ -154,12 +155,13 @@ TEST(XNNPACK_WEIGHTS_CACHE, SoftFinalization) { TfLiteXNNPackDelegateOptionsDefault(); delegate_options.weights_cache = weights_cache.get(); - std::unique_ptr interpreter1; - ASSERT_EQ(kTfLiteOk, InterpreterBuilder(model, resolver)(&interpreter1)); - ASSERT_EQ(kTfLiteOk, interpreter1->AllocateTensors()); std::unique_ptr delegate1(TfLiteXNNPackDelegateCreate(&delegate_options), TfLiteXNNPackDelegateDelete); + + std::unique_ptr interpreter1; + ASSERT_EQ(kTfLiteOk, InterpreterBuilder(model, resolver)(&interpreter1)); + ASSERT_EQ(kTfLiteOk, interpreter1->AllocateTensors()); ASSERT_EQ(kTfLiteOk, interpreter1->ModifyGraphWithDelegate(delegate1.get())); ASSERT_TRUE( @@ -168,12 +170,12 @@ TEST(XNNPACK_WEIGHTS_CACHE, SoftFinalization) { ASSERT_EQ(kTfLiteOk, interpreter1->Invoke()); // Build a second interpreter, it should work after soft finalization. 
- std::unique_ptr interpreter2; - ASSERT_EQ(kTfLiteOk, InterpreterBuilder(model, resolver)(&interpreter2)); - ASSERT_EQ(kTfLiteOk, interpreter2->AllocateTensors()); std::unique_ptr delegate2(TfLiteXNNPackDelegateCreate(&delegate_options), TfLiteXNNPackDelegateDelete); + std::unique_ptr interpreter2; + ASSERT_EQ(kTfLiteOk, InterpreterBuilder(model, resolver)(&interpreter2)); + ASSERT_EQ(kTfLiteOk, interpreter2->AllocateTensors()); ASSERT_EQ(kTfLiteOk, interpreter2->ModifyGraphWithDelegate(delegate2.get())); ASSERT_EQ(kTfLiteOk, interpreter2->Invoke()); } @@ -196,13 +198,13 @@ TEST_P(WeightsCacheTest, SoftFinalizationMultithreaded) { delegate_options.weights_cache = weights_cache.get(); // Create the first interpreter and finalize it. + std::unique_ptr + initial_delegate(TfLiteXNNPackDelegateCreate(&delegate_options), + TfLiteXNNPackDelegateDelete); std::unique_ptr initial_interpreter; ASSERT_EQ(kTfLiteOk, InterpreterBuilder(model, resolver)(&initial_interpreter)); ASSERT_EQ(kTfLiteOk, initial_interpreter->AllocateTensors()); - std::unique_ptr - initial_delegate(TfLiteXNNPackDelegateCreate(&delegate_options), - TfLiteXNNPackDelegateDelete); ASSERT_EQ(kTfLiteOk, initial_interpreter->ModifyGraphWithDelegate( initial_delegate.get())); @@ -221,14 +223,14 @@ TEST_P(WeightsCacheTest, SoftFinalizationMultithreaded) { threads.reserve(num_threads); for (size_t i = 0; i < num_threads; i++) { threads.emplace_back(std::thread([&] { - std::unique_ptr interpreter; - ASSERT_EQ(kTfLiteOk, InterpreterBuilder(model, resolver)(&interpreter)); - ASSERT_EQ(kTfLiteOk, interpreter->AllocateTensors()); - std::unique_ptr delegate(TfLiteXNNPackDelegateCreate(&delegate_options), TfLiteXNNPackDelegateDelete); + std::unique_ptr interpreter; + ASSERT_EQ(kTfLiteOk, InterpreterBuilder(model, resolver)(&interpreter)); + ASSERT_EQ(kTfLiteOk, interpreter->AllocateTensors()); + ASSERT_EQ(kTfLiteOk, interpreter->ModifyGraphWithDelegate(delegate.get())); ASSERT_EQ(kTfLiteOk, interpreter->Invoke()); diff 
--git a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc index ed9bcf7a47168b..f005f69c545a73 100644 --- a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc +++ b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc @@ -71,6 +71,13 @@ namespace tflite { namespace xnnpack { namespace { +// VisitDotAttentionNode uses a clamp to add a constant value to the XNNPack +// subgraph. The constant data must outlive the XNNPack delegate and there is no +// simple way of doing this. Therefore a clamp was used to clamp some arbitrary +// data to this constant value. The static input data to the clamp can be +// anything. +const float kConstantClampData = 0.f; + constexpr char kOdmlSDPA[] = "odml.scaled_dot_product_attention"; template @@ -91,6 +98,96 @@ void CopyTensorDataInt32OrInt64(int64_t* dst, const TfLiteTensor& tensor, } } +bool CheckZeroPoint(TfLiteContext* context, const TfLiteTensor& tensor, int t, + const TfLiteIntArray* quantization_zero_point) { + if (quantization_zero_point == nullptr) { + TF_LITE_KERNEL_LOG(context, + "missing zero point quantization parameters for " + "%s tensor %d in XNNPACK delegate", + TfLiteTypeGetName(tensor.type), t); + return false; + } + return true; +} + +bool CheckFp16Scale(TfLiteContext* context, const TfLiteTensor& tensor, int t, + const TfLiteBlockwiseQuantization* quantization_params) { + const TfLiteTensor& scale = context->tensors[quantization_params->scale]; + int num_scales = NumElements(&scale); + std::vector dequantized_scale(num_scales); + DequantizeFloat16(reinterpret_cast(scale.data.data), + dequantized_scale.data(), num_scales); + for (int i = 0; i < num_scales; i++) { + if (!std::isnormal(dequantized_scale[i]) || dequantized_scale[i] <= 0.0f) { + TF_LITE_KERNEL_LOG(context, + "unsupported scale value (%f) in channel %d for " + "%s tensor %d in XNNPACK delegate", + dequantized_scale[i], i, + TfLiteTypeGetName(tensor.type), t); + return false; + } + } + 
return true; +} + +bool CheckFp32Scale(TfLiteContext* context, const TfLiteTensor& tensor, int t, + const TfLiteFloatArray* quantization_scale, + const TfLiteIntArray* quantization_zero_point) { + if (quantization_scale == nullptr) { + TF_LITE_KERNEL_LOG(context, + "missing scale quantization parameters for %s " + "tensor %d in XNNPACK delegate", + TfLiteTypeGetName(tensor.type), t); + return false; + } + if (quantization_zero_point != nullptr && + quantization_scale->size != quantization_zero_point->size) { + TF_LITE_KERNEL_LOG(context, + "mismatching number of scale (%d) and zero " + "point (%d) quantization parameters for %s " + "tensor %d in XNNPACK delegate", + quantization_scale->size, quantization_zero_point->size, + TfLiteTypeGetName(tensor.type), t); + return false; + } + for (int i = 0; i < quantization_scale->size; i++) { + const float scale = quantization_scale->data[i]; + if (!std::isnormal(scale) || scale <= 0.0f) { + TF_LITE_KERNEL_LOG(context, + "unsupported scale value (%f) in channel %d for " + "%s tensor %d in XNNPACK delegate", + scale, i, TfLiteTypeGetName(tensor.type), t); + return false; + } + } + return true; +} + +xnn_datatype CheckPerTensorQuantization( + TfLiteContext* context, const TfLiteTensor& tensor, int t, + const TfLiteFloatArray* quantization_scale, + const TfLiteIntArray* quantization_zero_point) { + // Per-tensor quantization parameters + if (kTfLiteInt8 != tensor.type) { + TF_LITE_KERNEL_LOG(context, + "unsupported per-tensor quantization scale " + "parameter for %s tensor %d in XNNPACK delegate", + TfLiteTypeGetName(tensor.type), t); + return xnn_datatype_invalid; + } + + const int zero_point = quantization_zero_point->data[0]; + if (zero_point < std::numeric_limits::min() || + zero_point > std::numeric_limits::max()) { + TF_LITE_KERNEL_LOG(context, + "unsupported zero-point value (%d) for INT8 " + "tensor %d in XNNPACK delegate", + zero_point, t); + return xnn_datatype_invalid; + } + return xnn_datatype_qint8; +} + 
xnn_datatype GetXNNPackDatatype(TfLiteContext* context, const TfLiteTensor& tensor, int t) { switch (tensor.type) { @@ -163,111 +260,108 @@ xnn_datatype GetXNNPackDatatype(TfLiteContext* context, } case kTfLiteInt8: case kTfLiteInt4: { - if (tensor.quantization.type != kTfLiteAffineQuantization) { - TF_LITE_KERNEL_LOG(context, - "unsupported quantization type %d for %s " - "tensor %d in XNNPACK delegate", - tensor.quantization.type, - TfLiteTypeGetName(tensor.type), t); - return xnn_datatype_invalid; - } - const auto quantization_params = - static_cast( - tensor.quantization.params); - if (quantization_params->scale == nullptr) { - TF_LITE_KERNEL_LOG(context, - "missing scale quantization parameters for %s " - "tensor %d in XNNPACK delegate", - TfLiteTypeGetName(tensor.type), t); - return xnn_datatype_invalid; - } - if (quantization_params->zero_point == nullptr) { - TF_LITE_KERNEL_LOG(context, - "missing zero point quantization parameters for " - "%s tensor %d in XNNPACK delegate", - TfLiteTypeGetName(tensor.type), t); - return xnn_datatype_invalid; - } - if (quantization_params->scale->size != - quantization_params->zero_point->size) { - TF_LITE_KERNEL_LOG(context, - "mismatching number of scale (%d) and zero " - "point (%d) quantization parameters for %s " - "tensor %d in XNNPACK delegate", - quantization_params->scale->size, - quantization_params->zero_point->size, - TfLiteTypeGetName(tensor.type), t); - return xnn_datatype_invalid; - } - - for (int i = 0; i < quantization_params->scale->size; i++) { - const float scale = quantization_params->scale->data[i]; - if (!std::isnormal(scale) || scale <= 0.0f) { - TF_LITE_KERNEL_LOG(context, - "unsupported scale value (%f) in channel %d for " - "%s tensor %d in XNNPACK delegate", - scale, i, TfLiteTypeGetName(tensor.type), t); - return xnn_datatype_invalid; + switch (tensor.quantization.type) { + case kTfLiteAffineQuantization: { + const auto quantization_params = + static_cast( + tensor.quantization.params); + const 
auto quantization_scale = quantization_params->scale; + const auto quantization_zero_point = quantization_params->zero_point; + if (!CheckFp32Scale(context, tensor, t, quantization_scale, + quantization_zero_point)) { + return xnn_datatype_invalid; + } + if (quantization_scale->size == 1) { + return CheckPerTensorQuantization(context, tensor, t, + quantization_scale, + quantization_zero_point); + } + if (!CheckZeroPoint(context, tensor, t, quantization_zero_point)) { + return xnn_datatype_invalid; + } + if (NumDimensions(&tensor) >= 1 && + quantization_scale->size == + SizeOfDimension(&tensor, + quantization_params->quantized_dimension)) { + // Per-channel quantization parameters + for (int c = 0; + c < SizeOfDimension(&tensor, + quantization_params->quantized_dimension); + c++) { + if (quantization_params->zero_point->data[c] != 0 && + (tensor.type != kTfLiteInt4 && + quantization_params->zero_point->data[c] != 8)) { + TF_LITE_KERNEL_LOG(context, + "unsupported zero-point value %d in channel " + "%d of %s tensor %d in XNNPACK delegate", + quantization_params->zero_point->data[c], c, + TfLiteTypeGetName(tensor.type), t); + return xnn_datatype_invalid; + } + } + } else { + TF_LITE_KERNEL_LOG( + context, + "mismatching number of quantization parameters %d and outer " + "dimension %d for INT8 tensor %d in XNNPACK delegate", + quantization_params->scale->size, + SizeOfDimension(&tensor, + quantization_params->quantized_dimension), + t); + return xnn_datatype_invalid; + } + break; } - } - - if (quantization_params->scale->size == 1) { - // Per-tensor quantization parameters - if (kTfLiteInt8 != tensor.type) { - TF_LITE_KERNEL_LOG(context, - "unsupported per-tensor quantization scale " - "parameter for %s tensor %d in XNNPACK delegate", - TfLiteTypeGetName(tensor.type), t); - return xnn_datatype_invalid; + case kTfLiteBlockwiseQuantization: { + const auto quantization_params = + reinterpret_cast( + tensor.quantization.params); + if (!CheckFp16Scale(context, tensor, t, 
quantization_params)) { + return xnn_datatype_invalid; + } + int64_t num_scales = + NumElements(&context->tensors[quantization_params->scale]); + int64_t num_filter_elements = NumElements(&tensor); + if (num_filter_elements / num_scales != + quantization_params->blocksize) { + TF_LITE_KERNEL_LOG( + context, + "Unsupported combination of filter elements %" PRId64 + " number of scales %" PRId64 " and blocksize %" PRId32 + " for %s tensor %d in XNNPACK delegate", + num_filter_elements, num_scales, quantization_params->blocksize, + tensor.name, t); + return xnn_datatype_invalid; + } + break; } - - const int zero_point = quantization_params->zero_point->data[0]; - if (zero_point < std::numeric_limits::min() || - zero_point > std::numeric_limits::max()) { + default: TF_LITE_KERNEL_LOG(context, - "unsupported zero-point value (%d) for INT8 " + "unsupported quantization type %d for %s " "tensor %d in XNNPACK delegate", - zero_point, t); + tensor.quantization.type, + TfLiteTypeGetName(tensor.type), t); return xnn_datatype_invalid; - } - return xnn_datatype_qint8; - } else if (NumDimensions(&tensor) >= 1 && - quantization_params->scale->size == - SizeOfDimension( - &tensor, quantization_params->quantized_dimension)) { - // Per-channel quantization parameters - for (int c = 0; - c < - SizeOfDimension(&tensor, quantization_params->quantized_dimension); - c++) { - if (quantization_params->zero_point->data[c] != 0 && - (tensor.type != kTfLiteInt4 && - quantization_params->zero_point->data[c] != 8)) { - TF_LITE_KERNEL_LOG(context, - "unsupported zero-point value %d in channel " - "%d of %s tensor %d in XNNPACK delegate", - quantization_params->zero_point->data[c], c, - TfLiteTypeGetName(tensor.type), t); - return xnn_datatype_invalid; + } + + switch (tensor.type) { + case kTfLiteInt4: + switch (tensor.quantization.type) { + case kTfLiteAffineQuantization: + return xnn_datatype_qcint4; + case kTfLiteBlockwiseQuantization: + return xnn_datatype_qbint4; + default: + 
TF_LITE_KERNEL_LOG(context, + "Unsupported quantization type %d for INT4 " + "tensor %d in XNNPACK delegate", + tensor.quantization.type, t); + return xnn_datatype_invalid; } - } - switch (tensor.type) { - case kTfLiteInt4: - return xnn_datatype_qcint4; - case kTfLiteInt8: - return xnn_datatype_qcint8; - default: - return xnn_datatype_invalid; - } - } else { - TF_LITE_KERNEL_LOG( - context, - "mismatching number of quantization parameters %d and outer " - "dimension %d for INT8 tensor %d in XNNPACK delegate", - quantization_params->scale->size, - SizeOfDimension(&tensor, quantization_params->quantized_dimension), - t); - return xnn_datatype_invalid; + case kTfLiteInt8: + return xnn_datatype_qcint8; + default: + return xnn_datatype_invalid; } break; } @@ -623,6 +717,8 @@ class Delegate { return (options_.flags & TFLITE_XNNPACK_DELEGATE_FLAG_ENABLE_SLINKY) != 0; } + uint32_t runtime_flags() const { return options_.runtime_flags; } + bool support_variable_ops() const { if (options_.flags & TFLITE_XNNPACK_DELEGATE_FLAG_VARIABLE_OPERATORS) { return true; @@ -1118,6 +1214,20 @@ class Subgraph { ->quantized_dimension, dims.data(), data, XNN_INVALID_VALUE_ID, flags, &xnnpack_id); break; + case xnn_datatype_qbint4: { + const auto* quantization_params = + reinterpret_cast( + context->tensors[t].quantization.params); + const TfLiteTensor& scale_tensor = + context->tensors[quantization_params->scale]; + status = xnn_define_blockwise_quantized_tensor_value_v2( + subgraph.get(), datatype, 0, + reinterpret_cast(scale_tensor.data.data), + dims.size(), quantization_params->quantized_dimension, + quantization_params->blocksize, dims.data(), data, + XNN_INVALID_VALUE_ID, flags, xnn_datatype_fp16, &xnnpack_id); + break; + } default: status = xnn_define_tensor_value( subgraph.get(), datatype, dims.size(), dims.data(), data, @@ -1206,6 +1316,7 @@ class Subgraph { constexpr uint32_t XNN_FLAG_SLINKY_ENABLED = 0x40000000; flags |= XNN_FLAG_SLINKY_ENABLED; } + flags |= 
delegate.runtime_flags(); if (delegate.weight_cache_provider_.IsActive() && delegate.weight_cache_provider_.CanStartBuildStep()) { @@ -1901,6 +2012,22 @@ class Subgraph { node_index); } + static TfLiteStatus CheckTensorFloatType(TfLiteContext* context, + const TfLiteTensor& tensor, + int tensor_index, int node_index) { + switch (tensor.type) { + case kTfLiteFloat32: + case kTfLiteFloat16: + return kTfLiteOk; + default: + TF_LITE_MAYBE_KERNEL_LOG( + context, "%s: unsupported type %s in tensor #%d in node #%d", + __FUNCTION__, TfLiteTypeGetName(tensor.type), tensor_index, + node_index); + return kTfLiteError; + } + } + static TfLiteStatus CheckTensorFloat32OrQInt8Type(const Delegate& delegate, TfLiteContext* context, const TfLiteTensor& tensor, @@ -2168,32 +2295,59 @@ class Subgraph { case kTfLiteInt8: if (delegate.support_signed_8bit_quantization() && (kTfLiteInt8 == tensor.type || kTfLiteInt4 == tensor.type)) { - if (tensor.quantization.type != kTfLiteAffineQuantization) { - TF_LITE_MAYBE_KERNEL_LOG( - context, - "unsupported quantization type %d in tensor #%d in node #%d", - tensor.quantization.type, tensor_index, node_index); - return kTfLiteError; - } - const TfLiteAffineQuantization* quantization_params = - static_cast( - tensor.quantization.params); - if (quantization_params->scale == nullptr) { - TF_LITE_MAYBE_KERNEL_LOG(context, - "missing scale quantization parameters in " - "tensor #%d in node #%d", - tensor_index, node_index); - return kTfLiteError; - } - if (quantization_params->scale->size > 1 && - quantization_params->quantized_dimension != - expected_quantized_dimension) { - TF_LITE_MAYBE_KERNEL_LOG( - context, - "unsupported quantized dimension %d in tensor #%d in node #%d", - quantization_params->quantized_dimension, tensor_index, - node_index); - return kTfLiteError; + switch (tensor.quantization.type) { + case kTfLiteAffineQuantization: { + const TfLiteAffineQuantization* quantization_params = + static_cast( + tensor.quantization.params); + if 
(quantization_params->scale == nullptr) { + TF_LITE_MAYBE_KERNEL_LOG( + context, + "missing scale quantization parameters in " + "tensor #%d in node #%d", + tensor_index, node_index); + return kTfLiteError; + } + if (quantization_params->scale->size > 1 && + quantization_params->quantized_dimension != + expected_quantized_dimension) { + TF_LITE_MAYBE_KERNEL_LOG( + context, + "unsupported quantized dimension %d in tensor #%d in node " + "#%d", + quantization_params->quantized_dimension, tensor_index, + node_index); + return kTfLiteError; + } + break; + } + case kTfLiteBlockwiseQuantization: { + const TfLiteBlockwiseQuantization* quantization_params = + reinterpret_cast( + tensor.quantization.params); + if (quantization_params->scale == kTfLiteOptionalTensor) { + TF_LITE_MAYBE_KERNEL_LOG( + context, + "missing scale quantization parameters in " + "tensor #%d in node #%d", + tensor_index, node_index); + return kTfLiteError; + } + if (quantization_params->blocksize % 32 != 0) { + TF_LITE_MAYBE_KERNEL_LOG( + context, + "Blocksize %" PRId32 + " must be multiple of 32 in tensor #%d in node #%d", + quantization_params->blocksize, tensor_index, node_index); + return kTfLiteError; + } + break; + } + default: + TF_LITE_MAYBE_KERNEL_LOG( + context, + "unsupported quantization type %d in tensor #%d in node #%d", + tensor.quantization.type, tensor_index, node_index); } return kTfLiteOk; } @@ -2696,16 +2850,44 @@ class Subgraph { #endif switch (registration->builtin_code) { case kTfLiteBuiltinAbs: - return VisitAbsNode(subgraph, delegate, logging_context, node_index, - node, context->tensors, input_output_tensors); - case kTfLiteBuiltinAdd: { - const TfLiteAddParams* add_params = - static_cast(node->builtin_data); + case kTfLiteBuiltinCeil: + case kTfLiteBuiltinCos: + case kTfLiteBuiltinDequantize: + case kTfLiteBuiltinElu: + case kTfLiteBuiltinExp: + case kTfLiteBuiltinFloor: + case kTfLiteBuiltinGelu: + case kTfLiteBuiltinHardSwish: + case kTfLiteBuiltinLeakyRelu: + case 
kTfLiteBuiltinLogistic: + case kTfLiteBuiltinNeg: + case kTfLiteBuiltinQuantize: + case kTfLiteBuiltinRelu: + case kTfLiteBuiltinRelu6: + case kTfLiteBuiltinReluN1To1: + case kTfLiteBuiltinRound: + case kTfLiteBuiltinRsqrt: + case kTfLiteBuiltinSin: + case kTfLiteBuiltinSqrt: + case kTfLiteBuiltinSquare: + case kTfLiteBuiltinTanh: + return VisitUnaryNode(subgraph, delegate, logging_context, node_index, + node, (BuiltinOperator)registration->builtin_code, + context->tensors, input_output_tensors); + + case kTfLiteBuiltinAdd: + case kTfLiteBuiltinDiv: + case kTfLiteBuiltinMaximum: + case kTfLiteBuiltinMinimum: + case kTfLiteBuiltinMul: + case kTfLiteBuiltinPrelu: + case kTfLiteBuiltinSquaredDifference: + case kTfLiteBuiltinSub: + return VisitBinaryNode( + subgraph, delegate, logging_context, node_index, node, + (BuiltinOperator)registration->builtin_code, context->tensors, + quasi_static_tensors, input_output_tensors); - return VisitAddNode(subgraph, delegate, logging_context, node_index, - node, context->tensors, add_params, - input_output_tensors); - } case kTfLiteBuiltinAssignVariable: return VisitAssignVariableNode(subgraph, delegate, logging_context, node_index, node, context->tensors, @@ -2726,9 +2908,6 @@ class Subgraph { node_index, node, context->tensors, batchmatmul_params, input_output_tensors); } - case kTfLiteBuiltinCeil: - return VisitCeilNode(subgraph, delegate, logging_context, node_index, - node, context->tensors, input_output_tensors); case kTfLiteBuiltinConcatenation: { const TfLiteConcatenationParams* concat_params = static_cast(node->builtin_data); @@ -2761,21 +2940,6 @@ class Subgraph { subgraph, delegate, logging_context, node_index, node, context->tensors, depth_to_space_params, input_output_tensors); } - case kTfLiteBuiltinDequantize: - return VisitDequantizeNode(subgraph, delegate, logging_context, - node_index, node, context->tensors, - input_output_tensors); - case kTfLiteBuiltinDiv: { - const TfLiteDivParams* div_params = - 
static_cast(node->builtin_data); - - return VisitDivNode(subgraph, delegate, logging_context, node_index, - node, context->tensors, div_params, - input_output_tensors); - } - case kTfLiteBuiltinElu: - return VisitEluNode(subgraph, delegate, logging_context, node_index, - node, context->tensors, input_output_tensors); case kTfLiteBuiltinExpandDims: return VisitExpandDimsNode(subgraph, delegate, logging_context, node_index, node, context->tensors, @@ -2798,28 +2962,6 @@ class Subgraph { fc_params, quasi_static_tensors, input_output_tensors); } - case kTfLiteBuiltinFloor: - return VisitFloorNode(subgraph, delegate, logging_context, node_index, - node, context->tensors, input_output_tensors); - case kTfLiteBuiltinGelu: - return VisitGeluNode(subgraph, delegate, logging_context, node_index, - node, context->tensors, input_output_tensors); - case kTfLiteBuiltinHardSwish: - return VisitHardSwishNode(subgraph, delegate, logging_context, - node_index, node, context->tensors, - input_output_tensors); - case kTfLiteBuiltinLeakyRelu: { - const TfLiteLeakyReluParams* leaky_relu_params = - static_cast(node->builtin_data); - - return VisitLeakyReluNode(subgraph, delegate, logging_context, - node_index, node, context->tensors, - leaky_relu_params, input_output_tensors); - } - case kTfLiteBuiltinLogistic: - return VisitLogisticNode(subgraph, delegate, logging_context, - node_index, node, context->tensors, - input_output_tensors); case kTfLiteBuiltinMaxPool2d: { const TfLitePoolParams* pool_params = static_cast(node->builtin_data); @@ -2837,9 +2979,6 @@ class Subgraph { context->tensors, reducer_params, input_output_tensors); } - case kTfLiteBuiltinMaximum: - return VisitMaximumNode(subgraph, delegate, logging_context, node_index, - node, context->tensors, input_output_tensors); case kTfLiteBuiltinMean: { const TfLiteReducerParams* reducer_params = static_cast(node->builtin_data); @@ -2848,48 +2987,13 @@ class Subgraph { context->tensors, reducer_params, input_output_tensors); } - 
case kTfLiteBuiltinMinimum: - return VisitMinimumNode(subgraph, delegate, logging_context, node_index, - node, context->tensors, input_output_tensors); - case kTfLiteBuiltinMul: { - const TfLiteMulParams* mul_params = - static_cast(node->builtin_data); - - return VisitMulNode(subgraph, delegate, logging_context, node_index, - node, context->tensors, mul_params, - input_output_tensors); - } - case kTfLiteBuiltinNeg: - return VisitNegNode(subgraph, delegate, logging_context, node_index, - node, context->tensors, input_output_tensors); case kTfLiteBuiltinPad: return VisitPadNode(subgraph, delegate, logging_context, node_index, node, context->tensors, input_output_tensors); - case kTfLiteBuiltinPrelu: - return VisitPreluNode(subgraph, delegate, logging_context, node_index, - node, context->tensors, quasi_static_tensors, - input_output_tensors); - case kTfLiteBuiltinQuantize: - return VisitQuantizeNode(subgraph, delegate, logging_context, - node_index, node, context->tensors, - input_output_tensors); case kTfLiteBuiltinReadVariable: return VisitReadVariableNode(subgraph, delegate, logging_context, node_index, node, context->tensors, input_output_tensors); - case kTfLiteBuiltinRelu: - return VisitReluNode(subgraph, delegate, logging_context, node_index, - node, context->tensors, 0.0f, - std::numeric_limits::infinity(), - input_output_tensors); - case kTfLiteBuiltinReluN1To1: - return VisitReluNode(subgraph, delegate, logging_context, node_index, - node, context->tensors, -1.0f, 1.0f, - input_output_tensors); - case kTfLiteBuiltinRelu6: - return VisitReluNode(subgraph, delegate, logging_context, node_index, - node, context->tensors, 0.0f, 6.0f, - input_output_tensors); case kTfLiteBuiltinReshape: { const TfLiteReshapeParams* reshape_params = static_cast(node->builtin_data); @@ -2906,12 +3010,6 @@ class Subgraph { node_index, node, context->tensors, resize_params, input_output_tensors); } - case kTfLiteBuiltinRound: - return VisitRoundNode(subgraph, delegate, 
logging_context, node_index, - node, context->tensors, input_output_tensors); - case kTfLiteBuiltinRsqrt: - return VisitRsqrtNode(subgraph, delegate, logging_context, node_index, - node, context->tensors, input_output_tensors); case kTfLiteBuiltinSlice: return VisitSliceNode(subgraph, delegate, logging_context, node_index, node, context->tensors, input_output_tensors); @@ -2938,16 +3036,6 @@ class Subgraph { node, context->tensors, split_params, input_output_tensors); } - case kTfLiteBuiltinSqrt: - return VisitSqrtNode(subgraph, delegate, logging_context, node_index, - node, context->tensors, input_output_tensors); - case kTfLiteBuiltinSquare: - return VisitSquareNode(subgraph, delegate, logging_context, node_index, - node, context->tensors, input_output_tensors); - case kTfLiteBuiltinSquaredDifference: - return VisitSquaredDifferenceNode(subgraph, delegate, logging_context, - node_index, node, context->tensors, - input_output_tensors); case kTfLiteBuiltinStridedSlice: { const auto* params = static_cast(node->builtin_data); @@ -2955,17 +3043,6 @@ class Subgraph { node_index, node, context->tensors, params, input_output_tensors); } - case kTfLiteBuiltinSub: { - const TfLiteSubParams* sub_params = - static_cast(node->builtin_data); - - return VisitSubNode(subgraph, delegate, logging_context, node_index, - node, context->tensors, sub_params, - input_output_tensors); - } - case kTfLiteBuiltinTanh: - return VisitTanhNode(subgraph, delegate, logging_context, node_index, - node, context->tensors, input_output_tensors); case kTfLiteBuiltinTranspose: { return VisitTransposeNode(subgraph, delegate, logging_context, node_index, node, context->tensors, @@ -3051,119 +3128,13 @@ class Subgraph { } } - static TfLiteStatus VisitAbsNode( - xnn_subgraph_t subgraph, const Delegate& delegate, - TfLiteContext* logging_context, int node_index, TfLiteNode* node, + static TfLiteStatus VisitAssignVariableNode( + xnn_subgraph_t subgraph, Delegate& delegate, + TfLiteContext* logging_context, 
int node_index, const TfLiteNode* node, const TfLiteTensor* tensors, const std::unordered_map& input_output_tensors) { - TF_LITE_ENSURE_STATUS(CheckNumInputsAndOutputs( - logging_context, node, 1, 1, BuiltinOperator_ABS, node_index)); - - const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]]; - TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( - logging_context, input_tensor, node->inputs->data[0], node_index)); - - const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; - TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( - logging_context, output_tensor, node->outputs->data[0], node_index)); - - if (subgraph != nullptr) { - const xnn_status status = xnn_define_abs( - subgraph, - /*input_id=*/input_output_tensors.at(node->inputs->data[0]), - /*output_id=*/input_output_tensors.at(node->outputs->data[0]), - /*flags=*/0); - if (status != xnn_status_success) { - TF_LITE_KERNEL_LOG(logging_context, "failed to delegate %s node #%d", - EnumNameBuiltinOperator(BuiltinOperator_ABS), - node_index); - return kTfLiteError; - } - } - - return kTfLiteOk; - } - - static TfLiteStatus VisitAddNode( - xnn_subgraph_t subgraph, const Delegate& delegate, - TfLiteContext* logging_context, int node_index, TfLiteNode* node, - const TfLiteTensor* tensors, const TfLiteAddParams* add_params, - const std::unordered_map& input_output_tensors) { - TF_LITE_ENSURE_STATUS(CheckNumInputsAndOutputs( - logging_context, node, 2, 1, BuiltinOperator_ADD, node_index)); - - const TfLiteTensor& input1_tensor = tensors[node->inputs->data[0]]; - TF_LITE_ENSURE_STATUS( - CheckTensorFloat32OrQUInt8Type(delegate, logging_context, input1_tensor, - node->inputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorShape( - logging_context, input1_tensor, /*min_num_dims=*/0, - /*max_num_dims=*/XNN_MAX_TENSOR_DIMS, node->inputs->data[0], - BuiltinOperator_ADD, node_index)); - - const TfLiteTensor& input2_tensor = tensors[node->inputs->data[1]]; - TF_LITE_ENSURE_STATUS( - 
CheckTensorFloat32OrQUInt8Type(delegate, logging_context, input2_tensor, - node->inputs->data[1], node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorShape( - logging_context, input2_tensor, /*min_num_dims=*/0, - /*max_num_dims=*/XNN_MAX_TENSOR_DIMS, node->inputs->data[1], - BuiltinOperator_ADD, node_index)); - - const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; - TF_LITE_ENSURE_STATUS( - CheckTensorFloat32OrQUInt8Type(delegate, logging_context, output_tensor, - node->outputs->data[0], node_index)); - - if (input1_tensor.type != input2_tensor.type || - input1_tensor.type != output_tensor.type) { - TF_LITE_MAYBE_KERNEL_LOG(logging_context, - "unsupported mixed types in ADD operator #%d", - node_index); - return kTfLiteError; - } - const float scale_min = 1.0f / 1024.0f; - const float scale_max = 256.0f; - TF_LITE_ENSURE_STATUS(CheckTensorsInputOutputScale( - logging_context, input1_tensor, output_tensor, scale_min, scale_max, - BuiltinOperator_ADD, node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorsInputOutputScale( - logging_context, input2_tensor, output_tensor, scale_min, scale_max, - BuiltinOperator_ADD, node_index)); - - float output_min = -std::numeric_limits::infinity(); - float output_max = +std::numeric_limits::infinity(); - if (add_params != nullptr) { - TF_LITE_ENSURE_STATUS(ConvertActivationToOutputRange( - logging_context, node_index, add_params->activation, &output_min, - &output_max)); - } - - if (subgraph != nullptr) { - const xnn_status status = xnn_define_add2( - subgraph, output_min, output_max, - /*input1_id=*/input_output_tensors.at(node->inputs->data[0]), - /*input2_id=*/input_output_tensors.at(node->inputs->data[1]), - /*output_id=*/input_output_tensors.at(node->outputs->data[0]), - /*flags=*/0); - if (status != xnn_status_success) { - TF_LITE_KERNEL_LOG(logging_context, "failed to delegate %s node #%d", - EnumNameBuiltinOperator(BuiltinOperator_ADD), - node_index); - return kTfLiteError; - } - } - - return kTfLiteOk; - } - - 
static TfLiteStatus VisitAssignVariableNode( - xnn_subgraph_t subgraph, Delegate& delegate, - TfLiteContext* logging_context, int node_index, const TfLiteNode* node, - const TfLiteTensor* tensors, - const std::unordered_map& input_output_tensors) { - if (!delegate.support_variable_ops()) { - return kTfLiteError; + if (!delegate.support_variable_ops()) { + return kTfLiteError; } if (subgraph == nullptr) { const int resource_tensor_id = node->inputs->data[0]; @@ -3342,31 +3313,44 @@ class Subgraph { } // Validate or create the quantization parameters for the per-channel - // quantized input_b. Note that we currently only expect the `B` tensor - // to be per-tensor quantized, and not per-channel (see b/332675940). + // quantized input_b. TfLiteAffineQuantization* quant_params_b = reinterpret_cast( input_b.quantization.params); + const int num_quant_params = quant_params_b->scale->size; + float* scale_b = quant_params_b->scale->data; + const int zero_point_b = num_quant_params > 1 + ? quant_params_b->zero_point->data[0] + : input_b.params.zero_point; + int32_t quantized_dimension = quant_params_b->quantized_dimension; if (quant_params_b->scale->size != batch_size_b * n) { - if (quant_params_b->scale->size != 1) { + if ((batch_size_b * n) % num_quant_params) { TF_LITE_MAYBE_KERNEL_LOG( logging_context, "failed to delegate %s node #%d. 
unexpected number of " - "quantizations scales (expected %d or 1, got %d)", + "quantizations scales (expected a divisor of %d, got %d)", EnumNameBuiltinOperator(BuiltinOperator_BATCH_MATMUL), - node_index, batch_size_b * n, quant_params_b->scale->size); + node_index, batch_size_b * n, num_quant_params); return kTfLiteError; } + TfLiteFloatArray* new_scale_b = + TfLiteFloatArrayCreate(num_quant_params + batch_size_b * n); + if (num_quant_params == 1) { + std::fill_n(new_scale_b->data, new_scale_b->size, + input_b.params.scale); + } else { + std::copy_n(quant_params_b->scale->data, num_quant_params, + new_scale_b->data); + for (int k = 0; k < batch_size_b * n; k++) { + new_scale_b->data[num_quant_params + k] = + quant_params_b->scale->data[k % num_quant_params]; + } + } TfLiteFloatArrayFree(quant_params_b->scale); - quant_params_b->scale = TfLiteFloatArrayCreate(batch_size_b * n); - std::fill_n(quant_params_b->scale->data, batch_size_b * n, - input_b.params.scale); - TfLiteIntArrayFree(quant_params_b->zero_point); - quant_params_b->zero_point = TfLiteIntArrayCreate(batch_size_b * n); - std::fill_n(quant_params_b->zero_point->data, batch_size_b * n, - input_b.params.zero_point); - quant_params_b->quantized_dimension = - params->adj_y ? num_dims_b - 2 : num_dims_b - 1; + new_scale_b->size = num_quant_params; + quant_params_b->scale = new_scale_b; + scale_b = new_scale_b->data + num_quant_params; + quantized_dimension = params->adj_y ? num_dims_b - 2 : num_dims_b - 1; } // Create the quantized input_b. 
@@ -3374,12 +3358,11 @@ class Subgraph { for (int i = 0; i < num_dims_b; ++i) { dims_b[i] = SizeOfDimension(&input_b, i); } - const int32_t zero_point_value = quant_params_b->zero_point->data[0]; uint32_t cq_input_b_id = XNN_INVALID_VALUE_ID; if (xnn_status status = xnn_define_channelwise_quantized_tensor_value_v2( - subgraph, xnn_datatype_qcint8, zero_point_value, - quant_params_b->scale->data, dims_b.size(), + subgraph, xnn_datatype_qcint8, zero_point_b, scale_b, + dims_b.size(), /*channel_dim=*/ (params->adj_y ? num_dims_b - 2 : num_dims_b - 1), dims_b.data(), GetTensorData(&input_b), @@ -3452,39 +3435,6 @@ class Subgraph { return kTfLiteOk; } - static TfLiteStatus VisitCeilNode( - xnn_subgraph_t subgraph, const Delegate& delegate, - TfLiteContext* logging_context, int node_index, TfLiteNode* node, - const TfLiteTensor* tensors, - const std::unordered_map& input_output_tensors) { - TF_LITE_ENSURE_STATUS(CheckNumInputsAndOutputs( - logging_context, node, 1, 1, BuiltinOperator_CEIL, node_index)); - - const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]]; - TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( - logging_context, input_tensor, node->inputs->data[0], node_index)); - - const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; - TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( - logging_context, output_tensor, node->outputs->data[0], node_index)); - - if (subgraph != nullptr) { - const xnn_status status = xnn_define_ceiling( - subgraph, - /*input_id=*/input_output_tensors.at(node->inputs->data[0]), - /*output_id=*/input_output_tensors.at(node->outputs->data[0]), - /*flags=*/0); - if (status != xnn_status_success) { - TF_LITE_KERNEL_LOG(logging_context, "failed to delegate %s node #%d", - EnumNameBuiltinOperator(BuiltinOperator_CEIL), - node_index); - return kTfLiteError; - } - } - - return kTfLiteOk; - } - static TfLiteStatus VisitConcatenationNode( xnn_subgraph_t subgraph, const Delegate& delegate, TfLiteContext* logging_context, int 
node_index, TfLiteNode* node, @@ -3960,37 +3910,198 @@ class Subgraph { return kTfLiteOk; } - static TfLiteStatus VisitDequantizeNode( + static TfLiteStatus VisitBinaryNode( xnn_subgraph_t subgraph, const Delegate& delegate, TfLiteContext* logging_context, int node_index, TfLiteNode* node, - const TfLiteTensor* tensors, + tflite::BuiltinOperator op_type, const TfLiteTensor* tensors, + const std::unordered_set& quasi_static_tensors, const std::unordered_map& input_output_tensors) { - TF_LITE_ENSURE_STATUS(CheckNumInputsAndOutputs( - logging_context, node, 1, 1, BuiltinOperator_DEQUANTIZE, node_index)); - - const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]]; - TF_LITE_ENSURE_STATUS( - CheckTensorQInt8OrQUInt8Type(delegate, logging_context, input_tensor, - node->inputs->data[0], node_index)); + // Get the input and output tensors. + TF_LITE_ENSURE_STATUS(CheckNumInputsAndOutputs(logging_context, node, 2, 1, + op_type, node_index)); + const int input1_id = node->inputs->data[0]; + const int input2_id = node->inputs->data[1]; + const int output_id = node->outputs->data[0]; + const TfLiteTensor& input1_tensor = tensors[input1_id]; + const TfLiteTensor& input2_tensor = tensors[input2_id]; + const TfLiteTensor& output_tensor = tensors[output_id]; + + // Check the input shapes. TF_LITE_ENSURE_STATUS(CheckTensorShape( - logging_context, input_tensor, /*min_num_dims=*/0, - /*max_num_dims=*/XNN_MAX_TENSOR_DIMS, node->inputs->data[0], - BuiltinOperator_DEQUANTIZE, node_index)); + logging_context, input1_tensor, /*min_num_dims=*/0, + /*max_num_dims=*/XNN_MAX_TENSOR_DIMS, input1_id, op_type, node_index)); + TF_LITE_ENSURE_STATUS(CheckTensorShape( + logging_context, input2_tensor, /*min_num_dims=*/0, + /*max_num_dims=*/XNN_MAX_TENSOR_DIMS, input2_id, op_type, node_index)); + + // Check the input/output tensor types. 
+ switch (op_type) { + case BuiltinOperator_ADD: + case BuiltinOperator_MUL: + case BuiltinOperator_SUB: + TF_LITE_ENSURE_STATUS(CheckTensorFloat32OrQUInt8Type( + delegate, logging_context, input1_tensor, input1_id, node_index)); + TF_LITE_ENSURE_STATUS(CheckTensorFloat32OrQUInt8Type( + delegate, logging_context, input2_tensor, input2_id, node_index)); + TF_LITE_ENSURE_STATUS(CheckTensorFloat32OrQUInt8Type( + delegate, logging_context, output_tensor, output_id, node_index)); + if (input1_tensor.type != input2_tensor.type || + input1_tensor.type != output_tensor.type) { + TF_LITE_MAYBE_KERNEL_LOG( + logging_context, "unsupported mixed types in %s operator #%d", + EnumNameBuiltinOperator(op_type), node_index); + return kTfLiteError; + } + break; + case BuiltinOperator_DIV: + case BuiltinOperator_MAXIMUM: + case BuiltinOperator_MINIMUM: + case BuiltinOperator_PRELU: + case BuiltinOperator_SQUARED_DIFFERENCE: + TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( + logging_context, input1_tensor, input1_id, node_index)); + TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( + logging_context, input2_tensor, input2_id, node_index)); + TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( + logging_context, output_tensor, output_id, node_index)); + break; + default: + TF_LITE_KERNEL_LOG( + logging_context, + "failed to delegate %s node #%d as a binary operator", + EnumNameBuiltinOperator(op_type), node_index); + return kTfLiteError; + } - const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; - TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( - logging_context, output_tensor, node->outputs->data[0], node_index)); + // Extract any op-specific params. 
+ float output_min = -std::numeric_limits::infinity(); + float output_max = +std::numeric_limits::infinity(); + switch (op_type) { + case BuiltinOperator_ADD: { + const float scale_min = 1.0f / 1024.0f; + const float scale_max = 256.0f; + TF_LITE_ENSURE_STATUS(CheckTensorsInputOutputScale( + logging_context, input1_tensor, output_tensor, scale_min, scale_max, + op_type, node_index)); + TF_LITE_ENSURE_STATUS(CheckTensorsInputOutputScale( + logging_context, input2_tensor, output_tensor, scale_min, scale_max, + op_type, node_index)); + const TfLiteAddParams* add_params = + static_cast(node->builtin_data); + if (add_params != nullptr) { + TF_LITE_ENSURE_STATUS(ConvertActivationToOutputRange( + logging_context, node_index, add_params->activation, &output_min, + &output_max)); + } + break; + } + case BuiltinOperator_DIV: { + const TfLiteDivParams* div_params = + static_cast(node->builtin_data); + if (div_params != nullptr) { + TF_LITE_ENSURE_STATUS(ConvertActivationToOutputRange( + logging_context, node_index, div_params->activation, &output_min, + &output_max)); + } + break; + } + case BuiltinOperator_MUL: { + const float scale_min = 1.0f / 65536.0f; + const float scale_max = 256.0f; + TF_LITE_ENSURE_STATUS(CheckTensorsInputProductOutputScale( + logging_context, input1_tensor, input2_tensor, output_tensor, + scale_min, scale_max, op_type, node_index)); + const TfLiteMulParams* mul_params = + static_cast(node->builtin_data); + if (mul_params != nullptr) { + TF_LITE_ENSURE_STATUS(ConvertActivationToOutputRange( + logging_context, node_index, mul_params->activation, &output_min, + &output_max)); + } + break; + } + case BuiltinOperator_PRELU: + if (quasi_static_tensors.count(input2_id) == 0) { + TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation( + logging_context, input2_tensor, input2_id, op_type, node_index)); + } + break; + case BuiltinOperator_SUB: { + const float scale_min = 1.0f / 1024.0f; + const float scale_max = 256.0f; + 
TF_LITE_ENSURE_STATUS(CheckTensorsInputOutputScale( + logging_context, input1_tensor, output_tensor, scale_min, scale_max, + op_type, node_index)); + TF_LITE_ENSURE_STATUS(CheckTensorsInputOutputScale( + logging_context, input2_tensor, output_tensor, scale_min, scale_max, + op_type, node_index)); + const TfLiteSubParams* sub_params = + static_cast(node->builtin_data); + if (sub_params != nullptr) { + TF_LITE_ENSURE_STATUS(ConvertActivationToOutputRange( + logging_context, node_index, sub_params->activation, &output_min, + &output_max)); + } + break; + } + default: + break; + } if (subgraph != nullptr) { - const xnn_status status = xnn_define_convert( - subgraph, - /*input_id=*/input_output_tensors.at(node->inputs->data[0]), - /*output_id=*/input_output_tensors.at(node->outputs->data[0]), - /*flags=*/0); + // Setup the binary op params. + struct xnn_binary_params params; + params.output_min = output_min; + params.output_max = output_max; + + // Set the binary op type and any special params associated with it. + enum xnn_binary_operator binary_op_type = xnn_binary_invalid; + switch (op_type) { + case BuiltinOperator_ADD: + binary_op_type = xnn_binary_add; + break; + case BuiltinOperator_DIV: + binary_op_type = xnn_binary_divide; + break; + case BuiltinOperator_MAXIMUM: + binary_op_type = xnn_binary_maximum; + break; + case BuiltinOperator_MINIMUM: + binary_op_type = xnn_binary_minimum; + break; + case BuiltinOperator_MUL: + binary_op_type = xnn_binary_multiply; + break; + case BuiltinOperator_PRELU: + binary_op_type = xnn_binary_prelu; + break; + case BuiltinOperator_SQUARED_DIFFERENCE: + binary_op_type = xnn_binary_squared_difference; + break; + case BuiltinOperator_SUB: + binary_op_type = xnn_binary_subtract; + break; + default: + TF_LITE_KERNEL_LOG( + logging_context, + "failed to delegate %s node #%d as a binary operator", + EnumNameBuiltinOperator(op_type), node_index); + return kTfLiteError; + } + + // Create the subgraph node. 
+ const xnn_status status = + xnn_define_binary(subgraph, binary_op_type, ¶ms, + /*input1_id=*/input_output_tensors.at(input1_id), + /*input2_id=*/input_output_tensors.at(input2_id), + /*output_id=*/input_output_tensors.at(output_id), + /*flags=*/0); if (status != xnn_status_success) { - TF_LITE_KERNEL_LOG(logging_context, "failed to delegate %s node #%d", - EnumNameBuiltinOperator(BuiltinOperator_DEQUANTIZE), - node_index); + TF_LITE_KERNEL_LOG( + logging_context, + "failed to delegate %s node #%d (binary_op_type=%i, status=%i)", + EnumNameBuiltinOperator(BuiltinOperator_DIV), node_index, + binary_op_type, status); return kTfLiteError; } } @@ -3998,49 +4109,263 @@ class Subgraph { return kTfLiteOk; } - static TfLiteStatus VisitDivNode( + static TfLiteStatus VisitUnaryNode( xnn_subgraph_t subgraph, const Delegate& delegate, TfLiteContext* logging_context, int node_index, TfLiteNode* node, - const TfLiteTensor* tensors, const TfLiteDivParams* div_params, + tflite::BuiltinOperator op_type, const TfLiteTensor* tensors, const std::unordered_map& input_output_tensors) { - TF_LITE_ENSURE_STATUS(CheckNumInputsAndOutputs( - logging_context, node, 2, 1, BuiltinOperator_DIV, node_index)); - - const TfLiteTensor& input1_tensor = tensors[node->inputs->data[0]]; - TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( - logging_context, input1_tensor, node->inputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorShape( - logging_context, input1_tensor, /*min_num_dims=*/0, - /*max_num_dims=*/XNN_MAX_TENSOR_DIMS, node->inputs->data[0], - BuiltinOperator_DIV, node_index)); - - const TfLiteTensor& input2_tensor = tensors[node->inputs->data[1]]; - TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( - logging_context, input2_tensor, node->inputs->data[1], node_index)); + // Get the input and output tensors. 
+ TF_LITE_ENSURE_STATUS(CheckNumInputsAndOutputs(logging_context, node, 1, 1, + op_type, node_index)); + const int input_id = node->inputs->data[0]; + const int output_id = node->outputs->data[0]; + const TfLiteTensor& input_tensor = tensors[input_id]; + const TfLiteTensor& output_tensor = tensors[output_id]; + + // Check the input tensor shape. TF_LITE_ENSURE_STATUS(CheckTensorShape( - logging_context, input2_tensor, /*min_num_dims=*/0, - /*max_num_dims=*/XNN_MAX_TENSOR_DIMS, node->inputs->data[1], - BuiltinOperator_DIV, node_index)); - - const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; - TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( - logging_context, output_tensor, node->outputs->data[0], node_index)); - - float output_min = -std::numeric_limits::infinity(); - float output_max = +std::numeric_limits::infinity(); - if (div_params != nullptr) { - TF_LITE_ENSURE_STATUS(ConvertActivationToOutputRange( - logging_context, node_index, div_params->activation, &output_min, - &output_max)); + logging_context, input_tensor, /*min_num_dims=*/0, + /*max_num_dims=*/XNN_MAX_TENSOR_DIMS, input_id, op_type, node_index)); + + // Check the input/output tensor types. 
+ switch (op_type) { + case BuiltinOperator_ABS: + case BuiltinOperator_CEIL: + case BuiltinOperator_COS: + case BuiltinOperator_EXP: + case BuiltinOperator_FLOOR: + case BuiltinOperator_GELU: + case BuiltinOperator_HARD_SWISH: + case BuiltinOperator_NEG: + case BuiltinOperator_RELU_N1_TO_1: + case BuiltinOperator_RELU: + case BuiltinOperator_RELU6: + case BuiltinOperator_ROUND: + case BuiltinOperator_RSQRT: + case BuiltinOperator_SIN: + case BuiltinOperator_SQRT: + case BuiltinOperator_SQUARE: + TF_LITE_ENSURE_STATUS(CheckTensorFloatType( + logging_context, input_tensor, input_id, node_index)); + TF_LITE_ENSURE_STATUS(CheckTensorFloatType( + logging_context, output_tensor, output_id, node_index)); + break; + case BuiltinOperator_DEQUANTIZE: + TF_LITE_ENSURE_STATUS(CheckTensorQInt8OrQUInt8Type( + delegate, logging_context, input_tensor, input_id, node_index)); + TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( + logging_context, output_tensor, output_id, node_index)); + break; + case BuiltinOperator_ELU: + TF_LITE_ENSURE_STATUS(CheckTensorFloat32OrQInt8Type( + delegate, logging_context, input_tensor, input_id, node_index)); + TF_LITE_ENSURE_STATUS(CheckTensorFloat32OrQInt8Type( + delegate, logging_context, output_tensor, output_id, node_index)); + break; + case BuiltinOperator_LOGISTIC: + case BuiltinOperator_TANH: + TF_LITE_ENSURE_STATUS(CheckTensorFloat32OrQUInt8Type( + delegate, logging_context, input_tensor, input_id, node_index)); + TF_LITE_ENSURE_STATUS(CheckTensorFloat32OrQUInt8Type( + delegate, logging_context, output_tensor, output_id, node_index)); + break; + case BuiltinOperator_LEAKY_RELU: { + TF_LITE_ENSURE_STATUS(CheckTensorFloat32OrQUInt8Type( + delegate, logging_context, input_tensor, input_id, node_index)); + TF_LITE_ENSURE_STATUS(CheckTensorFloat32OrQUInt8Type( + delegate, logging_context, output_tensor, output_id, node_index)); + const TfLiteLeakyReluParams* leaky_relu_params = + static_cast(node->builtin_data); + if 
(!std::isnormal(leaky_relu_params->alpha) || + leaky_relu_params->alpha == 0.0f) { + TF_LITE_MAYBE_KERNEL_LOG( + logging_context, "unsupported alpha %g in LEAKY_RELU node #%d", + leaky_relu_params->alpha, node_index); + return kTfLiteError; + } + const float input_scale = + GetTensorScaleOrDefault(input_tensor, std::nanf("")); + const float output_scale = + GetTensorScaleOrDefault(output_tensor, std::nanf("")); + if (std::isnormal(input_scale) && std::isnormal(output_scale)) { + const float positive_scale = input_scale / output_scale; + if (positive_scale < 1.0f / 256.0f || positive_scale > 128.0f) { + TF_LITE_MAYBE_KERNEL_LOG( + logging_context, + "unsupported positive input-to-output scale " + "%g in LEAKY_RELU node #%d", + positive_scale, node_index); + return kTfLiteError; + } + const float negative_scale = + positive_scale * leaky_relu_params->alpha; + if (negative_scale < -127.99609375f || negative_scale > 128.0f || + std::fabs(negative_scale) < 1.0f / 256.0f) { + TF_LITE_MAYBE_KERNEL_LOG( + logging_context, + "unsupported negative input-to-output scale " + "%g in LEAKY_RELU node #%d", + negative_scale, node_index); + return kTfLiteError; + } + } + break; + } + case BuiltinOperator_QUANTIZE: { + TF_LITE_ENSURE_STATUS(CheckTensorFloat32OrQUInt8Type( + delegate, logging_context, input_tensor, input_id, node_index)); + TF_LITE_ENSURE_STATUS(CheckTensorQInt8OrQUInt8Type( + delegate, logging_context, output_tensor, output_id, node_index)); + const xnn_datatype input_datatype = + GetXNNPackDatatype(logging_context, input_tensor, input_id); + const xnn_datatype output_datatype = + GetXNNPackDatatype(logging_context, output_tensor, output_id); + bool supported_combination = false; + switch (input_datatype) { + case xnn_datatype_fp32: + supported_combination = true; + break; + case xnn_datatype_qint8: + case xnn_datatype_quint8: + if (input_datatype == output_datatype) { + const float input_scale = + GetTensorScaleOrDefault(input_tensor, std::nanf("")); + const float 
output_scale = + GetTensorScaleOrDefault(output_tensor, std::nanf("")); + const float input_output_scale = input_scale / output_scale; + if (input_output_scale < 1.0f / 256.0f || + input_output_scale > 128.0f) { + TF_LITE_MAYBE_KERNEL_LOG( + logging_context, + "unsupported input-to-output scale in QUANTIZE node #%d", + node_index); + return kTfLiteError; + } + supported_combination = true; + } + break; + default: + break; + } + if (!supported_combination) { + TF_LITE_MAYBE_KERNEL_LOG( + logging_context, + "unsupported combination of input type (%s) and " + "output type (%s) in QUANTIZE node #%d", + TfLiteTypeGetName(input_tensor.type), + TfLiteTypeGetName(output_tensor.type), node_index); + return kTfLiteError; + } + break; + } + default: + TF_LITE_KERNEL_LOG( + logging_context, + "failed to delegate %s node #%d as a binary operator", + EnumNameBuiltinOperator(op_type), node_index); + return kTfLiteError; } if (subgraph != nullptr) { - const xnn_status status = xnn_define_divide( - subgraph, output_min, output_max, - /*input1_id=*/input_output_tensors.at(node->inputs->data[0]), - /*input2_id=*/input_output_tensors.at(node->inputs->data[1]), - /*output_id=*/input_output_tensors.at(node->outputs->data[0]), - /*flags=*/0); + // Setup the unary op params. + union xnn_unary_params params; + + // Set the binary op type and any special params associated with it. 
+ enum xnn_unary_operator unary_op_type = xnn_unary_invalid; + switch (op_type) { + case BuiltinOperator_ABS: + unary_op_type = xnn_unary_abs; + break; + case BuiltinOperator_CEIL: + unary_op_type = xnn_unary_ceiling; + break; + case BuiltinOperator_COS: + unary_op_type = xnn_unary_cosine; + break; + case BuiltinOperator_DEQUANTIZE: + case BuiltinOperator_QUANTIZE: + unary_op_type = xnn_unary_convert; + break; + case BuiltinOperator_ELU: + unary_op_type = xnn_unary_elu; + params.elu.alpha = 1.0f; + break; + case BuiltinOperator_EXP: + unary_op_type = xnn_unary_exp; + break; + case BuiltinOperator_FLOOR: + unary_op_type = xnn_unary_floor; + break; + case BuiltinOperator_GELU: { + const TfLiteGeluParams* gelu_params = + static_cast(node->builtin_data); + unary_op_type = + gelu_params->approximate ? xnn_unary_approxgelu : xnn_unary_gelu; + break; + } + case BuiltinOperator_HARD_SWISH: + unary_op_type = xnn_unary_hardswish; + break; + case BuiltinOperator_LEAKY_RELU: { + const TfLiteLeakyReluParams* leaky_relu_params = + static_cast(node->builtin_data); + params.leaky_relu.negative_slope = leaky_relu_params->alpha; + unary_op_type = xnn_unary_leaky_relu; + break; + } + case BuiltinOperator_LOGISTIC: + unary_op_type = xnn_unary_sigmoid; + break; + case BuiltinOperator_NEG: + unary_op_type = xnn_unary_negate; + break; + case BuiltinOperator_RELU: + params.clamp.min = 0.0f; + params.clamp.max = std::numeric_limits::infinity(); + unary_op_type = xnn_unary_clamp; + break; + case BuiltinOperator_RELU_N1_TO_1: + params.clamp.min = -1.0f; + params.clamp.max = 1.0f; + unary_op_type = xnn_unary_clamp; + break; + case BuiltinOperator_RELU6: + params.clamp.min = 0.0f; + params.clamp.max = 6.0f; + unary_op_type = xnn_unary_clamp; + break; + case BuiltinOperator_ROUND: + unary_op_type = xnn_unary_bankers_rounding; + break; + case BuiltinOperator_RSQRT: + unary_op_type = xnn_unary_reciprocal_square_root; + break; + case BuiltinOperator_SIN: + unary_op_type = xnn_unary_sine; + break; 
+ case BuiltinOperator_SQRT: + unary_op_type = xnn_unary_square_root; + break; + case BuiltinOperator_SQUARE: + unary_op_type = xnn_unary_square; + break; + case BuiltinOperator_TANH: + unary_op_type = xnn_unary_tanh; + break; + default: + TF_LITE_KERNEL_LOG( + logging_context, + "failed to delegate %s node #%d as a binary operator", + EnumNameBuiltinOperator(op_type), node_index); + return kTfLiteError; + } + + // Create the subgraph node. + const xnn_status status = + xnn_define_unary(subgraph, unary_op_type, ¶ms, + /*input_id=*/input_output_tensors.at(input_id), + /*output_id=*/input_output_tensors.at(output_id), + /*flags=*/0); if (status != xnn_status_success) { TF_LITE_KERNEL_LOG(logging_context, "failed to delegate %s node #%d", EnumNameBuiltinOperator(BuiltinOperator_DIV), @@ -4052,63 +4377,28 @@ class Subgraph { return kTfLiteOk; } - static TfLiteStatus VisitEluNode( + static TfLiteStatus VisitExpandDimsNode( xnn_subgraph_t subgraph, const Delegate& delegate, TfLiteContext* logging_context, int node_index, TfLiteNode* node, const TfLiteTensor* tensors, const std::unordered_map& input_output_tensors) { + return kTfLiteError; TF_LITE_ENSURE_STATUS(CheckNumInputsAndOutputs( - logging_context, node, 1, 1, BuiltinOperator_ELU, node_index)); - + logging_context, node, 2, 1, BuiltinOperator_EXPAND_DIMS, node_index)); const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]]; TF_LITE_ENSURE_STATUS( - CheckTensorFloat32OrQInt8Type(delegate, logging_context, input_tensor, - node->inputs->data[0], node_index)); + CheckTensorFloat32OrQUInt8Type(delegate, logging_context, input_tensor, + node->inputs->data[0], node_index)); + const TfLiteTensor& axis_tensor = tensors[node->inputs->data[1]]; + TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation( + logging_context, axis_tensor, node->inputs->data[1], + BuiltinOperator_EXPAND_DIMS, node_index)); - const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; - TF_LITE_ENSURE_STATUS( - 
CheckTensorFloat32OrQInt8Type(delegate, logging_context, output_tensor, - node->outputs->data[0], node_index)); - - if (subgraph != nullptr) { - const xnn_status status = xnn_define_elu( - subgraph, /*alpha=*/1.0f, - /*input_id=*/input_output_tensors.at(node->inputs->data[0]), - /*output_id=*/input_output_tensors.at(node->outputs->data[0]), - /*flags=*/0); - if (status != xnn_status_success) { - TF_LITE_KERNEL_LOG(logging_context, "failed to delegate %s node #%d", - EnumNameBuiltinOperator(BuiltinOperator_ELU), - node_index); - return kTfLiteError; - } - } - - return kTfLiteOk; - } - - static TfLiteStatus VisitExpandDimsNode( - xnn_subgraph_t subgraph, const Delegate& delegate, - TfLiteContext* logging_context, int node_index, TfLiteNode* node, - const TfLiteTensor* tensors, - const std::unordered_map& input_output_tensors) { - return kTfLiteError; - TF_LITE_ENSURE_STATUS(CheckNumInputsAndOutputs( - logging_context, node, 2, 1, BuiltinOperator_EXPAND_DIMS, node_index)); - const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]]; - TF_LITE_ENSURE_STATUS( - CheckTensorFloat32OrQUInt8Type(delegate, logging_context, input_tensor, - node->inputs->data[0], node_index)); - const TfLiteTensor& axis_tensor = tensors[node->inputs->data[1]]; - TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation( - logging_context, axis_tensor, node->inputs->data[1], - BuiltinOperator_EXPAND_DIMS, node_index)); - - const size_t num_new_axes = NumElements(&axis_tensor); + const int64_t num_new_axes = NumElements(&axis_tensor); if (num_new_axes != 1) { TF_LITE_MAYBE_KERNEL_LOG(logging_context, - "unexpected number of axes (%d) in node #%d: " - "TFLite only supports 1 new axes", + "unexpected number of axes (%" PRId64 + ") in node #%d: TFLite only supports 1 new axes", num_new_axes, node_index); return kTfLiteError; } @@ -4226,8 +4516,12 @@ class Subgraph { (input_tensor.type == kTfLiteFloat32 && (filter_tensor.type == kTfLiteInt4 || filter_tensor.type == kTfLiteInt8))); + bool 
supported_srq = (input_tensor.type == kTfLiteInt8 && + (filter_tensor.type == kTfLiteInt4 || + filter_tensor.type == kTfLiteInt8)); if (input_tensor.type != output_tensor.type || - ((input_tensor.type != filter_tensor.type) && !dynamically_quantized)) { + ((input_tensor.type != filter_tensor.type) && + !(dynamically_quantized || supported_srq))) { TF_LITE_MAYBE_KERNEL_LOG( logging_context, "unsupported mixed types in FULLY_CONNECTED operator #%d", @@ -4298,14 +4592,37 @@ class Subgraph { std::vector filter_dims( &filter_tensor.dims->data[0], &filter_tensor.dims->data[NumDimensions(&filter_tensor)]); - int32_t zero_point_value = filter_params->zero_point->data[0]; uint32_t kernel_id = XNN_INVALID_VALUE_ID; - status = xnn_define_channelwise_quantized_tensor_value_v2( - subgraph, filter_datatype, zero_point_value, - filter_params->scale->data, filter_dims.size(), /*channel_dim=*/0, - filter_dims.data(), GetTensorData(&filter_tensor), - XNN_INVALID_VALUE_ID, - /*flags=*/0, &kernel_id); + switch (filter_datatype) { + case xnn_datatype_qcint4: + case xnn_datatype_qcint8: { + int32_t zero_point_value = filter_params->zero_point->data[0]; + status = xnn_define_channelwise_quantized_tensor_value_v2( + subgraph, filter_datatype, zero_point_value, + filter_params->scale->data, filter_dims.size(), + /*channel_dim=*/0, filter_dims.data(), + GetTensorData(&filter_tensor), XNN_INVALID_VALUE_ID, + /*flags=*/0, &kernel_id); + break; + } + case xnn_datatype_qbint4: { + const auto* quantization_params = + reinterpret_cast( + tensors[node->inputs->data[1]].quantization.params); + const TfLiteTensor& scale_tensor = + tensors[quantization_params->scale]; + status = xnn_define_blockwise_quantized_tensor_value_v2( + subgraph, filter_datatype, 0, + reinterpret_cast(scale_tensor.data.data), + filter_dims.size(), quantization_params->quantized_dimension, + quantization_params->blocksize, filter_dims.data(), + GetTensorData(&filter_tensor), XNN_INVALID_VALUE_ID, + /*flags=*/0, 
xnn_datatype_fp16, &kernel_id); + break; + } + default: + return kTfLiteError; + } if (status != xnn_status_success) { TF_LITE_KERNEL_LOG( logging_context, "failed to update filter tensor %s node #%d", @@ -4354,215 +4671,6 @@ class Subgraph { return kTfLiteOk; } - static TfLiteStatus VisitFloorNode( - xnn_subgraph_t subgraph, const Delegate& delegate, - TfLiteContext* logging_context, int node_index, TfLiteNode* node, - const TfLiteTensor* tensors, - const std::unordered_map& input_output_tensors) { - TF_LITE_ENSURE_STATUS(CheckNumInputsAndOutputs( - logging_context, node, 1, 1, BuiltinOperator_FLOOR, node_index)); - - const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]]; - TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( - logging_context, input_tensor, node->inputs->data[0], node_index)); - - const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; - TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( - logging_context, output_tensor, node->outputs->data[0], node_index)); - - if (subgraph != nullptr) { - const xnn_status status = xnn_define_floor( - subgraph, - /*input_id=*/input_output_tensors.at(node->inputs->data[0]), - /*output_id=*/input_output_tensors.at(node->outputs->data[0]), - /*flags=*/0); - if (status != xnn_status_success) { - TF_LITE_KERNEL_LOG(logging_context, "failed to delegate %s node #%d", - EnumNameBuiltinOperator(BuiltinOperator_FLOOR), - node_index); - return kTfLiteError; - } - } - - return kTfLiteOk; - } - - static TfLiteStatus VisitGeluNode( - xnn_subgraph_t subgraph, const Delegate& delegate, - TfLiteContext* logging_context, int node_index, TfLiteNode* node, - const TfLiteTensor* tensors, - const std::unordered_map& input_output_tensors) { - TF_LITE_ENSURE_STATUS(CheckNumInputsAndOutputs( - logging_context, node, 1, 1, BuiltinOperator_GELU, node_index)); - - const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]]; - TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( - logging_context, input_tensor, 
node->inputs->data[0], node_index)); - - const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; - TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( - logging_context, output_tensor, node->outputs->data[0], node_index)); - - const TfLiteGeluParams* gelu_params = - static_cast(node->builtin_data); - - if (subgraph != nullptr) { - const xnn_status status = xnn_define_unary( - subgraph, - /*type=*/gelu_params->approximate ? xnn_unary_approxgelu - : xnn_unary_gelu, - /*params=*/nullptr, - /*input_id=*/input_output_tensors.at(node->inputs->data[0]), - /*output_id=*/input_output_tensors.at(node->outputs->data[0]), - /*flags=*/0); - if (status != xnn_status_success) { - TF_LITE_KERNEL_LOG(logging_context, "failed to delegate %s node #%d", - EnumNameBuiltinOperator(BuiltinOperator_GELU), - node_index); - return kTfLiteError; - } - } - - return kTfLiteOk; - } - - static TfLiteStatus VisitHardSwishNode( - xnn_subgraph_t subgraph, const Delegate& delegate, - TfLiteContext* logging_context, int node_index, TfLiteNode* node, - const TfLiteTensor* tensors, - const std::unordered_map& input_output_tensors) { - TF_LITE_ENSURE_STATUS(CheckNumInputsAndOutputs( - logging_context, node, 1, 1, BuiltinOperator_HARD_SWISH, node_index)); - - const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]]; - TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( - logging_context, input_tensor, node->inputs->data[0], node_index)); - - const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; - TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( - logging_context, output_tensor, node->outputs->data[0], node_index)); - - if (subgraph != nullptr) { - const xnn_status status = xnn_define_hardswish( - subgraph, - /*input_id=*/input_output_tensors.at(node->inputs->data[0]), - /*output_id=*/input_output_tensors.at(node->outputs->data[0]), - /*flags=*/0); - if (status != xnn_status_success) { - TF_LITE_KERNEL_LOG(logging_context, "failed to delegate %s node #%d", - 
EnumNameBuiltinOperator(BuiltinOperator_HARD_SWISH), - node_index); - return kTfLiteError; - } - } - - return kTfLiteOk; - } - - static TfLiteStatus VisitLeakyReluNode( - xnn_subgraph_t subgraph, const Delegate& delegate, - TfLiteContext* logging_context, int node_index, TfLiteNode* node, - const TfLiteTensor* tensors, - const TfLiteLeakyReluParams* leaky_relu_params, - const std::unordered_map& input_output_tensors) { - TF_LITE_ENSURE_STATUS(CheckNumInputsAndOutputs( - logging_context, node, 1, 1, BuiltinOperator_LEAKY_RELU, node_index)); - - const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]]; - TF_LITE_ENSURE_STATUS( - CheckTensorFloat32OrQUInt8Type(delegate, logging_context, input_tensor, - node->inputs->data[0], node_index)); - - const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; - TF_LITE_ENSURE_STATUS( - CheckTensorFloat32OrQUInt8Type(delegate, logging_context, output_tensor, - node->outputs->data[0], node_index)); - - if (!std::isnormal(leaky_relu_params->alpha) || - leaky_relu_params->alpha == 0.0f) { - TF_LITE_MAYBE_KERNEL_LOG(logging_context, - "unsupported alpha %g in LEAKY_RELU node #%d", - leaky_relu_params->alpha, node_index); - return kTfLiteError; - } - - const float input_scale = - GetTensorScaleOrDefault(input_tensor, std::nanf("")); - const float output_scale = - GetTensorScaleOrDefault(output_tensor, std::nanf("")); - if (std::isnormal(input_scale) && std::isnormal(output_scale)) { - const float positive_scale = input_scale / output_scale; - if (positive_scale < 1.0f / 256.0f || positive_scale > 128.0f) { - TF_LITE_MAYBE_KERNEL_LOG(logging_context, - "unsupported positive input-to-output scale " - "%g in LEAKY_RELU node #%d", - positive_scale, node_index); - return kTfLiteError; - } - - const float negative_scale = positive_scale * leaky_relu_params->alpha; - if (negative_scale < -127.99609375f || negative_scale > 128.0f || - std::fabs(negative_scale) < 1.0f / 256.0f) { - TF_LITE_MAYBE_KERNEL_LOG(logging_context, 
- "unsupported negative input-to-output scale " - "%g in LEAKY_RELU node #%d", - negative_scale, node_index); - return kTfLiteError; - } - } - - if (subgraph != nullptr) { - const xnn_status status = xnn_define_leaky_relu( - subgraph, leaky_relu_params->alpha, - /*input_id=*/input_output_tensors.at(node->inputs->data[0]), - /*output_id=*/input_output_tensors.at(node->outputs->data[0]), - /*flags=*/0); - if (status != xnn_status_success) { - TF_LITE_KERNEL_LOG(logging_context, "failed to delegate %s node #%d", - EnumNameBuiltinOperator(BuiltinOperator_LEAKY_RELU), - node_index); - return kTfLiteError; - } - } - - return kTfLiteOk; - } - - static TfLiteStatus VisitLogisticNode( - xnn_subgraph_t subgraph, const Delegate& delegate, - TfLiteContext* logging_context, int node_index, TfLiteNode* node, - const TfLiteTensor* tensors, - const std::unordered_map& input_output_tensors) { - TF_LITE_ENSURE_STATUS(CheckNumInputsAndOutputs( - logging_context, node, 1, 1, BuiltinOperator_LOGISTIC, node_index)); - - const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]]; - TF_LITE_ENSURE_STATUS( - CheckTensorFloat32OrQUInt8Type(delegate, logging_context, input_tensor, - node->inputs->data[0], node_index)); - - const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; - TF_LITE_ENSURE_STATUS( - CheckTensorFloat32OrQUInt8Type(delegate, logging_context, output_tensor, - node->outputs->data[0], node_index)); - - if (subgraph != nullptr) { - const xnn_status status = xnn_define_sigmoid( - subgraph, - /*input_id=*/input_output_tensors.at(node->inputs->data[0]), - /*output_id=*/input_output_tensors.at(node->outputs->data[0]), - /*flags=*/0); - if (status != xnn_status_success) { - TF_LITE_KERNEL_LOG(logging_context, "failed to delegate %s node #%d", - EnumNameBuiltinOperator(BuiltinOperator_LOGISTIC), - node_index); - return kTfLiteError; - } - } - - return kTfLiteOk; - } - static TfLiteStatus VisitMaxPool2DNode( xnn_subgraph_t subgraph, const Delegate& delegate, 
TfLiteContext* logging_context, int node_index, TfLiteNode* node, @@ -4688,44 +4796,6 @@ class Subgraph { return kTfLiteOk; } - static TfLiteStatus VisitMaximumNode( - xnn_subgraph_t subgraph, const Delegate& delegate, - TfLiteContext* logging_context, int node_index, TfLiteNode* node, - const TfLiteTensor* tensors, - const std::unordered_map& input_output_tensors) { - TF_LITE_ENSURE_STATUS(CheckNumInputsAndOutputs( - logging_context, node, 2, 1, BuiltinOperator_MAXIMUM, node_index)); - - const TfLiteTensor& input1_tensor = tensors[node->inputs->data[0]]; - TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( - logging_context, input1_tensor, node->inputs->data[0], node_index)); - - const TfLiteTensor& input2_tensor = tensors[node->inputs->data[1]]; - TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( - logging_context, input2_tensor, node->inputs->data[1], node_index)); - - const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; - TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( - logging_context, output_tensor, node->outputs->data[0], node_index)); - - if (subgraph != nullptr) { - const xnn_status status = xnn_define_maximum2( - subgraph, - /*input1_id=*/input_output_tensors.at(node->inputs->data[0]), - /*input2_id=*/input_output_tensors.at(node->inputs->data[1]), - /*output_id=*/input_output_tensors.at(node->outputs->data[0]), - /*flags=*/0); - if (status != xnn_status_success) { - TF_LITE_KERNEL_LOG(logging_context, "failed to delegate %s node #%d", - EnumNameBuiltinOperator(BuiltinOperator_MAXIMUM), - node_index); - return kTfLiteError; - } - } - - return kTfLiteOk; - } - static TfLiteStatus VisitMediaPipeDeconvolutionNode( xnn_subgraph_t subgraph, const Delegate& delegate, TfLiteContext* logging_context, int node_index, TfLiteNode* node, @@ -4962,140 +5032,6 @@ class Subgraph { return kTfLiteOk; } - static TfLiteStatus VisitMinimumNode( - xnn_subgraph_t subgraph, const Delegate& delegate, - TfLiteContext* logging_context, int node_index, TfLiteNode* node, - 
const TfLiteTensor* tensors, - const std::unordered_map& input_output_tensors) { - TF_LITE_ENSURE_STATUS(CheckNumInputsAndOutputs( - logging_context, node, 2, 1, BuiltinOperator_MINIMUM, node_index)); - - const TfLiteTensor& input1_tensor = tensors[node->inputs->data[0]]; - TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( - logging_context, input1_tensor, node->inputs->data[0], node_index)); - - const TfLiteTensor& input2_tensor = tensors[node->inputs->data[1]]; - TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( - logging_context, input2_tensor, node->inputs->data[1], node_index)); - - const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; - TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( - logging_context, output_tensor, node->outputs->data[0], node_index)); - - if (subgraph != nullptr) { - const xnn_status status = xnn_define_minimum2( - subgraph, - /*input1_id=*/input_output_tensors.at(node->inputs->data[0]), - /*input2_id=*/input_output_tensors.at(node->inputs->data[1]), - /*output_id=*/input_output_tensors.at(node->outputs->data[0]), - /*flags=*/0); - if (status != xnn_status_success) { - TF_LITE_KERNEL_LOG(logging_context, "failed to delegate %s node #%d", - EnumNameBuiltinOperator(BuiltinOperator_MINIMUM), - node_index); - return kTfLiteError; - } - } - - return kTfLiteOk; - } - - static TfLiteStatus VisitMulNode( - xnn_subgraph_t subgraph, const Delegate& delegate, - TfLiteContext* logging_context, int node_index, TfLiteNode* node, - const TfLiteTensor* tensors, const TfLiteMulParams* mul_params, - const std::unordered_map& input_output_tensors) { - TF_LITE_ENSURE_STATUS(CheckNumInputsAndOutputs( - logging_context, node, 2, 1, BuiltinOperator_MUL, node_index)); - - const TfLiteTensor& input1_tensor = tensors[node->inputs->data[0]]; - TF_LITE_ENSURE_STATUS( - CheckTensorFloat32OrQUInt8Type(delegate, logging_context, input1_tensor, - node->inputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorShape( - logging_context, input1_tensor, 
/*min_num_dims=*/0, - /*max_num_dims=*/XNN_MAX_TENSOR_DIMS, node->inputs->data[0], - BuiltinOperator_MUL, node_index)); - - const TfLiteTensor& input2_tensor = tensors[node->inputs->data[1]]; - TF_LITE_ENSURE_STATUS( - CheckTensorFloat32OrQUInt8Type(delegate, logging_context, input2_tensor, - node->inputs->data[1], node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorShape( - logging_context, input2_tensor, /*min_num_dims=*/0, - /*max_num_dims=*/XNN_MAX_TENSOR_DIMS, node->inputs->data[1], - BuiltinOperator_MUL, node_index)); - - const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; - TF_LITE_ENSURE_STATUS( - CheckTensorFloat32OrQUInt8Type(delegate, logging_context, output_tensor, - node->outputs->data[0], node_index)); - - const float scale_min = 1.0f / 65536.0f; - const float scale_max = 256.0f; - TF_LITE_ENSURE_STATUS(CheckTensorsInputProductOutputScale( - logging_context, input1_tensor, input2_tensor, output_tensor, scale_min, - scale_max, BuiltinOperator_MUL, node_index)); - - float output_min = -std::numeric_limits::infinity(); - float output_max = +std::numeric_limits::infinity(); - if (mul_params != nullptr) { - TF_LITE_ENSURE_STATUS(ConvertActivationToOutputRange( - logging_context, node_index, mul_params->activation, &output_min, - &output_max)); - } - - if (subgraph != nullptr) { - const xnn_status status = xnn_define_multiply2( - subgraph, output_min, output_max, - /*input1_id=*/input_output_tensors.at(node->inputs->data[0]), - /*input2_id=*/input_output_tensors.at(node->inputs->data[1]), - /*output_id=*/input_output_tensors.at(node->outputs->data[0]), - /*flags=*/0); - if (status != xnn_status_success) { - TF_LITE_KERNEL_LOG(logging_context, "failed to delegate %s node #%d", - EnumNameBuiltinOperator(BuiltinOperator_MUL), - node_index); - return kTfLiteError; - } - } - - return kTfLiteOk; - } - - static TfLiteStatus VisitNegNode( - xnn_subgraph_t subgraph, const Delegate& delegate, - TfLiteContext* logging_context, int node_index, TfLiteNode* 
node, - const TfLiteTensor* tensors, - const std::unordered_map& input_output_tensors) { - TF_LITE_ENSURE_STATUS(CheckNumInputsAndOutputs( - logging_context, node, 1, 1, BuiltinOperator_NEG, node_index)); - - const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]]; - TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( - logging_context, input_tensor, node->inputs->data[0], node_index)); - - const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; - TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( - logging_context, output_tensor, node->outputs->data[0], node_index)); - - if (subgraph != nullptr) { - const xnn_status status = xnn_define_negate( - subgraph, - /*input_id=*/input_output_tensors.at(node->inputs->data[0]), - /*output_id=*/input_output_tensors.at(node->outputs->data[0]), - /*flags=*/0); - if (status != xnn_status_success) { - TF_LITE_KERNEL_LOG(logging_context, "failed to delegate %s node #%d", - EnumNameBuiltinOperator(BuiltinOperator_NEG), - node_index); - return kTfLiteError; - } - } - - return kTfLiteOk; - } - static TfLiteStatus VisitPadNode( xnn_subgraph_t subgraph, const Delegate& delegate, TfLiteContext* logging_context, int node_index, TfLiteNode* node, @@ -5158,151 +5094,18 @@ class Subgraph { std::array post_paddings{}; for (int i = 0; i < SizeOfDimension(&paddings_tensor, 0); i++) { pre_paddings[i] = static_cast(paddings_data[i * 2 + 0]); - post_paddings[i] = static_cast(paddings_data[i * 2 + 1]); - } - - const xnn_status status = xnn_define_static_constant_pad( - subgraph, pre_paddings.data(), post_paddings.data(), - /*padding_value=*/0.0f, - /*input_id=*/input_output_tensors.at(node->inputs->data[0]), - /*output_id=*/input_output_tensors.at(node->outputs->data[0]), - /*flags=*/0); - if (status != xnn_status_success) { - TF_LITE_KERNEL_LOG(logging_context, "failed to delegate %s node #%d", - EnumNameBuiltinOperator(BuiltinOperator_PAD), - node_index); - return kTfLiteError; - } - } - - return kTfLiteOk; - } - - static 
TfLiteStatus VisitPreluNode( - xnn_subgraph_t subgraph, const Delegate& delegate, - TfLiteContext* logging_context, int node_index, TfLiteNode* node, - const TfLiteTensor* tensors, - const std::unordered_set& quasi_static_tensors, - const std::unordered_map& input_output_tensors) { - TF_LITE_ENSURE_STATUS(CheckNumInputsAndOutputs( - logging_context, node, 2, 1, BuiltinOperator_PRELU, node_index)); - - const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]]; - TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( - logging_context, input_tensor, node->inputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorShape( - logging_context, input_tensor, 1, XNN_MAX_TENSOR_DIMS, - node->inputs->data[0], BuiltinOperator_PRELU, node_index)); - - const TfLiteTensor& slope_tensor = tensors[node->inputs->data[1]]; - TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( - logging_context, slope_tensor, node->inputs->data[1], node_index)); - TF_LITE_ENSURE_STATUS(CheckSlopeTensorShape( - logging_context, slope_tensor, node->inputs->data[1], - BuiltinOperator_PRELU, node_index)); - if (quasi_static_tensors.count(node->inputs->data[1]) == 0) { - TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation( - logging_context, slope_tensor, node->inputs->data[1], - BuiltinOperator_PRELU, node_index)); - } - - const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; - TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( - logging_context, output_tensor, node->outputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorShape( - logging_context, output_tensor, 1, XNN_MAX_TENSOR_DIMS, - node->outputs->data[0], BuiltinOperator_PRELU, node_index)); - - if (subgraph != nullptr) { - const xnn_status status = xnn_define_prelu( - subgraph, - /*input_id=*/input_output_tensors.at(node->inputs->data[0]), - /*slope_id=*/input_output_tensors.at(node->inputs->data[1]), - /*output_id=*/input_output_tensors.at(node->outputs->data[0]), - /*flags=*/0); - if (status != xnn_status_success) { - 
TF_LITE_KERNEL_LOG(logging_context, "failed to delegate %s node #%d", - EnumNameBuiltinOperator(BuiltinOperator_PRELU), - node_index); - return kTfLiteError; - } - } - - return kTfLiteOk; - } - - static TfLiteStatus VisitQuantizeNode( - xnn_subgraph_t subgraph, const Delegate& delegate, - TfLiteContext* logging_context, int node_index, TfLiteNode* node, - const TfLiteTensor* tensors, - const std::unordered_map& input_output_tensors) { - TF_LITE_ENSURE_STATUS(CheckNumInputsAndOutputs( - logging_context, node, 1, 1, BuiltinOperator_QUANTIZE, node_index)); - - const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]]; - TF_LITE_ENSURE_STATUS( - CheckTensorFloat32OrQUInt8Type(delegate, logging_context, input_tensor, - node->inputs->data[0], node_index)); - - const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; - TF_LITE_ENSURE_STATUS( - CheckTensorQInt8OrQUInt8Type(delegate, logging_context, output_tensor, - node->outputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorShape( - logging_context, input_tensor, /*min_num_dims=*/0, - /*max_num_dims=*/XNN_MAX_TENSOR_DIMS, node->inputs->data[0], - BuiltinOperator_QUANTIZE, node_index)); - - const xnn_datatype input_datatype = GetXNNPackDatatype( - logging_context, input_tensor, node->inputs->data[0]); - const xnn_datatype output_datatype = GetXNNPackDatatype( - logging_context, output_tensor, node->outputs->data[0]); - bool supported_combination = false; - switch (input_datatype) { - case xnn_datatype_fp32: - supported_combination = true; - break; - case xnn_datatype_qint8: - case xnn_datatype_quint8: - if (input_datatype == output_datatype) { - const float input_scale = - GetTensorScaleOrDefault(input_tensor, std::nanf("")); - const float output_scale = - GetTensorScaleOrDefault(output_tensor, std::nanf("")); - const float input_output_scale = input_scale / output_scale; - if (input_output_scale < 1.0f / 256.0f || - input_output_scale > 128.0f) { - TF_LITE_MAYBE_KERNEL_LOG( - 
logging_context, - "unsupported input-to-output scale in QUANTIZE node #%d", - node_index); - return kTfLiteError; - } - supported_combination = true; - } - break; - default: - break; - } - if (!supported_combination) { - TF_LITE_MAYBE_KERNEL_LOG(logging_context, - "unsupported combination of input type (%s) and " - "output type (%s) in QUANTIZE node #%d", - TfLiteTypeGetName(input_tensor.type), - TfLiteTypeGetName(output_tensor.type), - node_index); - return kTfLiteError; - } + post_paddings[i] = static_cast(paddings_data[i * 2 + 1]); + } - if (subgraph != nullptr) { - const xnn_status status = xnn_define_convert( - subgraph, + const xnn_status status = xnn_define_static_constant_pad( + subgraph, pre_paddings.data(), post_paddings.data(), + /*padding_value=*/0.0f, /*input_id=*/input_output_tensors.at(node->inputs->data[0]), /*output_id=*/input_output_tensors.at(node->outputs->data[0]), /*flags=*/0); if (status != xnn_status_success) { TF_LITE_KERNEL_LOG(logging_context, "failed to delegate %s node #%d", - EnumNameBuiltinOperator(BuiltinOperator_QUANTIZE), + EnumNameBuiltinOperator(BuiltinOperator_PAD), node_index); return kTfLiteError; } @@ -5350,39 +5153,6 @@ class Subgraph { return kTfLiteOk; } - static TfLiteStatus VisitReluNode( - xnn_subgraph_t subgraph, const Delegate& delegate, - TfLiteContext* logging_context, int node_index, TfLiteNode* node, - const TfLiteTensor* tensors, float output_min, float output_max, - const std::unordered_map& input_output_tensors) { - TF_LITE_ENSURE_STATUS(CheckNumInputsAndOutputs( - logging_context, node, 1, 1, BuiltinOperator_RELU, node_index)); - - const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]]; - TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( - logging_context, input_tensor, node->inputs->data[0], node_index)); - - const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; - TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( - logging_context, output_tensor, node->outputs->data[0], node_index)); - - if 
(subgraph != nullptr) { - const xnn_status status = xnn_define_clamp( - subgraph, output_min, output_max, - /*input_id=*/input_output_tensors.at(node->inputs->data[0]), - /*output_id=*/input_output_tensors.at(node->outputs->data[0]), - /*flags=*/0); - if (status != xnn_status_success) { - TF_LITE_KERNEL_LOG(logging_context, "failed to delegate %s node #%d", - EnumNameBuiltinOperator(BuiltinOperator_RELU), - node_index); - return kTfLiteError; - } - } - - return kTfLiteOk; - } - static TfLiteStatus VisitReshapeNode( xnn_subgraph_t subgraph, const Delegate& delegate, TfLiteContext* logging_context, int node_index, TfLiteNode* node, @@ -5576,39 +5346,6 @@ class Subgraph { return kTfLiteOk; } - static TfLiteStatus VisitRoundNode( - xnn_subgraph_t subgraph, const Delegate& delegate, - TfLiteContext* logging_context, int node_index, TfLiteNode* node, - const TfLiteTensor* tensors, - const std::unordered_map& input_output_tensors) { - TF_LITE_ENSURE_STATUS(CheckNumInputsAndOutputs( - logging_context, node, 1, 1, BuiltinOperator_ROUND, node_index)); - - const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]]; - TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( - logging_context, input_tensor, node->inputs->data[0], node_index)); - - const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; - TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( - logging_context, output_tensor, node->outputs->data[0], node_index)); - - if (subgraph != nullptr) { - const xnn_status status = xnn_define_bankers_rounding( - subgraph, - /*input_id=*/input_output_tensors.at(node->inputs->data[0]), - /*output_id=*/input_output_tensors.at(node->outputs->data[0]), - /*flags=*/0); - if (status != xnn_status_success) { - TF_LITE_KERNEL_LOG(logging_context, "failed to delegate %s node #%d", - EnumNameBuiltinOperator(BuiltinOperator_ROUND), - node_index); - return kTfLiteError; - } - } - - return kTfLiteOk; - } - static TfLiteStatus VisitSliceNode( xnn_subgraph_t subgraph, const Delegate& 
delegate, TfLiteContext* logging_context, int node_index, TfLiteNode* node, @@ -5884,78 +5621,6 @@ class Subgraph { return kTfLiteOk; } - static TfLiteStatus VisitSquareNode( - xnn_subgraph_t subgraph, const Delegate& delegate, - TfLiteContext* logging_context, int node_index, TfLiteNode* node, - const TfLiteTensor* tensors, - const std::unordered_map& input_output_tensors) { - TF_LITE_ENSURE_STATUS(CheckNumInputsAndOutputs( - logging_context, node, 1, 1, BuiltinOperator_SQUARE, node_index)); - - const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]]; - TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( - logging_context, input_tensor, node->inputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorShape( - logging_context, input_tensor, /*min_num_dims=*/0, - /*max_num_dims=*/XNN_MAX_TENSOR_DIMS, node->inputs->data[0], - BuiltinOperator_SQUARE, node_index)); - - const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; - TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( - logging_context, output_tensor, node->outputs->data[0], node_index)); - - if (subgraph != nullptr) { - const xnn_status status = xnn_define_square( - subgraph, - /*input_id=*/input_output_tensors.at(node->inputs->data[0]), - /*output_id=*/input_output_tensors.at(node->outputs->data[0]), - /*flags=*/0); - if (status != xnn_status_success) { - TF_LITE_KERNEL_LOG(logging_context, "failed to delegate %s node #%d", - EnumNameBuiltinOperator(BuiltinOperator_SQUARE), - node_index); - return kTfLiteError; - } - } - - return kTfLiteOk; - } - - static TfLiteStatus VisitTanhNode( - xnn_subgraph_t subgraph, const Delegate& delegate, - TfLiteContext* logging_context, int node_index, TfLiteNode* node, - const TfLiteTensor* tensors, - const std::unordered_map& input_output_tensors) { - TF_LITE_ENSURE_STATUS(CheckNumInputsAndOutputs( - logging_context, node, 1, 1, BuiltinOperator_TANH, node_index)); - - const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]]; - 
TF_LITE_ENSURE_STATUS( - CheckTensorFloat32OrQUInt8Type(delegate, logging_context, input_tensor, - node->inputs->data[0], node_index)); - - const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; - TF_LITE_ENSURE_STATUS( - CheckTensorFloat32OrQUInt8Type(delegate, logging_context, output_tensor, - node->outputs->data[0], node_index)); - - if (subgraph != nullptr) { - const xnn_status status = xnn_define_tanh( - subgraph, - /*input_id=*/input_output_tensors.at(node->inputs->data[0]), - /*output_id=*/input_output_tensors.at(node->outputs->data[0]), - /*flags=*/0); - if (status != xnn_status_success) { - TF_LITE_KERNEL_LOG(logging_context, "failed to delegate %s node #%d", - EnumNameBuiltinOperator(BuiltinOperator_TANH), - node_index); - return kTfLiteError; - } - } - - return kTfLiteOk; - } - static TfLiteStatus VisitTransposeNode( xnn_subgraph_t subgraph, const Delegate& delegate, TfLiteContext* logging_context, int node_index, TfLiteNode* node, @@ -6003,119 +5668,6 @@ class Subgraph { return kTfLiteOk; } - static TfLiteStatus VisitSqrtNode( - xnn_subgraph_t subgraph, const Delegate& delegate, - TfLiteContext* logging_context, int node_index, TfLiteNode* node, - const TfLiteTensor* tensors, - const std::unordered_map& input_output_tensors) { - TF_LITE_ENSURE_STATUS(CheckNumInputsAndOutputs( - logging_context, node, 1, 1, BuiltinOperator_SQRT, node_index)); - - const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]]; - TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( - logging_context, input_tensor, node->inputs->data[0], node_index)); - - const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; - TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( - logging_context, output_tensor, node->outputs->data[0], node_index)); - - if (subgraph != nullptr) { - const xnn_status status = xnn_define_square_root( - subgraph, - /*input_id=*/input_output_tensors.at(node->inputs->data[0]), - /*output_id=*/input_output_tensors.at(node->outputs->data[0]), - 
/*flags=*/0); - if (status != xnn_status_success) { - TF_LITE_KERNEL_LOG(logging_context, "failed to delegate %s node #%d", - EnumNameBuiltinOperator(BuiltinOperator_SQRT), - node_index); - return kTfLiteError; - } - } - - return kTfLiteOk; - } - - static TfLiteStatus VisitRsqrtNode( - xnn_subgraph_t subgraph, const Delegate& delegate, - TfLiteContext* logging_context, int node_index, TfLiteNode* node, - const TfLiteTensor* tensors, - const std::unordered_map& input_output_tensors) { - TF_LITE_ENSURE_STATUS(CheckNumInputsAndOutputs( - logging_context, node, 1, 1, BuiltinOperator_RSQRT, node_index)); - - const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]]; - TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( - logging_context, input_tensor, node->inputs->data[0], node_index)); - - const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; - TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( - logging_context, output_tensor, node->outputs->data[0], node_index)); - - if (subgraph != nullptr) { - const xnn_status status = xnn_define_reciprocal_square_root( - subgraph, /*input_id=*/input_output_tensors.at(node->inputs->data[0]), - /*output_id=*/input_output_tensors.at(node->outputs->data[0]), - /*flags=*/0); - if (status != xnn_status_success) { - TF_LITE_KERNEL_LOG(logging_context, "failed to delegate %s node #%d", - EnumNameBuiltinOperator(BuiltinOperator_RSQRT), - node_index); - return kTfLiteError; - } - } - - return kTfLiteOk; - } - - static TfLiteStatus VisitSquaredDifferenceNode( - xnn_subgraph_t subgraph, const Delegate& delegate, - TfLiteContext* logging_context, int node_index, TfLiteNode* node, - const TfLiteTensor* tensors, - const std::unordered_map& input_output_tensors) { - TF_LITE_ENSURE_STATUS(CheckNumInputsAndOutputs( - logging_context, node, 2, 1, BuiltinOperator_SQUARED_DIFFERENCE, - node_index)); - - const TfLiteTensor& input1_tensor = tensors[node->inputs->data[0]]; - TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( - logging_context, 
input1_tensor, node->inputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorShape( - logging_context, input1_tensor, /*min_num_dims=*/0, - /*max_num_dims=*/XNN_MAX_TENSOR_DIMS, node->inputs->data[0], - BuiltinOperator_SQUARED_DIFFERENCE, node_index)); - - const TfLiteTensor& input2_tensor = tensors[node->inputs->data[1]]; - TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( - logging_context, input2_tensor, node->inputs->data[1], node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorShape( - logging_context, input2_tensor, /*min_num_dims=*/0, - /*max_num_dims=*/XNN_MAX_TENSOR_DIMS, node->inputs->data[1], - BuiltinOperator_SQUARED_DIFFERENCE, node_index)); - - const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; - TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( - logging_context, output_tensor, node->outputs->data[0], node_index)); - - if (subgraph != nullptr) { - const xnn_status status = xnn_define_squared_difference( - subgraph, - /*input1_id=*/input_output_tensors.at(node->inputs->data[0]), - /*input2_id=*/input_output_tensors.at(node->inputs->data[1]), - /*output_id=*/input_output_tensors.at(node->outputs->data[0]), - /*flags=*/0); - if (status != xnn_status_success) { - TF_LITE_KERNEL_LOG( - logging_context, "failed to delegate %s node #%d", - EnumNameBuiltinOperator(BuiltinOperator_SQUARED_DIFFERENCE), - node_index); - return kTfLiteError; - } - } - - return kTfLiteOk; - } - static TfLiteStatus VisitStridedSliceNode( xnn_subgraph_t subgraph, const Delegate& delegate, TfLiteContext* logging_context, int node_index, TfLiteNode* node, @@ -6246,13 +5798,14 @@ class Subgraph { // inside our kernels; check here and punt those to the default // delegate implementation for it to decide how to handle them. const int64_t extent = input_tensor.dims->data[i]; - const size_t offset = begins[i] < 0 ? begins[i] + extent : begins[i]; - const size_t size = + const int64_t offset = begins[i] < 0 ? 
begins[i] + extent : begins[i]; + const int64_t size = ends[i] <= 0 ? ends[i] + extent - offset : ends[i] - offset; if (offset + size > extent) { TF_LITE_MAYBE_KERNEL_LOG(logging_context, - "offset %zu + size %zu exceeds extent %zu in " - "STRIDED_SLICE node #%d for dimension %zu", + "offset %" PRId64 " + size %" PRId64 + " exceeds extent %" PRId64 + " in STRIDED_SLICE node #%d for dimension %zu", offset, size, extent, node_index, i); return kTfLiteError; } @@ -6367,7 +5920,7 @@ class Subgraph { TF_LITE_ENSURE_EQ( logging_context, xnn_status_success, xnn_define_tensor_value(subgraph, xnn_datatype_fp32, /*num_dims=*/0, - /*dims=*/nullptr, &query_proj.dims->data[3], + /*dims=*/nullptr, &kConstantClampData, XNN_INVALID_VALUE_ID, 0, &scale_orig_id)); TF_LITE_ENSURE_EQ( logging_context, xnn_status_success, @@ -6719,72 +6272,6 @@ class Subgraph { return kTfLiteOk; } - static TfLiteStatus VisitSubNode( - xnn_subgraph_t subgraph, const Delegate& delegate, - TfLiteContext* logging_context, int node_index, TfLiteNode* node, - const TfLiteTensor* tensors, const TfLiteSubParams* sub_params, - const std::unordered_map& input_output_tensors) { - TF_LITE_ENSURE_STATUS(CheckNumInputsAndOutputs( - logging_context, node, 2, 1, BuiltinOperator_SUB, node_index)); - - const TfLiteTensor& input1_tensor = tensors[node->inputs->data[0]]; - TF_LITE_ENSURE_STATUS( - CheckTensorFloat32OrQUInt8Type(delegate, logging_context, input1_tensor, - node->inputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorShape( - logging_context, input1_tensor, /*min_num_dims=*/0, - /*max_num_dims=*/XNN_MAX_TENSOR_DIMS, node->inputs->data[0], - BuiltinOperator_SUB, node_index)); - - const TfLiteTensor& input2_tensor = tensors[node->inputs->data[1]]; - TF_LITE_ENSURE_STATUS( - CheckTensorFloat32OrQUInt8Type(delegate, logging_context, input2_tensor, - node->inputs->data[1], node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorShape( - logging_context, input2_tensor, /*min_num_dims=*/0, - 
/*max_num_dims=*/XNN_MAX_TENSOR_DIMS, node->inputs->data[1], - BuiltinOperator_SUB, node_index)); - - const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; - TF_LITE_ENSURE_STATUS( - CheckTensorFloat32OrQUInt8Type(delegate, logging_context, output_tensor, - node->outputs->data[0], node_index)); - - const float scale_min = 1.0f / 1024.0f; - const float scale_max = 256.0f; - TF_LITE_ENSURE_STATUS(CheckTensorsInputOutputScale( - logging_context, input1_tensor, output_tensor, scale_min, scale_max, - BuiltinOperator_SUB, node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorsInputOutputScale( - logging_context, input2_tensor, output_tensor, scale_min, scale_max, - BuiltinOperator_SUB, node_index)); - - float output_min = -std::numeric_limits::infinity(); - float output_max = +std::numeric_limits::infinity(); - if (sub_params != nullptr) { - TF_LITE_ENSURE_STATUS(ConvertActivationToOutputRange( - logging_context, node_index, sub_params->activation, &output_min, - &output_max)); - } - - if (subgraph != nullptr) { - const xnn_status status = xnn_define_subtract( - subgraph, output_min, output_max, - /*input1_id=*/input_output_tensors.at(node->inputs->data[0]), - /*input2_id=*/input_output_tensors.at(node->inputs->data[1]), - /*output_id=*/input_output_tensors.at(node->outputs->data[0]), - /*flags=*/0); - if (status != xnn_status_success) { - TF_LITE_KERNEL_LOG(logging_context, "failed to delegate %s node #%d", - EnumNameBuiltinOperator(BuiltinOperator_SUB), - node_index); - return kTfLiteError; - } - } - - return kTfLiteOk; - } - static TfLiteStatus VisitTransposeConvNode( xnn_subgraph_t subgraph, const Delegate& delegate, TfLiteContext* logging_context, int node_index, TfLiteNode* node, diff --git a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h index e7f1713072776a..ccd2840a8a13ed 100644 --- a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h +++ 
b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h @@ -58,6 +58,8 @@ typedef struct { // Number of threads to use in the thread pool. // 0 or negative value means no thread pool used. int32_t num_threads; + // Flags to pass to `xnn_create_runtime` + uint32_t runtime_flags; // Bitfield with any combination of the following binary options: // - TFLITE_XNNPACK_DELEGATE_FLAG_QS8 // - TFLITE_XNNPACK_DELEGATE_FLAG_QU8 diff --git a/tensorflow/lite/examples/label_image/bitmap_helpers.cc b/tensorflow/lite/examples/label_image/bitmap_helpers.cc index d3698f3b22218b..32d7f443fc49d8 100644 --- a/tensorflow/lite/examples/label_image/bitmap_helpers.cc +++ b/tensorflow/lite/examples/label_image/bitmap_helpers.cc @@ -23,6 +23,7 @@ limitations under the License. #include #include #include +#include #include "tensorflow/lite/examples/label_image/label_image.h" #include "tensorflow/lite/examples/label_image/log.h" diff --git a/tensorflow/lite/examples/label_image/label_image.cc b/tensorflow/lite/examples/label_image/label_image.cc index 9368f00d0fc5fe..1d441b46b3cc34 100644 --- a/tensorflow/lite/examples/label_image/label_image.cc +++ b/tensorflow/lite/examples/label_image/label_image.cc @@ -22,6 +22,7 @@ limitations under the License. #include // NOLINT(build/include_order) #include // NOLINT(build/include_order) +#include #include #include #include diff --git a/tensorflow/lite/examples/label_image/label_image_test.cc b/tensorflow/lite/examples/label_image/label_image_test.cc index d4e2e87270484b..02410987e62894 100644 --- a/tensorflow/lite/examples/label_image/label_image_test.cc +++ b/tensorflow/lite/examples/label_image/label_image_test.cc @@ -15,7 +15,10 @@ limitations under the License. 
#include "tensorflow/lite/examples/label_image/label_image.h" +#include #include +#include +#include #include #include "tensorflow/lite/c/c_api_types.h" diff --git a/tensorflow/lite/experimental/acceleration/mini_benchmark/BUILD b/tensorflow/lite/experimental/acceleration/mini_benchmark/BUILD index 4360e6a615f64e..6d14959b1a9b74 100644 --- a/tensorflow/lite/experimental/acceleration/mini_benchmark/BUILD +++ b/tensorflow/lite/experimental/acceleration/mini_benchmark/BUILD @@ -294,8 +294,8 @@ cc_library( hdrs = ["decode_jpeg_register.h"], copts = tflite_copts(), deps = [ + ":decode_jpeg_status", ":libjpeg_decoder", - "//tensorflow/lite:string", "//tensorflow/lite:string_util", "//tensorflow/lite/core/c:c_api_types", "//tensorflow/lite/core/c:common", diff --git a/tensorflow/lite/experimental/acceleration/mini_benchmark/decode_jpeg.cc b/tensorflow/lite/experimental/acceleration/mini_benchmark/decode_jpeg.cc index b1e2d619904ca5..ea6e7ff5ad5574 100644 --- a/tensorflow/lite/experimental/acceleration/mini_benchmark/decode_jpeg.cc +++ b/tensorflow/lite/experimental/acceleration/mini_benchmark/decode_jpeg.cc @@ -14,17 +14,17 @@ limitations under the License. 
==============================================================================*/ #include #include +#include #include -#include #include "flatbuffers/flexbuffers.h" // from @flatbuffers #include "tensorflow/lite/core/c/c_api_types.h" #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/experimental/acceleration/mini_benchmark/decode_jpeg_register.h" +#include "tensorflow/lite/experimental/acceleration/mini_benchmark/decode_jpeg_status.h" #include "tensorflow/lite/experimental/acceleration/mini_benchmark/libjpeg_decoder.h" #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" #include "tensorflow/lite/kernels/kernel_util.h" -#include "tensorflow/lite/string_type.h" #include "tensorflow/lite/string_util.h" namespace tflite { @@ -124,7 +124,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { std::unique_ptr decoder = LibjpegDecoder::Create(decoder_status); if (decoder_status.code != kTfLiteOk) { - TF_LITE_KERNEL_LOG(context, decoder_status.error_message.c_str()); + TF_LITE_KERNEL_LOG(context, "%s", decoder_status.error_message.c_str()); return kTfLiteError; } @@ -166,7 +166,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { output_array_offset += kOutputImageSize; if (decode_status.code != kTfLiteOk) { - TF_LITE_KERNEL_LOG(context, decode_status.error_message.c_str()); + TF_LITE_KERNEL_LOG(context, "%s", decode_status.error_message.c_str()); return kTfLiteError; } } diff --git a/tensorflow/lite/experimental/acceleration/mini_benchmark/validator.h b/tensorflow/lite/experimental/acceleration/mini_benchmark/validator.h index de00c24e9cb999..f99e300af83bba 100644 --- a/tensorflow/lite/experimental/acceleration/mini_benchmark/validator.h +++ b/tensorflow/lite/experimental/acceleration/mini_benchmark/validator.h @@ -115,11 +115,6 @@ class Validator { std::unique_ptr model_loader_; const ComputeSettings* compute_settings_; - // Optional. Interpreter that runs on CPU. 
- std::unique_ptr golden_interpreter_; - // Interpreter that runs with delegate enabled, using the compute settings - // passed to the Validator constructor. - std::unique_ptr interpreter_; // Op resolver used to create the interpreters. Depending on the // compute_settings_, it may or may not include the default delegate. std::unique_ptr<::tflite::MutableOpResolver> resolver_; @@ -129,6 +124,11 @@ class Validator { TfLiteOpaqueDelegatePtr opaque_delegate_ = TfLiteOpaqueDelegatePtr(nullptr, [](TfLiteOpaqueDelegate*) {}); std::unique_ptr delegate_plugin_; + // Optional. Interpreter that runs on CPU. + std::unique_ptr golden_interpreter_; + // Interpreter that runs with delegate enabled, using the compute settings + // passed to the Validator constructor. + std::unique_ptr interpreter_; int validation_entrypoint_index_ = -1; Subgraph* validation_entrypoint_ = nullptr; Subgraph* main_model_ = nullptr; diff --git a/tensorflow/lite/experimental/genai/BUILD b/tensorflow/lite/experimental/genai/BUILD index 144734c0ba3660..35de2af73a28fd 100644 --- a/tensorflow/lite/experimental/genai/BUILD +++ b/tensorflow/lite/experimental/genai/BUILD @@ -26,12 +26,9 @@ cc_library( "//tensorflow/lite/experimental/resource", "//tensorflow/lite/experimental/resource:cache_buffer", "//tensorflow/lite/kernels:kernel_util", - "//tensorflow/lite/kernels:reference_ops", "//tensorflow/lite/kernels/internal:common", - "//tensorflow/lite/kernels/internal:compatibility", "//tensorflow/lite/kernels/internal:reference_base", "//tensorflow/lite/kernels/internal:tensor", - "//tensorflow/lite/kernels/internal:tensor_utils", "//tensorflow/lite/kernels/internal:types", "@flatbuffers", ], diff --git a/tensorflow/lite/experimental/genai/genai_ops_wrapper.cc b/tensorflow/lite/experimental/genai/genai_ops_wrapper.cc index 8fa8451909e57e..cd57ea2cc55d80 100644 --- a/tensorflow/lite/experimental/genai/genai_ops_wrapper.cc +++ b/tensorflow/lite/experimental/genai/genai_ops_wrapper.cc @@ -13,6 +13,8 @@ See the 
License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "pybind11/pybind11.h" // from @pybind11 #include "pybind11/pytypes.h" // from @pybind11 #include "tensorflow/lite/experimental/genai/genai_ops.h" diff --git a/tensorflow/lite/experimental/genai/kvcache.cc b/tensorflow/lite/experimental/genai/kvcache.cc index 59fa3abd7ed510..f4f8bacf43eb0e 100644 --- a/tensorflow/lite/experimental/genai/kvcache.cc +++ b/tensorflow/lite/experimental/genai/kvcache.cc @@ -267,10 +267,10 @@ TfLiteStatus KVCacheEval(TfLiteContext* context, TfLiteNode* node) { v_ptr = v_ptr + sizeof(float) * op_data->layer_index * elements_in_one_block; // 0. Ensure output ptr is pointing to the cache data - TF_LITE_ENSURE_EQ(context, k_ptr, op_data->key_cache_ptr); - TF_LITE_ENSURE_EQ(context, v_ptr, op_data->value_cache_ptr); - TF_LITE_ENSURE_EQ(context, k_ptr, kfull->data.data); - TF_LITE_ENSURE_EQ(context, v_ptr, vfull->data.data); + TF_LITE_ENSURE(context, k_ptr == op_data->key_cache_ptr); + TF_LITE_ENSURE(context, v_ptr == op_data->value_cache_ptr); + TF_LITE_ENSURE(context, k_ptr == kfull->data.data); + TF_LITE_ENSURE(context, v_ptr == vfull->data.data); // 1. Determine which slots the inputs take up, and which slots are in the // existing span of the cache. diff --git a/tensorflow/lite/experimental/litert/BUILD b/tensorflow/lite/experimental/litert/BUILD deleted file mode 100644 index 23b07d5602d7c8..00000000000000 --- a/tensorflow/lite/experimental/litert/BUILD +++ /dev/null @@ -1,18 +0,0 @@ -# Copyright 2024 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], -) diff --git a/tensorflow/lite/experimental/litert/build_common/BUILD b/tensorflow/lite/experimental/litert/build_common/BUILD deleted file mode 100644 index 735f1cbed03c2c..00000000000000 --- a/tensorflow/lite/experimental/litert/build_common/BUILD +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright 2024 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -load("@bazel_skylib//:bzl_library.bzl", "bzl_library") - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], -) - -exports_files(srcs = [ - "export_litert_only_darwin.lds", - "export_litert_only_linux.lds", - "export_litert_runtime_only_darwin.lds", - "export_litert_runtime_only_linux.lds", -]) - -bzl_library( - name = "special_rule_bzl", - srcs = ["special_rule.bzl"], - visibility = ["//visibility:private"], -) - -bzl_library( - name = "litert_build_defs_bzl", - srcs = ["litert_build_defs.bzl"], - visibility = ["//visibility:private"], -) - -bzl_library( - name = "tfl_model_gen_bzl", - srcs = ["tfl_model_gen.bzl"], - visibility = ["//visibility:private"], -) diff --git a/tensorflow/lite/experimental/litert/build_common/export_litert_only_darwin.lds b/tensorflow/lite/experimental/litert/build_common/export_litert_only_darwin.lds deleted file mode 100644 index a51afcee0a21f0..00000000000000 --- a/tensorflow/lite/experimental/litert/build_common/export_litert_only_darwin.lds +++ /dev/null @@ -1,8 +0,0 @@ -# Compiler Plugin -*LiteRt*CompilerPlugin* - -# Compiled Result -*LiteRt*CompiledResult* - -# Dispatch -*LiteRtDispatch* diff --git a/tensorflow/lite/experimental/litert/build_common/export_litert_only_linux.lds b/tensorflow/lite/experimental/litert/build_common/export_litert_only_linux.lds deleted file mode 100644 index 97b05c1d655a71..00000000000000 --- a/tensorflow/lite/experimental/litert/build_common/export_litert_only_linux.lds +++ /dev/null @@ -1,29 +0,0 @@ -VERS_1.0 { - - /* - Export abi-stable "vendor" implemented symbols. - - TODO: Add all vendor symbols. Also export qnn libc++ symbols - (statically linked) as "protected" as needed. 
- */ - - global: - - /* Compiler Plugin */ - - LiteRt*CompilerPlugin*; - - /* Compiled Result */ - - LiteRt*CompiledResult*; - - /* Dispatch */ - - LiteRtDispatch*; - - local: - - /* Hide everything else */ - - *; -}; diff --git a/tensorflow/lite/experimental/litert/build_common/export_litert_runtime_only_darwin.lds b/tensorflow/lite/experimental/litert/build_common/export_litert_runtime_only_darwin.lds deleted file mode 100644 index 9638faa6b23e98..00000000000000 --- a/tensorflow/lite/experimental/litert/build_common/export_litert_runtime_only_darwin.lds +++ /dev/null @@ -1,2 +0,0 @@ -# All LiteRt C APIs -LiteRt* diff --git a/tensorflow/lite/experimental/litert/build_common/export_litert_runtime_only_linux.lds b/tensorflow/lite/experimental/litert/build_common/export_litert_runtime_only_linux.lds deleted file mode 100644 index 6948af4950cfd6..00000000000000 --- a/tensorflow/lite/experimental/litert/build_common/export_litert_runtime_only_linux.lds +++ /dev/null @@ -1,20 +0,0 @@ -VERS_1.0 { - - /* - Export abi-stable "vendor" implemented symbols. - - TODO: Add all vendor symbols. Also export qnn libc++ symbols - (statically linked) as "protected" as needed. - */ - - global: - - /* All LiteRt C APIs */ - LiteRt*; - - local: - - /* Hide everything else */ - - *; -}; diff --git a/tensorflow/lite/experimental/litert/build_common/litert_build_defs.bzl b/tensorflow/lite/experimental/litert/build_common/litert_build_defs.bzl deleted file mode 100644 index a6b13cb2d18767..00000000000000 --- a/tensorflow/lite/experimental/litert/build_common/litert_build_defs.bzl +++ /dev/null @@ -1,321 +0,0 @@ -# Copyright 2024 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Common LiteRT Build Utilities.""" - -#################################################################################################### -# Util - -_LRT_SO_PREFIX = "libLiteRt" -_SO_EXT = ".so" -_SHARED_LIB_SUFFIX = "_so" - -# Public - -def make_linkopt(opt): - return "-Wl,{}".format(opt) - -def make_rpaths(rpaths): - return make_linkopt("-rpath={}".format(":".join(rpaths))) - -def append_rule_kwargs(rule_kwargs, **append): - for k, v in append.items(): - append_to = rule_kwargs.pop(k, []) - append_to += v - rule_kwargs[k] = append_to - -def absolute_label(label, package_name = None): - """Get the absolute label for a given label. - - Args: - label: The label to convert to absolute. - package_name: The package name to use if the label is relative. - - Returns: - The absolute label. 
- """ - if label.startswith("//"): - if ":" in label: - return label - return "%s:%s" % (label, label.rsplit("/", 1)[-1]) - if not package_name: - package_name = native.package_name() - if label.startswith(":"): - return "//%s%s" % (package_name, label) - if ":" in label: - return "//%s/%s" % (package_name, label) - return "//%s:%s" % (package_name, label) - -# Private - -def _valid_shared_lib_name(name): - return name.endswith(_SHARED_LIB_SUFFIX) - -def _valid_so_name(name): - return name.startswith(_LRT_SO_PREFIX) and name.endswith(_SO_EXT) - -def _make_target_ref(name): - return ":{}".format(name) - -#################################################################################################### -# Explicitly Link System Libraries ("ungrte") - -_SYS_RPATHS_X86_64 = [ - "/usr/lib/x86_64-linux-gnu", - "/lib/x86_64-linux-gnu", -] -_SYS_RPATHS_LINKOPT_X86_64 = make_rpaths(_SYS_RPATHS_X86_64) - -_SYS_ELF_INTERPRETER_X86_64 = "/lib64/ld-linux-x86-64.so.2" -_SYS_ELF_INTERPRETER_LINKOPT_X86_64 = make_linkopt("--dynamic-linker={}".format(_SYS_ELF_INTERPRETER_X86_64)) - -#################################################################################################### -# Symbol Hiding - -_EXPORT_LRT_ONLY_SCRIPT_LINUX = "//tensorflow/lite/experimental/litert/build_common:export_litert_only_linux.lds" -_EXPORT_LRT_ONLY_SCRIPT_DARWIN = "//tensorflow/lite/experimental/litert/build_common:export_litert_only_darwin.lds" -_EXPORT_LRT_ONLY_LINKOPT_LINUX = make_linkopt("--version-script=$(location {})".format(_EXPORT_LRT_ONLY_SCRIPT_LINUX)) -_EXPORT_LRT_ONLY_LINKOPT_DARWIN = make_linkopt("-exported_symbols_list,$(location {})".format(_EXPORT_LRT_ONLY_SCRIPT_DARWIN)) - -def symbol_opts(): - """Defines linker flags whether to include symbols or not.""" - return select({ - "//tensorflow:debug": [], - "//conditions:default": [ - # Omit symbol table, for all non debug builds - "-Wl,-s", - ], - }) - -def export_lrt_only_script(): - return select({ - "//tensorflow:linux_x86_64": 
[_EXPORT_LRT_ONLY_SCRIPT_LINUX], - "//tensorflow:android": [_EXPORT_LRT_ONLY_SCRIPT_LINUX], - "//tensorflow:macos": [_EXPORT_LRT_ONLY_SCRIPT_DARWIN], - "//tensorflow:ios": [_EXPORT_LRT_ONLY_SCRIPT_DARWIN], - "//conditions:default": [], - }) - -def export_lrt_only_linkopt(): - return select({ - "//tensorflow:linux_x86_64": [_EXPORT_LRT_ONLY_LINKOPT_LINUX], - "//tensorflow:android": [_EXPORT_LRT_ONLY_LINKOPT_LINUX], - "//tensorflow:macos": [_EXPORT_LRT_ONLY_LINKOPT_DARWIN], - "//tensorflow:ios": [_EXPORT_LRT_ONLY_LINKOPT_DARWIN], - "//conditions:default": [], - }) + symbol_opts() - -_EXPORT_LRT_RUNTIME_ONLY_SCRIPT_LINUX = "//tensorflow/lite/experimental/litert/build_common:export_litert_runtime_only_linux.lds" -_EXPORT_LRT_RUNTIME_ONLY_SCRIPT_DARWIN = "//tensorflow/lite/experimental/litert/build_common:export_litert_runtime_only_darwin.lds" -_EXPORT_LRT_RUNTIME_ONLY_LINKOPT_LINUX = make_linkopt("--version-script=$(location {})".format(_EXPORT_LRT_RUNTIME_ONLY_SCRIPT_LINUX)) -_EXPORT_LRT_RUNTIME_ONLY_LINKOPT_DARWIN = make_linkopt("-exported_symbols_list,$(location {})".format(_EXPORT_LRT_RUNTIME_ONLY_SCRIPT_DARWIN)) - -# TODO b/391390553: Add "-Wl,--no-undefined" to make sure all symbols are defined. -_EXPORT_LRT_COMMON_LINKOPTS_LINUX = [ - "-Wl,--no-export-dynamic", # Only inc syms referenced by dynamic obj. - "-Wl,--gc-sections", # Eliminate unused code and data. 
- "-Wl,--as-needed", # Don't link unused libs.a -] - -def export_lrt_runtime_only_script(): - return select({ - "//tensorflow:linux_x86_64": [_EXPORT_LRT_RUNTIME_ONLY_SCRIPT_LINUX], - "//tensorflow:android": [_EXPORT_LRT_RUNTIME_ONLY_SCRIPT_LINUX], - "//tensorflow:macos": [_EXPORT_LRT_RUNTIME_ONLY_SCRIPT_DARWIN], - "//tensorflow:ios": [_EXPORT_LRT_RUNTIME_ONLY_SCRIPT_DARWIN], - "//conditions:default": [], - }) - -def export_lrt_runtime_only_linkopt(): - return select({ - "//tensorflow:linux_x86_64": _EXPORT_LRT_COMMON_LINKOPTS_LINUX + [_EXPORT_LRT_RUNTIME_ONLY_LINKOPT_LINUX], - "//tensorflow:android": _EXPORT_LRT_COMMON_LINKOPTS_LINUX + [ - "-Wl,-z,max-page-size=16384", - _EXPORT_LRT_RUNTIME_ONLY_LINKOPT_LINUX, - ], - "//tensorflow:macos": [_EXPORT_LRT_RUNTIME_ONLY_LINKOPT_DARWIN], - "//tensorflow:ios": [_EXPORT_LRT_RUNTIME_ONLY_LINKOPT_DARWIN], - "//conditions:default": [], - }) + symbol_opts() - -#################################################################################################### -# Macros - -# Private - -def _litert_base( - rule, - ungrte = False, - **cc_rule_kwargs): - """ - Base rule for LiteRT targets. - - Args: - rule: The underlying rule to use (e.g., cc_test, cc_library). - ungrte: Whether to link against system libraries ("ungrte"). - **cc_rule_kwargs: Keyword arguments to pass to the underlying rule. - """ - if ungrte: - append_rule_kwargs( - cc_rule_kwargs, - linkopts = select({ - "//tensorflow:linux_x86_64": [_SYS_ELF_INTERPRETER_LINKOPT_X86_64, _SYS_RPATHS_LINKOPT_X86_64], - "//conditions:default": [], - }), - ) - rule(**cc_rule_kwargs) - -# Public - -def litert_test( - ungrte = False, - use_sys_malloc = False, - **cc_test_kwargs): - """ - LiteRT test rule. - - Args: - ungrte: Whether to link against system libraries ("ungrte"). - use_sys_malloc: Whether to use the system malloc. - **cc_test_kwargs: Keyword arguments to pass to the underlying rule. 
- """ - if use_sys_malloc: - # copybara:uncomment cc_test_kwargs["malloc"] = "//base:system_malloc" - pass - - append_rule_kwargs( - cc_test_kwargs, - deps = ["@com_google_googletest//:gtest_main"], - ) - - _litert_base( - native.cc_test, - ungrte, - **cc_test_kwargs - ) - -def litert_lib( - ungrte = False, - **cc_lib_kwargs): - """ - LiteRT library rule. - - Args: - ungrte: Whether to link against system libraries ("ungrte"). - **cc_lib_kwargs: Keyword arguments to pass to the underlying rule. - """ - _litert_base( - native.cc_library, - ungrte, - **cc_lib_kwargs - ) - -def litert_bin( - ungrte = False, - export_litert_only = False, - **cc_bin_kwargs): - """ - LiteRT binary rule. - - Args: - ungrte: Whether to link against system libraries ("ungrte"). - export_litert_only: Whether to export only LiteRT symbols. - **cc_bin_kwargs: Keyword arguments to pass to the underlying rule. - """ - if export_litert_only: - append_rule_kwargs( - cc_bin_kwargs, - linkopts = export_lrt_only_linkopt(), - deps = export_lrt_only_script(), - ) - - _litert_base( - native.cc_binary, - ungrte, - **cc_bin_kwargs - ) - -def litert_dynamic_lib( - name, - shared_lib_name, - so_name, - export_litert_only = False, - ungrte = False, - **cc_lib_kwargs): - """ - LiteRT dynamic library rule. - - Args: - name: The name of the library. - shared_lib_name: The name of the shared library. - so_name: The name of the shared object file. - export_litert_only: Whether to export only LiteRT symbols. - ungrte: Whether to link against system libraries ("ungrte"). - **cc_lib_kwargs: Keyword arguments to pass to the underlying rule. - """ - if not _valid_shared_lib_name(shared_lib_name): - fail("\"shared_lib_name\" must end with \"_so\"") - if not _valid_so_name(so_name): - fail("\"so_name\" must be \"libLiteRt*.so\"") - - lib_name = name - cc_lib_kwargs["name"] = lib_name - - lib_target_ref = _make_target_ref(lib_name) - - vis = cc_lib_kwargs.get("visibility", None) - - # Share tags for all targets. 
- tags = cc_lib_kwargs.get("tags", []) - - litert_lib( - ungrte = ungrte, - **cc_lib_kwargs - ) - - user_link_flags = [] - additional_linker_inputs = [] - if export_litert_only: - user_link_flags = export_lrt_only_linkopt() - additional_linker_inputs = export_lrt_only_script() - - native.cc_shared_library( - name = shared_lib_name, - shared_lib_name = so_name, - user_link_flags = user_link_flags, - additional_linker_inputs = additional_linker_inputs, - tags = tags, - visibility = vis, - deps = [lib_target_ref], - ) - -def copy_file(name, src, target, visibility = None): - input_path = "$(location %s)" % src - output_path = "$(@D)/" + target - - native.genrule( - name = name, - srcs = [src], - outs = [target], - visibility = visibility, - cmd = "cp %s %s" % (input_path, output_path), - ) - -def gtest_main_no_heapcheck_deps(): - # copybara:uncomment_begin(google-only) - # return ["//testing/base/public:gunit_main_no_heapcheck"] - # copybara:uncomment_end - # copybara:comment_begin(oss-only) - return ["@com_google_googletest//:gtest_main"] - # copybara:comment_end diff --git a/tensorflow/lite/experimental/litert/build_common/special_rule.bzl b/tensorflow/lite/experimental/litert/build_common/special_rule.bzl deleted file mode 100644 index e6a3c1c47fcf1e..00000000000000 --- a/tensorflow/lite/experimental/litert/build_common/special_rule.bzl +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright 2025 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""External versions of LiteRT build rules that differ outside of Google.""" - -def lite_rt_friends(): - """Internal visibility for packages outside of LiteRT code location. - - Return the package group declaration for internal code locations that need - visibility to LiteRT APIs""" - - return [] - -def gles_deps(): - """This is a no-op outside of Google.""" - return [] - -def gles_headers(): - """This is a no-op outside of Google.""" - return [] - -def gles_linkopts(): - """This is a no-op outside of Google.""" - return [] diff --git a/tensorflow/lite/experimental/litert/build_common/tfl_model_gen.bzl b/tensorflow/lite/experimental/litert/build_common/tfl_model_gen.bzl deleted file mode 100644 index 654c8684cf5754..00000000000000 --- a/tensorflow/lite/experimental/litert/build_common/tfl_model_gen.bzl +++ /dev/null @@ -1,49 +0,0 @@ -# Copyright 2025 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Utility to generate tflite models from MLIR files.""" - -def tfl_model_gen(name, srcs, subdir = "testdata"): - """ - Generates tflite models from MLIR files. - - Args: - name: name of the rule. - srcs: list of MLIR files. - subdir: subdirectory to place the generated tflite files. 
- """ - OUT_DIR = "$(RULEDIR)" - CONVERTER = "//tensorflow/compiler/mlir/lite:tf_tfl_translate" - CMD = """ - for mlir_file in $(SRCS); do - $(location {converter}) --input-mlir $$mlir_file --o={out_dir}/{subdir}/$$(basename $$mlir_file .mlir).tflite - done - """.format( - converter = CONVERTER, - out_dir = OUT_DIR, - subdir = subdir, - ) - - native.genrule( - name = name, - srcs = srcs, - outs = [s.removesuffix(".mlir") + ".tflite" for s in srcs], - cmd = CMD, - tools = [CONVERTER], - ) - - native.filegroup( - name = name + "_files", - srcs = [name], - ) diff --git a/tensorflow/lite/experimental/litert/c/BUILD b/tensorflow/lite/experimental/litert/c/BUILD deleted file mode 100644 index 6ea15aa4478430..00000000000000 --- a/tensorflow/lite/experimental/litert/c/BUILD +++ /dev/null @@ -1,589 +0,0 @@ -# copybara:uncomment_begin(google-only) -# load("//devtools/deps/check:deps_check.bzl", "check_dependencies") -# -# copybara:uncomment_end -load("//tensorflow/lite/experimental/litert/build_common:litert_build_defs.bzl", "copy_file", "export_lrt_runtime_only_linkopt", "export_lrt_runtime_only_script") -load("//tensorflow/lite/experimental/litert/build_common:special_rule.bzl", "gles_deps", "gles_headers", "gles_linkopts") - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//visibility:public"], -) - -cc_library( - name = "litert_common", - srcs = ["litert_common.cc"], - hdrs = ["litert_common.h"], -) - -cc_test( - name = "litert_common_test", - srcs = ["litert_common_test.cc"], - deps = [ - ":litert_common", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "litert_any", - hdrs = ["litert_any.h"], -) - -cc_library( - name = "litert_environment", - srcs = ["litert_environment.cc"], - hdrs = ["litert_environment.h"], - deps = [ - ":litert_common", - ":litert_environment_options", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - 
"//tensorflow/lite/experimental/litert/core:environment", - "//tensorflow/lite/experimental/litert/runtime:gpu_environment", - "//tensorflow/lite/experimental/litert/runtime/accelerators:auto_registration", - "@com_google_absl//absl/types:span", - ], -) - -cc_library( - name = "litert_environment_options", - srcs = ["litert_environment_options.cc"], - hdrs = ["litert_environment_options.h"], - deps = [ - ":litert_any", - ":litert_common", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/core:environment_options", - ], -) - -cc_library( - name = "litert_environment_options_header", - hdrs = ["litert_environment_options.h"], - tags = ["avoid_dep"], - deps = [ - ":litert_any", - ":litert_common", - ], -) - -cc_test( - name = "litert_environment_options_test", - srcs = ["litert_environment_options_test.cc"], - deps = [ - ":litert_any", - ":litert_common", - ":litert_environment_options", - "//tensorflow/lite/experimental/litert/core:environment_options", - "//tensorflow/lite/experimental/litert/test:matchers", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "litert_logging", - srcs = [ - "litert_logging.cc", - ], - hdrs = [ - "litert_logging.h", - ], - deps = [ - ":litert_common", - "//tensorflow/lite:minimal_logging", - ], -) - -cc_test( - name = "litert_logging_test", - srcs = [ - "litert_logging_test.cc", - ], - deps = [ - ":litert_common", - ":litert_logging", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "litert_layout", - hdrs = ["litert_layout.h"], - deps = [ - ":litert_common", - ":litert_op_code", - "//tensorflow/lite/core/c:c_api_types", - "//tensorflow/lite/experimental/litert/cc:litert_buffer_ref", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "@com_google_absl//absl/strings:string_view", - ], -) - -cc_library( - name = "litert_model", - srcs = ["litert_model.cc"], - hdrs = ["litert_model.h"], - deps = [ - ":litert_common", - 
":litert_layout", - ":litert_op_code", - "//tensorflow/lite/core/c:c_api_types", - "//tensorflow/lite/experimental/litert/cc:litert_buffer_ref", - "//tensorflow/lite/experimental/litert/core/model", - "//tensorflow/lite/experimental/litert/core/model:model_load", - "//tensorflow/lite/experimental/litert/core/model:model_serialize", - "@com_google_absl//absl/strings:string_view", - ], -) - -cc_test( - name = "litert_model_test", - srcs = ["litert_model_test.cc"], - deps = [ - ":litert_common", - ":litert_model", - ":litert_op_code", - "//tensorflow/lite/experimental/litert/cc:litert_buffer_ref", - "//tensorflow/lite/experimental/litert/core/model", - "//tensorflow/lite/experimental/litert/core/util:flatbuffer_tools", - "//tensorflow/lite/experimental/litert/test:matchers", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "litert_op_code", - hdrs = ["litert_op_code.h"], - deps = ["//tensorflow/lite:builtin_ops"], -) - -cc_library( - name = "litert_options", - srcs = ["litert_options.cc"], - hdrs = [ - "litert_options.h", - ], - deps = [ - ":litert_common", - ":litert_op_code", - "//tensorflow/compiler/mlir/lite/core:model_builder_base", - "//tensorflow/lite/c:c_api_types", - "//tensorflow/lite/experimental/litert/core/model", - ], -) - -cc_test( - name = "litert_options_test", - srcs = ["litert_options_test.cc"], - data = [ - "//tensorflow/lite/experimental/litert/test:mlir_test_data", - ], - tags = ["no_oss"], - deps = [ - ":litert_common", - ":litert_options", - "//tensorflow/lite/experimental/litert/test:common", - "//tensorflow/lite/experimental/litert/test:matchers", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "litert_event_type", - hdrs = ["litert_event_type.h"], -) - -cc_library( - name = "litert_event", - srcs = ["litert_event.cc"], - hdrs = ["litert_event.h"], - deps = [ - ":litert_common", - ":litert_event_type", - 
":litert_logging", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/runtime:event", - ], -) - -cc_library( - name = "litert_tensor_buffer_types", - srcs = [], - hdrs = ["litert_tensor_buffer_types.h"], -) - -cc_library( - name = "litert_gl_types", - srcs = [], - hdrs = ["litert_gl_types.h"], -) - -cc_library( - name = "litert_tensor_buffer", - srcs = [ - "litert_tensor_buffer.cc", - "litert_tensor_buffer_requirements.cc", - ], - hdrs = [ - "litert_tensor_buffer.h", - "litert_tensor_buffer_requirements.h", - ], - linkopts = select({ - "//tensorflow:android": [ - "-landroid", - ], - "//conditions:default": [], - }) + gles_linkopts(), - deps = [ - ":litert_common", - ":litert_event", - ":litert_gl_types", - ":litert_logging", - ":litert_model", - ":litert_tensor_buffer_types", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/runtime:tensor_buffer", - "@com_google_absl//absl/types:span", - "@opencl_headers", - ] + gles_deps(), -) - -cc_test( - name = "litert_tensor_buffer_test", - srcs = [ - "litert_tensor_buffer_test.cc", - ], - # require GPU to run OpenCL tests. 
- tags = [ - "requires-gpu-nvidia", - ], - deps = [ - ":litert_common", - ":litert_model", - ":litert_tensor_buffer", - "//tensorflow/lite/experimental/litert/cc:litert_layout", - "//tensorflow/lite/experimental/litert/runtime:tensor_buffer", - "@com_google_googletest//:gtest_main", - ], -) - -cc_test( - name = "litert_tensor_buffer_requirements_test", - srcs = [ - "litert_tensor_buffer_requirements_test.cc", - ], - linkopts = select({ - "//tensorflow:android": ["-landroid"], - "//conditions:default": [], - }), - deps = [ - ":litert_common", - ":litert_tensor_buffer", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "litert_dispatch_delegate", - hdrs = [ - "litert_dispatch_delegate.h", - ], - deps = [ - ":litert_environment_options", - "//tensorflow/lite/c:c_api", - "//tensorflow/lite/c:c_api_opaque", - "//tensorflow/lite/c:c_api_types", - "//tensorflow/lite/c:common", - "//tensorflow/lite/delegates/utils:simple_opaque_delegate", - "//tensorflow/lite/experimental/litert/runtime/dispatch:dispatch_delegate", - "//tensorflow/lite/experimental/litert/vendors/c:litert_dispatch_c_api", - ], -) - -cc_library( - name = "litert_compilation_options", - srcs = ["litert_compilation_options.cc"], - hdrs = [ - "litert_compilation_options.h", - ], - deps = [ - ":litert_accelerator_compilation_options", - ":litert_common", - ":litert_logging", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/runtime:compilation_options", - ], -) - -cc_test( - name = "litert_compilation_options_test", - srcs = ["litert_compilation_options_test.cc"], - deps = [ - ":litert_accelerator_compilation_options", - ":litert_common", - ":litert_compilation_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "litert_compiled_model", - srcs = ["litert_compiled_model.cc"], - hdrs = [ - 
"litert_compiled_model.h", - ], - deps = [ - ":litert_common", - ":litert_compilation_options", - ":litert_environment", - ":litert_logging", - ":litert_model", - ":litert_tensor_buffer", - "//tensorflow/lite/c:c_api_types", - "//tensorflow/lite/experimental/litert/runtime:compiled_model", - ], -) - -cc_test( - name = "litert_compiled_model_test", - srcs = [ - "litert_compiled_model_test.cc", - ], - data = [ - "//tensorflow/lite/experimental/litert/test:testdata/simple_model.tflite", - ], - deps = [ - ":litert_common", - ":litert_compilation_options", - ":litert_compiled_model", - ":litert_environment", - ":litert_model", - ":litert_tensor_buffer", - "//tensorflow/lite/experimental/litert/test:common", - "//tensorflow/lite/experimental/litert/test:simple_model", - "@com_google_absl//absl/log:absl_log", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest_main", - ], -) - -# The same test as `litert_compiled_model_test` but using the shared library `libLiteRtRuntimeCApi.so`. 
-cc_test( - name = "litert_compiled_model_shared_lib_test", - srcs = [ - "litert_compiled_model_test.cc", - ], - data = [ - "//tensorflow/lite/experimental/litert/test:testdata/simple_model.tflite", - ], - deps = [ - ":litert_runtime_c_api_shared_lib", - "//tensorflow/lite/experimental/litert/test:common", - "//tensorflow/lite/experimental/litert/test:simple_model", - "@com_google_absl//absl/log:absl_log", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "litert_accelerator", - srcs = ["litert_accelerator.cc"], - hdrs = ["litert_accelerator.h"], - deps = [ - ":litert_common", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/core:environment", - "//tensorflow/lite/experimental/litert/runtime:accelerator", - ], -) - -cc_test( - name = "litert_accelerator_test", - srcs = ["litert_accelerator_test.cc"], - deps = [ - ":litert_accelerator", - ":litert_accelerator_registration", - ":litert_common", - ":litert_environment", - "//tensorflow/lite/experimental/litert/runtime:accelerator", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "litert_accelerator_compilation_options", - srcs = ["litert_accelerator_compilation_options.cc"], - hdrs = ["litert_accelerator_compilation_options.h"], - deps = [ - ":litert_common", - ], -) - -cc_test( - name = "litert_accelerator_compilation_options_test", - srcs = ["litert_accelerator_compilation_options_test.cc"], - deps = [ - ":litert_accelerator_compilation_options", - ":litert_common", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/core:version", - "//tensorflow/lite/experimental/litert/test:matchers", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "litert_accelerator_registration", - srcs = 
["litert_accelerator_registration.cc"], - hdrs = ["litert_accelerator_registration.h"], - deps = [ - ":litert_accelerator", - ":litert_accelerator_compilation_options", - ":litert_common", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/core:environment", - "//tensorflow/lite/experimental/litert/runtime:accelerator", - ], -) - -cc_test( - name = "litert_accelerator_registration_test", - srcs = ["litert_accelerator_registration_test.cc"], - deps = [ - ":litert_accelerator", - ":litert_accelerator_compilation_options", - ":litert_accelerator_registration", - ":litert_common", - ":litert_environment", - "//tensorflow/lite/experimental/litert/runtime:accelerator", - "@com_google_googletest//:gtest_main", - ], -) - -filegroup( - name = "litert_model_srcs", - srcs = ["litert_model.cc"], - visibility = ["//tensorflow/lite/experimental/litert/core/model:__pkg__"], -) - -filegroup( - name = "litert_model_hdrs", - srcs = ["litert_model.h"], - visibility = ["//tensorflow/lite/experimental/litert/core/model:__pkg__"], -) - -# Collection of all C API targets. -LITERT_C_API_COMMON_DEPS = [ - ":litert_accelerator", - ":litert_accelerator_registration", - ":litert_any", - ":litert_common", - ":litert_compiled_model", - ":litert_compilation_options", - ":litert_dispatch_delegate", - ":litert_event", - ":litert_environment", - ":litert_layout", - ":litert_logging", - ":litert_model", - ":litert_op_code", - ":litert_options", - ":litert_tensor_buffer", -] - -# This test verifies that the C API header files can build via C compiler. -cc_test( - name = "litert_c_api_common_test", - srcs = ["litert_c_api_common_test.c"], - copts = ["--std=c11"], - linkopts = ["-ldl"], - deps = LITERT_C_API_COMMON_DEPS, -) - -# Build `litert/c:litert_runtime_c_api_so` for `libLiteRtRuntimeCApi.so`. 
-cc_shared_library( - name = "litert_runtime_c_api_so", - additional_linker_inputs = export_lrt_runtime_only_script(), - shared_lib_name = "libLiteRtRuntimeCApi.so", - user_link_flags = export_lrt_runtime_only_linkopt() + select({ - "//tensorflow:android": ["-landroid"], - "//conditions:default": [], - }) + ["-Wl,-soname=libLiteRtRuntimeCApi.so"], - deps = LITERT_C_API_COMMON_DEPS, -) - -cc_library( - name = "litert_dispatch_headers", - hdrs = [ - ":litert_environment.h", - ":litert_environment_options.h", - ":litert_accelerator.h", - ":litert_accelerator_compilation_options.h", - ":litert_any.h", - ":litert_common.h", - ":litert_compiled_model.h", - ":litert_compilation_options.h", - ":litert_event.h", - ":litert_event_type.h", - ":litert_layout.h", - ":litert_logging.h", - ":litert_model.h", - ":litert_op_code.h", - ":litert_options.h", - ":litert_tensor_buffer.h", - ":litert_tensor_buffer_requirements.h", - ":litert_tensor_buffer_types.h", - ":litert_gl_types.h", - # Needed for litert/c/litert_op_code.h - "//tensorflow/lite:builtin_ops.h", - # Neeeded for litert/c/litert_model.h - "//tensorflow/lite/c:tensorflowlite_c_api_hdrs_filegroup", - "//tensorflow/lite/core/c:headers_filegroup", - ], # Export all header files (.h) in this directory - deps = [ - "@opencl_headers", - ], -) - -copy_file( - name = "copy_litert_runtime_c_api_so", - src = "//tensorflow/lite/experimental/litert/c:litert_runtime_c_api_so", - target = "libLiteRtRuntimeCApi.so", -) - -# This is cc_library target based on `libLiteRtRuntimeCApi.so`. 
-cc_library( - name = "litert_runtime_c_api_shared_lib", - srcs = [":litert_runtime_c_api_so"], - hdrs = glob(["litert_*.h"]) + [ - # Needed for litert/c/litert_op_code.h - "//tensorflow/lite:builtin_ops.h", - # Neeeded for litert/c/litert_model.h - "//tensorflow/lite/c:tensorflowlite_c_api_hdrs_filegroup", - "//tensorflow/lite/core/c:headers_filegroup", - ], - linkstatic = 1, - deps = [ - # only depend on headers - "@opencl_headers", - "//tensorflow/lite/experimental/litert/vendors/c:litert_dispatch_c_api", - ] + gles_headers(), -) - -# copybara:uncomment_begin(google-only) -# # Check that litert runtime doesn't depend on MLIR. -# check_dependencies( -# of = [":litert_runtime_c_api_shared_lib"], -# dont_match_regexp = "^//third_party/llvm/llvm-project/mlir", -# ) -# copybara:uncomment_end - -exports_files(srcs = glob(["litert_*.h"])) diff --git a/tensorflow/lite/experimental/litert/c/litert_accelerator.cc b/tensorflow/lite/experimental/litert/c/litert_accelerator.cc deleted file mode 100644 index 3e90de58967d07..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_accelerator.cc +++ /dev/null @@ -1,108 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/c/litert_accelerator.h" - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_environment.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/core/environment.h" -#include "tensorflow/lite/experimental/litert/runtime/accelerator.h" - -#ifdef __cplusplus -extern "C" { -#endif - -// Gets the number of accelerators registered to LiteRT. -LiteRtStatus LiteRtGetNumAccelerators(LiteRtEnvironment environment, - LiteRtParamIndex* num_accelerators) { - if (!environment || !num_accelerators) { - return kLiteRtStatusErrorInvalidArgument; - } - *num_accelerators = environment->GetAcceleratorRegistry().size(); - return kLiteRtStatusOk; -} - -// Gets the accelerator at given index that is registered to LiteRT. -LiteRtStatus LiteRtGetAccelerator(LiteRtEnvironment environment, - LiteRtParamIndex index, - - LiteRtAccelerator* accelerator) { - if (!environment || !accelerator) { - return kLiteRtStatusErrorInvalidArgument; - } - litert::Expected registered_accelerator = - environment->GetAcceleratorRegistry().Get(index); - if (!registered_accelerator.HasValue()) { - return registered_accelerator.Error().Status(); - } - *accelerator = registered_accelerator.Value(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetAcceleratorName(LiteRtAccelerator accelerator, - char const** name) { - if (!accelerator || !accelerator->GetName || !name) { - return kLiteRtStatusErrorInvalidArgument; - } - return accelerator->GetName(accelerator, name); -} - -LiteRtStatus LiteRtGetAcceleratorId(LiteRtAccelerator accelerator, - LiteRtAcceleratorId* id) { - if (!accelerator || !accelerator->env || !id) { - return kLiteRtStatusErrorInvalidArgument; - } - litert::Expected index = - accelerator->env->GetAcceleratorRegistry().FindAcceleratorIndex( - accelerator); - if (!index.HasValue()) { - return index.Error().Status(); 
- } - *id = index.Value(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetAcceleratorVersion(LiteRtAccelerator accelerator, - LiteRtApiVersion* version) { - if (!accelerator || !accelerator->GetVersion || !version) { - return kLiteRtStatusErrorInvalidArgument; - } - return accelerator->GetVersion(accelerator, version); -} - -LiteRtStatus LiteRtGetAcceleratorHardwareSupport( - LiteRtAccelerator accelerator, LiteRtHwAcceleratorSet* supported_hardware) { - if (!accelerator || !accelerator->GetHardwareSupport || !supported_hardware) { - return kLiteRtStatusErrorInvalidArgument; - } - return accelerator->GetHardwareSupport(accelerator, supported_hardware); -} - -LiteRtStatus LiteRtIsAcceleratorDelegateResponsibleForJitCompilation( - LiteRtAccelerator accelerator, bool* does_jit_compilation) { - if (!accelerator || !does_jit_compilation) { - return kLiteRtStatusErrorInvalidArgument; - } - if (!accelerator->IsTfLiteDelegateResponsibleForJitCompilation) { - *does_jit_compilation = false; - return kLiteRtStatusOk; - } - return accelerator->IsTfLiteDelegateResponsibleForJitCompilation( - accelerator, does_jit_compilation); -} - -#ifdef __cplusplus -} // extern "C" -#endif diff --git a/tensorflow/lite/experimental/litert/c/litert_accelerator.h b/tensorflow/lite/experimental/litert/c/litert_accelerator.h deleted file mode 100644 index ff3ec4bf14f9a6..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_accelerator.h +++ /dev/null @@ -1,76 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_ACCELERATOR_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_ACCELERATOR_H_ - -#include -#include -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_environment.h" - -#ifdef __cplusplus -extern "C" { -#endif - -LITERT_DEFINE_HANDLE(LiteRtAccelerator); - -typedef size_t LiteRtAcceleratorId; - -// Gets the number of accelerators registered to LiteRT. -LiteRtStatus LiteRtGetNumAccelerators(LiteRtEnvironment environment, - LiteRtParamIndex* num_accelerators); - -// Gets the accelerator at given index that is registered to LiteRT. -LiteRtStatus LiteRtGetAccelerator(LiteRtEnvironment environment, - LiteRtParamIndex index, - LiteRtAccelerator* accelerator); - -// Fetches the name of the accelerator. -// -// Note: client code does not need to manage the `name` lifetime. -LiteRtStatus LiteRtGetAcceleratorName(LiteRtAccelerator accelerator, - char const** name); - -// Fetches the accelerator identifier. -// -// The identifier is a runtime unique number, provided by the registrar to the -// accelerator upon registration. -LiteRtStatus LiteRtGetAcceleratorId(LiteRtAccelerator accelerator, - LiteRtAcceleratorId* id); - -// Fetches the version of the accelerator implementation. -// -// Note: This is NOT the LiteRT version. It's the accelerator specific software -// implementation version. -LiteRtStatus LiteRtGetAcceleratorVersion(LiteRtAccelerator accelerator, - LiteRtApiVersion* version); - -// Fetches the accelerator hardware. -// -// `supported_hardware` is a bitfield of `LiteRtHwAccelerators` values. -LiteRtStatus LiteRtGetAcceleratorHardwareSupport( - LiteRtAccelerator accelerator, LiteRtHwAcceleratorSet* supported_hardware); - -// Returns whether the accelerator TFLite delegate does some JIT compilation. 
-LiteRtStatus LiteRtIsAcceleratorDelegateResponsibleForJitCompilation( - LiteRtAccelerator accelerator, bool* does_jit_compilation); - -#ifdef __cplusplus -} // extern "C" -#endif - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_ACCELERATOR_H_ diff --git a/tensorflow/lite/experimental/litert/c/litert_accelerator_compilation_options.cc b/tensorflow/lite/experimental/litert/c/litert_accelerator_compilation_options.cc deleted file mode 100644 index fb6e0016e69885..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_accelerator_compilation_options.cc +++ /dev/null @@ -1,145 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/c/litert_accelerator_compilation_options.h" - -#include -#include -#include -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" - -struct LiteRtAcceleratorCompilationOptionsT { - LiteRtApiVersion payload_version; - std::string payload_identifier; - std::unique_ptr payload_data; - LiteRtAcceleratorCompilationOptionsT* next = nullptr; - - LiteRtAcceleratorCompilationOptionsT(const LiteRtApiVersion& payload_version_, - std::string payload_identifier_, - void* payload_data_, - void (*payload_destructor_)(void*)) - : payload_version(payload_version_), - payload_identifier(std::move(payload_identifier_)), - payload_data(payload_data_, payload_destructor_) {} -}; - -LiteRtStatus LiteRtCreateAcceleratorCompilationOptions( - const LiteRtApiVersion* payload_version, const char* payload_identifier, - void* payload_data, void (*payload_destructor)(void*), - LiteRtAcceleratorCompilationOptions* options) { - if (!payload_version || !payload_identifier || !payload_data || - !payload_destructor || !options) { - return kLiteRtStatusErrorInvalidArgument; - } - *options = new LiteRtAcceleratorCompilationOptionsT( - *payload_version, std::string(payload_identifier), payload_data, - payload_destructor); - return kLiteRtStatusOk; -} - -void LiteRtDestroyAcceleratorCompilationOptions( - LiteRtAcceleratorCompilationOptions options) { - while (options) { - LiteRtAcceleratorCompilationOptions next = options->next; - delete options; - options = next; - } -} - -LiteRtStatus LiteRtGetAcceleratorCompilationOptionsVersion( - LiteRtAcceleratorCompilationOptions options, - LiteRtApiVersion* payload_version) { - if (!options || !payload_version) { - return kLiteRtStatusErrorInvalidArgument; - } - *payload_version = options->payload_version; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetAcceleratorCompilationOptionsIdentifier( - LiteRtAcceleratorCompilationOptions options, - const char** payload_identifier) { - if 
(!options || !payload_identifier) { - return kLiteRtStatusErrorInvalidArgument; - } - *payload_identifier = options->payload_identifier.c_str(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetAcceleratorCompilationOptionsData( - LiteRtAcceleratorCompilationOptions options, void** payload_data) { - if (!options || !payload_data) { - return kLiteRtStatusErrorInvalidArgument; - } - *payload_data = options->payload_data.get(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtFindAcceleratorCompilationOptionsData( - LiteRtAcceleratorCompilationOptions options, const char* payload_identifier, - LiteRtApiVersion* payload_version, void** payload_data) { - if (!options || !payload_identifier || !payload_version || !payload_data) { - return kLiteRtStatusErrorInvalidArgument; - } - while (options) { - if (!strcmp(options->payload_identifier.c_str(), payload_identifier)) { - *payload_version = options->payload_version; - *payload_data = options->payload_data.get(); - return kLiteRtStatusOk; - } else { - options = options->next; - } - } - return kLiteRtStatusErrorNotFound; -} - -LiteRtStatus LiteRtGetNextAcceleratorCompilationOptions( - LiteRtAcceleratorCompilationOptions* options) { - if (!options || !*options) { - return kLiteRtStatusErrorInvalidArgument; - } - *options = (*options)->next; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtAppendAcceleratorCompilationOptions( - LiteRtAcceleratorCompilationOptions* options, - LiteRtAcceleratorCompilationOptions appended_options) { - if (!options || !appended_options) { - return kLiteRtStatusErrorInvalidArgument; - } - while (*options) { - options = &((*options)->next); - } - *options = appended_options; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtPopAcceleratorCompilationOptions( - LiteRtAcceleratorCompilationOptions* options) { - if (!options) { - return kLiteRtStatusErrorInvalidArgument; - } - LiteRtAcceleratorCompilationOptions* last = options; - while ((*last)->next) { - last = &(*last)->next; - } - if 
(*last) { - LiteRtDestroyAcceleratorCompilationOptions(*last); - *last = nullptr; - } - return kLiteRtStatusOk; -} diff --git a/tensorflow/lite/experimental/litert/c/litert_accelerator_compilation_options.h b/tensorflow/lite/experimental/litert/c/litert_accelerator_compilation_options.h deleted file mode 100644 index cfeff3df8d6cf7..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_accelerator_compilation_options.h +++ /dev/null @@ -1,93 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_ACCELERATOR_COMPILATION_OPTIONS_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_ACCELERATOR_COMPILATION_OPTIONS_H_ - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" - -#ifdef __cplusplus -extern "C" { -#endif - -// A linked list of versioned accelerator compilation options. List items -// include: -// -// - a unique payload identifier field (string), used to distinguish payloads of -// different types; -// -// - a payload field and associated payload destructor callback; -// -// - a payload version field, used by the consumer code to know the structure of -// the payload. 
-LITERT_DEFINE_HANDLE(LiteRtAcceleratorCompilationOptions); - -LiteRtStatus LiteRtCreateAcceleratorCompilationOptions( - const LiteRtApiVersion* payload_version, const char* payload_identifier, - void* payload_data, void (*payload_destructor)(void* payload_data), - LiteRtAcceleratorCompilationOptions* options); - -// Releases an entire options list starting from `options`. -// -// Warning: Once an `options` item has been appended to another `options` item, -// the user will no longer need to destoy the former `options` item manually -// with this function. -void LiteRtDestroyAcceleratorCompilationOptions( - LiteRtAcceleratorCompilationOptions options); - -// Gets the payload version field of the first item in the given `options` list. -LiteRtStatus LiteRtGetAcceleratorCompilationOptionsVersion( - LiteRtAcceleratorCompilationOptions options, - LiteRtApiVersion* payload_version); - -// Gets the patload identifier field of the first item in the given `options` -// list. -LiteRtStatus LiteRtGetAcceleratorCompilationOptionsIdentifier( - LiteRtAcceleratorCompilationOptions options, - const char** payload_identifier); - -// Gets the payload data field of the first item in the given `options` list. -LiteRtStatus LiteRtGetAcceleratorCompilationOptionsData( - LiteRtAcceleratorCompilationOptions options, void** payload_data); - -// Gets the payload version and data for the `options` list item with a given -// payload identifier. Return kLiteRtStatusErrorNotFound if not such item is -// found. -LiteRtStatus LiteRtFindAcceleratorCompilationOptionsData( - LiteRtAcceleratorCompilationOptions options, const char* payload_identifier, - LiteRtApiVersion* payload_version, void** payload_data); - -// Iterate through the next item in the option list pointed by `options` and -// sets parameter `options` to null if there is no next item. 
-LiteRtStatus LiteRtGetNextAcceleratorCompilationOptions( - LiteRtAcceleratorCompilationOptions* options); - -// Appends `next_options` to the list ponted by `options` and takes ownership of -// the appended object. While parameter `options` must be non-null, `*options` -// may however be null, in which case this call is equivalent to `*options = -// appended_options`. -LiteRtStatus LiteRtAppendAcceleratorCompilationOptions( - LiteRtAcceleratorCompilationOptions* options, - LiteRtAcceleratorCompilationOptions appended_options); - -// Removes and deallocates the last option in the linked list pointed by -// parameter `options`. -LiteRtStatus LiteRtPopAcceleratorCompilationOptions( - LiteRtAcceleratorCompilationOptions* options); - -#ifdef __cplusplus -} // extern "C" -#endif - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_ACCELERATOR_COMPILATION_OPTIONS_H_ diff --git a/tensorflow/lite/experimental/litert/c/litert_accelerator_compilation_options_test.cc b/tensorflow/lite/experimental/litert/c/litert_accelerator_compilation_options_test.cc deleted file mode 100644 index f13195942893e1..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_accelerator_compilation_options_test.cc +++ /dev/null @@ -1,182 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/c/litert_accelerator_compilation_options.h" - -#include -#include -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/core/version.h" -#include "tensorflow/lite/experimental/litert/test/matchers.h" - -namespace { - -using testing::StrEq; -using testing::litert::IsError; - -struct DummyAccleratorCompilationOptions { - static constexpr const LiteRtApiVersion kVersion = {0, 1, 0}; - static constexpr const char* const kIdentifier = "dummy-accelerator"; - - int dummy_option = 3; - - // Allocates and sets the basic structure for the accelerator options. - static litert::Expected CreateOptions() { - auto* payload = new DummyAccleratorCompilationOptions; - auto payload_destructor = [](void* payload) { - delete reinterpret_cast(payload); - }; - return CreateOptions(kVersion, kIdentifier, payload, payload_destructor); - } - - static litert::Expected CreateOptions( - LiteRtApiVersion version, const char* identifier, void* payload, - void (*payload_destructor)(void*)) { - LiteRtAcceleratorCompilationOptions options; - LITERT_RETURN_IF_ERROR(LiteRtCreateAcceleratorCompilationOptions( - &version, identifier, payload, payload_destructor, &options)); - return options; - } -}; - -class LiteRtAcceleratorOptionsTest : public testing::Test { - public: - void SetUp() override { - auto options = DummyAccleratorCompilationOptions::CreateOptions(); - EXPECT_TRUE(options); - options_ = *options; - } - - void TearDown() override { - LiteRtDestroyAcceleratorCompilationOptions(options_); - options_ = nullptr; - } - - LiteRtAcceleratorCompilationOptions options_ = nullptr; -}; - -TEST_F(LiteRtAcceleratorOptionsTest, CreateAndDestroyDoesntLeak) {} - -TEST_F(LiteRtAcceleratorOptionsTest, GetIdentifier) { - const char* identifier = nullptr; - 
LITERT_EXPECT_OK( - LiteRtGetAcceleratorCompilationOptionsIdentifier(options_, &identifier)); - EXPECT_THAT(identifier, - StrEq(DummyAccleratorCompilationOptions::kIdentifier)); - EXPECT_THAT( - LiteRtGetAcceleratorCompilationOptionsIdentifier(nullptr, &identifier), - IsError(kLiteRtStatusErrorInvalidArgument)); - EXPECT_EQ(LiteRtGetAcceleratorCompilationOptionsIdentifier(options_, nullptr), - kLiteRtStatusErrorInvalidArgument); -} - -TEST_F(LiteRtAcceleratorOptionsTest, GetVersion) { - LiteRtApiVersion version; - EXPECT_EQ(LiteRtGetAcceleratorCompilationOptionsVersion(options_, &version), - kLiteRtStatusOk); - EXPECT_TRUE(litert::internal::IsSameVersion( - version, DummyAccleratorCompilationOptions::kVersion)); - EXPECT_EQ(LiteRtGetAcceleratorCompilationOptionsVersion(nullptr, &version), - kLiteRtStatusErrorInvalidArgument); - EXPECT_EQ(LiteRtGetAcceleratorCompilationOptionsVersion(options_, nullptr), - kLiteRtStatusErrorInvalidArgument); -} - -TEST_F(LiteRtAcceleratorOptionsTest, CreatingAndDestroyingAListWorks) { - auto appended_options1 = DummyAccleratorCompilationOptions::CreateOptions(); - ASSERT_TRUE(appended_options1); - auto appended_options2 = DummyAccleratorCompilationOptions::CreateOptions(); - ASSERT_TRUE(appended_options2); - - EXPECT_EQ( - LiteRtAppendAcceleratorCompilationOptions(&options_, *appended_options1), - kLiteRtStatusOk); - EXPECT_EQ( - LiteRtAppendAcceleratorCompilationOptions(&options_, *appended_options2), - kLiteRtStatusOk); - - // Iterate through the list to check that the links have been correctly added. - - LiteRtAcceleratorCompilationOptions options_it = options_; - ASSERT_EQ(LiteRtGetNextAcceleratorCompilationOptions(&options_it), - kLiteRtStatusOk); - EXPECT_EQ(options_it, *appended_options1); - - ASSERT_EQ(LiteRtGetNextAcceleratorCompilationOptions(&options_it), - kLiteRtStatusOk); - EXPECT_EQ(options_it, *appended_options2); - - // The list is destroyed in the `TearDown()` function. 
-} - -TEST_F(LiteRtAcceleratorOptionsTest, FindData) { - constexpr LiteRtApiVersion appended_options_version = {1, 2, 3}; - constexpr auto* appended_options_id = "appended_options_id"; - void* appended_options_data = reinterpret_cast(12345); - constexpr auto appended_options_destructor = [](void*) {}; - - auto appended_options = DummyAccleratorCompilationOptions::CreateOptions( - appended_options_version, appended_options_id, appended_options_data, - appended_options_destructor); - - EXPECT_EQ( - LiteRtAppendAcceleratorCompilationOptions(&options_, *appended_options), - kLiteRtStatusOk); - - LiteRtApiVersion payload_version; - void* payload_data; - EXPECT_EQ(LiteRtFindAcceleratorCompilationOptionsData( - options_, appended_options_id, &payload_version, &payload_data), - kLiteRtStatusOk); - - EXPECT_EQ(payload_version.major, appended_options_version.major); - EXPECT_EQ(payload_version.minor, appended_options_version.minor); - EXPECT_EQ(payload_version.patch, appended_options_version.patch); - EXPECT_EQ(payload_data, appended_options_data); - - // The list is destroyed in the `TearDown()` function. -} - -TEST_F(LiteRtAcceleratorOptionsTest, Pop) { - constexpr LiteRtApiVersion appended_options_version = {1, 2, 3}; - constexpr auto* appended_options_id = "appended_options_id"; - void* appended_options_data = reinterpret_cast(12345); - constexpr auto appended_options_destructor = [](void*) {}; - - auto appended_options = DummyAccleratorCompilationOptions::CreateOptions( - appended_options_version, appended_options_id, appended_options_data, - appended_options_destructor); - - EXPECT_EQ( - LiteRtAppendAcceleratorCompilationOptions(&options_, *appended_options), - kLiteRtStatusOk); - - LiteRtApiVersion payload_version; - void* payload_data; - EXPECT_EQ(LiteRtFindAcceleratorCompilationOptionsData( - options_, appended_options_id, &payload_version, &payload_data), - kLiteRtStatusOk); - - // After poping the last item, we shouldn't be able to find it any longer. 
- EXPECT_EQ(LiteRtPopAcceleratorCompilationOptions(&options_), kLiteRtStatusOk); - EXPECT_NE(LiteRtFindAcceleratorCompilationOptionsData( - options_, appended_options_id, &payload_version, &payload_data), - kLiteRtStatusOk); - - // The list is destroyed in the `TearDown()` function. -} - -} // namespace diff --git a/tensorflow/lite/experimental/litert/c/litert_accelerator_registration.cc b/tensorflow/lite/experimental/litert/c/litert_accelerator_registration.cc deleted file mode 100644 index 8404f7275b7fb5..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_accelerator_registration.cc +++ /dev/null @@ -1,122 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/c/litert_accelerator_registration.h" - -#include -#include - -#include "tensorflow/lite/experimental/litert/c/litert_accelerator.h" -#include "tensorflow/lite/experimental/litert/c/litert_accelerator_compilation_options.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_environment.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/core/environment.h" -#include "tensorflow/lite/experimental/litert/runtime/accelerator.h" - -LiteRtStatus LiteRtCreateAccelerator(LiteRtAccelerator* accelerator) { - if (!accelerator) { - return kLiteRtStatusErrorInvalidArgument; - } - *accelerator = - litert::internal::AcceleratorRegistry::CreateEmptyAccelerator().release(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtDestroyAccelerator(LiteRtAccelerator accelerator) { - litert::internal::AcceleratorRegistry::DestroyAccelerator(accelerator); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtRegisterAccelerator(LiteRtEnvironment environment, - LiteRtAccelerator accelerator, - void* data, void (*ReleaseData)(void*)) { - std::unique_ptr data_guard(data, ReleaseData); - litert::internal::AcceleratorRegistry::Ptr accelerator_guard(accelerator); - if (!accelerator_guard) { - return kLiteRtStatusErrorInvalidArgument; - } - accelerator_guard->env = environment; - litert::Expected registered_accelerator = - environment->GetAcceleratorRegistry().RegisterAccelerator( - std::move(accelerator_guard)); - if (!registered_accelerator.HasValue()) { - return registered_accelerator.Error().Status(); - } - registered_accelerator.Value()->data = data_guard.release(); - registered_accelerator.Value()->ReleaseData = ReleaseData; - return kLiteRtStatusOk; -} - -// Sets the function used to retrieve the accelerator name. 
-LiteRtStatus LiteRtSetAcceleratorGetName( - LiteRtAccelerator accelerator, - LiteRtStatus (*GetName)(LiteRtAccelerator accelerator, const char** name)) { - if (!accelerator) { - return kLiteRtStatusErrorInvalidArgument; - } - accelerator->GetName = GetName; - return kLiteRtStatusOk; -} - -// Sets the function used to retrieve the accelerator version. -LiteRtStatus LiteRtSetAcceleratorGetVersion( - LiteRtAccelerator accelerator, - LiteRtStatus (*GetVersion)(LiteRtAccelerator accelerator, - LiteRtApiVersion* version)) { - if (!accelerator) { - return kLiteRtStatusErrorInvalidArgument; - } - accelerator->GetVersion = GetVersion; - return kLiteRtStatusOk; -} - -// Sets the function used to retrieve the accelerator hardware support. -LiteRtStatus LiteRtSetAcceleratorGetHardwareSupport( - LiteRtAccelerator accelerator, - LiteRtStatus (*GetHardwareSupport)( - LiteRtAccelerator accelerator, - LiteRtHwAcceleratorSet* supported_hardware)) { - if (!accelerator) { - return kLiteRtStatusErrorInvalidArgument; - } - accelerator->GetHardwareSupport = GetHardwareSupport; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtSetDelegateFunction( - LiteRtAccelerator accelerator, - LiteRtStatus (*CreateDelegate)(LiteRtAccelerator accelerator, - LiteRtAcceleratorCompilationOptions options, - void** delegate), - void (*DestroyDelegate)(void* delegate)) { - if (!accelerator) { - return kLiteRtStatusErrorInvalidArgument; - } - accelerator->CreateDelegate = CreateDelegate; - accelerator->DestroyDelegate = DestroyDelegate; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtSetIsAcceleratorDelegateResponsibleForJitCompilation( - LiteRtAccelerator accelerator, - LiteRtStatus (*IsTfLiteDelegateResponsibleForJitCompilation)( - LiteRtAcceleratorT* accelerator, bool* does_jit_compilation)) { - if (!accelerator) { - return kLiteRtStatusErrorInvalidArgument; - } - accelerator->IsTfLiteDelegateResponsibleForJitCompilation = - IsTfLiteDelegateResponsibleForJitCompilation; - return kLiteRtStatusOk; -} 
diff --git a/tensorflow/lite/experimental/litert/c/litert_accelerator_registration.h b/tensorflow/lite/experimental/litert/c/litert_accelerator_registration.h deleted file mode 100644 index 19369d436daea0..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_accelerator_registration.h +++ /dev/null @@ -1,98 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_ACCELERATOR_REGISTRATION_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_ACCELERATOR_REGISTRATION_H_ - -#include "tensorflow/lite/experimental/litert/c/litert_accelerator.h" -#include "tensorflow/lite/experimental/litert/c/litert_accelerator_compilation_options.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_environment.h" - -#ifdef __cplusplus -extern "C" { -#endif - -// Creates an empty accelerator handle. -LiteRtStatus LiteRtCreateAccelerator(LiteRtAccelerator* accelerator); - -// Destroys an accelerator handle. -// -// Warning: This SHOULD NOT BE CALLED after a call to -// `LiteRtRegisterAccelerator`. -LiteRtStatus LiteRtDestroyAccelerator(LiteRtAccelerator accelerator); - -// Sets the registration data AND clean-up function, then registers the -// accelerator with the LiteRT environment. -// -// - `data` and `ReleaseData` may be null. 
-// -// Note: After this function returns successfully, `data` is managed by the -// LiteRT environment. `ReleaseData` is called to release its memory. -// -// Warning: In case of failure, `accelerator` is released and `data` is released -// using `ReleaseData`. -LiteRtStatus LiteRtRegisterAccelerator(LiteRtEnvironment environment, - LiteRtAccelerator accelerator, - void* data, void (*ReleaseData)(void*)); - -// Sets the function used to retrieve the accelerator name. -LiteRtStatus LiteRtSetAcceleratorGetName( - LiteRtAccelerator accelerator, - LiteRtStatus (*GetName)(LiteRtAccelerator accelerator, const char** name)); - -// Sets the function used to retrieve the accelerator implementation version. -// -// Note: This is NOT the LiteRT version. It's the accelerator specific software -// implementation version. -LiteRtStatus LiteRtSetAcceleratorGetVersion( - LiteRtAccelerator accelerator, - LiteRtStatus (*GetVersion)(LiteRtAccelerator accelerator, - LiteRtApiVersion* version)); - -// Sets the function used to retrieve the accelerator hardware support. -LiteRtStatus LiteRtSetAcceleratorGetHardwareSupport( - LiteRtAccelerator accelerator, - LiteRtStatus (*GetHardwareSupport)( - LiteRtAccelerator accelerator, - LiteRtHwAcceleratorSet* supported_hardware)); - -// Sets the function used to return a Delegate to apply the accelerator by the -// compiled model and its destructor. The returned Delegate object is owned by -// the compiled model. Used void** for the Delegate instead of -// TfLiteOpaqueDelegate** to avoid TFLite dependency. -LiteRtStatus LiteRtSetDelegateFunction( - LiteRtAccelerator accelerator, - LiteRtStatus (*CreateDelegate)(LiteRtAccelerator accelerator, - LiteRtAcceleratorCompilationOptions options, - void** delegate), - void (*DestroyDelegate)(void* delegate)); - -// Sets the function used to surface whether the delegate created by the -// accelerator does JIT compilation or not. 
-// -// This affects whether the compiled model creation will apply the accelerator -// without an explicit request in the JIT compilation options. -// -// If this isn't set, the result will be treated as `false`. -LiteRtStatus LiteRtSetIsAcceleratorDelegateResponsibleForJitCompilation( - LiteRtAccelerator accelerator, - LiteRtStatus (*IsTfLiteDelegateResponsibleForJitCompilation)( - LiteRtAccelerator accelerator, bool* does_jit_compilation)); - -#ifdef __cplusplus -} // extern "C" -#endif - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_ACCELERATOR_REGISTRATION_H_ diff --git a/tensorflow/lite/experimental/litert/c/litert_accelerator_registration_test.cc b/tensorflow/lite/experimental/litert/c/litert_accelerator_registration_test.cc deleted file mode 100644 index df7a051949447f..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_accelerator_registration_test.cc +++ /dev/null @@ -1,142 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/c/litert_accelerator_registration.h" - -#include - -#include -#include "tensorflow/lite/experimental/litert/c/litert_accelerator.h" -#include "tensorflow/lite/experimental/litert/c/litert_accelerator_compilation_options.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_environment.h" -#include "tensorflow/lite/experimental/litert/runtime/accelerator.h" - -namespace { - -class DummyAccelerator { - public: - static std::unique_ptr CpuAccelerator() { - auto accelerator = std::make_unique(); - accelerator->hardware_support_ = kLiteRtHwAcceleratorCpu; - return accelerator; - } - - static void Destroy(void* dummy_accelerator) { - DummyAccelerator* instance = - reinterpret_cast(dummy_accelerator); - delete instance; - } - - static LiteRtStatus GetName(LiteRtAccelerator accelerator, - const char** name) { - return kLiteRtStatusOk; - } - - static LiteRtStatus GetVersion(LiteRtAccelerator accelerator, - LiteRtApiVersion* version) { - return kLiteRtStatusOk; - } - - static LiteRtStatus GetHardwareSupport( - LiteRtAccelerator accelerator, - LiteRtHwAcceleratorSet* supported_hardware) { - return kLiteRtStatusOk; - } - - static LiteRtStatus CreateDelegate( - LiteRtAccelerator accelerator, - LiteRtAcceleratorCompilationOptions options, void** delegate) { - return kLiteRtStatusOk; - } - - static void DestroyDelegate(void* delegate) {} - - LiteRtHwAccelerators hardware_support_; -}; - -TEST(LiteRtAcceleratorRegistrationTest, SetAcceleratorGetNameWorks) { - LiteRtAcceleratorT accelerator; - EXPECT_EQ(LiteRtSetAcceleratorGetName(nullptr, DummyAccelerator::GetName), - kLiteRtStatusErrorInvalidArgument); - LiteRtSetAcceleratorGetName(&accelerator, DummyAccelerator::GetName); - EXPECT_EQ(accelerator.GetName, DummyAccelerator::GetName); -} - -TEST(LiteRtAcceleratorRegistrationTest, SetAcceleratorGetVersionWorks) { - LiteRtAcceleratorT accelerator; - EXPECT_EQ( - 
LiteRtSetAcceleratorGetVersion(nullptr, DummyAccelerator::GetVersion), - kLiteRtStatusErrorInvalidArgument); - LiteRtSetAcceleratorGetVersion(&accelerator, DummyAccelerator::GetVersion); - EXPECT_EQ(accelerator.GetVersion, DummyAccelerator::GetVersion); -} - -TEST(LiteRtAcceleratorRegistrationTest, SetAcceleratorGetHardwareSupportWorks) { - LiteRtAcceleratorT accelerator; - EXPECT_EQ(LiteRtSetAcceleratorGetHardwareSupport( - nullptr, DummyAccelerator::GetHardwareSupport), - kLiteRtStatusErrorInvalidArgument); - LiteRtSetAcceleratorGetHardwareSupport(&accelerator, - DummyAccelerator::GetHardwareSupport); - EXPECT_EQ(accelerator.GetHardwareSupport, - DummyAccelerator::GetHardwareSupport); -} - -TEST(LiteRtAcceleratorRegistrationTest, SetDelegateFunctionsWorks) { - LiteRtAcceleratorT accelerator; - EXPECT_EQ(LiteRtSetDelegateFunction(nullptr, DummyAccelerator::CreateDelegate, - DummyAccelerator::DestroyDelegate), - kLiteRtStatusErrorInvalidArgument); - LiteRtSetDelegateFunction(&accelerator, DummyAccelerator::CreateDelegate, - DummyAccelerator::DestroyDelegate); - EXPECT_EQ(accelerator.CreateDelegate, DummyAccelerator::CreateDelegate); - EXPECT_EQ(accelerator.DestroyDelegate, DummyAccelerator::DestroyDelegate); -} - -TEST(LiteRtAcceleratorRegistrationTest, CreateDestroyAcceleratorDoesntLeak) { - LiteRtAccelerator accelerator; - ASSERT_EQ(LiteRtCreateAccelerator(&accelerator), kLiteRtStatusOk); - ASSERT_EQ(LiteRtDestroyAccelerator(accelerator), kLiteRtStatusOk); -} - -TEST(LiteRtAcceleratorRegistrationTest, RegisterAcceleratorWorks) { - LiteRtEnvironment env_; - LiteRtEnvironmentCreate(/*num_options=*/0, /*options=*/nullptr, &env_); - auto dummy_accelerator = DummyAccelerator::CpuAccelerator(); - LiteRtAccelerator accelerator; - LiteRtCreateAccelerator(&accelerator); - LiteRtSetAcceleratorGetName(accelerator, DummyAccelerator::GetName); - LiteRtSetAcceleratorGetVersion(accelerator, DummyAccelerator::GetVersion); - LiteRtSetAcceleratorGetHardwareSupport(accelerator, - 
DummyAccelerator::GetHardwareSupport); - LiteRtRegisterAccelerator(env_, accelerator, dummy_accelerator.release(), - DummyAccelerator::Destroy); - LiteRtDestroyEnvironment(env_); -} - -TEST(LiteRtAcceleratorRegistrationTest, - RegisterAcceleratorFailsForNullAccelerator) { - LiteRtEnvironment env; - LiteRtEnvironmentCreate(/*num_options=*/0, /*options=*/nullptr, &env); - // We check that the memory is correctly deallocated if the registration - // fails. - auto dummy_accelerator = DummyAccelerator::CpuAccelerator(); - EXPECT_EQ(LiteRtRegisterAccelerator(env, nullptr, dummy_accelerator.release(), - DummyAccelerator::Destroy), - kLiteRtStatusErrorInvalidArgument); - LiteRtDestroyEnvironment(env); -} - -} // namespace diff --git a/tensorflow/lite/experimental/litert/c/litert_accelerator_test.cc b/tensorflow/lite/experimental/litert/c/litert_accelerator_test.cc deleted file mode 100644 index 0ef878b5cc3198..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_accelerator_test.cc +++ /dev/null @@ -1,287 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/c/litert_accelerator.h" - -#include -#include - -#include -#include -#include "tensorflow/lite/experimental/litert/c/litert_accelerator_registration.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_environment.h" -#include "tensorflow/lite/experimental/litert/runtime/accelerator.h" - -#define LITERT_ENSURE_OK(expr) \ - do { \ - LiteRtStatus status = (expr); \ - if (status != kLiteRtStatusOk) { \ - return status; \ - } \ - } while (0) - -namespace { -using testing::Eq; -using testing::Ne; -using testing::NotNull; -using testing::StrEq; - -class DummyAccelerator { - public: - // `hardware_support` is a bitfield of `LiteRtHwAccelerators` values. - static LiteRtStatus RegisterAccelerator(int hardware_support, - LiteRtEnvironment env) { - auto dummy_accelerator = std::make_unique(); - dummy_accelerator->hardware_support_ = hardware_support; - LiteRtAccelerator accelerator; - LiteRtCreateAccelerator(&accelerator); - LITERT_ENSURE_OK( - LiteRtSetAcceleratorGetName(accelerator, DummyAccelerator::GetName)); - LITERT_ENSURE_OK(LiteRtSetAcceleratorGetVersion( - accelerator, DummyAccelerator::GetVersion)); - LITERT_ENSURE_OK(LiteRtSetAcceleratorGetHardwareSupport( - accelerator, DummyAccelerator::GetHardwareSupport)); - LITERT_ENSURE_OK(LiteRtRegisterAccelerator(env, accelerator, - dummy_accelerator.release(), - DummyAccelerator::Destroy)); - return kLiteRtStatusOk; - } - - static void Destroy(void* dummy_accelerator) { - DummyAccelerator* instance = - reinterpret_cast(dummy_accelerator); - delete instance; - } - - static LiteRtStatus GetName(LiteRtAccelerator accelerator, - const char** name) { - if (!accelerator || !accelerator->data || !name) { - return kLiteRtStatusErrorInvalidArgument; - } - DummyAccelerator& self = - *reinterpret_cast(accelerator->data); - if (self.name_.empty()) { - self.name_.append("Dummy"); - if (self.hardware_support_ & 
kLiteRtHwAcceleratorCpu) { - self.name_.append("Cpu"); - } - if (self.hardware_support_ & kLiteRtHwAcceleratorGpu) { - self.name_.append("Gpu"); - } - self.name_.append("Accelerator"); - } - *name = self.name_.c_str(); - return kLiteRtStatusOk; - } - - static LiteRtStatus GetVersion(LiteRtAccelerator accelerator, - LiteRtApiVersion* version) { - if (!version) { - return kLiteRtStatusErrorInvalidArgument; - } - version->major = 1; - version->minor = 2; - version->patch = 3; - return kLiteRtStatusOk; - } - - static LiteRtStatus GetHardwareSupport( - LiteRtAccelerator accelerator, - LiteRtHwAcceleratorSet* supported_hardware) { - if (!accelerator || !accelerator->data || !supported_hardware) { - return kLiteRtStatusErrorInvalidArgument; - } - - const DummyAccelerator& self = - *reinterpret_cast(accelerator->data); - *supported_hardware = self.hardware_support_; - return kLiteRtStatusOk; - } - - int hardware_support_; - std::string name_; -}; - -class LiteRtAcceleratorTest : public testing::Test { - public: - LiteRtEnvironment env_; - void SetUp() override { - LiteRtEnvironmentCreate(/*num_options=*/0, nullptr, &env_); - DummyAccelerator::RegisterAccelerator(kLiteRtHwAcceleratorCpu, env_); - } - - void TearDown() override { LiteRtDestroyEnvironment(env_); } -}; - -TEST_F(LiteRtAcceleratorTest, IteratingOverAcceleratorsWorks) { - // CPU accelerator is registered in the SetUp function. 
- DummyAccelerator::RegisterAccelerator(kLiteRtHwAcceleratorGpu, env_); - - LiteRtParamIndex num_accelerators = 0; - ASSERT_THAT(LiteRtGetNumAccelerators(env_, &num_accelerators), - kLiteRtStatusOk); - ASSERT_THAT(num_accelerators, 2); - - EXPECT_THAT(LiteRtGetAccelerator(env_, 0, nullptr), - kLiteRtStatusErrorInvalidArgument); - LiteRtAccelerator accelerator0; - ASSERT_THAT(LiteRtGetAccelerator(env_, 0, &accelerator0), kLiteRtStatusOk); - EXPECT_THAT(accelerator0, NotNull()); - - EXPECT_THAT(LiteRtGetAccelerator(env_, 1, nullptr), - kLiteRtStatusErrorInvalidArgument); - LiteRtAccelerator accelerator1; - ASSERT_THAT(LiteRtGetAccelerator(env_, 1, &accelerator1), kLiteRtStatusOk); - EXPECT_THAT(accelerator1, NotNull()); - - EXPECT_THAT(accelerator0, Ne(accelerator1)); - - LiteRtAccelerator accelerator2; - EXPECT_THAT(LiteRtGetAccelerator(env_, 2, &accelerator2), - kLiteRtStatusErrorNotFound); -} - -TEST_F(LiteRtAcceleratorTest, GetAcceleratorNameWorks) { - LiteRtParamIndex num_accelerators = 0; - ASSERT_THAT(LiteRtGetNumAccelerators(env_, &num_accelerators), - kLiteRtStatusOk); - ASSERT_THAT(num_accelerators, 1); - - LiteRtAccelerator accelerator; - ASSERT_THAT(LiteRtGetAccelerator(env_, 0, &accelerator), kLiteRtStatusOk); - const char* name = nullptr; - ASSERT_THAT(LiteRtGetAcceleratorName(accelerator, &name), kLiteRtStatusOk); - EXPECT_THAT(name, StrEq("DummyCpuAccelerator")); - - EXPECT_THAT(LiteRtGetAcceleratorName(nullptr, &name), - kLiteRtStatusErrorInvalidArgument); - EXPECT_THAT(LiteRtGetAcceleratorName(accelerator, nullptr), - kLiteRtStatusErrorInvalidArgument); - // Make the accelerator invalid. 
- accelerator->GetName = nullptr; - EXPECT_THAT(LiteRtGetAcceleratorName(accelerator, &name), - kLiteRtStatusErrorInvalidArgument); -} - -TEST_F(LiteRtAcceleratorTest, GetAcceleratorIdWorks) { - LiteRtParamIndex num_accelerators = 0; - ASSERT_THAT(LiteRtGetNumAccelerators(env_, &num_accelerators), - kLiteRtStatusOk); - ASSERT_THAT(num_accelerators, 1); - - LiteRtAccelerator accelerator; - ASSERT_THAT(LiteRtGetAccelerator(env_, 0, &accelerator), kLiteRtStatusOk); - LiteRtAcceleratorId accelerator_id; - ASSERT_THAT(LiteRtGetAcceleratorId(accelerator, &accelerator_id), - kLiteRtStatusOk); - EXPECT_THAT(accelerator_id, Eq(0)); - - EXPECT_THAT(LiteRtGetAcceleratorId(nullptr, &accelerator_id), - kLiteRtStatusErrorInvalidArgument); - EXPECT_THAT(LiteRtGetAcceleratorId(accelerator, nullptr), - kLiteRtStatusErrorInvalidArgument); - // Make the accelerator invalid. - accelerator->env = nullptr; - EXPECT_THAT(LiteRtGetAcceleratorId(accelerator, &accelerator_id), - kLiteRtStatusErrorInvalidArgument); -} - -TEST_F(LiteRtAcceleratorTest, GetAcceleratorVersionWorks) { - LiteRtParamIndex num_accelerators = 0; - ASSERT_THAT(LiteRtGetNumAccelerators(env_, &num_accelerators), - kLiteRtStatusOk); - ASSERT_THAT(num_accelerators, 1); - - LiteRtAccelerator accelerator; - ASSERT_THAT(LiteRtGetAccelerator(env_, 0, &accelerator), kLiteRtStatusOk); - LiteRtApiVersion version; - ASSERT_THAT(LiteRtGetAcceleratorVersion(accelerator, &version), - kLiteRtStatusOk); - EXPECT_THAT(version.major, Eq(1)); - EXPECT_THAT(version.minor, Eq(2)); - EXPECT_THAT(version.patch, Eq(3)); - - EXPECT_THAT(LiteRtGetAcceleratorVersion(nullptr, &version), - kLiteRtStatusErrorInvalidArgument); - EXPECT_THAT(LiteRtGetAcceleratorVersion(accelerator, nullptr), - kLiteRtStatusErrorInvalidArgument); - // Make the accelerator invalid. 
- accelerator->GetVersion = nullptr; - EXPECT_THAT(LiteRtGetAcceleratorVersion(accelerator, &version), - kLiteRtStatusErrorInvalidArgument); -} - -TEST_F(LiteRtAcceleratorTest, GetAcceleratorHardwareSupportWorks) { - LiteRtParamIndex num_accelerators = 0; - ASSERT_THAT(LiteRtGetNumAccelerators(env_, &num_accelerators), - kLiteRtStatusOk); - ASSERT_THAT(num_accelerators, 1); - - LiteRtAccelerator accelerator; - ASSERT_THAT(LiteRtGetAccelerator(env_, 0, &accelerator), kLiteRtStatusOk); - int hardware_support; - ASSERT_THAT( - LiteRtGetAcceleratorHardwareSupport(accelerator, &hardware_support), - kLiteRtStatusOk); - EXPECT_THAT(hardware_support & kLiteRtHwAcceleratorCpu, true); - EXPECT_THAT(hardware_support & kLiteRtHwAcceleratorGpu, false); - EXPECT_THAT(hardware_support & kLiteRtHwAcceleratorNpu, false); - - EXPECT_THAT(LiteRtGetAcceleratorHardwareSupport(nullptr, &hardware_support), - kLiteRtStatusErrorInvalidArgument); - EXPECT_THAT(LiteRtGetAcceleratorHardwareSupport(accelerator, nullptr), - kLiteRtStatusErrorInvalidArgument); - // Make the accelerator invalid. 
- accelerator->GetHardwareSupport = nullptr; - EXPECT_THAT( - LiteRtGetAcceleratorHardwareSupport(accelerator, &hardware_support), - kLiteRtStatusErrorInvalidArgument); -} - -TEST_F(LiteRtAcceleratorTest, - IsAcceleratorDelegateResponsibleForJitCompilationWorks) { - LiteRtParamIndex num_accelerators = 0; - ASSERT_THAT(LiteRtGetNumAccelerators(env_, &num_accelerators), - kLiteRtStatusOk); - ASSERT_THAT(num_accelerators, 1); - - LiteRtAccelerator accelerator; - ASSERT_THAT(LiteRtGetAccelerator(env_, 0, &accelerator), kLiteRtStatusOk); - bool does_jit_compilation; - ASSERT_THAT(LiteRtIsAcceleratorDelegateResponsibleForJitCompilation( - accelerator, &does_jit_compilation), - kLiteRtStatusOk); - EXPECT_THAT(does_jit_compilation, false); - - EXPECT_THAT(LiteRtIsAcceleratorDelegateResponsibleForJitCompilation( - nullptr, &does_jit_compilation), - kLiteRtStatusErrorInvalidArgument); - EXPECT_THAT(LiteRtIsAcceleratorDelegateResponsibleForJitCompilation( - accelerator, nullptr), - kLiteRtStatusErrorInvalidArgument); - - // Add an implementation to the function. - accelerator->IsTfLiteDelegateResponsibleForJitCompilation = - [](LiteRtAccelerator, bool* does_jit) { - *does_jit = true; - return kLiteRtStatusOk; - }; - EXPECT_THAT(LiteRtIsAcceleratorDelegateResponsibleForJitCompilation( - accelerator, &does_jit_compilation), - kLiteRtStatusOk); - EXPECT_THAT(does_jit_compilation, true); -} - -} // namespace diff --git a/tensorflow/lite/experimental/litert/c/litert_any.h b/tensorflow/lite/experimental/litert/c/litert_any.h deleted file mode 100644 index e8e67b0c80f239..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_any.h +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_ANY_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_ANY_H_ - -#include // NOLINT: To use bool type in C -#include - -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus - -typedef enum { - kLiteRtAnyTypeNone = 0, - kLiteRtAnyTypeBool = 1, - kLiteRtAnyTypeInt = 2, - kLiteRtAnyTypeReal = 3, - kLiteRtAnyTypeString = 8, - kLiteRtAnyTypeVoidPtr = 9, -} LiteRtAnyType; - -inline const char* LiteRtAnyTypeToString(LiteRtAnyType type) { - switch (type) { - case kLiteRtAnyTypeNone: - return "kLiteRtAnyTypeNone"; - case kLiteRtAnyTypeBool: - return "kLiteRtAnyTypeBool"; - case kLiteRtAnyTypeInt: - return "kLiteRtAnyTypeInt"; - case kLiteRtAnyTypeReal: - return "kLiteRtAnyTypeReal"; - case kLiteRtAnyTypeString: - return "kLiteRtAnyTypeString"; - case kLiteRtAnyTypeVoidPtr: - return "kLiteRtAnyTypeVoidPtr"; - } - return "Unknown"; -} - -typedef struct { - LiteRtAnyType type; - union { - bool bool_value; - int64_t int_value; - double real_value; - const char* str_value; - const void* ptr_value; - }; -} LiteRtAny; - -#ifdef __cplusplus -} -#endif // __cplusplus - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_ANY_H_ diff --git a/tensorflow/lite/experimental/litert/c/litert_c_api_common_test.c b/tensorflow/lite/experimental/litert/c/litert_c_api_common_test.c deleted file mode 100644 index f4aa75e231c297..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_c_api_common_test.c +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright 2024 Google LLC. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// This file exists to verify that the below header files can build, link, -// and run as C code. -#ifdef __cplusplus -#error "This file should be compiled as C code, not as C++." -#endif - -// Include all the header files in the litert/c directory. -#include "tensorflow/lite/experimental/litert/c/litert_accelerator.h" // NOLINT -#include "tensorflow/lite/experimental/litert/c/litert_accelerator_registration.h" // NOLINT -#include "tensorflow/lite/experimental/litert/c/litert_any.h" // NOLINT -#include "tensorflow/lite/experimental/litert/c/litert_common.h" // NOLINT -#include "tensorflow/lite/experimental/litert/c/litert_compiled_model.h" // NOLINT -#include "tensorflow/lite/experimental/litert/c/litert_compilation_options.h" // NOLINT -#include "tensorflow/lite/experimental/litert/c/litert_dispatch_delegate.h" // NOLINT -#include "tensorflow/lite/experimental/litert/c/litert_event.h" // NOLINT -#include "tensorflow/lite/experimental/litert/c/litert_layout.h" // NOLINT -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" // NOLINT -#include "tensorflow/lite/experimental/litert/c/litert_model.h" // NOLINT -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" // NOLINT -#include "tensorflow/lite/experimental/litert/c/litert_options.h" // NOLINT -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" // NOLINT -#include 
"tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h" // NOLINT - -int main(void) { return 0; } diff --git a/tensorflow/lite/experimental/litert/c/litert_common.cc b/tensorflow/lite/experimental/litert/c/litert_common.cc deleted file mode 100644 index adbecb6259f2f8..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_common.cc +++ /dev/null @@ -1,51 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" - -extern "C" { - -const char* LiteRtGetStatusString(LiteRtStatus status) { - switch (status) { - // NOLINTNEXTLINE(preprocessor-macros) -#define LITERT_STATUS_STR_CASE(STATUS) \ - case STATUS: \ - return #STATUS; - LITERT_STATUS_STR_CASE(kLiteRtStatusOk); - LITERT_STATUS_STR_CASE(kLiteRtStatusErrorInvalidArgument); - LITERT_STATUS_STR_CASE(kLiteRtStatusErrorMemoryAllocationFailure); - LITERT_STATUS_STR_CASE(kLiteRtStatusErrorRuntimeFailure); - LITERT_STATUS_STR_CASE(kLiteRtStatusErrorMissingInputTensor); - LITERT_STATUS_STR_CASE(kLiteRtStatusErrorUnsupported); - LITERT_STATUS_STR_CASE(kLiteRtStatusErrorNotFound); - LITERT_STATUS_STR_CASE(kLiteRtStatusErrorTimeoutExpired); - LITERT_STATUS_STR_CASE(kLiteRtStatusErrorFileIO); - LITERT_STATUS_STR_CASE(kLiteRtStatusErrorInvalidFlatbuffer); - LITERT_STATUS_STR_CASE(kLiteRtStatusErrorDynamicLoading); - LITERT_STATUS_STR_CASE(kLiteRtStatusErrorSerialization); - 
LITERT_STATUS_STR_CASE(kLiteRtStatusErrorCompilation); - LITERT_STATUS_STR_CASE(kLiteRtStatusErrorIndexOOB); - LITERT_STATUS_STR_CASE(kLiteRtStatusErrorInvalidIrType); - LITERT_STATUS_STR_CASE(kLiteRtStatusErrorInvalidGraphInvariant); - LITERT_STATUS_STR_CASE(kLiteRtStatusErrorGraphModification); - LITERT_STATUS_STR_CASE(kLiteRtStatusErrorInvalidToolConfig); - LITERT_STATUS_STR_CASE(kLiteRtStatusLegalizeNoMatch); - LITERT_STATUS_STR_CASE(kLiteRtStatusErrorInvalidLegalization); - LITERT_STATUS_STR_CASE(kLiteRtStatusErrorWrongVersion); - LITERT_STATUS_STR_CASE(kLiteRtStatusErrorUnknown); -#undef LITERT_STATUS_STR_CASE - } -} - -} // extern "C" diff --git a/tensorflow/lite/experimental/litert/c/litert_common.h b/tensorflow/lite/experimental/litert/c/litert_common.h deleted file mode 100644 index 4ed0ca5407c4ad..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_common.h +++ /dev/null @@ -1,139 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_COMMON_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_COMMON_H_ - -#include - -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus - -// Declares canonical opaque type. 
-#define LITERT_DEFINE_HANDLE(name) typedef struct name##T* name - -#if __ANDROID_API__ >= 26 -#define LITERT_HAS_AHWB_SUPPORT 1 -#else -#define LITERT_HAS_AHWB_SUPPORT 0 -#endif // __ANDROID_API__ >= 26 - -#if defined(__linux__) || defined(__ANDROID__) -#define LITERT_HAS_SYNC_FENCE_SUPPORT 1 -#else -#define LITERT_HAS_SYNC_FENCE_SUPPORT 0 -#endif - -#if defined(__ANDROID__) -#define LITERT_HAS_ION_SUPPORT 1 -#define LITERT_HAS_DMABUF_SUPPORT 1 -#define LITERT_HAS_FASTRPC_SUPPORT 1 -#define LITERT_HAS_OPENGL_SUPPORT 1 -#define LITERT_HAS_OPENCL_SUPPORT_DEFAULT 1 -// copybara:comment_begin(google-only) -#elif defined(GOOGLE_UNSUPPORTED_OS_LOONIX) -#define LITERT_HAS_ION_SUPPORT 0 -#define LITERT_HAS_DMABUF_SUPPORT 1 -#define LITERT_HAS_FASTRPC_SUPPORT 0 -#define LITERT_HAS_OPENCL_SUPPORT_DEFAULT 1 -// copybara:comment_end -#else -#define LITERT_HAS_ION_SUPPORT 0 -#define LITERT_HAS_DMABUF_SUPPORT 0 -#define LITERT_HAS_FASTRPC_SUPPORT 0 -#define LITERT_HAS_OPENCL_SUPPORT_DEFAULT 1 -#define LITERT_HAS_OPENGL_SUPPORT 0 -#endif - -#if defined(LITERT_DISABLE_OPENCL_SUPPORT) -#define LITERT_HAS_OPENCL_SUPPORT 0 -#else -#define LITERT_HAS_OPENCL_SUPPORT LITERT_HAS_OPENCL_SUPPORT_DEFAULT -#endif - -#define LITERT_API_VERSION_MAJOR 0 -#define LITERT_API_VERSION_MINOR 1 -#define LITERT_API_VERSION_PATCH 0 - -typedef struct LiteRtApiVersion { - int major; - int minor; - int patch; -} LiteRtApiVersion; - -typedef enum { - kLiteRtStatusOk = 0, - - // Generic errors. - kLiteRtStatusErrorInvalidArgument = 1, - kLiteRtStatusErrorMemoryAllocationFailure = 2, - kLiteRtStatusErrorRuntimeFailure = 3, - kLiteRtStatusErrorMissingInputTensor = 4, - kLiteRtStatusErrorUnsupported = 5, - kLiteRtStatusErrorNotFound = 6, - kLiteRtStatusErrorTimeoutExpired = 7, - kLiteRtStatusErrorWrongVersion = 8, - kLiteRtStatusErrorUnknown = 9, - - // File and loading related errors. 
- kLiteRtStatusErrorFileIO = 500, - kLiteRtStatusErrorInvalidFlatbuffer = 501, - kLiteRtStatusErrorDynamicLoading = 502, - kLiteRtStatusErrorSerialization = 503, - kLiteRtStatusErrorCompilation = 504, - - // IR related errors. - kLiteRtStatusErrorIndexOOB = 1000, - kLiteRtStatusErrorInvalidIrType = 1001, - kLiteRtStatusErrorInvalidGraphInvariant = 1002, - kLiteRtStatusErrorGraphModification = 1003, - - // Tool related errors. - kLiteRtStatusErrorInvalidToolConfig = 1500, - - // Legalization related errors. - kLiteRtStatusLegalizeNoMatch = 2000, - kLiteRtStatusErrorInvalidLegalization = 2001, -} LiteRtStatus; - -// Returns a string describing the status value. -const char* LiteRtGetStatusString(LiteRtStatus status); - -typedef enum : int { - kLiteRtHwAcceleratorNone = 0, - kLiteRtHwAcceleratorCpu = 1 << 0, - kLiteRtHwAcceleratorGpu = 1 << 1, - kLiteRtHwAcceleratorNpu = 1 << 2, -} LiteRtHwAccelerators; - -// A bit field of `LiteRtHwAccelerators` values. -typedef int LiteRtHwAcceleratorSet; - -// For indexing into LiteRT collections or counting LiteRT things. -typedef size_t LiteRtParamIndex; - -#if defined(_WIN32) -// Provides posix_memalign() missing in Windows. -#include - -#define posix_memalign(p, a, s) \ - (((*(p)) = _aligned_malloc((s), (a))), *(p) ? 0 : errno) -#endif // defined(_WIN32) - -#ifdef __cplusplus -} -#endif // __cplusplus - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_COMMON_H_ diff --git a/tensorflow/lite/experimental/litert/c/litert_common_test.cc b/tensorflow/lite/experimental/litert/c/litert_common_test.cc deleted file mode 100644 index be0993c1ce4733..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_common_test.cc +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" - -#include -#include - -namespace { - -using testing::Eq; -using testing::Gt; -using testing::Lt; -using testing::StrEq; - -TEST(GetStatusString, Works) { - EXPECT_THAT(LiteRtGetStatusString(kLiteRtStatusOk), StrEq("kLiteRtStatusOk")); - EXPECT_THAT(LiteRtGetStatusString(kLiteRtStatusErrorInvalidArgument), - StrEq("kLiteRtStatusErrorInvalidArgument")); - EXPECT_THAT(LiteRtGetStatusString(kLiteRtStatusErrorMemoryAllocationFailure), - StrEq("kLiteRtStatusErrorMemoryAllocationFailure")); - EXPECT_THAT(LiteRtGetStatusString(kLiteRtStatusErrorRuntimeFailure), - StrEq("kLiteRtStatusErrorRuntimeFailure")); - EXPECT_THAT(LiteRtGetStatusString(kLiteRtStatusErrorMissingInputTensor), - StrEq("kLiteRtStatusErrorMissingInputTensor")); - EXPECT_THAT(LiteRtGetStatusString(kLiteRtStatusErrorUnsupported), - StrEq("kLiteRtStatusErrorUnsupported")); - EXPECT_THAT(LiteRtGetStatusString(kLiteRtStatusErrorNotFound), - StrEq("kLiteRtStatusErrorNotFound")); - EXPECT_THAT(LiteRtGetStatusString(kLiteRtStatusErrorTimeoutExpired), - StrEq("kLiteRtStatusErrorTimeoutExpired")); - EXPECT_THAT(LiteRtGetStatusString(kLiteRtStatusErrorFileIO), - StrEq("kLiteRtStatusErrorFileIO")); - EXPECT_THAT(LiteRtGetStatusString(kLiteRtStatusErrorInvalidFlatbuffer), - StrEq("kLiteRtStatusErrorInvalidFlatbuffer")); - EXPECT_THAT(LiteRtGetStatusString(kLiteRtStatusErrorDynamicLoading), - StrEq("kLiteRtStatusErrorDynamicLoading")); - EXPECT_THAT(LiteRtGetStatusString(kLiteRtStatusErrorSerialization), - 
StrEq("kLiteRtStatusErrorSerialization")); - EXPECT_THAT(LiteRtGetStatusString(kLiteRtStatusErrorCompilation), - StrEq("kLiteRtStatusErrorCompilation")); - EXPECT_THAT(LiteRtGetStatusString(kLiteRtStatusErrorIndexOOB), - StrEq("kLiteRtStatusErrorIndexOOB")); - EXPECT_THAT(LiteRtGetStatusString(kLiteRtStatusErrorInvalidIrType), - StrEq("kLiteRtStatusErrorInvalidIrType")); - EXPECT_THAT(LiteRtGetStatusString(kLiteRtStatusErrorInvalidGraphInvariant), - StrEq("kLiteRtStatusErrorInvalidGraphInvariant")); - EXPECT_THAT(LiteRtGetStatusString(kLiteRtStatusErrorGraphModification), - StrEq("kLiteRtStatusErrorGraphModification")); - EXPECT_THAT(LiteRtGetStatusString(kLiteRtStatusErrorInvalidToolConfig), - StrEq("kLiteRtStatusErrorInvalidToolConfig")); - EXPECT_THAT(LiteRtGetStatusString(kLiteRtStatusLegalizeNoMatch), - StrEq("kLiteRtStatusLegalizeNoMatch")); - EXPECT_THAT(LiteRtGetStatusString(kLiteRtStatusErrorInvalidLegalization), - StrEq("kLiteRtStatusErrorInvalidLegalization")); - EXPECT_THAT(LiteRtGetStatusString(kLiteRtStatusErrorWrongVersion), - StrEq("kLiteRtStatusErrorWrongVersion")); -} - -} // namespace diff --git a/tensorflow/lite/experimental/litert/c/litert_compilation_options.cc b/tensorflow/lite/experimental/litert/c/litert_compilation_options.cc deleted file mode 100644 index 0ff32733d13d1e..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_compilation_options.cc +++ /dev/null @@ -1,91 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/c/litert_compilation_options.h" - -#include "tensorflow/lite/experimental/litert/c/litert_accelerator_compilation_options.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/runtime/compilation_options.h" - -#define LRT_CHECK_NON_NULL(handle) \ - if (!(handle)) { \ - LITERT_LOG(LITERT_ERROR, #handle " must not be null."); \ - return kLiteRtStatusErrorInvalidArgument; \ - } - -extern "C" { - -LiteRtStatus LiteRtCreateCompilationOptions(LiteRtCompilationOptions* options) { - LRT_CHECK_NON_NULL(options); - *options = new LiteRtCompilationOptionsT; - return kLiteRtStatusOk; -} - -void LiteRtDestroyCompilationOptions(LiteRtCompilationOptions options) { - delete options; -} - -LiteRtStatus LiteRtSetCompilationOptionsHardwareAccelerators( - LiteRtCompilationOptions options, - LiteRtHwAcceleratorSet hardware_accelerators) { - LRT_CHECK_NON_NULL(options); - if ((hardware_accelerators & - (kLiteRtHwAcceleratorCpu | kLiteRtHwAcceleratorGpu | - kLiteRtHwAcceleratorNpu)) != hardware_accelerators) { - LITERT_LOG(LITERT_ERROR, - "Invalid bitfield value for hardware accelerator set: %d.", - hardware_accelerators); - return kLiteRtStatusErrorInvalidArgument; - } - options->hardware_accelerators = hardware_accelerators; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetCompilationOptionsHardwareAccelerators( - LiteRtCompilationOptions options, - LiteRtHwAcceleratorSet* hardware_accelerators) { - LRT_CHECK_NON_NULL(options); - LRT_CHECK_NON_NULL(hardware_accelerators); - *hardware_accelerators = options->hardware_accelerators; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtAddAcceleratorCompilationOptions( - 
LiteRtCompilationOptions options, - LiteRtAcceleratorCompilationOptions accelerator_compilation_options) { - LRT_CHECK_NON_NULL(options); - LRT_CHECK_NON_NULL(accelerator_compilation_options); - LITERT_RETURN_IF_ERROR(options->accelerator_compilation_options.Append( - litert::AcceleratorCompilationOptions(accelerator_compilation_options, - /*owned=*/false))); - return kLiteRtStatusOk; -} - -// Retrieves the head of the accelerator compilation option list. -// -// Note: The following elements may be retrieved with -// `LiteRtGetNextAcceleratorCompilationOptions`. -LiteRtStatus LiteRtGetAcceleratorCompilationOptions( - LiteRtCompilationOptions options, - LiteRtAcceleratorCompilationOptions* accelerator_compilation_options) { - LRT_CHECK_NON_NULL(options); - LRT_CHECK_NON_NULL(accelerator_compilation_options); - *accelerator_compilation_options = - options->accelerator_compilation_options.Get(); - return kLiteRtStatusOk; -} - -} // extern "C" diff --git a/tensorflow/lite/experimental/litert/c/litert_compilation_options.h b/tensorflow/lite/experimental/litert/c/litert_compilation_options.h deleted file mode 100644 index d27aa3919ff2a2..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_compilation_options.h +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_COMPILATION_OPTIONS_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_COMPILATION_OPTIONS_H_ - -#include "tensorflow/lite/experimental/litert/c/litert_accelerator_compilation_options.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" - -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus - -// The compilation options for the LiteRtCompiledModel. -LITERT_DEFINE_HANDLE(LiteRtCompilationOptions); - -// Creates a compilation option object. -LiteRtStatus LiteRtCreateCompilationOptions(LiteRtCompilationOptions* options); - -// Destroys a compilation option object. -void LiteRtDestroyCompilationOptions(LiteRtCompilationOptions options); - -// Sets the requested hardware accelerators to apply during model compilation. -LiteRtStatus LiteRtSetCompilationOptionsHardwareAccelerators( - LiteRtCompilationOptions options, - LiteRtHwAcceleratorSet hardware_accelerators); - -// Gets the hardware accelerators to apply during model compilation. -LiteRtStatus LiteRtGetCompilationOptionsHardwareAccelerators( - LiteRtCompilationOptions options, - LiteRtHwAcceleratorSet* hardware_accelerators); - -// Adds compilation options for a specific accelerator to the accelerator -// compilation option list. -// -// Note: Multiple accelerator options may be added to the options object. -// -// Note: `accelerator_compilation_options`'s ownership is transferred to -// `options`. -LiteRtStatus LiteRtAddAcceleratorCompilationOptions( - LiteRtCompilationOptions options, - LiteRtAcceleratorCompilationOptions accelerator_compilation_options); - -// Retrieves the head of the accelerator compilation option list. -// -// Note: The following elements may be retrieved with -// `LiteRtGetNextAcceleratorCompilationOptions`. 
-LiteRtStatus LiteRtGetAcceleratorCompilationOptions( - LiteRtCompilationOptions options, - LiteRtAcceleratorCompilationOptions* accelerator_compilation_options); - -#ifdef __cplusplus -} -#endif // __cplusplus - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_COMPILATION_OPTIONS_H_ diff --git a/tensorflow/lite/experimental/litert/c/litert_compilation_options_test.cc b/tensorflow/lite/experimental/litert/c/litert_compilation_options_test.cc deleted file mode 100644 index 3941105430e68e..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_compilation_options_test.cc +++ /dev/null @@ -1,139 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/c/litert_compilation_options.h" - -#include -#include "tensorflow/lite/experimental/litert/c/litert_accelerator_compilation_options.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" - -namespace { - -TEST(LiteRtCompiledModelOptionsTest, CreateAndDestroyDontLeak) { - LiteRtCompilationOptions options; - ASSERT_EQ(LiteRtCreateCompilationOptions(&options), kLiteRtStatusOk); - LiteRtDestroyCompilationOptions(options); -} - -TEST(LiteRtCompiledModelOptionsTest, CreateWithANullPointerErrors) { - EXPECT_EQ(LiteRtCreateCompilationOptions(nullptr), - kLiteRtStatusErrorInvalidArgument); -} - -TEST(LiteRtCompiledModelOptionsTest, SetAndGetHardwareAcceleratorsWorks) { - LiteRtCompilationOptions options; - ASSERT_EQ(LiteRtCreateCompilationOptions(&options), kLiteRtStatusOk); - - LiteRtHwAcceleratorSet hardware_accelerators; - - EXPECT_EQ(LiteRtSetCompilationOptionsHardwareAccelerators( - options, kLiteRtHwAcceleratorNone), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtGetCompilationOptionsHardwareAccelerators( - options, &hardware_accelerators), - kLiteRtStatusOk); - EXPECT_EQ(hardware_accelerators, kLiteRtHwAcceleratorNone); - - EXPECT_EQ(LiteRtSetCompilationOptionsHardwareAccelerators( - options, kLiteRtHwAcceleratorCpu), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtGetCompilationOptionsHardwareAccelerators( - options, &hardware_accelerators), - kLiteRtStatusOk); - EXPECT_EQ(hardware_accelerators, kLiteRtHwAcceleratorCpu); - - EXPECT_EQ(LiteRtSetCompilationOptionsHardwareAccelerators( - options, kLiteRtHwAcceleratorGpu), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtGetCompilationOptionsHardwareAccelerators( - options, &hardware_accelerators), - kLiteRtStatusOk); - EXPECT_EQ(hardware_accelerators, kLiteRtHwAcceleratorGpu); - - EXPECT_EQ(LiteRtSetCompilationOptionsHardwareAccelerators( - options, 
kLiteRtHwAcceleratorNpu), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtGetCompilationOptionsHardwareAccelerators( - options, &hardware_accelerators), - kLiteRtStatusOk); - EXPECT_EQ(hardware_accelerators, kLiteRtHwAcceleratorNpu); - - EXPECT_EQ(LiteRtSetCompilationOptionsHardwareAccelerators( - options, (kLiteRtHwAcceleratorCpu | kLiteRtHwAcceleratorGpu | - kLiteRtHwAcceleratorNpu) + - 1), - kLiteRtStatusErrorInvalidArgument); - EXPECT_EQ(LiteRtSetCompilationOptionsHardwareAccelerators( - nullptr, kLiteRtHwAcceleratorNone), - kLiteRtStatusErrorInvalidArgument); - - LiteRtDestroyCompilationOptions(options); -} - -struct DummyAcceleratorCompilationOptions { - static constexpr const LiteRtApiVersion kVersion = {1, 0, 0}; - static constexpr const char* const kIdentifier = "dummy-accelerator"; - - // Allocates and sets the basic structure for the accelerator options. - static litert::Expected CreateOptions() { - LiteRtAcceleratorCompilationOptions options; - auto* payload = new DummyAcceleratorCompilationOptions; - auto payload_destructor = [](void* payload) { - delete reinterpret_cast(payload); - }; - LITERT_RETURN_IF_ERROR(LiteRtCreateAcceleratorCompilationOptions( - &kVersion, kIdentifier, payload, payload_destructor, &options)); - return options; - } -}; - -TEST(LiteRtCompiledModelOptionsTest, AddAcceleratorCompilationOptionsWorks) { - LiteRtCompilationOptions options; - ASSERT_EQ(LiteRtCreateCompilationOptions(&options), kLiteRtStatusOk); - - auto accelerator_compilation_options1 = - DummyAcceleratorCompilationOptions::CreateOptions(); - EXPECT_TRUE(accelerator_compilation_options1); - auto accelerator_compilation_options2 = - DummyAcceleratorCompilationOptions::CreateOptions(); - EXPECT_TRUE(accelerator_compilation_options2); - - EXPECT_EQ(LiteRtAddAcceleratorCompilationOptions( - nullptr, *accelerator_compilation_options1), - kLiteRtStatusErrorInvalidArgument); - EXPECT_EQ(LiteRtAddAcceleratorCompilationOptions(options, nullptr), - kLiteRtStatusErrorInvalidArgument); - 
- EXPECT_EQ(LiteRtAddAcceleratorCompilationOptions( - options, *accelerator_compilation_options1), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtAddAcceleratorCompilationOptions( - options, *accelerator_compilation_options2), - kLiteRtStatusOk); - - LiteRtAcceleratorCompilationOptions options_it = nullptr; - EXPECT_EQ(LiteRtGetAcceleratorCompilationOptions(options, &options_it), - kLiteRtStatusOk); - EXPECT_EQ(options_it, *accelerator_compilation_options1); - - EXPECT_EQ(LiteRtGetNextAcceleratorCompilationOptions(&options_it), - kLiteRtStatusOk); - EXPECT_EQ(options_it, *accelerator_compilation_options2); - - LiteRtDestroyCompilationOptions(options); -} - -} // namespace diff --git a/tensorflow/lite/experimental/litert/c/litert_compiled_model.cc b/tensorflow/lite/experimental/litert/c/litert_compiled_model.cc deleted file mode 100644 index 295fbf32c40596..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_compiled_model.cc +++ /dev/null @@ -1,144 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/c/litert_compiled_model.h" - -#include - -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_compilation_options.h" -#include "tensorflow/lite/experimental/litert/c/litert_environment.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h" -#include "tensorflow/lite/experimental/litert/runtime/compiled_model.h" - -#ifdef __cplusplus -extern "C" { -#endif - -LiteRtStatus LiteRtCreateCompiledModel( - LiteRtEnvironment environment, LiteRtModel model, - LiteRtCompilationOptions jit_compilation_options, - LiteRtCompiledModel* compiled_model) { - if (!environment || !model || !compiled_model) { - return kLiteRtStatusErrorInvalidArgument; - } - auto created_compiled_model = - LiteRtCompiledModelT::Create(environment, model, jit_compilation_options); - if (!created_compiled_model) { - LITERT_LOG(LITERT_ERROR, "%s", - created_compiled_model.Error().Message().c_str()); - return created_compiled_model.Error().Status(); - } - *compiled_model = created_compiled_model->release(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetCompiledModelInputBufferRequirements( - LiteRtCompiledModel compiled_model, LiteRtParamIndex signature_index, - LiteRtParamIndex input_index, - LiteRtTensorBufferRequirements* buffer_requirements) { - if (!compiled_model || !buffer_requirements) { - return kLiteRtStatusErrorInvalidArgument; - } - - auto res = compiled_model->GetInputBufferRequirementsCApi(signature_index, - input_index); - if (!res) { - LITERT_LOG(LITERT_ERROR, "%s", res.Error().Message().c_str()); - return res.Error().Status(); - } - *buffer_requirements = res.Value(); - return kLiteRtStatusOk; -} - -LiteRtStatus 
LiteRtGetCompiledModelOutputBufferRequirements( - LiteRtCompiledModel compiled_model, LiteRtParamIndex signature_index, - LiteRtParamIndex output_index, - LiteRtTensorBufferRequirements* buffer_requirements) { - if (!compiled_model || !buffer_requirements) { - return kLiteRtStatusErrorInvalidArgument; - } - - auto res = compiled_model->GetOutputBufferRequirementsCApi(signature_index, - output_index); - if (!res) { - LITERT_LOG(LITERT_ERROR, "%s", res.Error().Message().c_str()); - return res.Error().Status(); - } - *buffer_requirements = res.Value(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtRunCompiledModel(LiteRtCompiledModel compiled_model, - LiteRtParamIndex signature_index, - size_t num_input_buffers, - LiteRtTensorBuffer* input_buffers, - size_t num_output_buffers, - LiteRtTensorBuffer* output_buffers) { - if (!compiled_model || (num_input_buffers > 0 && !input_buffers) || - (num_output_buffers > 0 && !output_buffers)) { - return kLiteRtStatusErrorInvalidArgument; - } - - bool async = false; - auto res = - compiled_model->RunCApi(signature_index, num_input_buffers, input_buffers, - num_output_buffers, output_buffers, &async); - if (!res) { - LITERT_LOG(LITERT_ERROR, "%s", res.Error().Message().c_str()); - return res.Error().Status(); - } - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtRunCompiledModelAsync(LiteRtCompiledModel compiled_model, - LiteRtParamIndex signature_index, - size_t num_input_buffers, - LiteRtTensorBuffer* input_buffers, - size_t num_output_buffers, - LiteRtTensorBuffer* output_buffers, - bool* async) { - if (!compiled_model || (num_input_buffers > 0 && !input_buffers) || - (num_output_buffers > 0 && !output_buffers)) { - return kLiteRtStatusErrorInvalidArgument; - } - - if (async) { - *async = true; - } - bool async_ = true; - bool* async_ptr = async ? 
async : &async_; - - auto res = - compiled_model->RunCApi(signature_index, num_input_buffers, input_buffers, - num_output_buffers, output_buffers, async_ptr); - if (!res) { - LITERT_LOG(LITERT_ERROR, "%s", res.Error().Message().c_str()); - return res.Error().Status(); - } - return kLiteRtStatusOk; -} - -void LiteRtDestroyCompiledModel(LiteRtCompiledModel compiled_model) { - delete compiled_model; -} - -#ifdef __cplusplus -} // extern "C" -#endif diff --git a/tensorflow/lite/experimental/litert/c/litert_compiled_model.h b/tensorflow/lite/experimental/litert/c/litert_compiled_model.h deleted file mode 100644 index 76df573c5ea9e9..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_compiled_model.h +++ /dev/null @@ -1,136 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_COMPILED_MODEL_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_COMPILED_MODEL_H_ - -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_compilation_options.h" -#include "tensorflow/lite/experimental/litert/c/litert_environment.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h" - -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus - -// The LiteRtCompiledModel is a higher level inference API. It is created by -// provided model with compilation options. Internally, it instantiates runtime -// and applies Delegates mapped to the compilation options. -// It also supports getting LiteRtTensorBufferRequirements to create -// input/output TensorBuffers, and it allows to invoke the model with the -// input/output TensorBuffers. -// -// Example user flow: -// -// 1. Create LiteRtCompiledModel -// 2. Query the model input/output LiteRtTensorBufferRequirements -// 3. Create input/output LiteRtTensorBuffer -// 4. Fill the input LiteRtTensorBuffer with input data -// 5. Invoke the model with the input/output LiteRtTensorBuffer -// 6. Evaluate the output LiteRtTensorBuffer - -LITERT_DEFINE_HANDLE(LiteRtCompiledModel); - -// Creates a LiteRtCompiledModel from a LiteRtModel object. Parameter -// `jit_compilation_options` is optional and can be null, and is owned by the -// caller. The model is loaded into memory and the caller takes ownership of -// the returned object. -LiteRtStatus LiteRtCreateCompiledModel( - LiteRtEnvironment environment, LiteRtModel model, - LiteRtCompilationOptions compilation_options, - LiteRtCompiledModel* compiled_model); - -// Returns the buffer requirements for the given n-th input tensor. 
The returned -// LiteRtTensorBufferRequirements is used to create the input tensor -// buffer. -// -// Parameters: -// - compiled_model: the target `LiteRtCompiledModel` object. -// - signature_index: the index of the signature in `LiteRtModel`. -// - input_index: the index of the input tensor in the signature (subgraph). -// - buffer_requirements: the returned `LiteRtTensorBufferRequirements`. -LiteRtStatus LiteRtGetCompiledModelInputBufferRequirements( - LiteRtCompiledModel compiled_model, LiteRtParamIndex signature_index, - LiteRtParamIndex input_index, - LiteRtTensorBufferRequirements* buffer_requirements); - -// Returns the buffer requirements for the given n-th output tensor. The -// returned LiteRtTensorBufferRequirements is used to create the output tensor -// buffer. -// -// Parameters: -// - compiled_model: the target `LiteRtCompiledModel` object. -// - signature_index: the index of the signature in `LiteRtModel`. -// - input_index: the index of the input tensor in the signature (subgraph). -// - buffer_requirements: the returned `LiteRtTensorBufferRequirements`. -LiteRtStatus LiteRtGetCompiledModelOutputBufferRequirements( - LiteRtCompiledModel compiled_model, LiteRtParamIndex signature_index, - LiteRtParamIndex output_index, - LiteRtTensorBufferRequirements* buffer_requirements); - -// Runs the model of the given signature synchronously, with the provided -// input/output LiteRtTensorBuffer. -// -// Parameters: -// - compiled_model: the target `LiteRtCompiledModel` object. -// - signature_index: the index of the signature in `LiteRtModel`. -// - num_input_buffers: the number of input `LiteRtTensorBuffer`. -// - input_buffers: the array of input `LiteRtTensorBuffer`. -// - num_output_buffers: the number of output `LiteRtTensorBuffer`. -// - output_buffers: the array of output LiteRtTensorBuffer. 
-LiteRtStatus LiteRtRunCompiledModel(LiteRtCompiledModel compiled_model, - LiteRtParamIndex signature_index, - size_t num_input_buffers, - LiteRtTensorBuffer* input_buffers, - size_t num_output_buffers, - LiteRtTensorBuffer* output_buffers); - -// Runs the model of the given signature asynchronously, if possible, with the -// provided input/output LiteRtTensorBuffers. If asynchronous execution is -// possible, then the function sets parameter `async` to true; if asynchronous -// execution is not possible, then the function runs the model synchronously and -// sets parameter `async` to false. Note that: -// -// - Asynchronous execution is possible only in certain cases, based on the ops -// included in the model, the selected HW accelerator(s), and the capability -// of the user device hardware. -// -// - If asynchronous execution is indeed possible, it may be that only some -// parts of the model are run asynchronously (e.g., ops mapped to the GPU) -// while other parts of the model are still run synchronously with the -// invocation of this call (e.g., ops mapped to the CPU). -// -// - In case of asynchronous execution some or all of the output tensor buffers -// will have a synchronization event attached to them and the caller is -// responsible for passing such events to a downstream processing step. -// -// Parameters: -// - async: optional boolean to let the caller know if the model is being run -// asynchronously. 
-LiteRtStatus LiteRtRunCompiledModelAsync( - LiteRtCompiledModel compiled_model, LiteRtParamIndex signature_index, - size_t num_input_buffers, LiteRtTensorBuffer* input_buffers, - size_t num_output_buffers, LiteRtTensorBuffer* output_buffers, bool* async); - -void LiteRtDestroyCompiledModel(LiteRtCompiledModel compiled_model); - -#ifdef __cplusplus -} -#endif // __cplusplus - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_COMPILED_MODEL_H_ diff --git a/tensorflow/lite/experimental/litert/c/litert_compiled_model_test.cc b/tensorflow/lite/experimental/litert/c/litert_compiled_model_test.cc deleted file mode 100644 index 6aa617bed4f551..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_compiled_model_test.cc +++ /dev/null @@ -1,177 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/c/litert_compiled_model.h" - -#include -#include -#include - -#include -#include -#include "absl/log/absl_log.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_compilation_options.h" -#include "tensorflow/lite/experimental/litert/c/litert_environment.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h" -#include "tensorflow/lite/experimental/litert/test/common.h" -#include "tensorflow/lite/experimental/litert/test/testdata/simple_model_test_vectors.h" - -using testing::FloatNear; -using testing::Pointwise; - -namespace litert { -namespace { - -TEST(CompiledModelTest, Basic) { - auto path = testing::GetTestFilePath(kModelFileName); - - LiteRtModel model; - ASSERT_EQ(LiteRtCreateModelFromFile(path.c_str(), &model), kLiteRtStatusOk); - - LiteRtCompilationOptions jit_compilation_options; - ASSERT_EQ(LiteRtCreateCompilationOptions(&jit_compilation_options), - kLiteRtStatusOk); - ASSERT_EQ(LiteRtSetCompilationOptionsHardwareAccelerators( - jit_compilation_options, kLiteRtHwAcceleratorCpu), - kLiteRtStatusOk); - - LiteRtEnvironment environment; - LiteRtEnvOption options = {}; - ASSERT_EQ(LiteRtEnvironmentCreate(/*num_options=*/0, &options, &environment), - kLiteRtStatusOk); - - LiteRtCompiledModel compiled_model; - ASSERT_EQ(LiteRtCreateCompiledModel(environment, model, - jit_compilation_options, &compiled_model), - kLiteRtStatusOk); - - LiteRtDestroyCompilationOptions(jit_compilation_options); - - LiteRtSubgraph subgraph; - ASSERT_EQ(LiteRtGetModelSubgraph(model, 0, &subgraph), kLiteRtStatusOk); - - LiteRtParamIndex num_inputs; - ASSERT_EQ(LiteRtGetNumSubgraphInputs(subgraph, &num_inputs), kLiteRtStatusOk); - - 
std::vector input_tensor_buffers; - input_tensor_buffers.reserve(num_inputs); - for (auto i = 0; i < num_inputs; ++i) { - LiteRtTensorBufferRequirements tensor_buffer_requirements; - ASSERT_EQ(LiteRtGetCompiledModelInputBufferRequirements( - compiled_model, /*signature_index=*/0, i, - &tensor_buffer_requirements), - kLiteRtStatusOk); - LiteRtTensorBufferType tensor_buffer_type; - EXPECT_EQ( - LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( - tensor_buffer_requirements, /*type_index=*/0, &tensor_buffer_type), - kLiteRtStatusOk); - size_t tensor_buffer_size; - EXPECT_EQ(LiteRtGetTensorBufferRequirementsBufferSize( - tensor_buffer_requirements, &tensor_buffer_size), - kLiteRtStatusOk); - LiteRtTensorBuffer tensor_buffer; - EXPECT_EQ( - LiteRtCreateManagedTensorBuffer(tensor_buffer_type, &kInput0TensorType, - tensor_buffer_size, &tensor_buffer), - kLiteRtStatusOk); - input_tensor_buffers.push_back(tensor_buffer); - } - - LiteRtParamIndex num_outputs; - ASSERT_EQ(LiteRtGetNumSubgraphOutputs(subgraph, &num_outputs), - kLiteRtStatusOk); - - std::vector output_tensor_buffers; - output_tensor_buffers.reserve(num_outputs); - for (auto i = 0; i < num_outputs; ++i) { - LiteRtTensorBufferRequirements tensor_buffer_requirements; - ASSERT_EQ(LiteRtGetCompiledModelOutputBufferRequirements( - compiled_model, /*signature_index=*/0, i, - &tensor_buffer_requirements), - kLiteRtStatusOk); - LiteRtTensorBufferType tensor_buffer_type; - EXPECT_EQ( - LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( - tensor_buffer_requirements, /*type_index=*/0, &tensor_buffer_type), - kLiteRtStatusOk); - size_t tensor_buffer_size; - EXPECT_EQ(LiteRtGetTensorBufferRequirementsBufferSize( - tensor_buffer_requirements, &tensor_buffer_size), - kLiteRtStatusOk); - LiteRtTensorBuffer tensor_buffer; - EXPECT_EQ( - LiteRtCreateManagedTensorBuffer(tensor_buffer_type, &kInput0TensorType, - tensor_buffer_size, &tensor_buffer), - kLiteRtStatusOk); - 
output_tensor_buffers.push_back(tensor_buffer); - } - - { - ABSL_LOG(INFO) << "Filling inputs with data"; - void* host_mem_addr; - - ASSERT_EQ(LiteRtLockTensorBuffer(input_tensor_buffers[0], &host_mem_addr), - kLiteRtStatusOk); - std::memcpy(host_mem_addr, kTestInput0Tensor, sizeof(kTestInput0Tensor)); - ASSERT_EQ(LiteRtUnlockTensorBuffer(input_tensor_buffers[0]), - kLiteRtStatusOk); - - ASSERT_EQ(LiteRtLockTensorBuffer(input_tensor_buffers[1], &host_mem_addr), - kLiteRtStatusOk); - std::memcpy(host_mem_addr, kTestInput1Tensor, sizeof(kTestInput1Tensor)); - ASSERT_EQ(LiteRtUnlockTensorBuffer(input_tensor_buffers[1]), - kLiteRtStatusOk); - } - - ASSERT_EQ(LiteRtRunCompiledModel( - compiled_model, /*signature_index=*/0, - input_tensor_buffers.size(), input_tensor_buffers.data(), - output_tensor_buffers.size(), output_tensor_buffers.data()), - kLiteRtStatusOk); - - { - ABSL_LOG(INFO) << "Checking output..."; - void* host_mem_addr; - ASSERT_EQ(LiteRtLockTensorBuffer(output_tensor_buffers[0], &host_mem_addr), - kLiteRtStatusOk); - auto output = absl::MakeSpan(static_cast(host_mem_addr), - kTestOutputSize); - for (auto i = 0; i < kTestOutputSize; ++i) { - ABSL_LOG(INFO) << output[i] << "\t" << kTestOutputTensor[i]; - } - EXPECT_THAT(output, Pointwise(FloatNear(1e-3), kTestOutputTensor)); - ASSERT_EQ(LiteRtUnlockTensorBuffer(output_tensor_buffers[0]), - kLiteRtStatusOk); - } - - LiteRtDestroyCompiledModel(compiled_model); - LiteRtDestroyModel(model); - LiteRtDestroyEnvironment(environment); - - for (auto tensor_buffer : input_tensor_buffers) { - LiteRtDestroyTensorBuffer(tensor_buffer); - } - for (auto tensor_buffer : output_tensor_buffers) { - LiteRtDestroyTensorBuffer(tensor_buffer); - } -} - -} // namespace -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/c/litert_dispatch_delegate.h b/tensorflow/lite/experimental/litert/c/litert_dispatch_delegate.h deleted file mode 100644 index 7186bf794db7fe..00000000000000 --- 
a/tensorflow/lite/experimental/litert/c/litert_dispatch_delegate.h +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_DISPATCH_DELEGATE_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_DISPATCH_DELEGATE_H_ - -#include - -#include "tensorflow/lite/c/c_api_types.h" -#include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/experimental/litert/c/litert_environment_options.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" - -typedef struct LiteRtDispatchDelegateOptions LiteRtDispatchDelegateOptions; -typedef struct LiteRtEnvironmentT* LiteRtEnvironment; - -// Returns DispatchDelegateOptions populated with default values. -LiteRtDispatchDelegateOptions* LiteRtCreateDefaultDispatchDelegateOptions( - LiteRtEnvironment environment); - -TfLiteStatus LiteRtAddDispatchDelegateOption( - LiteRtDispatchDelegateOptions* options, LiteRtDispatchOption option); - -void LiteRtDestroyDispatchDelegateOptions( - LiteRtDispatchDelegateOptions* options); - -// Create a delegate that uses the Dispatch API for execution. Takes ownership -// of the passed `options`. Must outlive the TFL interpreter. -TfLiteOpaqueDelegate* LiteRtCreateDispatchDelegate( - LiteRtEnvironmentOptions environment_options, - LiteRtDispatchDelegateOptions* options); - -// Do any needed cleanup and delete 'delegate'. 
-void LiteRtDestroyDispatchDelegate(TfLiteOpaqueDelegate* delegate); - -// -// Common option helpers -// - -// Alloc base is the address of the first byte of flatbuffer model in memory. It -// is used by ops to find the start of npu byte code appended to the file. -TfLiteStatus LiteRtDispatchDelegateAddAllocBaseOption( - LiteRtDispatchDelegateOptions* options, const void* alloc_base); - -// Alloc fd is the file descriptor for an mmapped flatbuffer. It is used by ops -// to find the start of npu byte code appended to the file. -TfLiteStatus LiteRtDispatchDelegateAddAllocFdOption( - LiteRtDispatchDelegateOptions* options, int alloc_fd); - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_DISPATCH_DELEGATE_H_ diff --git a/tensorflow/lite/experimental/litert/c/litert_environment.cc b/tensorflow/lite/experimental/litert/c/litert_environment.cc deleted file mode 100644 index 702cf90e113379..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_environment.cc +++ /dev/null @@ -1,69 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/c/litert_environment.h" - -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_environment_options.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/core/environment.h" -#include "tensorflow/lite/experimental/litert/runtime/accelerators/auto_registration.h" -#include "tensorflow/lite/experimental/litert/runtime/gpu_environment.h" - -#ifdef __cplusplus -extern "C" { -#endif - -LiteRtStatus LiteRtEnvironmentCreate(int num_options, - const LiteRtEnvOption* options, - LiteRtEnvironment* environment) { - LITERT_RETURN_IF_ERROR(environment != nullptr, - kLiteRtStatusErrorInvalidArgument); - LITERT_ASSIGN_OR_RETURN(auto env, LiteRtEnvironmentT::CreateWithOptions( - absl::MakeSpan(options, num_options))); - litert::TriggerAcceleratorAutomaticRegistration(*env); - *environment = env.release(); - return kLiteRtStatusOk; -} - -void LiteRtDestroyEnvironment(LiteRtEnvironment environment) { - if (environment != nullptr) { - delete environment; - } -} - -LiteRtStatus LiteRtGetEnvironmentOptions(LiteRtEnvironment environment, - LiteRtEnvironmentOptions* options) { - LITERT_RETURN_IF_ERROR( - environment, litert::ErrorStatusBuilder(kLiteRtStatusErrorInvalidArgument) - << "Environment pointer is null."); - LITERT_RETURN_IF_ERROR( - options, litert::ErrorStatusBuilder(kLiteRtStatusErrorInvalidArgument) - << "Options pointer is null."); - *options = &environment->GetOptions(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGpuGlobalEnvironmentCreate(int num_options, - const LiteRtEnvOption* options) { - LITERT_ASSIGN_OR_RETURN(auto env, LiteRtEnvironmentT::CreateWithOptions( - absl::MakeSpan(options, num_options))); - litert::internal::GpuEnvironmentSingleton::Create(env.get()); - return kLiteRtStatusOk; -} - -#ifdef __cplusplus -} // extern "C" -#endif diff --git 
a/tensorflow/lite/experimental/litert/c/litert_environment.h b/tensorflow/lite/experimental/litert/c/litert_environment.h deleted file mode 100644 index 20b834df645513..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_environment.h +++ /dev/null @@ -1,54 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_ENVIRONMENT_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_ENVIRONMENT_H_ - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_environment_options.h" - -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus - -LITERT_DEFINE_HANDLE(LiteRtEnvironment); - -// Create a LiteRT environment with options. -// Used to set the path of the compiler plugin library and dispatch library. -// -// Note: options of kLiteRtEnvOptionTagOpenCl* shouldn't be set with this API. -LiteRtStatus LiteRtEnvironmentCreate(int num_options, - const LiteRtEnvOption* options, - LiteRtEnvironment* environment); - -// Destroy a created LiteRT environment. -void LiteRtDestroyEnvironment(LiteRtEnvironment environment); - -// Get the options that the environment was created with. -LiteRtStatus LiteRtGetEnvironmentOptions(LiteRtEnvironment environment, - LiteRtEnvironmentOptions* options); - -// Create a LiteRT GPU global environment with options. 
-// This API is usually called by the GPU accelerator implementation to set GPU -// environment options which affect the entire LiteRT runtime. -// -// Note: In most cases, users should not call this API directly. -LiteRtStatus LiteRtGpuGlobalEnvironmentCreate(int num_options, - const LiteRtEnvOption* options); - -#ifdef __cplusplus -} -#endif // __cplusplus - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_ENVIRONMENT_H_ diff --git a/tensorflow/lite/experimental/litert/c/litert_environment_options.cc b/tensorflow/lite/experimental/litert/c/litert_environment_options.cc deleted file mode 100644 index daf14bb79a046f..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_environment_options.cc +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/c/litert_environment_options.h" - -#include "tensorflow/lite/experimental/litert/c/litert_any.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/core/environment_options.h" - -extern "C" { - -LiteRtStatus LiteRtGetEnvironmentOptionsValue(LiteRtEnvironmentOptions options, - LiteRtEnvOptionTag tag, - LiteRtAny* value) { - LITERT_RETURN_IF_ERROR( - options, litert::ErrorStatusBuilder(kLiteRtStatusErrorInvalidArgument)) - << "`options` handle is null."; - LITERT_RETURN_IF_ERROR( - value, litert::ErrorStatusBuilder(kLiteRtStatusErrorInvalidArgument)) - << "`value` handle is null."; - LITERT_ASSIGN_OR_RETURN(*value, options->GetOption(tag)); - return kLiteRtStatusOk; -} - -} // extern "C" diff --git a/tensorflow/lite/experimental/litert/c/litert_environment_options.h b/tensorflow/lite/experimental/litert/c/litert_environment_options.h deleted file mode 100644 index ab778e230f51d3..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_environment_options.h +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_ENVIRONMENT_OPTIONS_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_ENVIRONMENT_OPTIONS_H_ - -#include "tensorflow/lite/experimental/litert/c/litert_any.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" - -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus - -typedef enum { - kLiteRtEnvOptionTagCompilerPluginLibraryDir = 0, - kLiteRtEnvOptionTagDispatchLibraryDir = 1, - kLiteRtEnvOptionTagOpenClDeviceId = 2, - kLiteRtEnvOptionTagOpenClPlatformId = 3, - kLiteRtEnvOptionTagOpenClContext = 4, - kLiteRtEnvOptionTagOpenClCommandQueue = 5, -} LiteRtEnvOptionTag; - -typedef struct { - LiteRtEnvOptionTag tag; - LiteRtAny value; -} LiteRtEnvOption; - -LITERT_DEFINE_HANDLE(LiteRtEnvironmentOptions); - -// Retrieves the value corresponding to the given tag. -// -// Returns kLiteRtStatusErrorNotFound if the option tag is not found. -LiteRtStatus LiteRtGetEnvironmentOptionsValue(LiteRtEnvironmentOptions options, - LiteRtEnvOptionTag tag, - LiteRtAny* value); - -#ifdef __cplusplus -} -#endif // __cplusplus - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_ENVIRONMENT_OPTIONS_H_ diff --git a/tensorflow/lite/experimental/litert/c/litert_environment_options_test.cc b/tensorflow/lite/experimental/litert/c/litert_environment_options_test.cc deleted file mode 100644 index d827ee4005edbd..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_environment_options_test.cc +++ /dev/null @@ -1,97 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/c/litert_environment_options.h" - -#include -#include -#include "tensorflow/lite/experimental/litert/c/litert_any.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/core/environment_options.h" -#include "tensorflow/lite/experimental/litert/test/matchers.h" - -namespace { - -using testing::AnyOf; -using testing::Eq; -using testing::Not; -using testing::StrEq; -using testing::litert::IsError; - -class LiteRtEnvironmentOptionsTest : public testing::Test { - public: - void SetUp() override { - constexpr const char* kStrValue = "string_value"; - dispatch_option_.tag = kLiteRtEnvOptionTagDispatchLibraryDir; - dispatch_option_.value.type = kLiteRtAnyTypeString; - dispatch_option_.value.str_value = kStrValue; - options_.SetOption(dispatch_option_); - - constexpr int kIntValue = 3; - cl_device_id_option_.tag = kLiteRtEnvOptionTagOpenClDeviceId; - cl_device_id_option_.value.type = kLiteRtAnyTypeInt; - cl_device_id_option_.value.int_value = kIntValue; - options_.SetOption(cl_device_id_option_); - - ASSERT_THAT(NotInsertedOptionTag(), - Not(AnyOf(dispatch_option_.tag, cl_device_id_option_.tag))); - } - - LiteRtEnvironmentOptions Options() { return &options_; } - const LiteRtEnvOption& DispatchOption() const { return dispatch_option_; } - const LiteRtEnvOption& ClDeviceIdOption() const { - return cl_device_id_option_; - } - - static constexpr LiteRtEnvOptionTag NotInsertedOptionTag() { - return kLiteRtEnvOptionTagOpenClPlatformId; - } - - private: - 
LiteRtEnvironmentOptionsT options_; - LiteRtEnvOption dispatch_option_; - LiteRtEnvOption cl_device_id_option_; -}; - -TEST_F(LiteRtEnvironmentOptionsTest, - LiteRtGetEnvironmentOptionsValueReturnsAnErrorForInvalidArguments) { - LiteRtAny option_value; - EXPECT_THAT( - LiteRtGetEnvironmentOptionsValue( - /*options=*/nullptr, kLiteRtEnvOptionTagOpenClContext, &option_value), - IsError(kLiteRtStatusErrorInvalidArgument)); - EXPECT_THAT( - LiteRtGetEnvironmentOptionsValue( - Options(), kLiteRtEnvOptionTagOpenClContext, /*value=*/nullptr), - IsError(kLiteRtStatusErrorInvalidArgument)); -} - -TEST_F(LiteRtEnvironmentOptionsTest, LiteRtGetEnvironmentOptionsValueWorks) { - LiteRtAny option_value; - LITERT_EXPECT_OK(LiteRtGetEnvironmentOptionsValue( - Options(), ClDeviceIdOption().tag, &option_value)); - EXPECT_THAT(option_value.type, Eq(ClDeviceIdOption().value.type)); - EXPECT_THAT(option_value.int_value, Eq(ClDeviceIdOption().value.int_value)); - - LITERT_EXPECT_OK(LiteRtGetEnvironmentOptionsValue( - Options(), DispatchOption().tag, &option_value)); - EXPECT_THAT(option_value.type, Eq(DispatchOption().value.type)); - EXPECT_THAT(option_value.str_value, StrEq(DispatchOption().value.str_value)); - - EXPECT_THAT(LiteRtGetEnvironmentOptionsValue( - Options(), NotInsertedOptionTag(), &option_value), - IsError(kLiteRtStatusErrorNotFound)); -} - -} // namespace diff --git a/tensorflow/lite/experimental/litert/c/litert_event.cc b/tensorflow/lite/experimental/litert/c/litert_event.cc deleted file mode 100644 index 09364b9cfdd1fa..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_event.cc +++ /dev/null @@ -1,105 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/c/litert_event.h" - -#include - -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_event_type.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/runtime/event.h" - -#ifdef __cplusplus -extern "C" { -#endif - -LiteRtStatus LiteRtCreateEventFromSyncFenceFd(int sync_fence_fd, bool owns_fd, - LiteRtEvent* event) { -#if LITERT_HAS_SYNC_FENCE_SUPPORT - *event = new LiteRtEventT{.type = LiteRtEventTypeSyncFenceFd, - .fd = sync_fence_fd, - .owns_fd = owns_fd}; - return kLiteRtStatusOk; -#else - return kLiteRtStatusErrorUnsupported; -#endif -} - -LiteRtStatus LiteRtCreateEventFromOpenClEvent(cl_event cl_event, - LiteRtEvent* event) { -#if LITERT_HAS_OPENCL_SUPPORT - *event = new LiteRtEventT{ - .type = LiteRtEventTypeOpenCl, - .opencl_event = cl_event, - }; - return kLiteRtStatusOk; -#else - return kLiteRtStatusErrorUnsupported; -#endif -} - -LiteRtStatus LiteRtGetEventEventType(LiteRtEvent event, LiteRtEventType* type) { - *type = event->type; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetEventSyncFenceFd(LiteRtEvent event, int* sync_fence_fd) { -#if LITERT_HAS_SYNC_FENCE_SUPPORT - if (event->type == LiteRtEventTypeSyncFenceFd) { - *sync_fence_fd = event->fd; - return kLiteRtStatusOk; - } -#endif - return kLiteRtStatusErrorUnsupported; -} - -LiteRtStatus 
LiteRtGetEventOpenClEvent(LiteRtEvent event, cl_event* cl_event) { -#if LITERT_HAS_OPENCL_SUPPORT - if (event->type == LiteRtEventTypeOpenCl) { - *cl_event = event->opencl_event; - return kLiteRtStatusOk; - } -#endif - return kLiteRtStatusErrorUnsupported; -} - -LiteRtStatus LiteRtCreateManagedEvent(LiteRtEventType type, - LiteRtEvent* event) { - auto event_res = LiteRtEventT::CreateManaged(type); - if (!event_res) { - return kLiteRtStatusErrorUnsupported; - } - *event = *event_res; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtEventWait(LiteRtEvent event, int64_t timeout_in_ms) { - LITERT_RETURN_IF_ERROR(event->Wait(timeout_in_ms)); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtEventSignal(LiteRtEvent event) { - LITERT_RETURN_IF_ERROR(event->Signal()); - return kLiteRtStatusOk; -} - -void LiteRtDestroyEvent(LiteRtEvent event) { delete event; } - -#ifdef __cplusplus -} // extern "C" -#endif diff --git a/tensorflow/lite/experimental/litert/c/litert_event.h b/tensorflow/lite/experimental/litert/c/litert_event.h deleted file mode 100644 index 16cc107168d0f0..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_event.h +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_EVENT_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_EVENT_H_ - -#include // NOLINT: To use bool type in C -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_event_type.h" - -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus - -// Forward declaration of OpenCL event to avoid including OpenCL headers. -typedef struct _cl_event* cl_event; - -LITERT_DEFINE_HANDLE(LiteRtEvent); - -LiteRtStatus LiteRtCreateEventFromSyncFenceFd(int sync_fence_fd, bool owns_fd, - LiteRtEvent* event); - -LiteRtStatus LiteRtCreateEventFromOpenClEvent(cl_event cl_event, - LiteRtEvent* event); - -LiteRtStatus LiteRtCreateManagedEvent(LiteRtEventType type, LiteRtEvent* event); - -LiteRtStatus LiteRtGetEventEventType(LiteRtEvent event, LiteRtEventType* type); - -LiteRtStatus LiteRtGetEventSyncFenceFd(LiteRtEvent event, int* sync_fence_fd); - -LiteRtStatus LiteRtGetEventOpenClEvent(LiteRtEvent event, cl_event* cl_event); - -// Pass -1 for timeout_in_ms for indefinite wait. -LiteRtStatus LiteRtEventWait(LiteRtEvent event, int64_t timeout_in_ms); - -// Signal the event to notify the waiters. -LiteRtStatus LiteRtEventSignal(LiteRtEvent event); - -void LiteRtDestroyEvent(LiteRtEvent event); - -#ifdef __cplusplus -} -#endif // __cplusplus - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_EVENT_H_ diff --git a/tensorflow/lite/experimental/litert/c/litert_event_type.h b/tensorflow/lite/experimental/litert/c/litert_event_type.h deleted file mode 100644 index 24c7124702dcea..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_event_type.h +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_EVENT_TYPE_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_EVENT_TYPE_H_ - -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus - -typedef enum { - LiteRtEventTypeUnknown = 0, - LiteRtEventTypeSyncFenceFd = 1, - LiteRtEventTypeOpenCl = 2, -} LiteRtEventType; - -#ifdef __cplusplus -} -#endif // __cplusplus - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_EVENT_TYPE_H_ diff --git a/tensorflow/lite/experimental/litert/c/litert_gl_types.h b/tensorflow/lite/experimental/litert/c/litert_gl_types.h deleted file mode 100644 index 4f394e19ded7d1..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_gl_types.h +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_GL_TYPES_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_GL_TYPES_H_ - -#include -#if LITERT_HAS_OPENGL_SUPPORT -#include -#include -#endif // LITERT_HAS_OPENGL_SUPPORT - -#ifdef __cplusplus -extern "C" { -#endif - -#if LITERT_HAS_OPENGL_SUPPORT -typedef GLenum LiteRtGLenum; -typedef GLuint LiteRtGLuint; -typedef GLint LiteRtGLint; -#else -// Allows for compilation of GL types when OpenGl support is not available. -typedef uint32_t LiteRtGLenum; -typedef uint32_t LiteRtGLuint; -typedef int32_t LiteRtGLint; -#endif // LITERT_HAS_OPENGL_SUPPORT - -#ifdef __cplusplus -} -#endif // __cplusplus - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_GL_TYPES_H_ diff --git a/tensorflow/lite/experimental/litert/c/litert_layout.h b/tensorflow/lite/experimental/litert/c/litert_layout.h deleted file mode 100644 index b641985b9793af..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_layout.h +++ /dev/null @@ -1,45 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_LAYOUT_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_LAYOUT_H_ - -#include -#include - -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus - -// Max number of dimensions in any ranked tensor type. -#define LITERT_TENSOR_MAX_RANK 8 - -// The shape information for tensor types of fixed rank. 
-typedef struct { - // The number of dimensions. - uint32_t rank; - - // Dimension sizes, array of length `rank`. Dynamic dimensions are anything - // less than 0. Everything from [rank, LITERT_MAX_RANK) is undefined. - int32_t dimensions[LITERT_TENSOR_MAX_RANK]; - - // Strides for a nomimal NWHC layout. NULL if unused. - const uint32_t* strides; -} LiteRtLayout; - -#ifdef __cplusplus -} -#endif // __cplusplus - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_LAYOUT_H_ diff --git a/tensorflow/lite/experimental/litert/c/litert_logging.cc b/tensorflow/lite/experimental/litert/c/litert_logging.cc deleted file mode 100644 index 66f92cd9e79545..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_logging.cc +++ /dev/null @@ -1,115 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" - -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/logger.h" -#include "tensorflow/lite/minimal_logging.h" - -class LiteRtLoggerT { - public: - LiteRtLogSeverity GetMinSeverity() { - return ConvertSeverity( - tflite::logging_internal::MinimalLogger::GetMinimumLogSeverity()); - } - - void SetMinSeverity(LiteRtLogSeverity severity) { - tflite::logging_internal::MinimalLogger::SetMinimumLogSeverity( - ConvertSeverity(severity)); - } - - void Log(LiteRtLogSeverity severity, const char* format, va_list args) { - tflite::logging_internal::MinimalLogger::LogFormatted( - ConvertSeverity(severity), format, args); - } - - private: - static tflite::LogSeverity ConvertSeverity(LiteRtLogSeverity severity) { - return static_cast(severity); - } - - static LiteRtLogSeverity ConvertSeverity(tflite::LogSeverity severity) { - return static_cast(severity); - } -}; - -#ifdef __cplusplus -extern "C" { -#endif - -LiteRtStatus LiteRtCreateLogger(LiteRtLogger* logger) { - if (!logger) { - return kLiteRtStatusErrorInvalidArgument; - } - *logger = new LiteRtLoggerT; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetMinLoggerSeverity(LiteRtLogger logger, - LiteRtLogSeverity* min_severity) { - if (!logger || !min_severity) { - return kLiteRtStatusErrorInvalidArgument; - } - *min_severity = logger->GetMinSeverity(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtSetMinLoggerSeverity(LiteRtLogger logger, - LiteRtLogSeverity min_severity) { - if (!logger) { - return kLiteRtStatusErrorInvalidArgument; - } - logger->SetMinSeverity(min_severity); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtLoggerLog(LiteRtLogger logger, LiteRtLogSeverity severity, - const char* format, ...) 
{ - if (!logger || !format) { - return kLiteRtStatusErrorInvalidArgument; - } - va_list args; - va_start(args, format); - logger->Log(severity, format, args); - va_end(args); - return kLiteRtStatusOk; -} - -void LiteRtDestroyLogger(LiteRtLogger logger) { - if (logger != nullptr) { - delete logger; - } -} - -#ifdef __cplusplus -} // extern "C" -#endif - -namespace { -LiteRtLoggerT StaticLogger; -LiteRtLogger DefaultLogger = &StaticLogger; -} // namespace - -LiteRtStatus LiteRtSetDefaultLogger(LiteRtLogger logger) { - if (!logger) { - return kLiteRtStatusErrorInvalidArgument; - } - DefaultLogger = logger; - return kLiteRtStatusOk; -} - -LiteRtLogger LiteRtGetDefaultLogger() { return DefaultLogger; } diff --git a/tensorflow/lite/experimental/litert/c/litert_logging.h b/tensorflow/lite/experimental/litert/c/litert_logging.h deleted file mode 100644 index 4570e76327b7f9..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_logging.h +++ /dev/null @@ -1,95 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_LOGGING_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_LOGGING_H_ - -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" - -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus - -LITERT_DEFINE_HANDLE(LiteRtLogger); - -// WARNING: The values of the following enum are to be kept in sync with -// tflite::LogSeverity. 
-typedef enum { - kLiteRtLogSeverityVerbose = 0, - kLiteRtLogSeverityInfo = 1, - kLiteRtLogSeverityWarning = 2, - kLiteRtLogSeverityError = 3, - kLiteRtLogSeveritySilent = 4, -} LiteRtLogSeverity; - -#define LITERT_VERBOSE kLiteRtLogSeverityVerbose -#define LITERT_INFO kLiteRtLogSeverityInfo -#define LITERT_WARNING kLiteRtLogSeverityWarning -#define LITERT_ERROR kLiteRtLogSeverityError -#define LITERT_SILENT kLiteRtLogSeveritySilent - -LiteRtStatus LiteRtCreateLogger(LiteRtLogger* logger); -LiteRtStatus LiteRtGetMinLoggerSeverity(LiteRtLogger logger, - LiteRtLogSeverity* min_severity); -LiteRtStatus LiteRtSetMinLoggerSeverity(LiteRtLogger logger, - LiteRtLogSeverity min_severity); -LiteRtStatus LiteRtLoggerLog(LiteRtLogger logger, LiteRtLogSeverity severity, - const char* format, ...); -void LiteRtDestroyLogger(LiteRtLogger logger); - -LiteRtLogger LiteRtGetDefaultLogger(); -LiteRtStatus LiteRtSetDefaultLogger(LiteRtLogger logger); -LiteRtStatus LiteRtDefaultLoggerLog(LiteRtLogSeverity severity, - const char* format, ...); - -#ifdef __cplusplus -} -#endif // __cplusplus - -#define LITERT_LOGGER_LOG_PROD(logger, severity, format, ...) \ - { \ - LiteRtLogSeverity __min_severity__; \ - if (LiteRtGetMinLoggerSeverity(logger, &__min_severity__) != \ - kLiteRtStatusOk) { \ - __min_severity__ = kLiteRtLogSeverityVerbose; \ - } \ - if (severity >= __min_severity__) { \ - LiteRtLoggerLog(logger, severity, "[%s:%d] " format, __FILE__, __LINE__, \ - ##__VA_ARGS__); \ - } \ - } - -#ifndef NDEBUG -#define LITERT_LOGGER_LOG LITERT_LOGGER_LOG_PROD -#else -#define LITERT_LOGGER_LOG(logger, severity, format, ...) \ - do { \ - LITERT_LOGGER_LOG_PROD(logger, severity, format, ##__VA_ARGS__); \ - } while (false) -#endif - -#define LITERT_LOG(severity, format, ...) \ - LITERT_LOGGER_LOG(LiteRtGetDefaultLogger(), severity, format, ##__VA_ARGS__); - -#define LITERT_ABORT abort() - -#define LITERT_FATAL(format, ...) 
\ - do { \ - LITERT_LOG(kLiteRtLogSeverityError, format, ##__VA_ARGS__) \ - LITERT_ABORT; \ - } while (0) - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_LOGGING_H_ diff --git a/tensorflow/lite/experimental/litert/c/litert_logging_test.cc b/tensorflow/lite/experimental/litert/c/litert_logging_test.cc deleted file mode 100644 index 148fc778f18915..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_logging_test.cc +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" - -#include // NOLINT: Need when ANDROID_API_LEVEL >= 26 -#include "tensorflow/lite/experimental/litert/c/litert_common.h" - -TEST(Layout, Creation) { - LiteRtLogger logger; - ASSERT_EQ(LiteRtCreateLogger(&logger), kLiteRtStatusOk); - LiteRtDestroyLogger(logger); -} - -TEST(Layout, MinLogging) { - LiteRtLogger logger; - ASSERT_EQ(LiteRtCreateLogger(&logger), kLiteRtStatusOk); - ASSERT_EQ(LiteRtSetMinLoggerSeverity(logger, LITERT_SILENT), kLiteRtStatusOk); - LiteRtLogSeverity min_severity; - ASSERT_EQ(LiteRtGetMinLoggerSeverity(logger, &min_severity), kLiteRtStatusOk); - ASSERT_EQ(min_severity, LITERT_SILENT); - LiteRtDestroyLogger(logger); -} diff --git a/tensorflow/lite/experimental/litert/c/litert_model.cc b/tensorflow/lite/experimental/litert/c/litert_model.cc deleted file mode 100644 index af83f970c94c2f..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_model.cc +++ /dev/null @@ -1,506 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/c/litert_model.h" - -#include -#include -#include -#include -#include - -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" -#include "tensorflow/lite/experimental/litert/core/model/model_load.h" -#include "tensorflow/lite/experimental/litert/core/model/model_serialize.h" - -#ifdef __cplusplus -extern "C" { -#endif - -// -// Model -// - -LiteRtStatus LiteRtCreateModelFromFile(const char* filename, - LiteRtModel* model) { - if (!filename || !model) { - return kLiteRtStatusErrorInvalidArgument; - } - - auto new_model = litert::internal::LoadModelFromFile(filename); - if (!new_model) { - return new_model.Error().Status(); - } - *model = new_model->release(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtCreateModelFromBuffer(const void* buffer_addr, - size_t buffer_size, - LiteRtModel* model) { - if (!buffer_addr || !buffer_size || !model) { - return kLiteRtStatusErrorInvalidArgument; - } - - auto new_model = litert::internal::LoadModelFromBuffer( - litert::BufferRef(buffer_addr, buffer_size)); - if (!new_model) { - return new_model.Error().Status(); - } - *model = new_model->release(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetNumModelSubgraphs(LiteRtModel model, - LiteRtParamIndex* num_subgraphs) { - if (model == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *num_subgraphs = model->Subgraphs().size(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetModelSubgraph(LiteRtModel model, - LiteRtParamIndex subgraph_index, - LiteRtSubgraph* subgraph) { - if (model == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - if (subgraph_index >= model->Subgraphs().size()) { - return kLiteRtStatusErrorIndexOOB; - } - *subgraph = 
&model->Subgraph(subgraph_index); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetMainModelSubgraphIndex( - LiteRtModel model, LiteRtParamIndex* main_subgraph_index) { - if (!model || !main_subgraph_index) { - return kLiteRtStatusErrorInvalidArgument; - } - *main_subgraph_index = LiteRtModelT::kMainSubgraphIndex; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetModelMetadata(LiteRtModel model, const char* metadata_key, - const void** metadata_buffer, - size_t* metadata_buffer_size) { - if (!model || !metadata_key || !metadata_buffer || !metadata_buffer_size) { - return kLiteRtStatusErrorInvalidArgument; - } - auto m_buf = model->FindMetadata(metadata_key); - if (!m_buf) { - return m_buf.Error().Status(); - } - *metadata_buffer = m_buf->Data(); - *metadata_buffer_size = m_buf->Size(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetNumModelSignatures(LiteRtModel model, - LiteRtParamIndex* num_signatures) { - if (!model || !num_signatures) { - return kLiteRtStatusErrorInvalidArgument; - } - *num_signatures = model->Signatures().size(); - return kLiteRtStatusOk; -} - -// Get the signature at the given index in the model -LiteRtStatus LiteRtGetModelSignature(LiteRtModel model, - LiteRtParamIndex signature_index, - LiteRtSignature* signature) { - if (!model || !signature) { - return kLiteRtStatusErrorInvalidArgument; - } - if (signature_index >= model->Signatures().size()) { - return kLiteRtStatusErrorIndexOOB; - } - *signature = model->Signatures().at(signature_index); - return kLiteRtStatusOk; -} - -void LiteRtDestroyModel(LiteRtModel model) { delete model; } - -LiteRtStatus LiteRtSerializeModel(LiteRtModel model, uint8_t** buf, - size_t* size, size_t* offset, - bool destroy_model, - LiteRtModelSerializationOptions options) { - auto serialized = litert::internal::SerializeModel( - std::move(*model), options.bytecode_alignment); - // Even if we fail to serialize, we still need to destroy the model if - // requested. 
This is because the model may have been partially serialized - // and we don't want to leak memory. Also if ownership of the model is - // transferred to the caller, we need to ensure that the model is destroyed - // when the caller is done with it. - if (destroy_model) { - delete model; - } - if (!serialized) { - return serialized.Error().Status(); - } - std::tie(*buf, *size, *offset) = serialized->Release(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtPushOp(LiteRtOpList op_list, LiteRtOp op, - LiteRtParamIndex index) { - if (!op_list || !op) { - return kLiteRtStatusErrorInvalidArgument; - } - op_list->Push(op, index); - return kLiteRtStatusOk; -} - -// -// Signature -// - -LiteRtStatus LiteRtGetDefaultSignatureKey(const char** signature_key) { - if (!signature_key) { - return kLiteRtStatusErrorInvalidArgument; - } - *signature_key = LiteRtSignatureT::kDefaultSignatureKey.data(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetSignatureKey(LiteRtSignature signature, - const char** signature_key) { - if (!signature || !signature_key) { - return kLiteRtStatusErrorInvalidArgument; - } - *signature_key = signature->Key().data(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetSignatureSubgraph(LiteRtSignature signature, - LiteRtSubgraph* subgraph) { - if (signature == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *subgraph = &signature->GetSubgraph(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetNumSignatureInputs(LiteRtSignature signature, - LiteRtParamIndex* num_inputs) { - if (!signature || !num_inputs) { - return kLiteRtStatusErrorInvalidArgument; - } - *num_inputs = signature->InputNames().size(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetSignatureInputName(LiteRtSignature signature, - LiteRtParamIndex input_idx, - const char** input_name) { - if (!signature || !input_name) { - return kLiteRtStatusErrorInvalidArgument; - } - if (input_idx >= signature->InputNames().size()) { - return 
kLiteRtStatusErrorIndexOOB; - } - *input_name = signature->InputNames().at(input_idx).data(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetNumSignatureOutputs(LiteRtSignature signature, - LiteRtParamIndex* num_outputs) { - if (!signature || !num_outputs) { - return kLiteRtStatusErrorInvalidArgument; - } - *num_outputs = signature->OutputNames().size(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetSignatureOutputName(LiteRtSignature signature, - LiteRtParamIndex output_idx, - const char** output_name) { - if (!signature || !output_name) { - return kLiteRtStatusErrorInvalidArgument; - } - if (output_idx >= signature->OutputNames().size()) { - return kLiteRtStatusErrorIndexOOB; - } - *output_name = signature->OutputNames().at(output_idx).data(); - return kLiteRtStatusOk; -} - -// -// Subgraph -// - -LiteRtStatus LiteRtGetNumSubgraphInputs(LiteRtSubgraph subgraph, - LiteRtParamIndex* num_inputs) { - if (!subgraph || !num_inputs) { - return kLiteRtStatusErrorInvalidArgument; - } - *num_inputs = subgraph->Inputs().size(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetSubgraphInput(LiteRtSubgraph subgraph, - LiteRtParamIndex input_index, - LiteRtTensor* input) { - if (!subgraph || !input) { - return kLiteRtStatusErrorInvalidArgument; - } else if (input_index < 0 || input_index >= subgraph->Inputs().size()) { - return kLiteRtStatusErrorIndexOOB; - } - *input = subgraph->Inputs()[input_index]; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetNumSubgraphOutputs(LiteRtSubgraph subgraph, - LiteRtParamIndex* num_outputs) { - if (!subgraph || !num_outputs) { - return kLiteRtStatusErrorInvalidArgument; - } - *num_outputs = subgraph->Outputs().size(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetSubgraphOutput(LiteRtSubgraph subgraph, - LiteRtParamIndex output_index, - LiteRtTensor* output) { - if (!subgraph || !output) { - return kLiteRtStatusErrorInvalidArgument; - } else if (output_index < 0 || output_index >= subgraph->Outputs().size()) 
{ - return kLiteRtStatusErrorIndexOOB; - } - *output = subgraph->Outputs()[output_index]; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetNumSubgraphOps(LiteRtSubgraph subgraph, - LiteRtParamIndex* num_ops) { - if (!subgraph || !num_ops) { - return kLiteRtStatusErrorInvalidArgument; - } - *num_ops = subgraph->Ops().size(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetSubgraphOp(LiteRtSubgraph subgraph, - LiteRtParamIndex op_index, LiteRtOp* op) { - if (!subgraph || !op) { - return kLiteRtStatusErrorInvalidArgument; - } else if (op_index < 0 || op_index >= subgraph->Ops().size()) { - return kLiteRtStatusErrorIndexOOB; - } - *op = subgraph->Ops()[op_index]; - return kLiteRtStatusOk; -} - -// -// Op -// - -LiteRtStatus LiteRtGetOpCode(LiteRtOp op, LiteRtOpCode* code) { - if (!op || !code) { - return kLiteRtStatusErrorInvalidArgument; - } - *code = op->OpCode(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetNumOpInputs(LiteRtOp op, LiteRtParamIndex* num_inputs) { - if (!op || !num_inputs) { - return kLiteRtStatusErrorInvalidArgument; - } - *num_inputs = op->Inputs().size(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetOpInput(LiteRtOp op, LiteRtParamIndex input_index, - LiteRtTensor* input) { - if (!op || !input) { - return kLiteRtStatusErrorInvalidArgument; - } else if (input_index < 0 || input_index >= op->Inputs().size()) { - return kLiteRtStatusErrorIndexOOB; - } - *input = op->Inputs()[input_index]; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetNumOpOutputs(LiteRtOp op, LiteRtParamIndex* num_outputs) { - if (!op || !num_outputs) { - return kLiteRtStatusErrorInvalidArgument; - } - *num_outputs = op->Outputs().size(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetOpOutput(LiteRtOp op, LiteRtParamIndex output_index, - LiteRtTensor* output) { - if (!op || !output) { - return kLiteRtStatusErrorInvalidArgument; - } else if (output_index < 0 || output_index >= op->Outputs().size()) { - return kLiteRtStatusErrorIndexOOB; - } 
- *output = op->Outputs()[output_index]; - return kLiteRtStatusOk; -} - -// -// Weights -// - -LiteRtStatus LiteRtGetWeightsBytes(LiteRtWeights weights, const void** addr, - size_t* size) { - if (!weights || !addr || !size) { - return kLiteRtStatusErrorInvalidArgument; - } - *addr = weights->Buffer().Data(); - *size = weights->Buffer().Size(); - return kLiteRtStatusOk; -} - -// -// Tensor -// - -LiteRtStatus LiteRtGetTensorWeights(LiteRtTensor tensor, - LiteRtWeights* weights) { - if (!tensor || !weights) { - return kLiteRtStatusErrorInvalidArgument; - } - *weights = &tensor->Weights(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetNumTensorUses(LiteRtTensor tensor, - LiteRtParamIndex* num_uses) { - if (!tensor || !num_uses) { - return kLiteRtStatusErrorInvalidArgument; - } - *num_uses = tensor->Users().size(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetTensorUse(LiteRtTensor tensor, LiteRtParamIndex use_index, - LiteRtOp* user, - LiteRtParamIndex* user_arg_index) { - if (!tensor || !user || !user_arg_index) { - return kLiteRtStatusErrorInvalidArgument; - } else if (use_index < 0 || use_index >= tensor->Users().size()) { - return kLiteRtStatusErrorIndexOOB; - } - *user = tensor->Users()[use_index]; - *user_arg_index = tensor->UserArgInds()[use_index]; - return kLiteRtStatusOk; -} - -// Null if subgraph input or constant. 
-LiteRtStatus LiteRtGetTensorDefiningOp(LiteRtTensor tensor, - bool* has_defining_op, - LiteRtTensorDefiningOp* defining_op) { - if (!tensor || !has_defining_op || !defining_op) { - return kLiteRtStatusErrorInvalidArgument; - } - if (tensor->DefiningOp() != nullptr) { - *has_defining_op = true; - defining_op->op = tensor->DefiningOp(); - defining_op->op_output_index = tensor->DefiningOpOutInd(); - } else { - *has_defining_op = false; - } - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetTensorTypeId(LiteRtTensor tensor, - LiteRtTensorTypeId* type_id) { - if (!tensor || !type_id) { - return kLiteRtStatusErrorInvalidArgument; - } - *type_id = tensor->Type().first; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetUnrankedTensorType( - LiteRtTensor tensor, LiteRtUnrankedTensorType* unranked_tensor_type) { - if (!tensor || !unranked_tensor_type) { - return kLiteRtStatusErrorInvalidArgument; - } else if (tensor->Type().first != kLiteRtUnrankedTensorType) { - return kLiteRtStatusErrorInvalidIrType; - } - *unranked_tensor_type = tensor->Type().second.unranked_tensor_type; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetRankedTensorType( - LiteRtTensor tensor, LiteRtRankedTensorType* ranked_tensor_type) { - if (!tensor || !ranked_tensor_type) { - return kLiteRtStatusErrorInvalidArgument; - } else if (tensor->Type().first != kLiteRtRankedTensorType) { - return kLiteRtStatusErrorInvalidIrType; - } - *ranked_tensor_type = tensor->Type().second.ranked_tensor_type; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetTensorName(LiteRtTensor tensor, const char** name) { - if (!tensor || !name) { - return kLiteRtStatusErrorInvalidArgument; - } - *name = tensor->Name().data(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetQuantizationTypeId(LiteRtTensor tensor, - LiteRtQuantizationTypeId* q_type_id) { - if (!tensor || !q_type_id) { - return kLiteRtStatusErrorInvalidArgument; - } - *q_type_id = tensor->Qparams().first; - return kLiteRtStatusOk; -} - 
-LiteRtStatus LiteRtGetPerTensorQuantization( - LiteRtTensor tensor, LiteRtQuantizationPerTensor* per_tensor_quantization) { - if (!tensor || !per_tensor_quantization) { - return kLiteRtStatusErrorInvalidArgument; - } else if (tensor->Qparams().first != kLiteRtQuantizationPerTensor) { - return kLiteRtStatusErrorInvalidIrType; - } - auto& per_tensor = tensor->Qparams().second.per_tensor; - per_tensor_quantization->scale = per_tensor.scale; - per_tensor_quantization->zero_point = per_tensor.zero_point; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetPerChannelQuantization( - LiteRtTensor tensor, - LiteRtQuantizationPerChannel* per_channel_quantization) { - if (!tensor || !per_channel_quantization) { - return kLiteRtStatusErrorInvalidArgument; - } else if (tensor->Qparams().first != kLiteRtQuantizationPerChannel) { - return kLiteRtStatusErrorInvalidIrType; - } - auto& per_channel = tensor->Qparams().second.per_channel; - per_channel_quantization->scales = per_channel.scales; - per_channel_quantization->zero_points = per_channel.zero_points; - per_channel_quantization->num_channels = per_channel.num_channels; - per_channel_quantization->quantized_dimension = - per_channel.quantized_dimension; - return kLiteRtStatusOk; -} - -#ifdef __cplusplus -} // extern "C" -#endif diff --git a/tensorflow/lite/experimental/litert/c/litert_model.h b/tensorflow/lite/experimental/litert/c/litert_model.h deleted file mode 100644 index ba55d759e23e77..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_model.h +++ /dev/null @@ -1,372 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_MODEL_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_MODEL_H_ - -#include // NOLINT: To use bool type in C -#include -#include - -#include "tensorflow/lite/core/c/c_api_types.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_layout.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" - -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus - -// -// Handles + Common -// - -// Constant data behind a tensor stored in the model. -LITERT_DEFINE_HANDLE(LiteRtWeights); - -// Values/edges of the models graph. -LITERT_DEFINE_HANDLE(LiteRtTensor); - -// Operations/nodes of the models graph. -LITERT_DEFINE_HANDLE(LiteRtOp); - -// Fundamental block of program, i.e. a function body. -LITERT_DEFINE_HANDLE(LiteRtSubgraph); - -// Signature of the model. -LITERT_DEFINE_HANDLE(LiteRtSignature); - -// A collection of subgraph + metadata + signature. -LITERT_DEFINE_HANDLE(LiteRtModel); - -// Append only list of ops. -LITERT_DEFINE_HANDLE(LiteRtOpList); - -// -// LiteRtTensor + Types -// - -// Get the string name associated with this tensor. This is an optional -// attribute and if not set will return a zero-length string. -LiteRtStatus LiteRtGetTensorName(LiteRtTensor tensor, const char** name); - -// TENSOR TYPES - -// Primitive types for elements in a tensor. 
-typedef enum { - kLiteRtElementTypeNone = kTfLiteNoType, - kLiteRtElementTypeBool = kTfLiteBool, - kLiteRtElementTypeInt4 = kTfLiteInt4, - kLiteRtElementTypeInt8 = kTfLiteInt8, - kLiteRtElementTypeInt16 = kTfLiteInt16, - kLiteRtElementTypeInt32 = kTfLiteInt32, - kLiteRtElementTypeInt64 = kTfLiteInt64, - kLiteRtElementTypeUInt8 = kTfLiteUInt8, - kLiteRtElementTypeUInt16 = kTfLiteUInt16, - kLiteRtElementTypeUInt32 = kTfLiteUInt32, - kLiteRtElementTypeUInt64 = kTfLiteUInt64, - kLiteRtElementTypeFloat16 = kTfLiteFloat16, - kLiteRtElementTypeBFloat16 = kTfLiteBFloat16, - kLiteRtElementTypeFloat32 = kTfLiteFloat32, - kLiteRtElementTypeFloat64 = kTfLiteFloat64, - kLiteRtElementTypeComplex64 = kTfLiteComplex64, - kLiteRtElementTypeComplex128 = kTfLiteComplex128, - kLiteRtElementTypeTfResource = kTfLiteResource, - kLiteRtElementTypeTfString = kTfLiteString, - kLiteRtElementTypeTfVariant = kTfLiteVariant, -} LiteRtElementType; - -// Tensor whose rank is dynamic. -typedef struct { - // The primitive element type of the constituent data. - LiteRtElementType element_type; -} LiteRtUnrankedTensorType; - -// Tensor whose rank is static but dimenions may be dynamic. -typedef struct { - // The primitive element type of the constituent data. - LiteRtElementType element_type; - - // Shape information. - LiteRtLayout layout; -} LiteRtRankedTensorType; - -// The identifier for tensor type union. -typedef enum { - // Type with fix ranked and possibly dynamic dimensions. - kLiteRtRankedTensorType = 0, - - // Type with dynamic rank. - kLiteRtUnrankedTensorType = 1, -} LiteRtTensorTypeId; - -// Get type identifier from tensor. -LiteRtStatus LiteRtGetTensorTypeId(LiteRtTensor tensor, - LiteRtTensorTypeId* type_id); - -// Get unranked tensor type info, return bad status if not unranked. -LiteRtStatus LiteRtGetUnrankedTensorType( - LiteRtTensor tensor, LiteRtUnrankedTensorType* unranked_tensor_type); - -// Get ranked tensor type info, return bad status if not ranked. 
-LiteRtStatus LiteRtGetRankedTensorType( - LiteRtTensor tensor, LiteRtRankedTensorType* ranked_tensor_type); - -// QUANTIZATION - -// Schema for tensors quantized with one set of q-params. -typedef struct { - // Scaling factor. - float scale; - - // The value that float:0 maps to in q-space. - int64_t zero_point; -} LiteRtQuantizationPerTensor; - -// Schema for tensors quantized with one set of q-params per channel. -typedef struct { - int32_t quantized_dimension; - uint64_t num_channels; - float* scales; - int64_t* zero_points; -} LiteRtQuantizationPerChannel; - -// The identifier for quantization scheme type union. -typedef enum { - // Tag for tensors without quantization. - kLiteRtQuantizationNone = 0, - - // Basic quantization, one set of q-params per tensor. - kLiteRtQuantizationPerTensor = 1, - - // [NOT IMPLEMENTED YET] Q-params for each element accross a single dimension. - kLiteRtQuantizationPerChannel = 2, - - // [NOT IMPLEMENTED YET] Q-params accross blocks of fixed size (e.g. 2048). - kLiteRtQuantizationBlockWise = 3, -} LiteRtQuantizationTypeId; - -// Get the identifier for the type of quantization for a given tensor. -LiteRtStatus LiteRtGetQuantizationTypeId(LiteRtTensor tensor, - LiteRtQuantizationTypeId* q_type_id); - -// Get the per-tensor quantization information for a given tensor if it has it. -LiteRtStatus LiteRtGetPerTensorQuantization( - LiteRtTensor tensor, LiteRtQuantizationPerTensor* per_tensor_quantization); - -// Get the per-channel quantization information for a given tensor if it has it. -LiteRtStatus LiteRtGetPerChannelQuantization( - LiteRtTensor tensor, - LiteRtQuantizationPerChannel* per_channel_quantization); - -// EDGES - -// Information about the about that defines a tensor. -typedef struct LiteRtTensorDefiningOp { - // The defining op itself. - LiteRtOp op; - - // The op output index that defines the specific tensor. 
- LiteRtParamIndex op_output_index; -} LiteRtTensorDefiningOp; - -// Information about a reference to a tensor in the graph. -typedef struct LiteRtTensorUserOp { - // The referring op itself. - LiteRtOp op; - - // Index of which operand the op refers to a specific tensor on. - LiteRtParamIndex op_input_index; -} LiteRtTensorUserOp; - -// Get all the ops that reference given tensor, and at what operand index. -LiteRtStatus LiteRtGetNumTensorUses(LiteRtTensor tensor, - LiteRtParamIndex* num_uses); -LiteRtStatus LiteRtGetTensorUse(LiteRtTensor tensor, LiteRtParamIndex use_index, - LiteRtOp* user, - LiteRtParamIndex* user_arg_index); - -// Get the op that defines this tensor and the corresponding output index. If -// tensor is a subgraph input, has_defining_op will be false. -LiteRtStatus LiteRtGetTensorDefiningOp(LiteRtTensor tensor, - bool* has_defining_op, - LiteRtTensorDefiningOp* defining_op); - -// WEIGHTS (constant data) - -// Get static weights associated with a given tensor. All tensors have weights, -// null weights have size = 0; -LiteRtStatus LiteRtGetTensorWeights(LiteRtTensor tensor, - LiteRtWeights* weights); - -// -// LiteRtWeights -// - -// Get opaque array from given tensor weights. -LiteRtStatus LiteRtGetWeightsBytes(LiteRtWeights weights, const void** addr, - size_t* size); - -// -// LiteRtOp -// - -// Get code corresponding to operation type for given op. -LiteRtStatus LiteRtGetOpCode(LiteRtOp op, LiteRtOpCode* code); - -// Get input tensors of given op. -LiteRtStatus LiteRtGetNumOpInputs(LiteRtOp op, LiteRtParamIndex* num_inputs); -LiteRtStatus LiteRtGetOpInput(LiteRtOp op, LiteRtParamIndex input_index, - LiteRtTensor* input); - -// Get output tensors of given op. -LiteRtStatus LiteRtGetNumOpOutputs(LiteRtOp op, LiteRtParamIndex* num_outputs); -LiteRtStatus LiteRtGetOpOutput(LiteRtOp op, LiteRtParamIndex output_index, - LiteRtTensor* output); - -// -// LiteRtSubgraph -// - -// Get input tensors for given subgraph. 
-LiteRtStatus LiteRtGetNumSubgraphInputs(LiteRtSubgraph subgraph, - LiteRtParamIndex* num_inputs); -LiteRtStatus LiteRtGetSubgraphInput(LiteRtSubgraph subgraph, - LiteRtParamIndex input_index, - LiteRtTensor* input); - -// Get output tensors for given subgraph. -LiteRtStatus LiteRtGetNumSubgraphOutputs(LiteRtSubgraph subgraph, - LiteRtParamIndex* num_outputs); -LiteRtStatus LiteRtGetSubgraphOutput(LiteRtSubgraph subgraph, - LiteRtParamIndex output_index, - LiteRtTensor* output); - -// Get all ops in given subgraph in a topological order. -LiteRtStatus LiteRtGetNumSubgraphOps(LiteRtSubgraph subgraph, - LiteRtParamIndex* num_ops); -LiteRtStatus LiteRtGetSubgraphOp(LiteRtSubgraph subgraph, - LiteRtParamIndex op_index, LiteRtOp* op); - -// -// LiteRtSignature -// - -// Default signature key. This is the key that is used if the model does not -// define any signatures. -LiteRtStatus LiteRtGetDefaultSignatureKey(const char** signature_key); - -// Get the signature key string defined in the model. -LiteRtStatus LiteRtGetSignatureKey(LiteRtSignature signature, - const char** signature_key); - -// Get the associated subgraph for the given signature. -LiteRtStatus LiteRtGetSignatureSubgraph(LiteRtSignature signature, - LiteRtSubgraph* subgraph); - -// Get the number of inputs for the given signature. -LiteRtStatus LiteRtGetNumSignatureInputs(LiteRtSignature signature, - LiteRtParamIndex* num_inputs); - -// Get the name of the i-th of input tensor name for the given signature. -LiteRtStatus LiteRtGetSignatureInputName(LiteRtSignature signature, - LiteRtParamIndex input_idx, - const char** input_name); - -// Get the number of outputs for the given signature. -LiteRtStatus LiteRtGetNumSignatureOutputs(LiteRtSignature signature, - LiteRtParamIndex* num_outputs); - -// Get the name of the i-th of output tensor name for the given signature. 
-LiteRtStatus LiteRtGetSignatureOutputName(LiteRtSignature signature, - LiteRtParamIndex output_idx, - const char** output_name); - -// -// LiteRtModel -// - -LiteRtStatus LiteRtCreateModelFromFile(const char* filename, - LiteRtModel* model); - -LiteRtStatus LiteRtCreateModelFromBuffer(const void* buffer_addr, - size_t buffer_size, - LiteRtModel* model); - -// Get the metadata buffer associated with given key if it exists. -LiteRtStatus LiteRtGetModelMetadata(LiteRtModel model, const char* metadata_key, - const void** metadata_buffer, - size_t* metadata_buffer_size); - -// Get the index of the entry subgraph. -// TODO: b/365299994 - Figure out signatures. -LiteRtStatus LiteRtGetMainModelSubgraphIndex( - LiteRtModel model, LiteRtParamIndex* main_subgraph_index); - -// Get number of subgraphs in model. -LiteRtStatus LiteRtGetNumModelSubgraphs(LiteRtModel model, - LiteRtParamIndex* num_subgraphs); - -// Get subgraph at given index in model. -LiteRtStatus LiteRtGetModelSubgraph(LiteRtModel model, - LiteRtParamIndex subgraph_index, - LiteRtSubgraph* subgraph); - -// Get the number of signatures defined in the model. -LiteRtStatus LiteRtGetNumModelSignatures(LiteRtModel model, - LiteRtParamIndex* num_signatures); - -// Get the signature at the given index in the model -LiteRtStatus LiteRtGetModelSignature(LiteRtModel model, - LiteRtParamIndex signature_index, - LiteRtSignature* signature); - -// Destroy the given model, freeing any memory it owns. -void LiteRtDestroyModel(LiteRtModel model); - -// -// Utility Types -// - -// An append only list of ops. -LiteRtStatus LiteRtPushOp(LiteRtOpList op_list, LiteRtOp op, - LiteRtParamIndex partition_index); - -// -// Serialization related functions -// - -// Options for model serialization. -typedef struct LiteRtModelSerializationOptions { - // Alignment for bytecode assets that are appended to the model. - // Alignment is enforced relative to the first byte of the flatbuffer. 
- size_t bytecode_alignment; -} LiteRtModelSerializationOptions; - -// Serializes model to valid tflite flatbuffer bytes. -// -// This destroys the model before it returns unless destroy_model is false. -// Caller takes ownership of `buf`. Flatbuffers are packed into their arrays -// back to front, so the valid flatbuffer is buf[offset, size]. See the above -// options for more details. -LiteRtStatus LiteRtSerializeModel(LiteRtModel model, uint8_t** buf, - size_t* size, size_t* offset, - bool destroy_model, - LiteRtModelSerializationOptions options); - -#ifdef __cplusplus -} -#endif // __cplusplus - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_MODEL_H_ diff --git a/tensorflow/lite/experimental/litert/c/litert_model_test.cc b/tensorflow/lite/experimental/litert/c/litert_model_test.cc deleted file mode 100644 index 8f41902e8b70f1..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_model_test.cc +++ /dev/null @@ -1,390 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/c/litert_model.h" - -#include -#include -#include -#include -#include -#include - -#include -#include -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" -#include "tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h" -#include "tensorflow/lite/experimental/litert/test/matchers.h" - -namespace { - -using ::litert::BufferRef; -using ::litert::OwningBufferRef; -using ::testing::ElementsAreArray; -using ::testing::litert::IsError; - -TEST(LiteRtWeightsTest, GetNullWeights) { - LiteRtWeightsT weights = {}; - - const void* addr; - size_t size; - LITERT_ASSERT_OK(LiteRtGetWeightsBytes(&weights, &addr, &size)); - - EXPECT_EQ(addr, nullptr); - EXPECT_EQ(size, 0); -} - -TEST(LiteRtWeightsTest, GetWeights) { - static constexpr std::array kData = {1, 2, 3}; - const uint8_t* kDataPtr = reinterpret_cast(kData.data()); - const auto kDataSize = kData.size() * sizeof(int32_t); - - LiteRtWeightsT weights; - SetWeightsFromOwnedBuffer(weights, - OwningBufferRef(kDataPtr, kDataSize)); - - const void* addr; - size_t size; - LITERT_ASSERT_OK(LiteRtGetWeightsBytes(&weights, &addr, &size)); - - EXPECT_NE(addr, nullptr); - EXPECT_EQ(size, 3 * sizeof(int32_t)); - - EXPECT_THAT(absl::MakeConstSpan(reinterpret_cast(addr), 3), - ElementsAreArray(kData)); -} - -TEST(LiteRtTensorTest, GetUnrankedType) { - static constexpr auto kElementType = kLiteRtElementTypeFloat32; - static constexpr auto kId = kLiteRtUnrankedTensorType; - - TensorType type; - type.first = kId; - type.second.unranked_tensor_type.element_type = kElementType; - - LiteRtTensorT tensor; - tensor.SetType(std::move(type)); - - LiteRtTensorTypeId id; - 
LITERT_ASSERT_OK(LiteRtGetTensorTypeId(&tensor, &id)); - ASSERT_EQ(id, kId); - - LiteRtUnrankedTensorType unranked; - LITERT_ASSERT_OK(LiteRtGetUnrankedTensorType(&tensor, &unranked)); - EXPECT_EQ(unranked.element_type, kElementType); -} - -TEST(LiteRtTensorTest, GetRankedTensorType) { - static constexpr auto kElementType = kLiteRtElementTypeFloat32; - static constexpr auto kId = kLiteRtRankedTensorType; - - LiteRtTensorT tensor; - tensor.SetType(MakeRankedTensorType(kElementType, {3, 3})); - - LiteRtTensorTypeId id; - LITERT_ASSERT_OK(LiteRtGetTensorTypeId(&tensor, &id)); - ASSERT_EQ(id, kId); - - LiteRtRankedTensorType ranked; - LITERT_ASSERT_OK(LiteRtGetRankedTensorType(&tensor, &ranked)); - EXPECT_EQ(ranked.element_type, kElementType); - ASSERT_EQ(ranked.layout.rank, 2); - EXPECT_THAT(absl::MakeConstSpan(ranked.layout.dimensions, 2), - ElementsAreArray({3, 3})); -} - -TEST(LiteRtTensorTest, GetUses) { - LiteRtTensorT tensor; - - LiteRtOpT user; - tensor.Users().push_back(&user); - tensor.UserArgInds().push_back(0); - - LiteRtOpT other_user; - tensor.Users().push_back(&other_user); - tensor.UserArgInds().push_back(1); - - LiteRtParamIndex num_uses; - LITERT_ASSERT_OK(LiteRtGetNumTensorUses(&tensor, &num_uses)); - ASSERT_EQ(num_uses, 2); - - LiteRtOp actual_user; - LiteRtParamIndex actual_user_arg_index; - LITERT_ASSERT_OK(LiteRtGetTensorUse(&tensor, /*use_index=*/0, &actual_user, - &actual_user_arg_index)); - ASSERT_EQ(actual_user, &user); - ASSERT_EQ(actual_user_arg_index, 0); - - LITERT_ASSERT_OK(LiteRtGetTensorUse(&tensor, /*use_index=*/1, &actual_user, - &actual_user_arg_index)); - ASSERT_EQ(actual_user, &other_user); - ASSERT_EQ(actual_user_arg_index, 1); -} - -TEST(LiteRtTensorTest, GetDefiningOp) { - LiteRtTensorT tensor; - - LiteRtOpT def_op; - tensor.SetDefiningOp(def_op, 0); - - LiteRtTensorDefiningOp actual_def_op; - bool has_defining_op; - LITERT_ASSERT_OK( - LiteRtGetTensorDefiningOp(&tensor, &has_defining_op, &actual_def_op)); - 
ASSERT_TRUE(has_defining_op); - EXPECT_EQ(actual_def_op.op, &def_op); - EXPECT_EQ(actual_def_op.op_output_index, 0); -} - -TEST(LiteRtTensorTest, NoDefiningOp) { - LiteRtTensorT tensor; - - LiteRtTensorDefiningOp actual_def_op; - bool has_defining_op; - LITERT_ASSERT_OK( - LiteRtGetTensorDefiningOp(&tensor, &has_defining_op, &actual_def_op)); - ASSERT_FALSE(has_defining_op); -} - -TEST(LiteRtTensorTest, Name) { - static constexpr const char kName[] = "foo"; - - LiteRtTensorT tensor; - tensor.SetName(std::string(kName)); - - const char* name; - LITERT_ASSERT_OK(LiteRtGetTensorName(&tensor, &name)); - EXPECT_STREQ(name, kName); -} - -TEST(LiteRtTensorTest, QuantizationNone) { - LiteRtTensorT tensor; - - LiteRtQuantizationTypeId q_type_id; - LITERT_ASSERT_OK(LiteRtGetQuantizationTypeId(&tensor, &q_type_id)); - EXPECT_EQ(q_type_id, kLiteRtQuantizationNone); - - LiteRtQuantizationPerTensor per_tensor_quantization; - EXPECT_NE(LiteRtGetPerTensorQuantization(&tensor, &per_tensor_quantization), - kLiteRtStatusOk); -} - -TEST(LiteRtTensorTest, QuantizationPerTensor) { - static constexpr auto kScale = 1.0; - static constexpr auto kZeroPoint = 1; - - LiteRtTensorT tensor; - tensor.SetQarams(MakePerTensorQuantization(kScale, kZeroPoint)); - - LiteRtQuantizationTypeId q_type_id; - LITERT_ASSERT_OK(LiteRtGetQuantizationTypeId(&tensor, &q_type_id)); - ASSERT_EQ(q_type_id, kLiteRtQuantizationPerTensor); - - LiteRtQuantizationPerTensor per_tensor_quantization; - LITERT_ASSERT_OK( - LiteRtGetPerTensorQuantization(&tensor, &per_tensor_quantization)); - - EXPECT_EQ(per_tensor_quantization.scale, kScale); - EXPECT_EQ(per_tensor_quantization.zero_point, kZeroPoint); -} - -TEST(LiteRtTensorTest, QuantizationPerChannel) { - static constexpr size_t kNumChannels = 2; - static constexpr size_t kQuantizedDimension = 0; - static constexpr float kScales[kNumChannels] = {1.0, 2.0}; - static constexpr int64_t kZps[kNumChannels] = {2, 3}; - - LiteRtTensorT tensor; - - { - auto per_channel = - 
MakePerChannelQuantization(kScales, kZps, kQuantizedDimension, tensor); - tensor.SetQarams(per_channel); - } - - LiteRtQuantizationTypeId q_type_id; - LITERT_ASSERT_OK(LiteRtGetQuantizationTypeId(&tensor, &q_type_id)); - ASSERT_EQ(q_type_id, kLiteRtQuantizationPerChannel); - - LiteRtQuantizationPerChannel per_channel_quantization; - LITERT_ASSERT_OK( - LiteRtGetPerChannelQuantization(&tensor, &per_channel_quantization)); - - EXPECT_THAT( - absl::MakeConstSpan(per_channel_quantization.scales, kNumChannels), - testing::ElementsAreArray(kScales)); - EXPECT_THAT( - absl::MakeConstSpan(per_channel_quantization.zero_points, kNumChannels), - testing::ElementsAreArray(kZps)); - ASSERT_EQ(per_channel_quantization.num_channels, kNumChannels); - ASSERT_EQ(per_channel_quantization.quantized_dimension, kQuantizedDimension); -} - -TEST(LiteRtOpTest, GetOpCode) { - static constexpr auto kCode = kLiteRtOpCodeTflCustom; - - LiteRtOpT op; - op.SetOpCode(kCode); - - LiteRtOpCode code; - LITERT_ASSERT_OK(LiteRtGetOpCode(&op, &code)); - EXPECT_EQ(code, kCode); -} - -TEST(LiteRtOpTest, GetInputs) { - LiteRtTensorT input1; - LiteRtTensorT input2; - - LiteRtOpT op; - op.Inputs().push_back(&input1); - op.Inputs().push_back(&input2); - - LiteRtParamIndex num_inputs; - LITERT_ASSERT_OK(LiteRtGetNumOpInputs(&op, &num_inputs)); - ASSERT_EQ(num_inputs, 2); - - LiteRtTensor actual_input; - LITERT_ASSERT_OK(LiteRtGetOpInput(&op, /*input_index=*/0, &actual_input)); - EXPECT_EQ(actual_input, &input1); - - LITERT_ASSERT_OK(LiteRtGetOpInput(&op, /*input_index=*/1, &actual_input)); - EXPECT_EQ(actual_input, &input2); -} - -TEST(LiteRtOpTest, GetOutputs) { - LiteRtTensorT output1; - LiteRtTensorT output2; - - LiteRtOpT op; - op.Outputs().push_back(&output1); - op.Outputs().push_back(&output2); - - LiteRtParamIndex num_outputs; - LITERT_ASSERT_OK(LiteRtGetNumOpOutputs(&op, &num_outputs)); - ASSERT_EQ(num_outputs, 2); - - LiteRtTensor actual_output; - LITERT_ASSERT_OK(LiteRtGetOpOutput(&op, 
/*output_index=*/0, &actual_output)); - EXPECT_EQ(actual_output, &output1); - - LITERT_ASSERT_OK(LiteRtGetOpOutput(&op, /*output_index=*/1, &actual_output)); - EXPECT_EQ(actual_output, &output2); -} - -TEST(LiteRtSubgraphTest, GetInputs) { - LiteRtTensorT input1; - LiteRtTensorT input2; - - LiteRtSubgraphT subgraph; - subgraph.Inputs().push_back(&input1); - subgraph.Inputs().push_back(&input2); - - LiteRtParamIndex num_inputs; - LITERT_ASSERT_OK(LiteRtGetNumSubgraphInputs(&subgraph, &num_inputs)); - - LiteRtTensor actual_input; - LITERT_ASSERT_OK( - LiteRtGetSubgraphInput(&subgraph, /*input_index=*/0, &actual_input)); - EXPECT_EQ(actual_input, &input1); - - LITERT_ASSERT_OK( - LiteRtGetSubgraphInput(&subgraph, /*input_index=*/1, &actual_input)); - EXPECT_EQ(actual_input, &input2); -} - -TEST(LiteRtSubgraphTest, GetOutputs) { - LiteRtTensorT output1; - LiteRtTensorT output2; - - LiteRtSubgraphT subgraph; - subgraph.Outputs().push_back(&output1); - subgraph.Outputs().push_back(&output2); - - LiteRtParamIndex num_outputs; - LITERT_ASSERT_OK(LiteRtGetNumSubgraphOutputs(&subgraph, &num_outputs)); - - LiteRtTensor actual_output; - LITERT_ASSERT_OK( - LiteRtGetSubgraphOutput(&subgraph, /*output_index=*/0, &actual_output)); - EXPECT_EQ(actual_output, &output1); - - LITERT_ASSERT_OK( - LiteRtGetSubgraphOutput(&subgraph, /*output_index=*/1, &actual_output)); - EXPECT_EQ(actual_output, &output2); -} - -TEST(LiteRtSubgraphTest, GetOps) { - LiteRtSubgraphT subgraph; - auto& op1 = subgraph.EmplaceOp(); - auto& op2 = subgraph.EmplaceOp(); - - LiteRtParamIndex num_ops; - LITERT_ASSERT_OK(LiteRtGetNumSubgraphOps(&subgraph, &num_ops)); - ASSERT_EQ(num_ops, 2); - - LiteRtOp actual_op; - LITERT_ASSERT_OK(LiteRtGetSubgraphOp(&subgraph, /*op_index=*/0, &actual_op)); - ASSERT_EQ(actual_op, &op1); - - LITERT_ASSERT_OK(LiteRtGetSubgraphOp(&subgraph, /*op_index=*/1, &actual_op)); - ASSERT_EQ(actual_op, &op2); -} - -TEST(LiteRtModelTest, GetMetadata) { - static constexpr absl::string_view 
kKey = "KEY"; - static constexpr absl::string_view kData = "DATA"; - - LiteRtModelT model; - model.PushMetadata(kKey, kData); - - const void* metadata; - size_t metadata_size; - LITERT_ASSERT_OK( - LiteRtGetModelMetadata(&model, kKey.data(), &metadata, &metadata_size)); - EXPECT_EQ(BufferRef(metadata, metadata_size).StrView(), kData); -} - -TEST(LiteRtModelTest, GetSubgraph) { - LiteRtModelT model; - auto& subgraph = model.EmplaceSubgraph(); - - LiteRtSubgraph actual_subgraph; - LITERT_ASSERT_OK(LiteRtGetModelSubgraph(&model, 0, &actual_subgraph)); - EXPECT_EQ(actual_subgraph, &subgraph); -} - -TEST(LiteRtModelTest, GetSubgraphOOB) { - LiteRtModelT model; - - LiteRtSubgraph actual_subgraph; - EXPECT_THAT(LiteRtGetModelSubgraph(&model, 0, &actual_subgraph), - IsError(kLiteRtStatusErrorIndexOOB)); -} - -TEST(LiteRtOpListTest, PushOps) { - LiteRtOpListT op_list; - LiteRtOpT op; - - LITERT_ASSERT_OK(LiteRtPushOp(&op_list, &op, 0)); - auto vec = op_list.Values(); - ASSERT_EQ(vec.size(), 1); - EXPECT_EQ(vec.front().first, &op); -} - -} // namespace diff --git a/tensorflow/lite/experimental/litert/c/litert_op_code.h b/tensorflow/lite/experimental/litert/c/litert_op_code.h deleted file mode 100644 index 529360e87dc415..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_op_code.h +++ /dev/null @@ -1,245 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_OP_CODE_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_OP_CODE_H_ - -#include "tensorflow/lite/builtin_ops.h" - -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus - -typedef enum { - kLiteRtOpCodeTflAdd = kTfLiteBuiltinAdd, - kLiteRtOpCodeTflAveragePool2d = kTfLiteBuiltinAveragePool2d, - kLiteRtOpCodeTflConcatenation = kTfLiteBuiltinConcatenation, - kLiteRtOpCodeTflConv2d = kTfLiteBuiltinConv2d, - kLiteRtOpCodeTflDepthwiseConv2d = kTfLiteBuiltinDepthwiseConv2d, - kLiteRtOpCodeTflDepthToSpace = kTfLiteBuiltinDepthToSpace, - kLiteRtOpCodeTflDequantize = kTfLiteBuiltinDequantize, - kLiteRtOpCodeTflEmbeddingLookup = kTfLiteBuiltinEmbeddingLookup, - kLiteRtOpCodeTflFloor = kTfLiteBuiltinFloor, - kLiteRtOpCodeTflFullyConnected = kTfLiteBuiltinFullyConnected, - kLiteRtOpCodeTflHashtableLookup = kTfLiteBuiltinHashtableLookup, - kLiteRtOpCodeTflL2Normalization = kTfLiteBuiltinL2Normalization, - kLiteRtOpCodeTflL2Pool2d = kTfLiteBuiltinL2Pool2d, - kLiteRtOpCodeTflLocalResponseNormalization = - kTfLiteBuiltinLocalResponseNormalization, - kLiteRtOpCodeTflLogistic = kTfLiteBuiltinLogistic, - kLiteRtOpCodeTflLshProjection = kTfLiteBuiltinLshProjection, - kLiteRtOpCodeTflLstm = kTfLiteBuiltinLstm, - kLiteRtOpCodeTflMaxPool2d = kTfLiteBuiltinMaxPool2d, - kLiteRtOpCodeTflMul = kTfLiteBuiltinMul, - kLiteRtOpCodeTflRelu = kTfLiteBuiltinRelu, - kLiteRtOpCodeTflReluN1To1 = kTfLiteBuiltinReluN1To1, - kLiteRtOpCodeTflRelu6 = kTfLiteBuiltinRelu6, - kLiteRtOpCodeTflReshape = kTfLiteBuiltinReshape, - kLiteRtOpCodeTflResizeBilinear = kTfLiteBuiltinResizeBilinear, - kLiteRtOpCodeTflRnn = kTfLiteBuiltinRnn, - kLiteRtOpCodeTflSoftmax = kTfLiteBuiltinSoftmax, - kLiteRtOpCodeTflSpaceToDepth = kTfLiteBuiltinSpaceToDepth, - kLiteRtOpCodeTflSvdf = kTfLiteBuiltinSvdf, - kLiteRtOpCodeTflTanh = kTfLiteBuiltinTanh, - kLiteRtOpCodeTflConcatEmbeddings = kTfLiteBuiltinConcatEmbeddings, - kLiteRtOpCodeTflSkipGram = 
kTfLiteBuiltinSkipGram, - kLiteRtOpCodeTflCall = kTfLiteBuiltinCall, - kLiteRtOpCodeTflCustom = kTfLiteBuiltinCustom, - kLiteRtOpCodeTflEmbeddingLookupSparse = kTfLiteBuiltinEmbeddingLookupSparse, - kLiteRtOpCodeTflPad = kTfLiteBuiltinPad, - kLiteRtOpCodeTflUnidirectionalSequenceRnn = - kTfLiteBuiltinUnidirectionalSequenceRnn, - kLiteRtOpCodeTflGather = kTfLiteBuiltinGather, - kLiteRtOpCodeTflBatchToSpaceNd = kTfLiteBuiltinBatchToSpaceNd, - kLiteRtOpCodeTflSpaceToBatchNd = kTfLiteBuiltinSpaceToBatchNd, - kLiteRtOpCodeTflTranspose = kTfLiteBuiltinTranspose, - kLiteRtOpCodeTflMean = kTfLiteBuiltinMean, - kLiteRtOpCodeTflSub = kTfLiteBuiltinSub, - kLiteRtOpCodeTflDiv = kTfLiteBuiltinDiv, - kLiteRtOpCodeTflSqueeze = kTfLiteBuiltinSqueeze, - kLiteRtOpCodeTflUnidirectionalSequenceLstm = - kTfLiteBuiltinUnidirectionalSequenceLstm, - kLiteRtOpCodeTflStridedSlice = kTfLiteBuiltinStridedSlice, - kLiteRtOpCodeTflBidirectionalSequenceRnn = - kTfLiteBuiltinBidirectionalSequenceRnn, - kLiteRtOpCodeTflExp = kTfLiteBuiltinExp, - kLiteRtOpCodeTflTopkV2 = kTfLiteBuiltinTopkV2, - kLiteRtOpCodeTflSplit = kTfLiteBuiltinSplit, - kLiteRtOpCodeTflLogSoftmax = kTfLiteBuiltinLogSoftmax, - kLiteRtOpCodeTflDelegate = kTfLiteBuiltinDelegate, - kLiteRtOpCodeTflBidirectionalSequenceLstm = - kTfLiteBuiltinBidirectionalSequenceLstm, - kLiteRtOpCodeTflCast = kTfLiteBuiltinCast, - kLiteRtOpCodeTflPrelu = kTfLiteBuiltinPrelu, - kLiteRtOpCodeTflMaximum = kTfLiteBuiltinMaximum, - kLiteRtOpCodeTflArgMax = kTfLiteBuiltinArgMax, - kLiteRtOpCodeTflMinimum = kTfLiteBuiltinMinimum, - kLiteRtOpCodeTflLess = kTfLiteBuiltinLess, - kLiteRtOpCodeTflNeg = kTfLiteBuiltinNeg, - kLiteRtOpCodeTflPadv2 = kTfLiteBuiltinPadv2, - kLiteRtOpCodeTflGreater = kTfLiteBuiltinGreater, - kLiteRtOpCodeTflGreaterEqual = kTfLiteBuiltinGreaterEqual, - kLiteRtOpCodeTflLessEqual = kTfLiteBuiltinLessEqual, - kLiteRtOpCodeTflSelect = kTfLiteBuiltinSelect, - kLiteRtOpCodeTflSlice = kTfLiteBuiltinSlice, - kLiteRtOpCodeTflSin = 
kTfLiteBuiltinSin, - kLiteRtOpCodeTflTransposeConv = kTfLiteBuiltinTransposeConv, - kLiteRtOpCodeTflSparseToDense = kTfLiteBuiltinSparseToDense, - kLiteRtOpCodeTflTile = kTfLiteBuiltinTile, - kLiteRtOpCodeTflExpandDims = kTfLiteBuiltinExpandDims, - kLiteRtOpCodeTflEqual = kTfLiteBuiltinEqual, - kLiteRtOpCodeTflNotEqual = kTfLiteBuiltinNotEqual, - kLiteRtOpCodeTflLog = kTfLiteBuiltinLog, - kLiteRtOpCodeTflSum = kTfLiteBuiltinSum, - kLiteRtOpCodeTflSqrt = kTfLiteBuiltinSqrt, - kLiteRtOpCodeTflRsqrt = kTfLiteBuiltinRsqrt, - kLiteRtOpCodeTflShape = kTfLiteBuiltinShape, - kLiteRtOpCodeTflPow = kTfLiteBuiltinPow, - kLiteRtOpCodeTflArgMin = kTfLiteBuiltinArgMin, - kLiteRtOpCodeTflFakeQuant = kTfLiteBuiltinFakeQuant, - kLiteRtOpCodeTflReduceProd = kTfLiteBuiltinReduceProd, - kLiteRtOpCodeTflReduceMax = kTfLiteBuiltinReduceMax, - kLiteRtOpCodeTflPack = kTfLiteBuiltinPack, - kLiteRtOpCodeTflLogicalOr = kTfLiteBuiltinLogicalOr, - kLiteRtOpCodeTflOneHot = kTfLiteBuiltinOneHot, - kLiteRtOpCodeTflLogicalAnd = kTfLiteBuiltinLogicalAnd, - kLiteRtOpCodeTflLogicalNot = kTfLiteBuiltinLogicalNot, - kLiteRtOpCodeTflUnpack = kTfLiteBuiltinUnpack, - kLiteRtOpCodeTflReduceMin = kTfLiteBuiltinReduceMin, - kLiteRtOpCodeTflFloorDiv = kTfLiteBuiltinFloorDiv, - kLiteRtOpCodeTflReduceAny = kTfLiteBuiltinReduceAny, - kLiteRtOpCodeTflSquare = kTfLiteBuiltinSquare, - kLiteRtOpCodeTflZerosLike = kTfLiteBuiltinZerosLike, - kLiteRtOpCodeTflFill = kTfLiteBuiltinFill, - kLiteRtOpCodeTflFloorMod = kTfLiteBuiltinFloorMod, - kLiteRtOpCodeTflRange = kTfLiteBuiltinRange, - kLiteRtOpCodeTflResizeNearestNeighbor = kTfLiteBuiltinResizeNearestNeighbor, - kLiteRtOpCodeTflLeakyRelu = kTfLiteBuiltinLeakyRelu, - kLiteRtOpCodeTflSquaredDifference = kTfLiteBuiltinSquaredDifference, - kLiteRtOpCodeTflMirrorPad = kTfLiteBuiltinMirrorPad, - kLiteRtOpCodeTflAbs = kTfLiteBuiltinAbs, - kLiteRtOpCodeTflSplitV = kTfLiteBuiltinSplitV, - kLiteRtOpCodeTflUnique = kTfLiteBuiltinUnique, - kLiteRtOpCodeTflCeil = 
kTfLiteBuiltinCeil, - kLiteRtOpCodeTflReverseV2 = kTfLiteBuiltinReverseV2, - kLiteRtOpCodeTflAddN = kTfLiteBuiltinAddN, - kLiteRtOpCodeTflGatherNd = kTfLiteBuiltinGatherNd, - kLiteRtOpCodeTflCos = kTfLiteBuiltinCos, - kLiteRtOpCodeTflWhere = kTfLiteBuiltinWhere, - kLiteRtOpCodeTflRank = kTfLiteBuiltinRank, - kLiteRtOpCodeTflElu = kTfLiteBuiltinElu, - kLiteRtOpCodeTflReverseSequence = kTfLiteBuiltinReverseSequence, - kLiteRtOpCodeTflMatrixDiag = kTfLiteBuiltinMatrixDiag, - kLiteRtOpCodeTflQuantize = kTfLiteBuiltinQuantize, - kLiteRtOpCodeTflMatrixSetDiag = kTfLiteBuiltinMatrixSetDiag, - kLiteRtOpCodeTflRound = kTfLiteBuiltinRound, - kLiteRtOpCodeTflHardSwish = kTfLiteBuiltinHardSwish, - kLiteRtOpCodeTflIf = kTfLiteBuiltinIf, - kLiteRtOpCodeTflWhile = kTfLiteBuiltinWhile, - kLiteRtOpCodeTflNonMaxSuppressionV4 = kTfLiteBuiltinNonMaxSuppressionV4, - kLiteRtOpCodeTflNonMaxSuppressionV5 = kTfLiteBuiltinNonMaxSuppressionV5, - kLiteRtOpCodeTflScatterNd = kTfLiteBuiltinScatterNd, - kLiteRtOpCodeTflSelectV2 = kTfLiteBuiltinSelectV2, - kLiteRtOpCodeTflDensify = kTfLiteBuiltinDensify, - kLiteRtOpCodeTflSegmentSum = kTfLiteBuiltinSegmentSum, - kLiteRtOpCodeTflBatchMatmul = kTfLiteBuiltinBatchMatmul, - kLiteRtOpCodeTflPlaceholderForGreaterOpCodeTfls = - kTfLiteBuiltinPlaceholderForGreaterOpCodes, - kLiteRtOpCodeTflCumsum = kTfLiteBuiltinCumsum, - kLiteRtOpCodeTflCallOnce = kTfLiteBuiltinCallOnce, - kLiteRtOpCodeTflBroadcastTo = kTfLiteBuiltinBroadcastTo, - kLiteRtOpCodeTflRfft2d = kTfLiteBuiltinRfft2d, - kLiteRtOpCodeTflConv3d = kTfLiteBuiltinConv3d, - kLiteRtOpCodeTflImag = kTfLiteBuiltinImag, - kLiteRtOpCodeTflReal = kTfLiteBuiltinReal, - kLiteRtOpCodeTflComplexAbs = kTfLiteBuiltinComplexAbs, - kLiteRtOpCodeTflHashtable = kTfLiteBuiltinHashtable, - kLiteRtOpCodeTflHashtableFind = kTfLiteBuiltinHashtableFind, - kLiteRtOpCodeTflHashtableImport = kTfLiteBuiltinHashtableImport, - kLiteRtOpCodeTflHashtableSize = kTfLiteBuiltinHashtableSize, - kLiteRtOpCodeTflReduceAll = 
kTfLiteBuiltinReduceAll, - kLiteRtOpCodeTflConv3dTranspose = kTfLiteBuiltinConv3dTranspose, - kLiteRtOpCodeTflVarHandle = kTfLiteBuiltinVarHandle, - kLiteRtOpCodeTflReadVariable = kTfLiteBuiltinReadVariable, - kLiteRtOpCodeTflAssignVariable = kTfLiteBuiltinAssignVariable, - kLiteRtOpCodeTflBroadcastArgs = kTfLiteBuiltinBroadcastArgs, - kLiteRtOpCodeTflRandomStandardNormal = kTfLiteBuiltinRandomStandardNormal, - kLiteRtOpCodeTflBucketize = kTfLiteBuiltinBucketize, - kLiteRtOpCodeTflRandomUniform = kTfLiteBuiltinRandomUniform, - kLiteRtOpCodeTflMultinomial = kTfLiteBuiltinMultinomial, - kLiteRtOpCodeTflGelu = kTfLiteBuiltinGelu, - kLiteRtOpCodeTflDynamicUpdateSlice = kTfLiteBuiltinDynamicUpdateSlice, - kLiteRtOpCodeTflRelu0To1 = kTfLiteBuiltinRelu0To1, - kLiteRtOpCodeTflUnsortedSegmentProd = kTfLiteBuiltinUnsortedSegmentProd, - kLiteRtOpCodeTflUnsortedSegmentMax = kTfLiteBuiltinUnsortedSegmentMax, - kLiteRtOpCodeTflUnsortedSegmentSum = kTfLiteBuiltinUnsortedSegmentSum, - kLiteRtOpCodeTflAtan2 = kTfLiteBuiltinAtan2, - kLiteRtOpCodeTflUnsortedSegmentMin = kTfLiteBuiltinUnsortedSegmentMin, - kLiteRtOpCodeTflSign = kTfLiteBuiltinSign, - kLiteRtOpCodeTflBitcast = kTfLiteBuiltinBitcast, - kLiteRtOpCodeTflBitwiseXor = kTfLiteBuiltinBitwiseXor, - kLiteRtOpCodeTflRightShift = kTfLiteBuiltinRightShift, - kLiteRtOpCodeShloLogistic = kTfLiteBuiltinStablehloLogistic, - kLiteRtOpCodeShloAdd = kTfLiteBuiltinStablehloAdd, - kLiteRtOpCodeShloDivide = kTfLiteBuiltinStablehloDivide, - kLiteRtOpCodeShloMultiply = kTfLiteBuiltinStablehloMultiply, - kLiteRtOpCodeShloMaximum = kTfLiteBuiltinStablehloMaximum, - kLiteRtOpCodeShloReshape = kTfLiteBuiltinStablehloReshape, - kLiteRtOpCodeShloClamp = kTfLiteBuiltinStablehloClamp, - kLiteRtOpCodeShloConcatenate = kTfLiteBuiltinStablehloConcatenate, - kLiteRtOpCodeShloBroadcastInDim = kTfLiteBuiltinStablehloBroadcastInDim, - kLiteRtOpCodeShloConvolution = kTfLiteBuiltinStablehloConvolution, - kLiteRtOpCodeShloSlice = kTfLiteBuiltinStablehloSlice, 
- kLiteRtOpCodeShloCustomCall = kTfLiteBuiltinStablehloCustomCall, - kLiteRtOpCodeShloReduce = kTfLiteBuiltinStablehloReduce, - kLiteRtOpCodeShloAbs = kTfLiteBuiltinStablehloAbs, - kLiteRtOpCodeShloAnd = kTfLiteBuiltinStablehloAnd, - kLiteRtOpCodeShloCosine = kTfLiteBuiltinStablehloCosine, - kLiteRtOpCodeShloExponential = kTfLiteBuiltinStablehloExponential, - kLiteRtOpCodeShloFloor = kTfLiteBuiltinStablehloFloor, - kLiteRtOpCodeShloLog = kTfLiteBuiltinStablehloLog, - kLiteRtOpCodeShloMinimum = kTfLiteBuiltinStablehloMinimum, - kLiteRtOpCodeShloNegate = kTfLiteBuiltinStablehloNegate, - kLiteRtOpCodeShloOr = kTfLiteBuiltinStablehloOr, - kLiteRtOpCodeShloPower = kTfLiteBuiltinStablehloPower, - kLiteRtOpCodeShloRemainder = kTfLiteBuiltinStablehloRemainder, - kLiteRtOpCodeShloRsqrt = kTfLiteBuiltinStablehloRsqrt, - kLiteRtOpCodeShloSelect = kTfLiteBuiltinStablehloSelect, - kLiteRtOpCodeShloSubtract = kTfLiteBuiltinStablehloSubtract, - kLiteRtOpCodeShloTanh = kTfLiteBuiltinStablehloTanh, - kLiteRtOpCodeShloScatter = kTfLiteBuiltinStablehloScatter, - kLiteRtOpCodeShloCompare = kTfLiteBuiltinStablehloCompare, - kLiteRtOpCodeShloConvert = kTfLiteBuiltinStablehloConvert, - kLiteRtOpCodeShloDynamicSlice = kTfLiteBuiltinStablehloDynamicSlice, - kLiteRtOpCodeShloDynamicUpdateSlice = - kTfLiteBuiltinStablehloDynamicUpdateSlice, - kLiteRtOpCodeShloPad = kTfLiteBuiltinStablehloPad, - kLiteRtOpCodeShloIota = kTfLiteBuiltinStablehloIota, - kLiteRtOpCodeShloGeneral = kTfLiteBuiltinStablehloDotGeneral, - kLiteRtOpCodeShloWindow = kTfLiteBuiltinStablehloReduceWindow, - kLiteRtOpCodeShloSort = kTfLiteBuiltinStablehloSort, - kLiteRtOpCodeShloWhile = kTfLiteBuiltinStablehloWhile, - kLiteRtOpCodeShloGather = kTfLiteBuiltinStablehloGather, - kLiteRtOpCodeShloTranspose = kTfLiteBuiltinStablehloTranspose, - kLiteRtOpCodeTflDilate = kTfLiteBuiltinDilate, - kLiteRtOpCodeShloRngBitGenerator = kTfLiteBuiltinStablehloRngBitGenerator, - kLiteRtOpCodeTflReduceWindow = kTfLiteBuiltinReduceWindow, - 
kLiteRtOpCodeShloComposite = kTfLiteBuiltinStablehloComposite, -} LiteRtOpCode; - -#ifdef __cplusplus -} -#endif // __cplusplus - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_OP_CODE_H_ diff --git a/tensorflow/lite/experimental/litert/c/litert_options.cc b/tensorflow/lite/experimental/litert/c/litert_options.cc deleted file mode 100644 index e14759ae641809..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_options.cc +++ /dev/null @@ -1,771 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/c/litert_options.h" - -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" - -#ifdef __cplusplus -extern "C" { -#endif - -// -// Op Options -// - -LiteRtStatus LiteRtGetAddFusedActivationOption(LiteRtOp op, - uint32_t* fused_activation) { - if (op->OpCode() != kLiteRtOpCodeTflAdd) { - return kLiteRtStatusErrorInvalidArgument; - } - const auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorNotFound; - } - *fused_activation = opts.AsAddOptions()->fused_activation_function; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetBatchMatmulAdjXOption(LiteRtOp op, bool* adj_x) { - if (op->OpCode() != kLiteRtOpCodeTflBatchMatmul) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *adj_x = opts.AsBatchMatMulOptions()->adj_x; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetBatchMatmulAdjYOption(LiteRtOp op, bool* adj_y) { - if (op->OpCode() != kLiteRtOpCodeTflBatchMatmul) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *adj_y = opts.AsBatchMatMulOptions()->adj_y; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetBatchMatmulAsymmetricQuantizeInputOption( - LiteRtOp op, bool* asymmetric_quantize_input) { - if (op->OpCode() != kLiteRtOpCodeTflBatchMatmul) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *asymmetric_quantize_input = - opts.AsBatchMatMulOptions()->asymmetric_quantize_inputs; - return 
kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetConcatenationFusedActivationOption( - LiteRtOp op, uint32_t* fused_activation) { - if (op->OpCode() != kLiteRtOpCodeTflConcatenation) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *fused_activation = opts.AsConcatenationOptions()->fused_activation_function; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetConcatenationAxisOption(LiteRtOp op, int32_t* axis) { - if (op->OpCode() != kLiteRtOpCodeTflConcatenation) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *axis = opts.AsConcatenationOptions()->axis; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetDivFusedActivationOption(LiteRtOp op, - uint32_t* fused_activation) { - if (op->OpCode() != kLiteRtOpCodeTflDiv) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *fused_activation = opts.AsDivOptions()->fused_activation_function; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetFullyConnectedFusedActivationOption( - LiteRtOp op, uint32_t* fused_activation) { - if (op->OpCode() != kLiteRtOpCodeTflFullyConnected) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *fused_activation = opts.AsFullyConnectedOptions()->fused_activation_function; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetFullyConnectedKeepNumDimsOption(LiteRtOp op, - bool* keep_num_dims) { - if (op->OpCode() != kLiteRtOpCodeTflFullyConnected) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if 
(opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *keep_num_dims = opts.AsFullyConnectedOptions()->keep_num_dims; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtFullyConnectedGetQuantizedBiasTypeOption( - LiteRtOp op, uint32_t* quantized_bias_type) { - if (op->OpCode() != kLiteRtOpCodeTflFullyConnected) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *quantized_bias_type = opts.AsFullyConnectedOptions()->quantized_bias_type; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetFullyConnectedAsymmetricQuantizeInputOption( - LiteRtOp op, bool* asymmetric_quantize_input) { - if (op->OpCode() != kLiteRtOpCodeTflFullyConnected) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *asymmetric_quantize_input = - opts.AsFullyConnectedOptions()->asymmetric_quantize_inputs; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetFullyConnectedWeightsFormatOption( - LiteRtOp op, uint32_t* weights_format) { - if (op->OpCode() != kLiteRtOpCodeTflFullyConnected) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *weights_format = opts.AsFullyConnectedOptions()->weights_format; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetMulFusedActivationOption(LiteRtOp op, - uint32_t* fused_activation) { - if (op->OpCode() != kLiteRtOpCodeTflMul) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *fused_activation = opts.AsMulOptions()->fused_activation_function; - return kLiteRtStatusOk; -} - -LiteRtStatus 
LiteRtGetSoftmaxBetaOption(LiteRtOp op, float* beta) { - if (op->OpCode() != kLiteRtOpCodeTflSoftmax) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *beta = opts.AsSoftmaxOptions()->beta; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetStridedSliceBeginMaskOption(LiteRtOp op, - int32_t* begin_mask) { - if (op->OpCode() != kLiteRtOpCodeTflStridedSlice) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *begin_mask = opts.AsStridedSliceOptions()->begin_mask; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetStridedSliceEndMaskOption(LiteRtOp op, - int32_t* end_mask) { - if (op->OpCode() != kLiteRtOpCodeTflStridedSlice) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *end_mask = opts.AsStridedSliceOptions()->end_mask; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetStridedSliceEllipsisMaskOption(LiteRtOp op, - int32_t* ellipsis_mask) { - if (op->OpCode() != kLiteRtOpCodeTflStridedSlice) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *ellipsis_mask = opts.AsStridedSliceOptions()->ellipsis_mask; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetStridedSliceNewAxisMaskOption(LiteRtOp op, - int32_t* new_axis_mask) { - if (op->OpCode() != kLiteRtOpCodeTflStridedSlice) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *new_axis_mask = opts.AsStridedSliceOptions()->new_axis_mask; - return 
kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetStridedSliceShrinkAxisMaskOption( - LiteRtOp op, int32_t* shrink_axis_mask) { - if (op->OpCode() != kLiteRtOpCodeTflStridedSlice) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *shrink_axis_mask = opts.AsStridedSliceOptions()->shrink_axis_mask; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetStridedSliceOffsetOption(LiteRtOp op, bool* offset) { - if (op->OpCode() != kLiteRtOpCodeTflStridedSlice) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *offset = opts.AsStridedSliceOptions()->offset; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetSubFusedActivationOption(LiteRtOp op, - uint32_t* fused_activation) { - if (op->OpCode() != kLiteRtOpCodeTflSub) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *fused_activation = opts.AsSubOptions()->fused_activation_function; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetReshapeNewShapeOption(LiteRtOp op, - const int32_t** new_shape, - int32_t* new_shape_size) { - if (op->OpCode() != kLiteRtOpCodeTflReshape) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - *new_shape_size = -1; - return kLiteRtStatusErrorInvalidArgument; - } - if (opts.AsReshapeOptions() == nullptr) { - *new_shape_size = -1; - return kLiteRtStatusOk; - } else { - *new_shape = opts.AsReshapeOptions()->new_shape.data(); - *new_shape_size = opts.AsReshapeOptions()->new_shape.size(); - } - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetSumKeepDimsOption(LiteRtOp op, bool* keepdims) { - if (op->OpCode() != 
kLiteRtOpCodeTflSum) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - // Sum OP options is stored as ReducerOptions. - *keepdims = opts.AsReducerOptions()->keep_dims; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetPackAxisOption(LiteRtOp op, int32_t* axis) { - if (op->OpCode() != kLiteRtOpCodeTflPack) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *axis = opts.AsPackOptions()->axis; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetGatherAxisOption(LiteRtOp op, int32_t* axis) { - if (op->OpCode() != kLiteRtOpCodeTflGather) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *axis = opts.AsGatherOptions()->axis; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetGatherBatchDimsOption(LiteRtOp op, int32_t* batch_dims) { - if (op->OpCode() != kLiteRtOpCodeTflGather) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *batch_dims = opts.AsGatherOptions()->batch_dims; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetMeanKeepDimsOption(LiteRtOp op, bool* keepdims) { - if (op->OpCode() != kLiteRtOpCodeTflMean) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - // Mean OP options is stored as ReducerOptions. 
- *keepdims = opts.AsReducerOptions()->keep_dims; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetSplitNumSplitsOption(LiteRtOp op, int32_t* num_splits) { - if (op->OpCode() != kLiteRtOpCodeTflSplit) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *num_splits = opts.AsSplitOptions()->num_splits; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetSHLOCompositeOpName(LiteRtOp op, const char** name) { - if (op->OpCode() != kLiteRtOpCodeShloComposite) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions2(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *name = opts.AsStableHLOCompositeOptions()->name.data(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetSHLOCompositeOpDecompositionSubgraphIndex( - LiteRtOp op, int32_t* subgraph_index) { - if (op->OpCode() != kLiteRtOpCodeShloComposite) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions2(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *subgraph_index = - opts.AsStableHLOCompositeOptions()->decomposition_subgraph_index; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetConv2dPaddingOption(LiteRtOp op, uint32_t* padding) { - if (op->OpCode() != kLiteRtOpCodeTflConv2d) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *padding = opts.AsConv2DOptions()->padding; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetConv2dStrideWOption(LiteRtOp op, int32_t* stride_w) { - if (op->OpCode() != kLiteRtOpCodeTflConv2d) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return 
kLiteRtStatusErrorInvalidArgument; - } - *stride_w = opts.AsConv2DOptions()->stride_w; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetConv2dStrideHOption(LiteRtOp op, int32_t* stride_h) { - if (op->OpCode() != kLiteRtOpCodeTflConv2d) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *stride_h = opts.AsConv2DOptions()->stride_h; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetConv2dFusedActivationOption( - LiteRtOp op, uint32_t* fused_activation_function) { - if (op->OpCode() != kLiteRtOpCodeTflConv2d) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *fused_activation_function = - opts.AsConv2DOptions()->fused_activation_function; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetConv2dDilationWOption(LiteRtOp op, - int32_t* dilation_w_factor) { - if (op->OpCode() != kLiteRtOpCodeTflConv2d) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *dilation_w_factor = opts.AsConv2DOptions()->dilation_w_factor; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetConv2dDilationHOption(LiteRtOp op, - int32_t* dilation_h_factor) { - if (op->OpCode() != kLiteRtOpCodeTflConv2d) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *dilation_h_factor = opts.AsConv2DOptions()->dilation_h_factor; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetDepthwiseConv2dPaddingOption(LiteRtOp op, - uint32_t* padding) { - if (op->OpCode() != kLiteRtOpCodeTflDepthwiseConv2d) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = 
litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *padding = opts.AsDepthwiseConv2DOptions()->padding; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetDepthwiseConv2dStrideWOption(LiteRtOp op, - int32_t* stride_w) { - if (op->OpCode() != kLiteRtOpCodeTflDepthwiseConv2d) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *stride_w = opts.AsDepthwiseConv2DOptions()->stride_w; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetDepthwiseConv2dStrideHOption(LiteRtOp op, - int32_t* stride_h) { - if (op->OpCode() != kLiteRtOpCodeTflDepthwiseConv2d) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *stride_h = opts.AsDepthwiseConv2DOptions()->stride_h; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetDepthwiseConv2dFusedActivationOption( - LiteRtOp op, uint32_t* fused_activation_function) { - if (op->OpCode() != kLiteRtOpCodeTflDepthwiseConv2d) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *fused_activation_function = - opts.AsDepthwiseConv2DOptions()->fused_activation_function; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetDepthwiseConv2dDilationWOption( - LiteRtOp op, int32_t* dilation_w_factor) { - if (op->OpCode() != kLiteRtOpCodeTflDepthwiseConv2d) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *dilation_w_factor = opts.AsDepthwiseConv2DOptions()->dilation_w_factor; - return kLiteRtStatusOk; -} - -LiteRtStatus 
LiteRtGetDepthwiseConv2dDilationHOptions( - LiteRtOp op, int32_t* dilation_h_factor) { - if (op->OpCode() != kLiteRtOpCodeTflDepthwiseConv2d) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *dilation_h_factor = opts.AsDepthwiseConv2DOptions()->dilation_h_factor; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetAveragePool2dOptions(LiteRtOp op, int8_t* padding, - int32_t* stride_w, int32_t* stride_h, - int32_t* filter_width, - int32_t* filter_height, - int8_t* fused_activation_function) { - if (op->OpCode() != kLiteRtOpCodeTflAveragePool2d) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - auto* options = opts.AsPool2DOptions(); - *padding = options->padding; - *stride_w = options->stride_w; - *stride_h = options->stride_h; - *filter_width = options->filter_width; - *filter_height = options->filter_height; - *fused_activation_function = options->fused_activation_function; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetAveragePool2dPaddingOption(LiteRtOp op, - uint32_t* padding) { - if (op->OpCode() != kLiteRtOpCodeTflAveragePool2d) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *padding = opts.AsPool2DOptions()->padding; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetAveragePool2dStrideWOption(LiteRtOp op, - int32_t* stride_w) { - if (op->OpCode() != kLiteRtOpCodeTflAveragePool2d) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *stride_w = opts.AsPool2DOptions()->stride_w; - return kLiteRtStatusOk; -} - 
-LiteRtStatus LiteRtGetAveragePool2dStrideHOption(LiteRtOp op, - int32_t* stride_h) { - if (op->OpCode() != kLiteRtOpCodeTflAveragePool2d) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *stride_h = opts.AsPool2DOptions()->stride_h; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetAveragePool2dFilterWidthOption(LiteRtOp op, - int32_t* filter_width) { - if (op->OpCode() != kLiteRtOpCodeTflAveragePool2d) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *filter_width = opts.AsPool2DOptions()->filter_width; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetAveragePool2dFilterHeightOption(LiteRtOp op, - int32_t* filter_height) { - if (op->OpCode() != kLiteRtOpCodeTflAveragePool2d) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *filter_height = opts.AsPool2DOptions()->filter_height; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetAveragePool2dFusedActivationOption( - LiteRtOp op, uint32_t* fused_activation_function) { - if (op->OpCode() != kLiteRtOpCodeTflAveragePool2d) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *fused_activation_function = - opts.AsPool2DOptions()->fused_activation_function; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetResizeBilinearAlignCornersOption(LiteRtOp op, - bool* align_corners) { - if (op->OpCode() != kLiteRtOpCodeTflResizeBilinear) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return 
kLiteRtStatusErrorInvalidArgument; - } - *align_corners = opts.AsResizeBilinearOptions()->align_corners; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetResizeBilinearHalfPixelCenterOption( - LiteRtOp op, bool* half_pixel_centers) { - if (op->OpCode() != kLiteRtOpCodeTflResizeBilinear) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *half_pixel_centers = opts.AsResizeBilinearOptions()->half_pixel_centers; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetLeakyReluAlphaOption(LiteRtOp op, float* alpha) { - if (op->OpCode() != kLiteRtOpCodeTflLeakyRelu) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *alpha = opts.AsLeakyReluOptions()->alpha; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetDepthToSpaceBlockSizeOption(LiteRtOp op, - int32_t* block_size) { - if (op->OpCode() != kLiteRtOpCodeTflDepthToSpace) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *block_size = opts.AsDepthToSpaceOptions()->block_size; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetSpaceToDepthBlockSizeOption(LiteRtOp op, - int32_t* block_size) { - if (op->OpCode() != kLiteRtOpCodeTflSpaceToDepth) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *block_size = opts.AsSpaceToDepthOptions()->block_size; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetResizeNearestNeighborAlignCornersOption( - LiteRtOp op, bool* align_corners) { - if (op->OpCode() != kLiteRtOpCodeTflResizeNearestNeighbor) { - return kLiteRtStatusErrorInvalidArgument; - } 
- auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *align_corners = opts.AsResizeNearestNeighborOptions()->align_corners; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetResizeNearestNeighborHalfPixelCenterOption( - LiteRtOp op, bool* half_pixel_centers) { - if (op->OpCode() != kLiteRtOpCodeTflResizeNearestNeighbor) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& opts = litert::internal::GetTflOptions(*op); - if (opts.value == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - *half_pixel_centers = - opts.AsResizeNearestNeighborOptions()->half_pixel_centers; - return kLiteRtStatusOk; -} - -#ifdef __cplusplus -} // extern "C" -#endif diff --git a/tensorflow/lite/experimental/litert/c/litert_options.h b/tensorflow/lite/experimental/litert/c/litert_options.h deleted file mode 100644 index ed746575c3d03e..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_options.h +++ /dev/null @@ -1,351 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_OPTIONS_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_OPTIONS_H_ - -#include // NOLINT: To use bool type in C -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" - -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus - -LITERT_DEFINE_HANDLE(LiteRtOp); - -//============================================================================== -// -// Get option APIs for LiteRt ADD op. -// Options: -// - FusedActivationOption : uint32_t -// -//============================================================================== -LiteRtStatus LiteRtGetAddFusedActivationOption(LiteRtOp op, - uint32_t* fused_activation); - -//============================================================================== -// -// Get option APIs for LiteRt BatchMatmul op. -// Options: -// - AdjXOption : bool -// - AdjYOption : bool -// - AsymmtericQuantizeInputOption : bool -// -//============================================================================== -LiteRtStatus LiteRtGetBatchMatmulAdjXOption(LiteRtOp op, bool* adj_x); -LiteRtStatus LiteRtGetBatchMatmulAdjYOption(LiteRtOp op, bool* adj_y); -LiteRtStatus LiteRtGetBatchMatmulAsymmetricQuantizeInputOption( - LiteRtOp op, bool* asymmetric_quantize_input); - -//============================================================================== -// -// Get option APIs for LiteRt Concatenation op. -// Options: -// - FusedActivationOption : uint32_t -// - AxisOption : int32_t -// -//============================================================================== -LiteRtStatus LiteRtGetConcatenationFusedActivationOption( - LiteRtOp op, uint32_t* fused_activation); -LiteRtStatus LiteRtGetConcatenationAxisOption(LiteRtOp op, int32_t* axis); - -//============================================================================== -// -// Get option APIs for LiteRt Div op. 
-// Options: -// - FusedActivationOption : uint32_t -// -//============================================================================== -LiteRtStatus LiteRtGetDivFusedActivationOption(LiteRtOp op, - uint32_t* fused_activation); - -//============================================================================== -// -// Get option APIs for LiteRt FullyConnected op. -// Options: -// - FusedActivationOption : uint32_t -// - WeightsFormatOption : uint32_t -// - KeepNumDimsOption : bool -// - QuantizedBiasTypeOption : uint32_t -// - AsymmtericQuantizeInputOption : bool -// -//============================================================================== -LiteRtStatus LiteRtGetFullyConnectedFusedActivationOption( - LiteRtOp op, uint32_t* fused_activation); -LiteRtStatus LiteRtGetFullyConnectedWeightsFormatOption( - LiteRtOp op, uint32_t* weights_format); -LiteRtStatus LiteRtGetFullyConnectedKeepNumDimsOption(LiteRtOp op, - bool* keep_num_dims); -LiteRtStatus LiteRtFullyConnectedGetQuantizedBiasTypeOption( - LiteRtOp op, uint32_t* quantized_bias_type); -LiteRtStatus LiteRtGetFullyConnectedAsymmetricQuantizeInputOption( - LiteRtOp op, bool* asymmetric_quantize_input); - -//============================================================================== -// -// Get option APIs for LiteRt Mul op. -// Options: -// - FusedActivationOption : uint32_t -// -//============================================================================== -LiteRtStatus LiteRtGetMulFusedActivationOption(LiteRtOp op, - uint32_t* fused_activation); - -//============================================================================== -// -// Get option APIs for LiteRt Softmax op. -// Options: -// - BetaOption : float -// -//============================================================================== -LiteRtStatus LiteRtGetSoftmaxBetaOption(LiteRtOp op, float* beta); - -//============================================================================== -// -// Get option APIs for LiteRt StridedSlice op. 
-// Options: -// - BeginMaskOption : int32_t -// - EndMaskOption : int32_t -// - EllipsisMaskOption : int32_t -// - NewAxisMaskOption : int32_t -// - ShrinkAxisMaskOption : int32_t -// - OffsetOption : bool - -//============================================================================== -LiteRtStatus LiteRtGetStridedSliceBeginMaskOption(LiteRtOp op, - int32_t* begin_mask); -LiteRtStatus LiteRtGetStridedSliceEndMaskOption(LiteRtOp op, int32_t* end_mask); -LiteRtStatus LiteRtGetStridedSliceEllipsisMaskOption(LiteRtOp op, - int32_t* ellipsis_mask); -LiteRtStatus LiteRtGetStridedSliceNewAxisMaskOption(LiteRtOp op, - int32_t* new_axis_mask); -LiteRtStatus LiteRtGetStridedSliceShrinkAxisMaskOption( - LiteRtOp op, int32_t* shrink_axis_mask); -LiteRtStatus LiteRtGetStridedSliceOffsetOption(LiteRtOp op, bool* offset); - -//============================================================================== -// -// Get option APIs for LiteRt Sub op. -// Options: -// - FusedActivationOption : uint32_t -// - (Not supported) PotScaleInt16Option : bool -// -//============================================================================== -LiteRtStatus LiteRtGetSubFusedActivationOption(LiteRtOp op, - uint32_t* fused_activation); - -//============================================================================== -// -// Get option APIs for LiteRt Reshape op. -// Options: -// - new_shape : int32_t[] -// -//============================================================================== -LiteRtStatus LiteRtGetReshapeNewShapeOption(LiteRtOp op, - const int32_t** new_shape, - int32_t* new_shape_size); - -//============================================================================== -// -// Get option APIs for LiteRt Sum op. 
-// Options: -// - KeepdimsOption : bool -// -//============================================================================== -LiteRtStatus LiteRtGetSumKeepDimsOption(LiteRtOp op, bool* keepdims); - -//============================================================================== -// -// Get option APIs for LiteRt Pack op. -// Options: -// - axisOption : int32_t -// -//============================================================================== -LiteRtStatus LiteRtGetPackAxisOption(LiteRtOp op, int32_t* axis); - -//============================================================================== -// -// Get option APIs for LiteRt Gather op. -// Options: -// - axisOption : int32_t -// - batch_dims : int32_t -// -//============================================================================== -LiteRtStatus LiteRtGetGatherAxisOption(LiteRtOp op, int32_t* axis); -LiteRtStatus LiteRtGetGatherBatchDimsOption(LiteRtOp op, int32_t* batch_dims); - -//============================================================================== -// -// Get option APIs for LiteRt Mean op. -// Options: -// - keepdimsOption : bool -// -//============================================================================== -LiteRtStatus LiteRtGetMeanKeepDimsOption(LiteRtOp op, bool* keepdims); - -//============================================================================== -// -// Get option APIs for LiteRt Split op. -// Options: -// - num_splits : int32_t -// -//============================================================================== -LiteRtStatus LiteRtGetSplitNumSplitsOption(LiteRtOp op, int32_t* num_splits); - -//============================================================================== -// -// Get option APIs for LiteRt SHLO Composite op. 
-// Options: -// - name : string -// - decomposition_subgraph_index : int32_t -// -//============================================================================== -LiteRtStatus LiteRtGetSHLOCompositeOpName(LiteRtOp op, const char** name); -LiteRtStatus LiteRtGetSHLOCompositeOpDecompositionSubgraphIndex( - LiteRtOp op, int32_t* decomposition_subgraph_index); - -//============================================================================== -// -// Get option APIs for LiteRt Conv2d op. -// Options: -// - padding : uint32_t -// - stride_w : int32_t -// - stride_h : int32_t -// - fused_activation_function : uint32_t -// - dilation_w_factor : int32_t -// - dilation_h_factor : int32_t -// -//============================================================================== -LiteRtStatus LiteRtGetConv2dPaddingOption(LiteRtOp op, uint32_t* padding); -LiteRtStatus LiteRtGetConv2dStrideWOption(LiteRtOp op, int32_t* stride_w); -LiteRtStatus LiteRtGetConv2dStrideHOption(LiteRtOp op, int32_t* stride_h); -LiteRtStatus LiteRtGetConv2dFusedActivationOption( - LiteRtOp op, uint32_t* fused_activation_function); -LiteRtStatus LiteRtGetConv2dDilationWOption(LiteRtOp op, - int32_t* dilation_w_factor); -LiteRtStatus LiteRtGetConv2dDilationHOption(LiteRtOp op, - int32_t* dilation_h_factor); - -//============================================================================== -// -// Get option APIs for LiteRt DepthwiseConv2d op. 
-// Options: -// - padding : uint32_t -// - stride_w : int32_t -// - stride_h : int32_t -// - fused_activation_function : uint32_t -// - dilation_w_factor : int32_t -// - dilation_h_factor : int32_t -// -//============================================================================== -LiteRtStatus LiteRtGetDepthwiseConv2dPaddingOption(LiteRtOp op, - uint32_t* padding); -LiteRtStatus LiteRtGetDepthwiseConv2dStrideWOption(LiteRtOp op, - int32_t* stride_w); -LiteRtStatus LiteRtGetDepthwiseConv2dStrideHOption(LiteRtOp op, - int32_t* stride_h); -LiteRtStatus LiteRtGetDepthwiseConv2dFusedActivationOption( - LiteRtOp op, uint32_t* fused_activation_function); -LiteRtStatus LiteRtGetDepthwiseConv2dDilationWOption( - LiteRtOp op, int32_t* dilation_w_factor); -LiteRtStatus LiteRtGetDepthwiseConv2dDilationHOptions( - LiteRtOp op, int32_t* dilation_h_factor); - -//============================================================================== -// -// Get option APIs for LiteRt AveragePool2d op. -// Options: -// - padding : uint32_t -// - stride_w : int32_t -// - stride_h : int32_t -// - filter_width : int32_t -// - filter_height : int32_t -// - fused_activation_function : uint32_t -// -//============================================================================== -LiteRtStatus LiteRtGetAveragePool2dPaddingOption(LiteRtOp op, - uint32_t* padding); -LiteRtStatus LiteRtGetAveragePool2dStrideWOption(LiteRtOp op, - int32_t* stride_w); -LiteRtStatus LiteRtGetAveragePool2dStrideHOption(LiteRtOp op, - int32_t* stride_h); -LiteRtStatus LiteRtGetAveragePool2dFilterWidthOption(LiteRtOp op, - int32_t* filter_width); -LiteRtStatus LiteRtGetAveragePool2dFilterHeightOption(LiteRtOp op, - int32_t* filter_height); -LiteRtStatus LiteRtGetAveragePool2dFusedActivationOption( - LiteRtOp op, uint32_t* fused_activation_function); - -//============================================================================== -// -// Get option APIs for LiteRt ResizeBilinear op. 
-// Options: -// - align_corners : bool -// - half_pixel_centers : bool -// -//============================================================================== -LiteRtStatus LiteRtGetResizeBilinearAlignCornersOption(LiteRtOp op, - bool* align_corners); -LiteRtStatus LiteRtGetResizeBilinearHalfPixelCenterOption( - LiteRtOp op, bool* half_pixel_centers); - -//============================================================================== -// -// Get option APIs for LiteRt LeakyRelu op. -// Options: -// - alpha : float -// -//============================================================================== -LiteRtStatus LiteRtGetLeakyReluAlphaOption(LiteRtOp op, float* alpha); - -//============================================================================== -// -// Get option APIs for LiteRt DepthToSpace op. -// Options: -// - block_size : int32_t -// -//============================================================================== -LiteRtStatus LiteRtGetDepthToSpaceBlockSizeOption(LiteRtOp op, - int32_t* block_size); - -//============================================================================== -// -// Get option APIs for LiteRt SpaceToDepth op. -// Options: -// - block_size : int32_t -// -//============================================================================== -LiteRtStatus LiteRtGetSpaceToDepthBlockSizeOption(LiteRtOp op, - int32_t* block_size); - -//============================================================================== -// -// Get option APIs for LiteRt ResizeNearestNeighbor op. 
-// Options: -// - align_corners : bool -// - half_pixel_centers : bool -// -//============================================================================== -LiteRtStatus LiteRtGetResizeNearestNeighborAlignCornersOption( - LiteRtOp op, bool* align_corners); -LiteRtStatus LiteRtGetResizeNearestNeighborHalfPixelCenterOption( - LiteRtOp op, bool* half_pixel_centers); - -#ifdef __cplusplus -} -#endif // __cplusplus - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_OPTIONS_H_ diff --git a/tensorflow/lite/experimental/litert/c/litert_options_test.cc b/tensorflow/lite/experimental/litert/c/litert_options_test.cc deleted file mode 100644 index 41c0c07f8bf39e..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_options_test.cc +++ /dev/null @@ -1,462 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/c/litert_options.h" - -#include - -#include // IWYU pragma: keep -#include -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/test/common.h" -#include "tensorflow/lite/experimental/litert/test/matchers.h" - -namespace { -using testing::litert::IsError; - -TEST(GetOpOptionTest, TestGetAddOptions) { - auto model = litert::testing::LoadTestFileModel("simple_add_op.tflite"); - auto subgraph = model.MainSubgraph(); - EXPECT_TRUE(subgraph); - - auto ops = subgraph->Ops(); - auto op = ops.front().Get(); - - uint32_t fused_activation; - LITERT_ASSERT_OK(LiteRtGetAddFusedActivationOption(op, &fused_activation)); - ASSERT_EQ(fused_activation, 0); -} - -TEST(GetOpOptionTest, TestGetBatchMatmulOptions) { - auto model = - litert::testing::LoadTestFileModel("simple_batch_matmul_op.tflite"); - auto subgraph = model.MainSubgraph(); - EXPECT_TRUE(subgraph); - - auto ops = subgraph->Ops(); - auto op = ops.front().Get(); - - bool adj_x; - LITERT_ASSERT_OK(LiteRtGetBatchMatmulAdjXOption(op, &adj_x)); - ASSERT_EQ(adj_x, false); - - bool adj_y; - LITERT_ASSERT_OK(LiteRtGetBatchMatmulAdjYOption(op, &adj_y)); - ASSERT_EQ(adj_y, false); - - bool asymmetric_quantize_input; - LITERT_ASSERT_OK(LiteRtGetBatchMatmulAsymmetricQuantizeInputOption( - op, &asymmetric_quantize_input)); - ASSERT_EQ(asymmetric_quantize_input, false); -} - -TEST(GetOpOptionTest, TestGetConcatenationOptions) { - auto model = - litert::testing::LoadTestFileModel("simple_concatenation_op.tflite"); - auto subgraph = model.MainSubgraph(); - EXPECT_TRUE(subgraph); - - auto ops = subgraph->Ops(); - auto op = ops.front().Get(); - - uint32_t fused_activation; - LITERT_ASSERT_OK( - LiteRtGetConcatenationFusedActivationOption(op, &fused_activation)); - ASSERT_EQ(fused_activation, 0); - - int32_t axis; - LITERT_ASSERT_OK(LiteRtGetConcatenationAxisOption(op, &axis)); - ASSERT_EQ(axis, 2); -} - -TEST(GetOpOptionTest, 
TestGetDivOptions) { - auto model = litert::testing::LoadTestFileModel("simple_div_op.tflite"); - auto subgraph = model.MainSubgraph(); - EXPECT_TRUE(subgraph); - - auto ops = subgraph->Ops(); - auto op = ops.front().Get(); - - uint32_t fused_activation; - LITERT_ASSERT_OK(LiteRtGetDivFusedActivationOption(op, &fused_activation)); - ASSERT_EQ(fused_activation, 0); -} - -TEST(GetOpOptionTest, TestGetFullyConnectedOptions) { - auto model = - litert::testing::LoadTestFileModel("simple_fully_connected_op.tflite"); - auto subgraph = model.MainSubgraph(); - EXPECT_TRUE(subgraph); - - auto ops = subgraph->Ops(); - auto op = ops.front().Get(); - - uint32_t fused_activation; - LITERT_ASSERT_OK( - LiteRtGetFullyConnectedFusedActivationOption(op, &fused_activation)); - ASSERT_EQ(fused_activation, 0); - - uint32_t weights_format; - LITERT_ASSERT_OK( - LiteRtGetFullyConnectedWeightsFormatOption(op, &weights_format)); - ASSERT_EQ(weights_format, 0); - - bool keep_num_dims; - LITERT_ASSERT_OK( - LiteRtGetFullyConnectedKeepNumDimsOption(op, &keep_num_dims)); - ASSERT_EQ(keep_num_dims, true); - - uint32_t quantized_bias_type; - LITERT_ASSERT_OK( - LiteRtFullyConnectedGetQuantizedBiasTypeOption(op, &quantized_bias_type)); - ASSERT_EQ(quantized_bias_type, 0); - - bool asymmetric_quantize_input; - LITERT_ASSERT_OK(LiteRtGetFullyConnectedAsymmetricQuantizeInputOption( - op, &asymmetric_quantize_input)); - ASSERT_EQ(asymmetric_quantize_input, false); -} - -TEST(GetOpOptionTest, TestGetMulOptions) { - auto model = litert::testing::LoadTestFileModel("simple_mul_op.tflite"); - auto subgraph = model.MainSubgraph(); - EXPECT_TRUE(subgraph); - - auto ops = subgraph->Ops(); - auto op = ops.front().Get(); - - uint32_t fused_activation; - LITERT_ASSERT_OK(LiteRtGetMulFusedActivationOption(op, &fused_activation)); - ASSERT_EQ(fused_activation, 0); -} - -TEST(GetOpOptionTest, TestGetSoftmaxOptions) { - auto model = litert::testing::LoadTestFileModel("simple_softmax_op.tflite"); - auto subgraph = 
model.MainSubgraph(); - EXPECT_TRUE(subgraph); - - auto ops = subgraph->Ops(); - auto op = ops.front().Get(); - - float beta; - LITERT_ASSERT_OK(LiteRtGetSoftmaxBetaOption(op, &beta)); - EXPECT_FLOAT_EQ(beta, 1.0); -} - -TEST(GetOpOptionTest, TestGetStridedSliceOptions) { - auto model = - litert::testing::LoadTestFileModel("simple_strided_slice_op.tflite"); - auto subgraph = model.MainSubgraph(); - EXPECT_TRUE(subgraph); - - auto ops = subgraph->Ops(); - auto op = ops.front().Get(); - - int32_t begin_mask; - LITERT_ASSERT_OK(LiteRtGetStridedSliceBeginMaskOption(op, &begin_mask)); - ASSERT_EQ(begin_mask, 0); - - int32_t end_mask; - LITERT_ASSERT_OK(LiteRtGetStridedSliceEndMaskOption(op, &end_mask)); - ASSERT_EQ(end_mask, 0); - - int32_t ellipsis_mask; - LITERT_ASSERT_OK(LiteRtGetStridedSliceEllipsisMaskOption(op, &ellipsis_mask)); - ASSERT_EQ(ellipsis_mask, 0); - - int32_t new_axis_mask; - LITERT_ASSERT_OK(LiteRtGetStridedSliceNewAxisMaskOption(op, &new_axis_mask)); - ASSERT_EQ(new_axis_mask, 0); - - int32_t shrink_axis_mask; - LITERT_ASSERT_OK( - LiteRtGetStridedSliceShrinkAxisMaskOption(op, &shrink_axis_mask)); - ASSERT_EQ(shrink_axis_mask, 0); - - bool offset; - LITERT_ASSERT_OK(LiteRtGetStridedSliceOffsetOption(op, &offset)); - ASSERT_EQ(offset, false); -} - -TEST(GetOpOptionTest, TestGetSubOptions) { - auto model = litert::testing::LoadTestFileModel("simple_sub_op.tflite"); - auto subgraph = model.MainSubgraph(); - EXPECT_TRUE(subgraph); - - auto ops = subgraph->Ops(); - auto op = ops.front().Get(); - - uint32_t fused_activation; - LITERT_ASSERT_OK(LiteRtGetSubFusedActivationOption(op, &fused_activation)); - ASSERT_EQ(fused_activation, 0); -} - -TEST(GetOpOptionTest, TestGetNullReshapeOptions) { - auto model = litert::testing::LoadTestFileModel("simple_reshape_op.tflite"); - auto subgraph = model.MainSubgraph(); - EXPECT_TRUE(subgraph); - - auto ops = subgraph->Ops(); - auto op = ops.front().Get(); - - const int32_t* new_shape = nullptr; - int32_t 
new_shape_size; - - EXPECT_THAT(LiteRtGetReshapeNewShapeOption(op, &new_shape, &new_shape_size), - IsError(kLiteRtStatusErrorInvalidArgument)); - ASSERT_EQ(new_shape_size, -1); -} - -TEST(GetOpOptionTest, TestGetSumOptions) { - auto model = litert::testing::LoadTestFileModel("simple_sum_op.tflite"); - auto subgraph = model.MainSubgraph(); - EXPECT_TRUE(subgraph); - - auto ops = subgraph->Ops(); - auto op = ops.front().Get(); - - bool keepdims; - LITERT_ASSERT_OK(LiteRtGetSumKeepDimsOption(op, &keepdims)); - ASSERT_EQ(keepdims, true); -} - -TEST(GetOpOptionTest, TestGetPackOptions) { - auto model = litert::testing::LoadTestFileModel("simple_pack_op.tflite"); - auto subgraph = model.MainSubgraph(); - EXPECT_TRUE(subgraph); - - auto ops = subgraph->Ops(); - auto op = ops.front().Get(); - - int32_t axis; - LITERT_ASSERT_OK(LiteRtGetPackAxisOption(op, &axis)); - ASSERT_EQ(axis, 0); -} - -TEST(GetOpOptionTest, TestGetGatherOptions) { - auto model = litert::testing::LoadTestFileModel("simple_gather_op.tflite"); - auto subgraph = model.MainSubgraph(); - EXPECT_TRUE(subgraph); - - auto ops = subgraph->Ops(); - auto op = ops.front().Get(); - - int32_t axis; - LITERT_ASSERT_OK(LiteRtGetGatherAxisOption(op, &axis)); - ASSERT_EQ(axis, 0); - - int32_t batch_dims; - LITERT_ASSERT_OK(LiteRtGetGatherBatchDimsOption(op, &batch_dims)); - ASSERT_EQ(batch_dims, 0); -} - -TEST(GetOpOptionTest, TestGetMeanOptions) { - auto model = litert::testing::LoadTestFileModel("simple_mean_op.tflite"); - auto subgraph = model.MainSubgraph(); - EXPECT_TRUE(subgraph); - - auto ops = subgraph->Ops(); - auto op = ops.front().Get(); - - bool keepdims; - LITERT_ASSERT_OK(LiteRtGetMeanKeepDimsOption(op, &keepdims)); - ASSERT_EQ(keepdims, false); -} - -TEST(GetOpOptionTest, TestGetSplitOptions) { - auto model = litert::testing::LoadTestFileModel("simple_split_op.tflite"); - auto subgraph = model.MainSubgraph(); - EXPECT_TRUE(subgraph); - - auto ops = subgraph->Ops(); - auto op = ops.front().Get(); - - 
int32_t num_splits; - LITERT_ASSERT_OK(LiteRtGetSplitNumSplitsOption(op, &num_splits)); - ASSERT_EQ(num_splits, 3); -} - -TEST(GetOpOptionTest, TestGetConv2dOptions) { - auto model = litert::testing::LoadTestFileModel("simple_conv_2d_op.tflite"); - auto subgraph = model.MainSubgraph(); - EXPECT_TRUE(subgraph); - - auto ops = subgraph->Ops(); - auto op = ops.front().Get(); - - uint32_t padding; - LITERT_ASSERT_OK(LiteRtGetConv2dPaddingOption(op, &padding)); - ASSERT_EQ(padding, 0); - int32_t stride_w; - LITERT_ASSERT_OK(LiteRtGetConv2dStrideWOption(op, &stride_w)); - ASSERT_EQ(stride_w, 1); - int32_t stride_h; - LITERT_ASSERT_OK(LiteRtGetConv2dStrideHOption(op, &stride_h)); - ASSERT_EQ(stride_h, 1); - uint32_t fused_activation_function; - LITERT_ASSERT_OK( - LiteRtGetConv2dFusedActivationOption(op, &fused_activation_function)); - ASSERT_EQ(fused_activation_function, 0); - int32_t dilation_w_factor; - LITERT_ASSERT_OK(LiteRtGetConv2dDilationWOption(op, &dilation_w_factor)); - ASSERT_EQ(dilation_w_factor, 1); - int32_t dilation_h_factor; - LITERT_ASSERT_OK(LiteRtGetConv2dDilationWOption(op, &dilation_h_factor)); - ASSERT_EQ(dilation_h_factor, 1); -} - -TEST(GetOpOptionTest, TestGetDepthwiseConv2dOptions) { - auto model = - litert::testing::LoadTestFileModel("simple_depthwise_conv_2d_op.tflite"); - auto subgraph = model.MainSubgraph(); - EXPECT_TRUE(subgraph); - - auto ops = subgraph->Ops(); - auto op = ops.front().Get(); - - uint32_t padding; - LITERT_ASSERT_OK(LiteRtGetDepthwiseConv2dPaddingOption(op, &padding)); - ASSERT_EQ(padding, 1); - int32_t stride_w; - LITERT_ASSERT_OK(LiteRtGetDepthwiseConv2dStrideWOption(op, &stride_w)); - ASSERT_EQ(stride_w, 1); - int32_t stride_h; - LITERT_ASSERT_OK(LiteRtGetDepthwiseConv2dStrideHOption(op, &stride_h)); - ASSERT_EQ(stride_h, 1); - uint32_t fused_activation_function; - LITERT_ASSERT_OK(LiteRtGetDepthwiseConv2dFusedActivationOption( - op, &fused_activation_function)); - ASSERT_EQ(fused_activation_function, 0); - int32_t 
dilation_w_factor; - LITERT_ASSERT_OK( - LiteRtGetDepthwiseConv2dDilationWOption(op, &dilation_w_factor)); - ASSERT_EQ(dilation_w_factor, 4); - int32_t dilation_h_factor; - LITERT_ASSERT_OK( - LiteRtGetDepthwiseConv2dDilationHOptions(op, &dilation_h_factor)); - ASSERT_EQ(dilation_h_factor, 4); -} - -TEST(GetOpOptionTest, TestGetAveragePool2dOptions) { - auto model = - litert::testing::LoadTestFileModel("simple_average_poll_2d.tflite"); - auto subgraph = model.MainSubgraph(); - EXPECT_TRUE(subgraph); - - auto ops = subgraph->Ops(); - auto op = ops.front().Get(); - - uint32_t padding; - LITERT_ASSERT_OK(LiteRtGetAveragePool2dPaddingOption(op, &padding)); - ASSERT_EQ(padding, 1); - int32_t stride_w; - LITERT_ASSERT_OK(LiteRtGetAveragePool2dStrideWOption(op, &stride_w)); - ASSERT_EQ(stride_w, 4); - int32_t stride_h; - LITERT_ASSERT_OK(LiteRtGetAveragePool2dStrideHOption(op, &stride_h)); - ASSERT_EQ(stride_h, 4); - int32_t filter_width; - LITERT_ASSERT_OK(LiteRtGetAveragePool2dFilterWidthOption(op, &filter_width)); - ASSERT_EQ(filter_width, 4); - int32_t filter_height; - LITERT_ASSERT_OK( - LiteRtGetAveragePool2dFilterHeightOption(op, &filter_height)); - ASSERT_EQ(filter_height, 4); - uint32_t fused_activation_function; - LITERT_ASSERT_OK(LiteRtGetAveragePool2dFusedActivationOption( - op, &fused_activation_function)); - ASSERT_EQ(fused_activation_function, 0); -} - -TEST(GetOpOptionTest, TestGetResizeBilinearOptions) { - auto model = - litert::testing::LoadTestFileModel("simple_resize_bilinear_op.tflite"); - auto subgraph = model.MainSubgraph(); - EXPECT_TRUE(subgraph); - - auto ops = subgraph->Ops(); - auto op = ops.front().Get(); - - bool align_corners; - LITERT_ASSERT_OK( - LiteRtGetResizeBilinearAlignCornersOption(op, &align_corners)); - ASSERT_EQ(align_corners, false); - bool half_pixel_centers; - LITERT_ASSERT_OK( - LiteRtGetResizeBilinearHalfPixelCenterOption(op, &half_pixel_centers)); - ASSERT_EQ(half_pixel_centers, true); -} - -TEST(GetOpOptionTest, 
TestGetLeakyReluOptions) { - auto model = - litert::testing::LoadTestFileModel("simple_leaky_relu_op.tflite"); - auto subgraph = model.MainSubgraph(); - EXPECT_TRUE(subgraph); - - auto ops = subgraph->Ops(); - auto op = ops.front().Get(); - - float alpha; - LITERT_ASSERT_OK(LiteRtGetLeakyReluAlphaOption(op, &alpha)); - ASSERT_FLOAT_EQ(alpha, 0.2); -} - -TEST(GetOpOptionTest, TestGetDepthToSpaceOptions) { - auto model = - litert::testing::LoadTestFileModel("simple_depth_to_space_op.tflite"); - auto subgraph = model.MainSubgraph(); - EXPECT_TRUE(subgraph); - - auto ops = subgraph->Ops(); - auto op = ops.front().Get(); - - int32_t block_size; - LITERT_ASSERT_OK(LiteRtGetDepthToSpaceBlockSizeOption(op, &block_size)); - ASSERT_EQ(block_size, 2); -} - -TEST(GetOpOptionTest, TestGetSpaceToDepthOptions) { - auto model = - litert::testing::LoadTestFileModel("simple_space_to_depth_op.tflite"); - auto subgraph = model.MainSubgraph(); - EXPECT_TRUE(subgraph); - - auto ops = subgraph->Ops(); - auto op = ops.front().Get(); - - int32_t block_size; - LITERT_ASSERT_OK(LiteRtGetSpaceToDepthBlockSizeOption(op, &block_size)); - ASSERT_EQ(block_size, 2); -} - -TEST(GetOpOptionTest, TestGetResizeNearestNeighborOptions) { - auto model = litert::testing::LoadTestFileModel( - "simple_resize_nearest_neighbor_op.tflite"); - auto subgraph = model.MainSubgraph(); - EXPECT_TRUE(subgraph); - - auto ops = subgraph->Ops(); - auto op = ops.front().Get(); - - bool align_corners; - LITERT_ASSERT_OK( - LiteRtGetResizeNearestNeighborAlignCornersOption(op, &align_corners)); - ASSERT_EQ(align_corners, false); - bool half_pixel_centers; - LITERT_ASSERT_OK(LiteRtGetResizeNearestNeighborHalfPixelCenterOption( - op, &half_pixel_centers)); - ASSERT_EQ(half_pixel_centers, true); -} - -} // namespace diff --git a/tensorflow/lite/experimental/litert/c/litert_tensor_buffer.cc b/tensorflow/lite/experimental/litert/c/litert_tensor_buffer.cc deleted file mode 100644 index 30588753ecb1ff..00000000000000 --- 
a/tensorflow/lite/experimental/litert/c/litert_tensor_buffer.cc +++ /dev/null @@ -1,480 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" - -#include -#include - -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_event.h" -#include "tensorflow/lite/experimental/litert/c/litert_gl_types.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/runtime/tensor_buffer.h" - -#if LITERT_HAS_OPENCL_SUPPORT -#include -#endif // LITERT_HAS_OPENCL_SUPPORT - -#ifdef __cplusplus -extern "C" { -#endif - -LiteRtStatus LiteRtCreateTensorBufferFromHostMemory( - const LiteRtRankedTensorType* tensor_type, void* host_buffer_addr, - size_t size, LiteRtHostMemoryDeallocator deallocator, - LiteRtTensorBuffer* tensor_buffer) { - if (!tensor_type || !host_buffer_addr || !tensor_buffer) { - return kLiteRtStatusErrorInvalidArgument; - } - auto created_tensor_buffer = LiteRtTensorBufferT::CreateFromHostMemory( - *tensor_type, - absl::MakeSpan(static_cast(host_buffer_addr), size), - deallocator); - if (!created_tensor_buffer) { - LITERT_LOG(LITERT_ERROR, "%s", - 
created_tensor_buffer.Error().Message().c_str()); - return created_tensor_buffer.Error().Status(); - } - *tensor_buffer = created_tensor_buffer->release(); - return kLiteRtStatusOk; -} - -#if LITERT_HAS_AHWB_SUPPORT -LiteRtStatus LiteRtCreateTensorBufferFromAhwb( - const LiteRtRankedTensorType* tensor_type, AHardwareBuffer* ahwb, - size_t ahwb_offset, LiteRtAhwbDeallocator deallocator, - LiteRtTensorBuffer* tensor_buffer) { - if (!tensor_type || !ahwb || !tensor_buffer) { - return kLiteRtStatusErrorInvalidArgument; - } - auto created_tensor_buffer = LiteRtTensorBufferT::CreateFromAhwb( - *tensor_type, ahwb, ahwb_offset, deallocator); - if (!created_tensor_buffer) { - LITERT_LOG(LITERT_ERROR, "%s", - created_tensor_buffer.Error().Message().c_str()); - return created_tensor_buffer.Error().Status(); - } - *tensor_buffer = created_tensor_buffer->release(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetTensorBufferAhwb(LiteRtTensorBuffer tensor_buffer, - AHardwareBuffer** ahwb) { - if (!tensor_buffer || !ahwb) { - return kLiteRtStatusErrorInvalidArgument; - } - - auto ahwb_buffer = tensor_buffer->GetAhwbBuffer(); - if (!ahwb_buffer) { - LITERT_LOG(LITERT_ERROR, "%s", ahwb_buffer.Error().Message().c_str()); - return ahwb_buffer.Error().Status(); - } - - *ahwb = *ahwb_buffer; - return kLiteRtStatusOk; -} -#endif // LITERT_HAS_AHWB_SUPPORT - -#if LITERT_HAS_ION_SUPPORT -LiteRtStatus LiteRtCreateTensorBufferFromIonBuffer( - const LiteRtRankedTensorType* tensor_type, void* ion_buffer_addr, - int ion_buffer_fd, size_t ion_buffer_size, size_t ion_buffer_offset, - LiteRtIonDeallocator deallocator, LiteRtTensorBuffer* tensor_buffer) { - if (!tensor_type || !tensor_buffer) { - return kLiteRtStatusErrorInvalidArgument; - } - auto created_tensor_buffer = LiteRtTensorBufferT::CreateFromIonBuffer( - *tensor_type, ion_buffer_addr, ion_buffer_fd, ion_buffer_size, - ion_buffer_offset, deallocator); - if (!created_tensor_buffer) { - LITERT_LOG(LITERT_ERROR, "%s", - 
created_tensor_buffer.Error().Message().c_str()); - return created_tensor_buffer.Error().Status(); - } - *tensor_buffer = created_tensor_buffer->release(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetTensorBufferIonBuffer(LiteRtTensorBuffer tensor_buffer, - void** ion_buffer_addr, - int* ion_buffer_fd) { - if (!tensor_buffer || !ion_buffer_addr || !ion_buffer_fd) { - return kLiteRtStatusErrorInvalidArgument; - } - - auto ion_buffer = tensor_buffer->GetIonBuffer(); - if (!ion_buffer) { - LITERT_LOG(LITERT_ERROR, "%s", ion_buffer.Error().Message().c_str()); - return ion_buffer.Error().Status(); - } - - *ion_buffer_addr = ion_buffer->first; - *ion_buffer_fd = ion_buffer->second; - return kLiteRtStatusOk; -} -#endif // LITERT_HAS_ION_SUPPORT - -#if LITERT_HAS_DMABUF_SUPPORT -LiteRtStatus LiteRtCreateTensorBufferFromDmaBufBuffer( - const LiteRtRankedTensorType* tensor_type, void* dmabuf_buffer_addr, - int dmabuf_buffer_fd, size_t dmabuf_buffer_size, - size_t dmabuf_buffer_offset, LiteRtDmaBufDeallocator deallocator, - LiteRtTensorBuffer* tensor_buffer) { - if (!tensor_type || !tensor_buffer) { - return kLiteRtStatusErrorInvalidArgument; - } - auto created_tensor_buffer = LiteRtTensorBufferT::CreateFromDmaBufBuffer( - *tensor_type, dmabuf_buffer_addr, dmabuf_buffer_fd, dmabuf_buffer_size, - dmabuf_buffer_offset, deallocator); - if (!created_tensor_buffer) { - LITERT_LOG(LITERT_ERROR, "%s", - created_tensor_buffer.Error().Message().c_str()); - return created_tensor_buffer.Error().Status(); - } - *tensor_buffer = created_tensor_buffer->release(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetTensorBufferDmaBufBuffer(LiteRtTensorBuffer tensor_buffer, - void** dmabuf_buffer_addr, - int* dmabuf_buffer_fd) { - if (!tensor_buffer || !dmabuf_buffer_addr || !dmabuf_buffer_fd) { - return kLiteRtStatusErrorInvalidArgument; - } - - auto dmabuf_buffer = tensor_buffer->GetDmaBufBuffer(); - if (!dmabuf_buffer) { - LITERT_LOG(LITERT_ERROR, "%s", 
dmabuf_buffer.Error().Message().c_str()); - return dmabuf_buffer.Error().Status(); - } - - *dmabuf_buffer_addr = dmabuf_buffer->first; - *dmabuf_buffer_fd = dmabuf_buffer->second; - return kLiteRtStatusOk; -} -#endif // LITERT_HAS_DMABUF_SUPPORT - -#if LITERT_HAS_OPENCL_SUPPORT -LiteRtStatus LiteRtCreateTensorBufferFromOpenClBuffer( - const LiteRtRankedTensorType* tensor_type, cl_mem cl_mem_addr, - size_t opencl_buffer_size, LiteRtOpenClDeallocator deallocator, - LiteRtTensorBuffer* buffer) { - if (!tensor_type || !buffer) { - return kLiteRtStatusErrorInvalidArgument; - } - auto created_tensor_buffer = LiteRtTensorBufferT::CreateFromOpenClBuffer( - *tensor_type, cl_mem_addr, opencl_buffer_size); - if (!created_tensor_buffer) { - LITERT_LOG(LITERT_ERROR, "%s", - created_tensor_buffer.Error().Message().c_str()); - return created_tensor_buffer.Error().Status(); - } - *buffer = created_tensor_buffer->release(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetTensorBufferOpenClBuffer(LiteRtTensorBuffer tensor_buffer, - cl_mem* cl_mem_addr) { - if (!tensor_buffer || !cl_mem_addr) { - return kLiteRtStatusErrorInvalidArgument; - } - - auto opencl_buffer = tensor_buffer->GetOpenClBuffer(); - if (!opencl_buffer) { - LITERT_LOG(LITERT_ERROR, "%s", opencl_buffer.Error().Message().c_str()); - return opencl_buffer.Error().Status(); - } - - *cl_mem_addr = (*opencl_buffer)->GetMemoryPtr(); - return kLiteRtStatusOk; -} -#endif // LITERT_HAS_OPENCL_SUPPORT - -#if LITERT_HAS_FASTRPC_SUPPORT -LiteRtStatus LiteRtCreateTensorBufferFromFastRpcBuffer( - const LiteRtRankedTensorType* tensor_type, void* fastrpc_buffer_addr, - int fastrpc_buffer_fd, size_t fastrpc_buffer_size, - size_t fastrpc_buffer_offset, LiteRtFastRpcDeallocator deallocator, - LiteRtTensorBuffer* tensor_buffer) { - if (!tensor_type || !tensor_buffer) { - return kLiteRtStatusErrorInvalidArgument; - } - auto created_tensor_buffer = LiteRtTensorBufferT::CreateFromFastRpcBuffer( - *tensor_type, fastrpc_buffer_addr, 
fastrpc_buffer_fd, fastrpc_buffer_size, - fastrpc_buffer_offset, deallocator); - if (!created_tensor_buffer) { - LITERT_LOG(LITERT_ERROR, "%s", - created_tensor_buffer.Error().Message().c_str()); - return created_tensor_buffer.Error().Status(); - } - *tensor_buffer = created_tensor_buffer->release(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetTensorBufferFastRpcBuffer( - LiteRtTensorBuffer tensor_buffer, void** fastrpc_buffer_addr, - int* fastrpc_buffer_fd) { - if (!tensor_buffer || !fastrpc_buffer_addr || !fastrpc_buffer_fd) { - return kLiteRtStatusErrorInvalidArgument; - } - - auto fastrpc_buffer = tensor_buffer->GetFastRpcBuffer(); - if (!fastrpc_buffer) { - LITERT_LOG(LITERT_ERROR, "%s", fastrpc_buffer.Error().Message().c_str()); - return fastrpc_buffer.Error().Status(); - } - - *fastrpc_buffer_addr = fastrpc_buffer->first; - *fastrpc_buffer_fd = fastrpc_buffer->second; - return kLiteRtStatusOk; -} -#endif // LITERT_HAS_FASTRPC_SUPPORT - -LiteRtStatus LiteRtCreateTensorBufferFromGlBuffer( - const LiteRtRankedTensorType* tensor_type, LiteRtGLenum target, - LiteRtGLuint id, size_t size_bytes, size_t offset, - LiteRtGlBufferDeallocator deallocator, LiteRtTensorBuffer* tensor_buffer) { - if (!tensor_type || !tensor_buffer) { - return kLiteRtStatusErrorInvalidArgument; - } - auto created_tensor_buffer = LiteRtTensorBufferT::CreateFromGlBuffer( - *tensor_type, target, id, size_bytes, offset, deallocator); - if (!created_tensor_buffer) { - LITERT_LOG(LITERT_ERROR, "%s", - created_tensor_buffer.Error().Message().data()); - return created_tensor_buffer.Error().Status(); - } - *tensor_buffer = created_tensor_buffer->release(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetTensorBufferGlBuffer(LiteRtTensorBuffer tensor_buffer, - LiteRtGLenum* target, - LiteRtGLuint* id, size_t* size_bytes, - size_t* offset) { - if (!tensor_buffer || !target || !id) { - return kLiteRtStatusErrorInvalidArgument; - } - - auto gl_buffer_expected = 
tensor_buffer->GetGlBuffer(); - if (!gl_buffer_expected) { - LITERT_LOG(LITERT_ERROR, "%s", - gl_buffer_expected.Error().Message().c_str()); - return gl_buffer_expected.Error().Status(); - } - *target = (*gl_buffer_expected)->target(); - *id = (*gl_buffer_expected)->id(); - *size_bytes = (*gl_buffer_expected)->size_bytes(); - *offset = (*gl_buffer_expected)->offset(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtCreateTensorBufferFromGlTexture( - const LiteRtRankedTensorType* tensor_type, LiteRtGLenum target, - LiteRtGLuint id, LiteRtGLenum format, size_t size_bytes, LiteRtGLint layer, - LiteRtGlTextureDeallocator deallocator, LiteRtTensorBuffer* tensor_buffer) { - if (!tensor_type || !tensor_buffer) { - return kLiteRtStatusErrorInvalidArgument; - } - auto created_tensor_buffer = LiteRtTensorBufferT::CreateFromGlTexture( - *tensor_type, target, id, format, size_bytes, layer, deallocator); - if (!created_tensor_buffer) { - LITERT_LOG(LITERT_ERROR, "%s", - created_tensor_buffer.Error().Message().c_str()); - return created_tensor_buffer.Error().Status(); - } - *tensor_buffer = created_tensor_buffer->release(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetTensorBufferGlTexture( - LiteRtTensorBuffer tensor_buffer, LiteRtGLenum* target, LiteRtGLuint* id, - LiteRtGLenum* format, size_t* size_bytes, LiteRtGLint* layer) { - if (!tensor_buffer || !target || !id || !format || !size_bytes || !layer) { - return kLiteRtStatusErrorInvalidArgument; - } - auto gl_texture_expected = tensor_buffer->GetGlTexture(); - if (!gl_texture_expected) { - LITERT_LOG(LITERT_ERROR, "%s", - gl_texture_expected.Error().Message().c_str()); - return gl_texture_expected.Error().Status(); - } - *target = (*gl_texture_expected)->target(); - *id = (*gl_texture_expected)->id(); - *format = (*gl_texture_expected)->format(); - *size_bytes = (*gl_texture_expected)->size_bytes(); - *layer = (*gl_texture_expected)->layer(); - return kLiteRtStatusOk; -} - -LiteRtStatus 
LiteRtCreateManagedTensorBuffer( - LiteRtTensorBufferType buffer_type, - const LiteRtRankedTensorType* tensor_type, size_t buffer_size, - LiteRtTensorBuffer* tensor_buffer) { - if (!tensor_type || !tensor_buffer) { - return kLiteRtStatusErrorInvalidArgument; - } - auto created_tensor_buffer = LiteRtTensorBufferT::CreateManaged( - buffer_type, *tensor_type, buffer_size); - if (!created_tensor_buffer) { - LITERT_LOG(LITERT_ERROR, "%s", - created_tensor_buffer.Error().Message().c_str()); - return created_tensor_buffer.Error().Status(); - } - *tensor_buffer = created_tensor_buffer->release(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtDuplicateTensorBuffer(LiteRtTensorBuffer tensor_buffer) { - if (!tensor_buffer) { - return kLiteRtStatusErrorInvalidArgument; - } - tensor_buffer->Duplicate(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetTensorBufferType(LiteRtTensorBuffer tensor_buffer, - LiteRtTensorBufferType* buffer_type) { - if (!tensor_buffer || !buffer_type) { - return kLiteRtStatusErrorInvalidArgument; - } - *buffer_type = tensor_buffer->buffer_type(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetTensorBufferTensorType( - LiteRtTensorBuffer tensor_buffer, LiteRtRankedTensorType* tensor_type) { - if (!tensor_buffer || !tensor_type) { - return kLiteRtStatusErrorInvalidArgument; - } - *tensor_type = tensor_buffer->tensor_type(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetTensorBufferSize(LiteRtTensorBuffer tensor_buffer, - size_t* buffer_size) { - if (!tensor_buffer || !buffer_size) { - return kLiteRtStatusErrorInvalidArgument; - } - *buffer_size = tensor_buffer->buffer_size(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetTensorBufferOffset(LiteRtTensorBuffer tensor_buffer, - size_t* buffer_offset) { - if (!tensor_buffer || !buffer_offset) { - return kLiteRtStatusErrorInvalidArgument; - } - *buffer_offset = tensor_buffer->buffer_offset(); - return kLiteRtStatusOk; -} - -LiteRtStatus 
LiteRtGetTensorBufferHostMemory(LiteRtTensorBuffer tensor_buffer, - void** host_memory_addr) { - if (!tensor_buffer || !host_memory_addr) { - return kLiteRtStatusErrorInvalidArgument; - } - - auto host_buffer = tensor_buffer->GetHostBuffer(); - if (!host_buffer) { - LITERT_LOG(LITERT_ERROR, "%s", host_buffer.Error().Message().c_str()); - return host_buffer.Error().Status(); - } - - *host_memory_addr = *host_buffer; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtHasTensorBufferEvent(LiteRtTensorBuffer tensor_buffer, - bool* has_event) { - if (!tensor_buffer || !has_event) { - return kLiteRtStatusErrorInvalidArgument; - } - *has_event = tensor_buffer->HasEvent(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetTensorBufferEvent(LiteRtTensorBuffer tensor_buffer, - LiteRtEvent* event) { - if (!tensor_buffer || !event) { - return kLiteRtStatusErrorInvalidArgument; - } - auto result = tensor_buffer->GetEvent(); - if (!result) { - LITERT_LOG(LITERT_ERROR, "%s", result.Error().Message().c_str()); - return result.Error().Status(); - } - *event = *result; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtSetTensorBufferEvent(LiteRtTensorBuffer tensor_buffer, - LiteRtEvent event) { - if (!tensor_buffer || !event) { - return kLiteRtStatusErrorInvalidArgument; - } - tensor_buffer->SetEvent(event); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtClearTensorBufferEvent(LiteRtTensorBuffer tensor_buffer) { - if (!tensor_buffer) { - return kLiteRtStatusErrorInvalidArgument; - } - tensor_buffer->ClearEvent(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtLockTensorBuffer(LiteRtTensorBuffer tensor_buffer, - void** host_mem_addr) { - if (!tensor_buffer || !host_mem_addr) { - return kLiteRtStatusErrorInvalidArgument; - } - - auto mapped_addr = tensor_buffer->Lock(); - if (!mapped_addr) { - LITERT_LOG(LITERT_ERROR, "%s", mapped_addr.Error().Message().c_str()); - return mapped_addr.Error().Status(); - } - - *host_mem_addr = *mapped_addr; - return kLiteRtStatusOk; -} - 
-LiteRtStatus LiteRtUnlockTensorBuffer(LiteRtTensorBuffer tensor_buffer) { - if (!tensor_buffer) { - return kLiteRtStatusErrorInvalidArgument; - } - - if (auto status = tensor_buffer->Unlock(); !status) { - LITERT_LOG(LITERT_ERROR, "%s", status.Error().Message().c_str()); - return status.Error().Status(); - } - - return kLiteRtStatusOk; -} - -void LiteRtDestroyTensorBuffer(LiteRtTensorBuffer tensor_buffer) { - if (tensor_buffer->Unref()) { - delete tensor_buffer; - } -} - -#ifdef __cplusplus -} // extern "C" -#endif diff --git a/tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h b/tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h deleted file mode 100644 index 7f1fd8af836e1a..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h +++ /dev/null @@ -1,244 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_TENSOR_BUFFER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_TENSOR_BUFFER_H_ - -#include -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_event.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#if LITERT_HAS_OPENCL_SUPPORT -#include -#endif // LITERT_HAS_OPENCL_SUPPORT -#include "tensorflow/lite/experimental/litert/c/litert_gl_types.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_types.h" - -#if LITERT_HAS_AHWB_SUPPORT -#include -#else -// Define a place holder AHardwareBuffer struct just to enable compilation. -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus -typedef struct AHardwareBuffer AHardwareBuffer; -#ifdef __cplusplus -} // extern "C" -#endif // __cplusplus -#endif // LITERT_HAS_AHWB_SUPPORT - -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus - -LITERT_DEFINE_HANDLE(LiteRtTensorBuffer); - -#define LITERT_HOST_MEMORY_BUFFER_ALIGNMENT 64 - -typedef void (*LiteRtHostMemoryDeallocator)(void* addr); -typedef void (*LiteRtAhwbDeallocator)(AHardwareBuffer* ahwb); -typedef void (*LiteRtIonDeallocator)(void* ion_buffer_addr); -typedef void (*LiteRtDmaBufDeallocator)(void* dmabuf_buffer_addr); -typedef void (*LiteRtFastRpcDeallocator)(void* fastrpc_buffer_addr); -typedef void (*LiteRtOpenClDeallocator)(void* opencl_buffer_addr); -typedef void (*LiteRtGlBufferDeallocator)(void* gl_buffer_addr); -typedef void (*LiteRtGlTextureDeallocator)(void* gl_texture_addr); - -// ///////////////////////////////////////////////////////////////////////////// -// TensorBuffers. -// ///////////////////////////////////////////////////////////////////////////// - -// Create a tensor buffer from an existing host memory buffer of a given size, -// with optional host memory buffer deallocator (it can be NULL). 
Return an -// error if the passed host memory buffer doesn't satisfy -// LITERT_HOST_MEMORY_BUFFER_ALIGNMENT alignment. -LiteRtStatus LiteRtCreateTensorBufferFromHostMemory( - const LiteRtRankedTensorType* tensor_type, void* host_buffer_addr, - size_t host_buffer_size, LiteRtHostMemoryDeallocator deallocator, - LiteRtTensorBuffer* buffer); - -// Return an error if the backing buffer is not allocated on the host memory. -LiteRtStatus LiteRtGetTensorBufferHostMemory(LiteRtTensorBuffer tensor_buffer, - void** host_memory_addr); - -#if LITERT_HAS_AHWB_SUPPORT -// Create a tensor buffer from an existing AHardwareBuffer, with optional -// AHardwareBuffer deallocator (it can be NULL). An non-zero `buffer_offset` can -// be used to specify multiple tensor buffers sharing the same underlying AHWB, -// in which case the provided AHWB must be sufficiently large to accomodate for -// the allocation needed for all tensor buffers sharing it. -LiteRtStatus LiteRtCreateTensorBufferFromAhwb( - const LiteRtRankedTensorType* tensor_type, AHardwareBuffer* ahwb, - size_t ahwb_offset, LiteRtAhwbDeallocator deallocator, - LiteRtTensorBuffer* buffer); - -// Return an error if the backing buffer is not an AhardwareBuffer. -LiteRtStatus LiteRtGetTensorBufferAhwb(LiteRtTensorBuffer tensor_buffer, - AHardwareBuffer** ahwb); -#endif // LITERT_HAS_AHWB_SUPPORT - -#if LITERT_HAS_ION_SUPPORT -// Create a tensor buffer from an existing ION buffer of a given size, with -// optional ION buffer deallocator (it can be NULL). An non-zero -// `ion_buffer_offset` can be used to specify multiple tensor buffers sharing -// the same underlying ION buffer, in which case parameter `ion_buffer_size` -// must be the entire size of the underlying ION memory buffer, including the -// allocation needed for all tensor buffers sharing it. 
-LiteRtStatus LiteRtCreateTensorBufferFromIonBuffer( - const LiteRtRankedTensorType* tensor_type, void* ion_buffer_addr, - int ion_buffer_fd, size_t ion_buffer_size, size_t ion_buffer_offset, - LiteRtIonDeallocator deallocator, LiteRtTensorBuffer* buffer); - -// Return an error if the backing buffer is not an ION buffer. -LiteRtStatus LiteRtGetTensorBufferIonBuffer(LiteRtTensorBuffer buffer, - void** ion_buffer_addr, - int* ion_buffer_fd); -#endif // LITERT_HAS_ION_SUPPORT - -#if LITERT_HAS_DMABUF_SUPPORT -// Create a tensor buffer from an existing DMA-BUF buffer of a given size, with -// optional DMA-BUF buffer deallocator (it can be NULL). An non-zero -// `dmabuf_buffer_offset` can be used to specify multiple tensor buffers sharing -// the same underlying ION buffer, in which case parameter `ion_buffer_size` -// must be the entire size of the underlying ION memory buffer, including the -// allocation needed for all tensor buffers sharing it. -LiteRtStatus LiteRtCreateTensorBufferFromDmaBufBuffer( - const LiteRtRankedTensorType* tensor_type, void* dmabuf_buffer_addr, - int dmabuf_buffer_fd, size_t dmabuf_buffer_size, - size_t dmabuf_buffer_offset, LiteRtDmaBufDeallocator deallocator, - LiteRtTensorBuffer* buffer); - -// Return an error if the backing buffer is not an DMA-BUF buffer. -LiteRtStatus LiteRtGetTensorBufferDmaBufBuffer(LiteRtTensorBuffer tensor_buffer, - void** dmabuf_buffer_addr, - int* dmabuf_buffer_fd); -#endif // LITERT_HAS_DMABUF_SUPPORT - -#if LITERT_HAS_FASTRPC_SUPPORT -// Create a tensor buffer from an existing FastRPC memory buffer of a given -// size, with optional FastRPC memory buffer deallocator (it can be NULL). An -// non-zero `fastrpc_buffer_offset` can be used to specify multiple tensor -// buffers sharing the same underlying FastRPC memory buffer, in which case -// parameter `fastrpc_buffer_size` must be the entire size of the underlying -// FastRPC memory buffer, including the allocation needed for all tensor buffers -// sharing it. 
-LiteRtStatus LiteRtCreateTensorBufferFromFastRpcBuffer( - const LiteRtRankedTensorType* tensor_type, void* fastrpc_buffer_addr, - int fastrpc_fd, size_t fastrpc_buffer_size, size_t fastrpc_buffer_offset, - LiteRtFastRpcDeallocator deallocator, LiteRtTensorBuffer* buffer); - -// Return an error if the backing buffer is not a FastRPC memory buffer. -LiteRtStatus LiteRtGetTensorBufferFastRpcBuffer( - LiteRtTensorBuffer tensor_buffer, void** fastrpc_buffer_addr, - int* fastrpc_buffer_fd); -#endif // LITERT_HAS_FASTRPC_SUPPORT - -#if LITERT_HAS_OPENCL_SUPPORT -// Create a tensor buffer from an existing OpenCL buffer of a given size, with -// optional opencl memory buffer deallocator (it can be NULL). An non-zero -// `opencl_buffer_offset` can be used to specify multiple tensor buffers sharing -// the same underlying OpenCL buffer, in which case parameter -// `opencl_buffer_size` must be the entire size of the underlying OpenCL -// memory buffer, including the allocation needed for all tensor buffers -// sharing it. -LiteRtStatus LiteRtCreateTensorBufferFromOpenClBuffer( - const LiteRtRankedTensorType* tensor_type, cl_mem cl_mem_addr, - size_t opencl_buffer_size, LiteRtOpenClDeallocator deallocator, - LiteRtTensorBuffer* buffer); - -// Return an error if the backing buffer is not a OpenCL buffer. 
-LiteRtStatus LiteRtGetTensorBufferOpenClBuffer(LiteRtTensorBuffer tensor_buffer, - cl_mem* cl_mem_addr); -#endif // LITERT_HAS_OPENCL_SUPPORT - -LiteRtStatus LiteRtCreateTensorBufferFromGlBuffer( - const LiteRtRankedTensorType* tensor_type, LiteRtGLenum target, - LiteRtGLuint id, size_t size_bytes, size_t offset, - LiteRtGlBufferDeallocator deallocator, LiteRtTensorBuffer* buffer); - -LiteRtStatus LiteRtGetTensorBufferGlBuffer(LiteRtTensorBuffer tensor_buffer, - LiteRtGLenum* target, - LiteRtGLuint* id, size_t* size_bytes, - size_t* offset); - -LiteRtStatus LiteRtCreateTensorBufferFromGlTexture( - const LiteRtRankedTensorType* tensor_type, LiteRtGLenum target, - LiteRtGLuint id, LiteRtGLenum format, size_t size_bytes, LiteRtGLint layer, - LiteRtGlTextureDeallocator deallocator, LiteRtTensorBuffer* buffer); - -LiteRtStatus LiteRtGetTensorBufferGlTexture( - LiteRtTensorBuffer tensor_buffer, LiteRtGLenum* target, LiteRtGLuint* id, - LiteRtGLenum* format, size_t* size_bytes, LiteRtGLint* layer); - -// Create a buffer backed by managed memory for a given size. -LiteRtStatus LiteRtCreateManagedTensorBuffer( - LiteRtTensorBufferType buffer_type, - const LiteRtRankedTensorType* tensor_type, size_t buffer_size, - LiteRtTensorBuffer* buffer); - -// Create a duplicate of the current tensor buffer. It will increase the -// reference count of a managed tensor buffer. And the number decreases when -// LiteRtDestroyTensorBuffer() is called. 
-LiteRtStatus LiteRtDuplicateTensorBuffer(LiteRtTensorBuffer tensor_buffer); - -LiteRtStatus LiteRtGetTensorBufferType(LiteRtTensorBuffer tensor_buffer, - LiteRtTensorBufferType* buffer_type); - -LiteRtStatus LiteRtGetTensorBufferTensorType( - LiteRtTensorBuffer tensor_buffer, LiteRtRankedTensorType* tensor_type); - -LiteRtStatus LiteRtGetTensorBufferSize(LiteRtTensorBuffer tensor_buffer, - size_t* size); - -LiteRtStatus LiteRtGetTensorBufferOffset(LiteRtTensorBuffer tensor_buffer, - size_t* offset); - -LiteRtStatus LiteRtHasTensorBufferEvent(LiteRtTensorBuffer tensor_buffer, - bool* has_event); - -// Return an event attached a given tensor buffer, or NULL if no such event -// exists. The tensor buffer retains ownership of the returned event. -LiteRtStatus LiteRtGetTensorBufferEvent(LiteRtTensorBuffer tensor_buffer, - LiteRtEvent* event); - -// Attach a given event to a given tensor buffer. The tensor buffer takes -// ownership of the event. -LiteRtStatus LiteRtSetTensorBufferEvent(LiteRtTensorBuffer tensor_buffer, - LiteRtEvent event); - -// Remove any event that may have been previously attached to the given tensor -// buffer and deallocate such event. -LiteRtStatus LiteRtClearTensorBufferEvent(LiteRtTensorBuffer tensor_buffer); - -// Lock a tensor buffer and map it to host memory, potentially synchronizing on -// an event that was previously attached to the tensor buffer with -// `LiteRtSetTensorBufferEvent`. -LiteRtStatus LiteRtLockTensorBuffer(LiteRtTensorBuffer tensor_buffer, - void** host_mem_addr); - -// Unlock a tensor buffer and (potentially) unmap it from host memory. -LiteRtStatus LiteRtUnlockTensorBuffer(LiteRtTensorBuffer buffer); - -// Destroy a tensor buffer. If the tensor buffer is managed, the number of -// references to it is decreased and released the underlying TensorBufferT when -// the last reference is removed. 
-void LiteRtDestroyTensorBuffer(LiteRtTensorBuffer buffer); - -#ifdef __cplusplus -} -#endif // __cplusplus - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_TENSOR_BUFFER_H_ diff --git a/tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.cc b/tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.cc deleted file mode 100644 index fce2e4049f88e2..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.cc +++ /dev/null @@ -1,92 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h" - -#include -#include -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/runtime/tensor_buffer_requirements.h" - -#ifdef __cplusplus -extern "C" { -#endif - -LiteRtStatus LiteRtCreateTensorBufferRequirements( - int num_supported_tensor_buffer_types, - const LiteRtTensorBufferType* supported_tensor_buffer_types, - size_t buffer_size, int num_strides, const uint32_t* strides, - LiteRtTensorBufferRequirements* requirements) { - if (num_supported_tensor_buffer_types < 1 || !supported_tensor_buffer_types || - !requirements) { - return kLiteRtStatusErrorInvalidArgument; - } - *requirements = new LiteRtTensorBufferRequirementsT( - num_supported_tensor_buffer_types, supported_tensor_buffer_types, - buffer_size, std::vector(strides, strides + num_strides)); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetNumTensorBufferRequirementsSupportedBufferTypes( - LiteRtTensorBufferRequirements requirements, int* num_types) { - if (!requirements || !num_types) { - return kLiteRtStatusErrorInvalidArgument; - } - *num_types = requirements->SupportedBufferTypes().size(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( - LiteRtTensorBufferRequirements requirements, int type_index, - LiteRtTensorBufferType* type) { - if (!requirements || type_index < 0 || - type_index >= requirements->SupportedBufferTypes().size()) { - return kLiteRtStatusErrorInvalidArgument; - } - *type = requirements->SupportedBufferTypes()[type_index]; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetTensorBufferRequirementsBufferSize( - LiteRtTensorBufferRequirements requirements, size_t* buffer_size) { - if (!requirements || !buffer_size) { - return kLiteRtStatusErrorInvalidArgument; - } - *buffer_size = 
requirements->BufferSize(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetTensorBufferRequirementsStrides( - LiteRtTensorBufferRequirements requirements, int* num_strides, - const uint32_t** strides) { - if (!requirements || !num_strides || !strides) { - return kLiteRtStatusErrorInvalidArgument; - } - auto& s = requirements->Strides(); - *num_strides = s.size(); - *strides = s.data(); - return kLiteRtStatusOk; -} - -void LiteRtDestroyTensorBufferRequirements( - LiteRtTensorBufferRequirements requirements) { - delete requirements; -} - -#ifdef __cplusplus -} // extern "C" -#endif diff --git a/tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h b/tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h deleted file mode 100644 index 1c691a3ee38e9f..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h +++ /dev/null @@ -1,57 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_TENSOR_BUFFER_REQUIREMENTS_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_TENSOR_BUFFER_REQUIREMENTS_H_ - -#include -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" - -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus - -LITERT_DEFINE_HANDLE(LiteRtTensorBufferRequirements); - -LiteRtStatus LiteRtCreateTensorBufferRequirements( - int num_supported_tensor_buffer_types, - const LiteRtTensorBufferType* supported_tensor_buffer_types, - size_t buffer_size, int num_strides, const uint32_t* strides, - LiteRtTensorBufferRequirements* requirements); - -LiteRtStatus LiteRtGetNumTensorBufferRequirementsSupportedBufferTypes( - LiteRtTensorBufferRequirements requirements, int* num_types); - -LiteRtStatus LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( - LiteRtTensorBufferRequirements requirements, int type_index, - LiteRtTensorBufferType* type); - -LiteRtStatus LiteRtGetTensorBufferRequirementsBufferSize( - LiteRtTensorBufferRequirements requirements, size_t* buffer_size); - -LiteRtStatus LiteRtGetTensorBufferRequirementsStrides( - LiteRtTensorBufferRequirements requirements, int* num_strides, - const uint32_t** strides); - -void LiteRtDestroyTensorBufferRequirements( - LiteRtTensorBufferRequirements requirements); - -#ifdef __cplusplus -} -#endif // __cplusplus - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_TENSOR_BUFFER_REQUIREMENTS_H_ diff --git a/tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements_test.cc b/tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements_test.cc deleted file mode 100644 index 6a61eff786cbc9..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements_test.cc +++ /dev/null @@ -1,92 +0,0 @@ -// Copyright 2024 Google LLC. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h" - -#include -#include -#include - -#include // NOLINT: Need when ANDROID_API_LEVEL >= 26 -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" - -namespace { - -constexpr const LiteRtTensorBufferType kSupportedTensorBufferTypes[] = { - kLiteRtTensorBufferTypeHostMemory, - kLiteRtTensorBufferTypeAhwb, - kLiteRtTensorBufferTypeIon, - kLiteRtTensorBufferTypeFastRpc, -}; - -constexpr const size_t kNumSupportedTensorBufferTypes = - sizeof(kSupportedTensorBufferTypes) / - sizeof(kSupportedTensorBufferTypes[0]); - -constexpr const size_t kBufferSize = 1234; - -} // namespace - -TEST(TensorBufferRequirements, NoStrides) { - LiteRtTensorBufferRequirements requirements; - ASSERT_EQ(LiteRtCreateTensorBufferRequirements( - kNumSupportedTensorBufferTypes, kSupportedTensorBufferTypes, - kBufferSize, - /*num_strides=*/0, /*strides=*/nullptr, &requirements), - kLiteRtStatusOk); - - int num_types; - ASSERT_EQ(LiteRtGetNumTensorBufferRequirementsSupportedBufferTypes( - requirements, &num_types), - kLiteRtStatusOk); - ASSERT_EQ(num_types, kNumSupportedTensorBufferTypes); - - for (auto i = 0; i < num_types; ++i) { - LiteRtTensorBufferType type; - ASSERT_EQ(LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( - requirements, i, &type), - 
kLiteRtStatusOk); - ASSERT_EQ(type, kSupportedTensorBufferTypes[i]); - } - - size_t size; - ASSERT_EQ(LiteRtGetTensorBufferRequirementsBufferSize(requirements, &size), - kLiteRtStatusOk); - ASSERT_EQ(size, kBufferSize); - - LiteRtDestroyTensorBufferRequirements(requirements); -} - -TEST(TensorBufferRequirements, WithStrides) { - constexpr std::array kStrides = {1, 2, 3}; - - LiteRtTensorBufferRequirements requirements; - ASSERT_EQ(LiteRtCreateTensorBufferRequirements( - kNumSupportedTensorBufferTypes, kSupportedTensorBufferTypes, - kBufferSize, kStrides.size(), kStrides.data(), &requirements), - kLiteRtStatusOk); - - int num_strides; - const uint32_t* strides; - ASSERT_EQ(LiteRtGetTensorBufferRequirementsStrides(requirements, &num_strides, - &strides), - kLiteRtStatusOk); - ASSERT_EQ(num_strides, kStrides.size()); - for (auto i = 0; i < kStrides.size(); ++i) { - ASSERT_EQ(strides[i], kStrides[i]); - } - - LiteRtDestroyTensorBufferRequirements(requirements); -} diff --git a/tensorflow/lite/experimental/litert/c/litert_tensor_buffer_test.cc b/tensorflow/lite/experimental/litert/c/litert_tensor_buffer_test.cc deleted file mode 100644 index c77388d382f5de..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_tensor_buffer_test.cc +++ /dev/null @@ -1,439 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" - -#include -#include - -#include // NOLINT: Need when ANDROID_API_LEVEL >= 26 -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_event.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_layout.h" -#include "tensorflow/lite/experimental/litert/runtime/ahwb_buffer.h" // IWYU pragma: keep -#include "tensorflow/lite/experimental/litert/runtime/dmabuf_buffer.h" // IWYU pragma: keep -#include "tensorflow/lite/experimental/litert/runtime/event.h" -#include "tensorflow/lite/experimental/litert/runtime/fastrpc_buffer.h" // IWYU pragma: keep -#include "tensorflow/lite/experimental/litert/runtime/gl_buffer.h" // IWYU pragma: keep -#include "tensorflow/lite/experimental/litert/runtime/ion_buffer.h" // IWYU pragma: keep -#include "tensorflow/lite/experimental/litert/runtime/open_cl_buffer.h" - -namespace { -constexpr const float kTensorData[] = {10, 20, 30, 40}; - -constexpr const int32_t kTensorDimensions[] = {sizeof(kTensorData) / - sizeof(kTensorData[0])}; - -constexpr const LiteRtRankedTensorType kTensorType = { - /*.element_type=*/kLiteRtElementTypeFloat32, - ::litert::BuildLayout(kTensorDimensions)}; - -} // namespace - -TEST(TensorBuffer, HostMemory) { - constexpr auto kTensorBufferType = kLiteRtTensorBufferTypeHostMemory; - - LiteRtTensorBuffer tensor_buffer; - ASSERT_EQ( - LiteRtCreateManagedTensorBuffer(kTensorBufferType, &kTensorType, - sizeof(kTensorData), &tensor_buffer), - kLiteRtStatusOk); - - LiteRtTensorBufferType buffer_type; - ASSERT_EQ(LiteRtGetTensorBufferType(tensor_buffer, &buffer_type), - kLiteRtStatusOk); - ASSERT_EQ(buffer_type, kTensorBufferType); - - LiteRtRankedTensorType tensor_type; - ASSERT_EQ(LiteRtGetTensorBufferTensorType(tensor_buffer, &tensor_type), - kLiteRtStatusOk); - ASSERT_EQ(tensor_type.element_type, 
kLiteRtElementTypeFloat32); - ASSERT_EQ(tensor_type.layout.rank, 1); - ASSERT_EQ(tensor_type.layout.dimensions[0], kTensorType.layout.dimensions[0]); - ASSERT_EQ(tensor_type.layout.strides, nullptr); - - size_t size; - ASSERT_EQ(LiteRtGetTensorBufferSize(tensor_buffer, &size), kLiteRtStatusOk); - ASSERT_EQ(size, sizeof(kTensorData)); - - size_t offset; - ASSERT_EQ(LiteRtGetTensorBufferOffset(tensor_buffer, &offset), - kLiteRtStatusOk); - ASSERT_EQ(offset, 0); - - void* host_mem_addr; - ASSERT_EQ(LiteRtLockTensorBuffer(tensor_buffer, &host_mem_addr), - kLiteRtStatusOk); - std::memcpy(host_mem_addr, kTensorData, sizeof(kTensorData)); - ASSERT_EQ(LiteRtUnlockTensorBuffer(tensor_buffer), kLiteRtStatusOk); - - ASSERT_EQ(LiteRtLockTensorBuffer(tensor_buffer, &host_mem_addr), - kLiteRtStatusOk); - ASSERT_EQ(std::memcmp(host_mem_addr, kTensorData, sizeof(kTensorData)), 0); - ASSERT_EQ(LiteRtUnlockTensorBuffer(tensor_buffer), kLiteRtStatusOk); - - LiteRtDestroyTensorBuffer(tensor_buffer); -} - -TEST(TensorBuffer, Ahwb) { - if (!litert::internal::AhwbBuffer::IsSupported()) { - GTEST_SKIP() << "AHardwareBuffers are not supported on this platform; " - "skipping the test"; - } - - constexpr auto kTensorBufferType = kLiteRtTensorBufferTypeAhwb; - - LiteRtTensorBuffer tensor_buffer; - ASSERT_EQ( - LiteRtCreateManagedTensorBuffer(kTensorBufferType, &kTensorType, - sizeof(kTensorData), &tensor_buffer), - kLiteRtStatusOk); - - LiteRtTensorBufferType buffer_type; - ASSERT_EQ(LiteRtGetTensorBufferType(tensor_buffer, &buffer_type), - kLiteRtStatusOk); - ASSERT_EQ(buffer_type, kTensorBufferType); - - LiteRtRankedTensorType tensor_type; - ASSERT_EQ(LiteRtGetTensorBufferTensorType(tensor_buffer, &tensor_type), - kLiteRtStatusOk); - ASSERT_EQ(tensor_type.element_type, kLiteRtElementTypeFloat32); - ASSERT_EQ(tensor_type.layout.rank, 1); - ASSERT_EQ(tensor_type.layout.dimensions[0], kTensorType.layout.dimensions[0]); - ASSERT_EQ(tensor_type.layout.strides, nullptr); - - size_t size; - 
ASSERT_EQ(LiteRtGetTensorBufferSize(tensor_buffer, &size), kLiteRtStatusOk); - ASSERT_EQ(size, sizeof(kTensorData)); - - size_t offset; - ASSERT_EQ(LiteRtGetTensorBufferOffset(tensor_buffer, &offset), - kLiteRtStatusOk); - ASSERT_EQ(offset, 0); - - void* host_mem_addr; - ASSERT_EQ(LiteRtLockTensorBuffer(tensor_buffer, &host_mem_addr), - kLiteRtStatusOk); - std::memcpy(host_mem_addr, kTensorData, sizeof(kTensorData)); - ASSERT_EQ(LiteRtUnlockTensorBuffer(tensor_buffer), kLiteRtStatusOk); - - ASSERT_EQ(LiteRtLockTensorBuffer(tensor_buffer, &host_mem_addr), - kLiteRtStatusOk); - ASSERT_EQ(std::memcmp(host_mem_addr, kTensorData, sizeof(kTensorData)), 0); - ASSERT_EQ(LiteRtUnlockTensorBuffer(tensor_buffer), kLiteRtStatusOk); - - LiteRtDestroyTensorBuffer(tensor_buffer); -} - -TEST(TensorBuffer, Ion) { - if (!litert::internal::IonBuffer::IsSupported()) { - GTEST_SKIP() - << "ION buffers are not supported on this platform; skipping the test"; - } - - constexpr auto kTensorBufferType = kLiteRtTensorBufferTypeIon; - - LiteRtTensorBuffer tensor_buffer; - ASSERT_EQ( - LiteRtCreateManagedTensorBuffer(kTensorBufferType, &kTensorType, - sizeof(kTensorData), &tensor_buffer), - kLiteRtStatusOk); - - LiteRtTensorBufferType buffer_type; - ASSERT_EQ(LiteRtGetTensorBufferType(tensor_buffer, &buffer_type), - kLiteRtStatusOk); - ASSERT_EQ(buffer_type, kTensorBufferType); - - LiteRtRankedTensorType tensor_type; - ASSERT_EQ(LiteRtGetTensorBufferTensorType(tensor_buffer, &tensor_type), - kLiteRtStatusOk); - ASSERT_EQ(tensor_type.element_type, kLiteRtElementTypeFloat32); - ASSERT_EQ(tensor_type.layout.rank, 1); - ASSERT_EQ(tensor_type.layout.dimensions[0], kTensorType.layout.dimensions[0]); - ASSERT_EQ(tensor_type.layout.strides, nullptr); - - size_t size; - ASSERT_EQ(LiteRtGetTensorBufferSize(tensor_buffer, &size), kLiteRtStatusOk); - ASSERT_EQ(size, sizeof(kTensorData)); - - size_t offset; - ASSERT_EQ(LiteRtGetTensorBufferOffset(tensor_buffer, &offset), - kLiteRtStatusOk); - 
ASSERT_EQ(offset, 0); - - void* host_mem_addr; - ASSERT_EQ(LiteRtLockTensorBuffer(tensor_buffer, &host_mem_addr), - kLiteRtStatusOk); - std::memcpy(host_mem_addr, kTensorData, sizeof(kTensorData)); - ASSERT_EQ(LiteRtUnlockTensorBuffer(tensor_buffer), kLiteRtStatusOk); - - ASSERT_EQ(LiteRtLockTensorBuffer(tensor_buffer, &host_mem_addr), - kLiteRtStatusOk); - ASSERT_EQ(std::memcmp(host_mem_addr, kTensorData, sizeof(kTensorData)), 0); - ASSERT_EQ(LiteRtUnlockTensorBuffer(tensor_buffer), kLiteRtStatusOk); - - LiteRtDestroyTensorBuffer(tensor_buffer); -} - -TEST(TensorBuffer, DmaBuf) { - if (!litert::internal::DmaBufBuffer::IsSupported()) { - GTEST_SKIP() - << "DMA-BUF buffers are not supported on this platform; skipping " - "the test"; - } - - constexpr auto kTensorBufferType = kLiteRtTensorBufferTypeDmaBuf; - - LiteRtTensorBuffer tensor_buffer; - ASSERT_EQ( - LiteRtCreateManagedTensorBuffer(kTensorBufferType, &kTensorType, - sizeof(kTensorData), &tensor_buffer), - kLiteRtStatusOk); - - LiteRtTensorBufferType buffer_type; - ASSERT_EQ(LiteRtGetTensorBufferType(tensor_buffer, &buffer_type), - kLiteRtStatusOk); - ASSERT_EQ(buffer_type, kTensorBufferType); - - LiteRtRankedTensorType tensor_type; - ASSERT_EQ(LiteRtGetTensorBufferTensorType(tensor_buffer, &tensor_type), - kLiteRtStatusOk); - ASSERT_EQ(tensor_type.element_type, kLiteRtElementTypeFloat32); - ASSERT_EQ(tensor_type.layout.rank, 1); - ASSERT_EQ(tensor_type.layout.dimensions[0], kTensorType.layout.dimensions[0]); - ASSERT_EQ(tensor_type.layout.strides, nullptr); - - size_t size; - ASSERT_EQ(LiteRtGetTensorBufferSize(tensor_buffer, &size), kLiteRtStatusOk); - ASSERT_EQ(size, sizeof(kTensorData)); - - size_t offset; - ASSERT_EQ(LiteRtGetTensorBufferOffset(tensor_buffer, &offset), - kLiteRtStatusOk); - ASSERT_EQ(offset, 0); - - void* host_mem_addr; - ASSERT_EQ(LiteRtLockTensorBuffer(tensor_buffer, &host_mem_addr), - kLiteRtStatusOk); - std::memcpy(host_mem_addr, kTensorData, sizeof(kTensorData)); - 
ASSERT_EQ(LiteRtUnlockTensorBuffer(tensor_buffer), kLiteRtStatusOk); - - ASSERT_EQ(LiteRtLockTensorBuffer(tensor_buffer, &host_mem_addr), - kLiteRtStatusOk); - ASSERT_EQ(std::memcmp(host_mem_addr, kTensorData, sizeof(kTensorData)), 0); - ASSERT_EQ(LiteRtUnlockTensorBuffer(tensor_buffer), kLiteRtStatusOk); - - LiteRtDestroyTensorBuffer(tensor_buffer); -} - -TEST(TensorBuffer, FastRpc) { - if (!litert::internal::FastRpcBuffer::IsSupported()) { - GTEST_SKIP() - << "FastRPC buffers are not supported on this platform; skipping " - "the test"; - } - - constexpr auto kTensorBufferType = kLiteRtTensorBufferTypeFastRpc; - - LiteRtTensorBuffer tensor_buffer; - ASSERT_EQ( - LiteRtCreateManagedTensorBuffer(kTensorBufferType, &kTensorType, - sizeof(kTensorData), &tensor_buffer), - kLiteRtStatusOk); - - LiteRtTensorBufferType buffer_type; - ASSERT_EQ(LiteRtGetTensorBufferType(tensor_buffer, &buffer_type), - kLiteRtStatusOk); - ASSERT_EQ(buffer_type, kTensorBufferType); - - LiteRtRankedTensorType tensor_type; - ASSERT_EQ(LiteRtGetTensorBufferTensorType(tensor_buffer, &tensor_type), - kLiteRtStatusOk); - ASSERT_EQ(tensor_type.element_type, kLiteRtElementTypeFloat32); - ASSERT_EQ(tensor_type.layout.rank, 1); - ASSERT_EQ(tensor_type.layout.dimensions[0], kTensorType.layout.dimensions[0]); - ASSERT_EQ(tensor_type.layout.strides, nullptr); - - size_t size; - ASSERT_EQ(LiteRtGetTensorBufferSize(tensor_buffer, &size), kLiteRtStatusOk); - ASSERT_EQ(size, sizeof(kTensorData)); - - size_t offset; - ASSERT_EQ(LiteRtGetTensorBufferOffset(tensor_buffer, &offset), - kLiteRtStatusOk); - ASSERT_EQ(offset, 0); - - void* host_mem_addr; - ASSERT_EQ(LiteRtLockTensorBuffer(tensor_buffer, &host_mem_addr), - kLiteRtStatusOk); - std::memcpy(host_mem_addr, kTensorData, sizeof(kTensorData)); - ASSERT_EQ(LiteRtUnlockTensorBuffer(tensor_buffer), kLiteRtStatusOk); - - ASSERT_EQ(LiteRtLockTensorBuffer(tensor_buffer, &host_mem_addr), - kLiteRtStatusOk); - ASSERT_EQ(std::memcmp(host_mem_addr, kTensorData, 
sizeof(kTensorData)), 0); - ASSERT_EQ(LiteRtUnlockTensorBuffer(tensor_buffer), kLiteRtStatusOk); - - LiteRtDestroyTensorBuffer(tensor_buffer); -} - -TEST(TensorBuffer, Event) { - constexpr auto kTensorBufferType = kLiteRtTensorBufferTypeHostMemory; - LiteRtTensorBuffer tensor_buffer; - ASSERT_EQ( - LiteRtCreateManagedTensorBuffer(kTensorBufferType, &kTensorType, - sizeof(kTensorData), &tensor_buffer), - kLiteRtStatusOk); - - bool has_event = true; - ASSERT_EQ(LiteRtHasTensorBufferEvent(tensor_buffer, &has_event), - kLiteRtStatusOk); - EXPECT_FALSE(has_event); - - LiteRtEvent event = new LiteRtEventT; - ASSERT_EQ(LiteRtSetTensorBufferEvent(tensor_buffer, event), kLiteRtStatusOk); - - has_event = false; - ASSERT_EQ(LiteRtHasTensorBufferEvent(tensor_buffer, &has_event), - kLiteRtStatusOk); - EXPECT_TRUE(has_event); - - LiteRtEvent actual_event; - ASSERT_EQ(LiteRtGetTensorBufferEvent(tensor_buffer, &actual_event), - kLiteRtStatusOk); - ASSERT_EQ(actual_event, event); - - ASSERT_EQ(LiteRtClearTensorBufferEvent(tensor_buffer), kLiteRtStatusOk); - ASSERT_EQ(actual_event, event); - - has_event = true; - ASSERT_EQ(LiteRtHasTensorBufferEvent(tensor_buffer, &has_event), - kLiteRtStatusOk); - EXPECT_FALSE(has_event); - - LiteRtDestroyTensorBuffer(tensor_buffer); -} - -TEST(TensorBuffer, OpenCL) { -// MSAN does not support GPU tests. 
-#if defined(MEMORY_SANITIZER) || defined(THREAD_SANITIZER) - GTEST_SKIP() << "GPU tests are not supported In msan or tsan"; -#endif - - if (!litert::internal::OpenClBuffer::IsSupported()) { - GTEST_SKIP() << "OpenCL buffers are not supported on this platform; " - "skipping the test"; - } - - constexpr auto kTensorBufferType = kLiteRtTensorBufferTypeOpenCl; - - LiteRtTensorBuffer tensor_buffer; - ASSERT_EQ( - LiteRtCreateManagedTensorBuffer(kTensorBufferType, &kTensorType, - sizeof(kTensorData), &tensor_buffer), - kLiteRtStatusOk); - - LiteRtTensorBufferType buffer_type; - ASSERT_EQ(LiteRtGetTensorBufferType(tensor_buffer, &buffer_type), - kLiteRtStatusOk); - ASSERT_EQ(buffer_type, kTensorBufferType); - - LiteRtRankedTensorType tensor_type; - ASSERT_EQ(LiteRtGetTensorBufferTensorType(tensor_buffer, &tensor_type), - kLiteRtStatusOk); - ASSERT_EQ(tensor_type.element_type, kLiteRtElementTypeFloat32); - ASSERT_EQ(tensor_type.layout.rank, 1); - ASSERT_EQ(tensor_type.layout.dimensions[0], kTensorType.layout.dimensions[0]); - ASSERT_EQ(tensor_type.layout.strides, nullptr); - - size_t size; - ASSERT_EQ(LiteRtGetTensorBufferSize(tensor_buffer, &size), kLiteRtStatusOk); - ASSERT_EQ(size, sizeof(kTensorData)); - - size_t offset; - ASSERT_EQ(LiteRtGetTensorBufferOffset(tensor_buffer, &offset), - kLiteRtStatusOk); - ASSERT_EQ(offset, 0); - - void* host_mem_addr; - ASSERT_EQ(LiteRtLockTensorBuffer(tensor_buffer, &host_mem_addr), - kLiteRtStatusOk); - std::memcpy(host_mem_addr, kTensorData, sizeof(kTensorData)); - ASSERT_EQ(LiteRtUnlockTensorBuffer(tensor_buffer), kLiteRtStatusOk); - - ASSERT_EQ(LiteRtLockTensorBuffer(tensor_buffer, &host_mem_addr), - kLiteRtStatusOk); - ASSERT_EQ(std::memcmp(host_mem_addr, kTensorData, sizeof(kTensorData)), 0); - ASSERT_EQ(LiteRtUnlockTensorBuffer(tensor_buffer), kLiteRtStatusOk); - - LiteRtDestroyTensorBuffer(tensor_buffer); -} - -#if LITERT_HAS_OPENGL_SUPPORT -TEST(TensorBuffer, GlBuffer) { -// MSAN does not support GPU tests. 
-#if defined(MEMORY_SANITIZER) - GTEST_SKIP() << "GPU tests are not supported In msan"; -#endif - - if (!litert::internal::GlBuffer::IsSupported()) { - GTEST_SKIP() << "OpenGL buffers are not supported on this platform; " - "skipping the test"; - } - - constexpr auto kTensorBufferType = kLiteRtTensorBufferTypeGlBuffer; - - LiteRtTensorBuffer tensor_buffer; - ASSERT_EQ( - LiteRtCreateManagedTensorBuffer(kTensorBufferType, &kTensorType, - sizeof(kTensorData), &tensor_buffer), - kLiteRtStatusOk); - - LiteRtTensorBufferType buffer_type; - ASSERT_EQ(LiteRtGetTensorBufferType(tensor_buffer, &buffer_type), - kLiteRtStatusOk); - ASSERT_EQ(buffer_type, kTensorBufferType); - - LiteRtRankedTensorType tensor_type; - ASSERT_EQ(LiteRtGetTensorBufferTensorType(tensor_buffer, &tensor_type), - kLiteRtStatusOk); - ASSERT_EQ(tensor_type.element_type, kLiteRtElementTypeFloat32); - ASSERT_EQ(tensor_type.layout.rank, 1); - ASSERT_EQ(tensor_type.layout.dimensions[0], kTensorType.layout.dimensions[0]); - ASSERT_EQ(tensor_type.layout.strides, nullptr); - - size_t size; - ASSERT_EQ(LiteRtGetTensorBufferSize(tensor_buffer, &size), kLiteRtStatusOk); - ASSERT_EQ(size, sizeof(kTensorData)); - - size_t offset; - ASSERT_EQ(LiteRtGetTensorBufferOffset(tensor_buffer, &offset), - kLiteRtStatusOk); - ASSERT_EQ(offset, 0); - - void* host_mem_addr; - ASSERT_EQ(LiteRtLockTensorBuffer(tensor_buffer, &host_mem_addr), - kLiteRtStatusOk); - std::memcpy(host_mem_addr, kTensorData, sizeof(kTensorData)); - ASSERT_EQ(LiteRtUnlockTensorBuffer(tensor_buffer), kLiteRtStatusOk); - - ASSERT_EQ(LiteRtLockTensorBuffer(tensor_buffer, &host_mem_addr), - kLiteRtStatusOk); - ASSERT_EQ(std::memcmp(host_mem_addr, kTensorData, sizeof(kTensorData)), 0); - ASSERT_EQ(LiteRtUnlockTensorBuffer(tensor_buffer), kLiteRtStatusOk); - - LiteRtDestroyTensorBuffer(tensor_buffer); -} -#endif // LITERT_HAS_OPENGL_SUPPORT diff --git a/tensorflow/lite/experimental/litert/c/litert_tensor_buffer_types.h 
b/tensorflow/lite/experimental/litert/c/litert_tensor_buffer_types.h deleted file mode 100644 index 1953915c153c98..00000000000000 --- a/tensorflow/lite/experimental/litert/c/litert_tensor_buffer_types.h +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_TENSOR_BUFFER_TYPES_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_TENSOR_BUFFER_TYPES_H_ - -typedef enum { - kLiteRtTensorBufferTypeUnknown = 0, - kLiteRtTensorBufferTypeHostMemory = 1, - kLiteRtTensorBufferTypeAhwb = 2, - kLiteRtTensorBufferTypeIon = 3, - kLiteRtTensorBufferTypeDmaBuf = 4, - kLiteRtTensorBufferTypeFastRpc = 5, - kLiteRtTensorBufferTypeOpenCl = 6, - kLiteRtTensorBufferTypeGlBuffer = 7, - kLiteRtTensorBufferTypeGlTexture = 8, -} LiteRtTensorBufferType; - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_TENSOR_BUFFER_TYPES_H_ diff --git a/tensorflow/lite/experimental/litert/cc/BUILD b/tensorflow/lite/experimental/litert/cc/BUILD deleted file mode 100644 index 566b35f370b70b..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/BUILD +++ /dev/null @@ -1,708 +0,0 @@ -# Copyright 2024 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//visibility:public"], -) - -cc_library( - name = "litert_environment", - hdrs = ["litert_environment.h"], - deps = [ - ":litert_any", - ":litert_expected", - ":litert_handle", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_environment", - "@com_google_absl//absl/types:span", - ], -) - -cc_test( - name = "litert_environment_test", - srcs = [ - "litert_environment_test.cc", - ], - linkopts = select({ - "//tensorflow:android": ["-llog"], - "//conditions:default": [], - }), - deps = [ - ":litert_any", - ":litert_compiled_model", - ":litert_environment", - ":litert_expected", - ":litert_model", - "//tensorflow/lite/experimental/litert/c:litert_tensor_buffer", - "//tensorflow/lite/experimental/litert/test:common", - "//tensorflow/lite/experimental/litert/test:matchers", - "//tensorflow/lite/experimental/litert/test:simple_model", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "litert_event", - hdrs = ["litert_event.h"], - deps = [ - ":litert_expected", - ":litert_handle", - ":litert_macros", - "//tensorflow/lite/experimental/litert/c:litert_event", - "//tensorflow/lite/experimental/litert/c:litert_event_type", - ], -) - -cc_library( - name = "litert_any", - hdrs = ["litert_any.h"], - deps = [ - ":litert_expected", - ":litert_macros", - 
"//tensorflow/lite/experimental/litert/c:litert_any", - "//tensorflow/lite/experimental/litert/c:litert_common", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - ], -) - -cc_test( - name = "litert_any_test", - srcs = [ - "litert_any_test.cc", - ], - linkopts = select({ - "//tensorflow:android": ["-llog"], - "//conditions:default": [], - }), - deps = [ - ":litert_any", - "//tensorflow/lite/experimental/litert/c:litert_common", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "litert_consts", - hdrs = [ - "litert_consts.h", - ], -) - -cc_library( - name = "litert_model", - srcs = ["litert_model.cc"], - hdrs = [ - "litert_model.h", - ], - deps = [ - ":litert_buffer_ref", - ":litert_consts", - ":litert_detail", - ":litert_element_type", - ":litert_expected", - ":litert_handle", - ":litert_layout", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - ], -) - -cc_test( - name = "litert_model_test", - srcs = [ - "litert_model_test.cc", - ], - data = [ - "//tensorflow/lite/experimental/litert/test:mlir_test_data", - ], - deps = [ - ":litert_element_type", - ":litert_layout", - ":litert_model", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/core/model", - "//tensorflow/lite/experimental/litert/test:common", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "litert_handle", - hdrs = ["litert_handle.h"], -) - -cc_library( - name = "litert_tensor_buffer", - hdrs = [ - 
"litert_tensor_buffer.h", - "litert_tensor_buffer_requirements.h", - ], - deps = [ - ":litert_detail", - ":litert_event", - ":litert_expected", - ":litert_handle", - ":litert_macros", - ":litert_model", - "//tensorflow/lite/c:c_api_types", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_event", - "//tensorflow/lite/experimental/litert/c:litert_gl_types", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_tensor_buffer", - "//tensorflow/lite/experimental/litert/c:litert_tensor_buffer_types", - "@com_google_absl//absl/types:span", - "@opencl_headers", - ], -) - -cc_test( - name = "litert_tensor_buffer_test", - srcs = [ - "litert_tensor_buffer_test.cc", - ], - linkopts = select({ - "//tensorflow:android": ["-landroid"], - "//conditions:default": [], - }), - deps = [ - ":litert_element_type", - ":litert_event", - ":litert_layout", - ":litert_macros", - ":litert_model", - ":litert_tensor_buffer", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_tensor_buffer", - "//tensorflow/lite/experimental/litert/c:litert_tensor_buffer_types", - "//tensorflow/lite/experimental/litert/runtime:tensor_buffer", - "//tensorflow/lite/experimental/litert/test:matchers", - "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest_main", - ] + select({ - "//tensorflow:android": [ - "//tensorflow/lite/delegates/gpu/gl:egl_environment", - ], - "//conditions:default": [], - }), -) - -cc_library( - name = "litert_tensor_buffer_requirements", - hdrs = [ - "litert_tensor_buffer_requirements.h", - ], - deps = [ - ":litert_detail", - ":litert_handle", - ":litert_macros", - "//tensorflow/lite/c:c_api_types", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_tensor_buffer", - 
"//tensorflow/lite/experimental/litert/cc:litert_expected", - "@com_google_absl//absl/types:span", - ], -) - -cc_test( - name = "litert_tensor_buffer_requirements_test", - srcs = [ - "litert_tensor_buffer_requirements_test.cc", - ], - deps = [ - ":litert_tensor_buffer", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_tensor_buffer", - "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "litert_buffer_ref", - hdrs = [ - "litert_buffer_ref.h", - ], - deps = [ - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - ], -) - -cc_library( - name = "litert_macros", - srcs = ["litert_macros.cc"], - hdrs = ["litert_macros.h"], - deps = [ - ":litert_expected", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - ], -) - -cc_test( - name = "litert_macros_test", - srcs = ["litert_macros_test.cc"], - deps = [ - ":litert_expected", - ":litert_macros", - "//tensorflow/lite/experimental/litert/c:litert_common", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "litert_expected", - hdrs = ["litert_expected.h"], - deps = [ - ":litert_detail", - "//tensorflow/lite/experimental/litert/c:litert_common", - "@com_google_absl//absl/log:absl_check", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - ], -) - -cc_test( - name = "litert_expected_test", - srcs = ["litert_expected_test.cc"], - deps = [ - ":litert_buffer_ref", - ":litert_expected", - "//tensorflow/lite/experimental/litert/c:litert_common", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "litert_detail", - hdrs = ["litert_detail.h"], - deps = [ - 
"//tensorflow/lite/experimental/litert/c:litert_common", - "@com_google_absl//absl/log:absl_check", - ], -) - -# Dispatch Delegate of LiteRt. -# Warning: This API is not ABI stable and is subject to change. -cc_library( - name = "litert_dispatch_delegate", - hdrs = [ - "litert_dispatch_delegate.h", - ], - deps = [ - "//tensorflow/lite/c:c_api", - "//tensorflow/lite/c:c_api_opaque", - "//tensorflow/lite/c:c_api_types", - "//tensorflow/lite/c:common", - "//tensorflow/lite/delegates/utils:simple_opaque_delegate", - "//tensorflow/lite/experimental/litert/c:litert_environment_options", - "//tensorflow/lite/experimental/litert/runtime/dispatch:dispatch_delegate", - ], -) - -cc_test( - name = "litert_buffer_ref_test", - srcs = ["litert_buffer_ref_test.cc"], - deps = [ - ":litert_buffer_ref", - "//tensorflow/lite/experimental/litert/core/util:flatbuffer_tools", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "litert_element_type", - hdrs = ["litert_element_type.h"], - deps = ["//tensorflow/lite/experimental/litert/c:litert_model"], -) - -cc_test( - name = "litert_element_type_test", - srcs = ["litert_element_type_test.cc"], - deps = [ - ":litert_element_type", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "litert_model_predicates", - srcs = ["litert_model_predicates.cc"], - hdrs = ["litert_model_predicates.h"], - deps = [ - ":litert_detail", - ":litert_model", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/types:span", - ], -) - -cc_library( - name = "litert_layout", - hdrs = ["litert_layout.h"], - deps = [ - ":litert_consts", - "//tensorflow/lite/experimental/litert/c:litert_layout", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/types:span", - ], -) - 
-cc_test( - name = "litert_model_predicates_test", - srcs = ["litert_model_predicates_test.cc"], - data = [ - "//tensorflow/lite/experimental/litert/test:mlir_test_data", - ], - deps = [ - ":litert_element_type", - ":litert_model", - ":litert_model_predicates", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/test:common", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/log:absl_check", - "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest_main", - ], -) - -cc_test( - name = "litert_layout_test", - srcs = ["litert_layout_test.cc"], - deps = [ - ":litert_layout", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "litert_compiled_model", - srcs = ["litert_compiled_model.cc"], - hdrs = ["litert_compiled_model.h"], - deps = [ - ":litert_compilation_options", - ":litert_environment", - ":litert_expected", - ":litert_handle", - ":litert_macros", - ":litert_model", - ":litert_tensor_buffer", - "//tensorflow/lite:framework", - "//tensorflow/lite/c:c_api_opaque", - "//tensorflow/lite/c:c_api_types", - "//tensorflow/lite/c:common", - "//tensorflow/lite/core:cc_api_stable", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_compiled_model", - "//tensorflow/lite/experimental/litert/c:litert_environment", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_tensor_buffer", - "//tensorflow/lite/kernels:builtin_ops", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/strings:string_view", - ], -) - -cc_library( - name = "litert_compilation_options", - hdrs = ["litert_compilation_options.h"], - deps = [ - ":litert_accelerator_compilation_options", - ":litert_environment", - ":litert_expected", - ":litert_handle", - ":litert_macros", - ":litert_model", - ":litert_tensor_buffer", - "//tensorflow/lite:framework", - 
"//tensorflow/lite/c:c_api_opaque", - "//tensorflow/lite/c:c_api_types", - "//tensorflow/lite/c:common", - "//tensorflow/lite/core:cc_api_stable", - "//tensorflow/lite/experimental/litert/c:litert_accelerator_compilation_options", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_compilation_options", - "//tensorflow/lite/experimental/litert/c:litert_compiled_model", - "//tensorflow/lite/experimental/litert/c:litert_environment", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_tensor_buffer", - "//tensorflow/lite/kernels:builtin_ops", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/strings:string_view", - ], -) - -cc_test( - name = "litert_compiled_model_test", - srcs = ["litert_compiled_model_test.cc"], - data = [ - "//tensorflow/lite/experimental/litert/test:testdata/simple_model.tflite", - ], - deps = [ - ":litert_compiled_model", - ":litert_environment", - ":litert_model", - ":litert_tensor_buffer", - "//tensorflow/lite:framework", - "//tensorflow/lite/c:c_api_opaque", - "//tensorflow/lite/c:common", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_tensor_buffer", - "//tensorflow/lite/experimental/litert/c:litert_tensor_buffer_types", - "//tensorflow/lite/experimental/litert/test:common", - "//tensorflow/lite/experimental/litert/test:matchers", - "//tensorflow/lite/experimental/litert/test:simple_model", - "//tensorflow/lite/kernels:builtin_ops", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/log:absl_log", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest_main", - ], -) - -cc_test( - name = "litert_compiled_model_integration_test", - srcs = ["litert_compiled_model_integration_test.cc"], - deps = 
[ - ":litert_buffer_ref", - ":litert_compiled_model", - ":litert_environment", - ":litert_event", - ":litert_model", - ":litert_tensor_buffer", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_dispatch_headers", - "//tensorflow/lite/experimental/litert/core/model:model_buffer", - "//tensorflow/lite/experimental/litert/runtime:tensor_buffer", - "//tensorflow/lite/experimental/litert/test:common", - "//tensorflow/lite/experimental/litert/test:matchers", - "//tensorflow/lite/experimental/litert/test:simple_model", - "@com_google_absl//absl/log:absl_log", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest_main", - ] + select({ - "//tensorflow:android": [ - "//tensorflow/lite/delegates/gpu/gl:egl_environment", - ], - "//conditions:default": [], - }), -) - -# copybara:uncomment_begin(google-only) -# cc_test( -# name = "litert_compiled_model_gpu_test", -# srcs = ["litert_compiled_model_gpu_test.cc"], -# data = [ -# "//tensorflow/lite/experimental/litert/test:testdata/simple_model.tflite", -# ], -# env = { -# "ASAN_OPTIONS": "detect_odr_violation=0", -# }, -# tags = [ -# "manual", -# "notap", -# "requires-gpu-nvidia", -# ], -# deps = [ -# ":litert_compiled_model", -# ":litert_environment", -# ":litert_event", -# ":litert_model", -# ":litert_tensor_buffer", -# "@com_google_googletest//:gtest_main", -# "@com_google_absl//absl/debugging:leak_check", -# "@com_google_absl//absl/log:absl_log", -# "@com_google_absl//absl/strings:string_view", -# "@com_google_absl//absl/types:span", -# "//third_party/odml/infra/ml_drift_delegate/litert:ml_drift_cl_accelerator", # buildcleaner: keep -# "//tensorflow/lite:framework", -# "//tensorflow/lite/c:c_api_opaque", -# "//tensorflow/lite/c:common", -# "//tensorflow/lite/experimental/litert/c:litert_common", -# "//tensorflow/lite/experimental/litert/c:litert_event", -# 
"//tensorflow/lite/experimental/litert/c:litert_event_type", -# "//tensorflow/lite/experimental/litert/c:litert_tensor_buffer_types", -# "//tensorflow/lite/experimental/litert/test:common", -# "//tensorflow/lite/experimental/litert/test:matchers", -# "//tensorflow/lite/experimental/litert/test:simple_model", -# "//tensorflow/lite/kernels:builtin_ops", -# ], -# ) -# -# # The same test as above, but for Android. -# # This test doesn't run on TAP. -# # libLiteRtGpuAccelerator.so and libLiteRtRuntimeCApi.so are required to run this test. -# cc_test( -# name = "litert_compiled_model_gpu_android_test", -# srcs = ["litert_compiled_model_gpu_test.cc"], -# data = [ -# "//tensorflow/lite/experimental/litert/test:testdata/simple_model.tflite", -# ], -# tags = [ -# "manual", -# "notap", -# ], -# deps = [ -# ":litert_compiled_model", -# ":litert_environment", -# ":litert_event", -# ":litert_model", -# ":litert_tensor_buffer", -# "@com_google_googletest//:gtest_main", -# "@com_google_absl//absl/debugging:leak_check", -# "@com_google_absl//absl/log:absl_log", -# "@com_google_absl//absl/strings:string_view", -# "@com_google_absl//absl/types:span", -# "//third_party/odml/infra/ml_drift_delegate/litert:ml_drift_cl_accelerator_shared_lib", # buildcleaner: keep -# "//tensorflow/lite:framework", -# "//tensorflow/lite/c:c_api_opaque", -# "//tensorflow/lite/c:common", -# "//tensorflow/lite/experimental/litert/c:litert_runtime_c_api_shared_lib", -# "//tensorflow/lite/experimental/litert/test:common", -# "//tensorflow/lite/experimental/litert/test:matchers", -# "//tensorflow/lite/experimental/litert/test:simple_model", -# "//tensorflow/lite/kernels:builtin_ops", -# ], -# ) -# copybara:uncomment_end - -cc_library( - name = "litert_tensor_buffer_utils", - srcs = ["litert_tensor_buffer_utils.cc"], - hdrs = ["litert_tensor_buffer_utils.h"], - deps = [ - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_tensor_buffer_types", - ], -) - 
-cc_library( - name = "litert_op_options", - srcs = ["litert_op_options.cc"], - hdrs = ["litert_op_options.h"], - deps = [ - ":litert_expected", - ":litert_macros", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/c:litert_options", - "@com_google_absl//absl/strings:string_view", - ], -) - -cc_test( - name = "litert_op_options_test", - srcs = ["litert_op_options_test.cc"], - deps = [ - ":litert_op_options", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/core/model", - "//tensorflow/lite/experimental/litert/core/util:flatbuffer_tools", - "//tensorflow/lite/schema:schema_fbs", - "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "litert_shared_library", - srcs = ["litert_shared_library.cc"], - hdrs = ["litert_shared_library.h"], - deps = [ - ":litert_expected", - ":litert_macros", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - ], -) - -cc_library( - name = "test_litert_shared_library_impl", - srcs = ["test_shared_library.cc"], -) - -cc_shared_library( - name = "test_litert_shared_library", - shared_lib_name = "test_shared_library.so", - deps = [":test_litert_shared_library_impl"], -) - -cc_test( - name = "litert_shared_library_test", - srcs = ["litert_shared_library_test.cc"], - data = [":test_litert_shared_library"], - defines = ["LITERT_DEFINE_GTEST_STATUS_PRINTER"], - deps = [ - ":litert_shared_library", - "//tensorflow/lite/experimental/litert/test:matchers", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest_main", - ], -) - -cc_test( - name = "litert_event_test", - srcs = 
["litert_event_test.cc"], - deps = [ - ":litert_event", - "//tensorflow/lite/experimental/litert/test:matchers", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "litert_accelerator_compilation_options", - hdrs = [ - "litert_accelerator_compilation_options.h", - ], - deps = [ - ":litert_expected", - ":litert_handle", - ":litert_macros", - "//tensorflow/lite/experimental/litert/c:litert_accelerator_compilation_options", - "//tensorflow/lite/experimental/litert/c:litert_common", - "@com_google_absl//absl/strings:string_view", - ], -) - -exports_files(srcs = glob(["litert_*.h"])) diff --git a/tensorflow/lite/experimental/litert/cc/litert_accelerator_compilation_options.h b/tensorflow/lite/experimental/litert/cc/litert_accelerator_compilation_options.h deleted file mode 100644 index 80e36b6b4c980f..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/litert_accelerator_compilation_options.h +++ /dev/null @@ -1,124 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_ACCELERATOR_COMPILATION_OPTIONS_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_ACCELERATOR_COMPILATION_OPTIONS_H_ - -#include -#include -#include - -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/c/litert_accelerator_compilation_options.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_handle.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" - -namespace litert { - -class AcceleratorCompilationOptions - : public internal::Handle { - public: - AcceleratorCompilationOptions() = default; - - // Parameter `owned` indicates if the created AcceleratorCompilationOptions - // object should take ownership of the provided `options` handle. - explicit AcceleratorCompilationOptions( - LiteRtAcceleratorCompilationOptions options, bool owned = true) - : internal::Handle(options, - owned) {} - - static Expected Create( - const LiteRtApiVersion& payload_version, - const std::string& payload_identifier, void* payload_data, - void (*payload_destructor)(void* payload_data)) { - LiteRtAcceleratorCompilationOptions options; - LITERT_RETURN_IF_ERROR(LiteRtCreateAcceleratorCompilationOptions( - &payload_version, payload_identifier.c_str(), payload_data, - payload_destructor, &options)); - return AcceleratorCompilationOptions(options); - } - - Expected GetVersion() const { - LiteRtApiVersion payload_version; - LITERT_RETURN_IF_ERROR( - LiteRtGetAcceleratorCompilationOptionsVersion(Get(), &payload_version)); - return payload_version; - } - - Expected GetIdentifier() const { - const char* payload_identifier; - LITERT_RETURN_IF_ERROR(LiteRtGetAcceleratorCompilationOptionsIdentifier( - Get(), &payload_identifier)); - return absl::string_view(payload_identifier); - } - - template - Expected GetData() const { - void* 
payload_data; - LITERT_RETURN_IF_ERROR( - LiteRtGetAcceleratorCompilationOptionsData(Get(), &payload_data)); - return reinterpret_cast(payload_data); - } - - template - Expected> FindData( - const std::string& payload_identifier) { - LiteRtApiVersion payload_version; - void* payload_data; - LITERT_RETURN_IF_ERROR(LiteRtFindAcceleratorCompilationOptionsData( - Get(), payload_identifier.c_str(), &payload_version, &payload_data)); - return std::make_pair(payload_version, reinterpret_cast(payload_data)); - } - - Expected Next() { - auto h = Get(); - LITERT_RETURN_IF_ERROR(LiteRtGetNextAcceleratorCompilationOptions(&h)); - return AcceleratorCompilationOptions(h, /*owned=*/false); - } - - Expected Append(AcceleratorCompilationOptions&& appended_options) { - auto h = Get(); - LITERT_RETURN_IF_ERROR(LiteRtAppendAcceleratorCompilationOptions( - &h, appended_options.Release())); - if (h != Get()) { - // If appending a new linked list item has changed the linked list head - // pointer, then we need to reflect that as the new handle. Note that - // should happen only if the previous handle was null. - assert(!Get()); - *this = AcceleratorCompilationOptions(h); - } - return {}; - } - - Expected Pop() { - auto h = Get(); - LITERT_RETURN_IF_ERROR(LiteRtPopAcceleratorCompilationOptions(&h)); - if (h != Get()) { - // If popping the last item has changed the linked list head pointer, then - // we release the current handle since it has been already destructed by - // the pop call, and then use the new head pointer as the new handle. 
- (void)Release(); - *this = AcceleratorCompilationOptions(h); - } - return {}; - } -}; - -} // namespace litert - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_ACCELERATOR_COMPILATION_OPTIONS_H_ diff --git a/tensorflow/lite/experimental/litert/cc/litert_any.h b/tensorflow/lite/experimental/litert/cc/litert_any.h deleted file mode 100644 index 97483ce3d63dcb..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/litert_any.h +++ /dev/null @@ -1,221 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_ANY_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_ANY_H_ - -#include -#include -#include -#include - -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/c/litert_any.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" - -namespace litert { - -inline std::any ToStdAny(LiteRtAny litert_any) { - std::any res; - switch (litert_any.type) { - case kLiteRtAnyTypeNone: - break; - case kLiteRtAnyTypeBool: - res = litert_any.bool_value; - break; - case kLiteRtAnyTypeInt: - res = litert_any.int_value; - break; - case kLiteRtAnyTypeReal: - res = litert_any.real_value; - break; - case kLiteRtAnyTypeString: - res = litert_any.str_value; - break; - case kLiteRtAnyTypeVoidPtr: - res = litert_any.ptr_value; - break; - } - return res; -} - -inline Expected ToLiteRtAny(const std::any& any) { - LiteRtAny result; - if (!any.has_value()) { - result.type = kLiteRtAnyTypeNone; - return result; - - } else if (any.type() == typeid(LiteRtAny::bool_value)) { - result.type = kLiteRtAnyTypeBool; - result.bool_value = std::any_cast(any); - return result; - - } else if (any.type() == typeid(int8_t)) { - result.type = kLiteRtAnyTypeInt; - result.int_value = std::any_cast(any); - return result; - - } else if (any.type() == typeid(int16_t)) { - result.type = kLiteRtAnyTypeInt; - result.int_value = std::any_cast(any); - return result; - - } else if (any.type() == typeid(int32_t)) { - result.type = kLiteRtAnyTypeInt; - result.int_value = std::any_cast(any); - return result; - - } else if (any.type() == typeid(int64_t)) { - result.type = kLiteRtAnyTypeInt; - result.int_value = std::any_cast(any); - return result; - - } else if (any.type() == typeid(float)) { - result.type = kLiteRtAnyTypeReal; - result.real_value 
= std::any_cast(any); - return result; - - } else if (any.type() == typeid(double)) { - result.type = kLiteRtAnyTypeReal; - result.real_value = std::any_cast(any); - return result; - - } else if (any.type() == typeid(LiteRtAny::str_value)) { - result.type = kLiteRtAnyTypeString; - result.str_value = std::any_cast(any); - return result; - - } else if (any.type() == typeid(absl::string_view)) { - result.type = kLiteRtAnyTypeString; - result.str_value = std::any_cast(any).data(); - return result; - - } else if (any.type() == typeid(LiteRtAny::ptr_value)) { - result.type = kLiteRtAnyTypeVoidPtr; - result.ptr_value = std::any_cast(any); - return result; - - } else { - return Error(kLiteRtStatusErrorInvalidArgument, - "Invalid argument for ToLiteRtAny"); - } -} - -namespace internal { - -inline Expected CheckType(const LiteRtAny& any, - const LiteRtAnyType type) { - if (any.type != kLiteRtAnyTypeString) { - return Error(kLiteRtStatusErrorInvalidArgument, - absl::StrFormat("Wrong LiteRtAny type. Expected %s, got %s.", - LiteRtAnyTypeToString(type), - LiteRtAnyTypeToString(any.type))); - } - return {}; -} - -template -Expected GetInt(const LiteRtAny& any) { - LITERT_RETURN_IF_ERROR(CheckType(any, kLiteRtAnyTypeInt)); - if (any.int_value > std::numeric_limits::max() || - any.int_value < std::numeric_limits::lowest()) { - return Error( - kLiteRtStatusErrorInvalidArgument, - absl::StrFormat("LiteRtAny integer is out of range. %v <= %v <= %v", - std::numeric_limits::lowest(), any.int_value, - std::numeric_limits::max())); - } - return any.int_value; -} - -template -Expected GetReal(const LiteRtAny& any) { - LITERT_RETURN_IF_ERROR(CheckType(any, kLiteRtAnyTypeReal)); - if (any.real_value > std::numeric_limits::max() || - any.real_value < std::numeric_limits::lowest()) { - return Error( - kLiteRtStatusErrorInvalidArgument, - absl::StrFormat( - "LiteRtAny integer is out of range. 
%v <= %v <= %v failed.", - std::numeric_limits::lowest(), any.real_value, - std::numeric_limits::max())); - } - return any.real_value; -} -} // namespace internal - -// Extracts the value from a LiteRtAny object with type checking. -template -inline Expected Get(const LiteRtAny& any); - -template <> -inline Expected Get(const LiteRtAny& any) { - LITERT_RETURN_IF_ERROR(internal::CheckType(any, kLiteRtAnyTypeBool)); - return any.bool_value; -} - -template <> -inline Expected Get(const LiteRtAny& any) { - return internal::GetInt(any); -} - -template <> -inline Expected Get(const LiteRtAny& any) { - return internal::GetInt(any); -} - -template <> -inline Expected Get(const LiteRtAny& any) { - return internal::GetInt(any); -} - -template <> -inline Expected Get(const LiteRtAny& any) { - return internal::GetInt(any); -} - -template <> -inline Expected Get(const LiteRtAny& any) { - return internal::GetReal(any); -} - -template <> -inline Expected Get(const LiteRtAny& any) { - return internal::GetReal(any); -} - -template <> -inline Expected Get(const LiteRtAny& any) { - LITERT_RETURN_IF_ERROR(internal::CheckType(any, kLiteRtAnyTypeString)); - return std::string(any.str_value); -} - -template <> -inline Expected Get(const LiteRtAny& any) { - LITERT_RETURN_IF_ERROR(internal::CheckType(any, kLiteRtAnyTypeString)); - return absl::string_view(any.str_value); -} - -template <> -inline Expected Get(const LiteRtAny& any) { - LITERT_RETURN_IF_ERROR(internal::CheckType(any, kLiteRtAnyTypeVoidPtr)); - return any.ptr_value; -} - -} // namespace litert - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_ANY_H_ diff --git a/tensorflow/lite/experimental/litert/cc/litert_any_test.cc b/tensorflow/lite/experimental/litert/cc/litert_any_test.cc deleted file mode 100644 index c6640ab8060c1c..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/litert_any_test.cc +++ /dev/null @@ -1,110 +0,0 @@ -// Copyright 2024 Google LLC. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include - -#include // NOLINT: Need when ANDROID_API_LEVEL >= 26 -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_any.h" - -TEST(Any, ConversionNone) { - EXPECT_FALSE( - litert::ToStdAny(LiteRtAny{/*.type=*/kLiteRtAnyTypeNone}).has_value()); - - ASSERT_EQ(litert::ToLiteRtAny(std::any())->type, kLiteRtAnyTypeNone); -} - -TEST(Any, ConversionBool) { - ASSERT_EQ(std::any_cast(litert::ToStdAny(LiteRtAny{ - /*.type=*/kLiteRtAnyTypeBool, {/*.bool_value=*/true}})), - true); - ASSERT_EQ(std::any_cast(litert::ToStdAny(LiteRtAny{ - /*.type=*/kLiteRtAnyTypeBool, {/*.bool_value=*/false}})), - false); - - ASSERT_EQ(litert::ToLiteRtAny(std::any(true))->type, kLiteRtAnyTypeBool); - ASSERT_EQ(litert::ToLiteRtAny(std::any(true))->bool_value, true); - ASSERT_EQ(litert::ToLiteRtAny(std::any(false))->type, kLiteRtAnyTypeBool); - ASSERT_EQ(litert::ToLiteRtAny(std::any(false))->bool_value, false); -} - -TEST(Any, ConversionInt) { - LiteRtAny litert_any; - litert_any.type = kLiteRtAnyTypeInt; - litert_any.int_value = 1234; - ASSERT_EQ(std::any_cast(litert::ToStdAny(litert_any)), 1234); - - ASSERT_EQ(litert::ToLiteRtAny(std::any(static_cast(12)))->type, - kLiteRtAnyTypeInt); - ASSERT_EQ(litert::ToLiteRtAny(std::any(static_cast(12)))->int_value, - 12); - ASSERT_EQ(litert::ToLiteRtAny(std::any(static_cast(1234)))->type, - kLiteRtAnyTypeInt); - 
ASSERT_EQ( - litert::ToLiteRtAny(std::any(static_cast(1234)))->int_value, - 1234); - ASSERT_EQ(litert::ToLiteRtAny(std::any(static_cast(1234)))->type, - kLiteRtAnyTypeInt); - ASSERT_EQ( - litert::ToLiteRtAny(std::any(static_cast(1234)))->int_value, - 1234); - ASSERT_EQ(litert::ToLiteRtAny(std::any(static_cast(1234)))->type, - kLiteRtAnyTypeInt); - ASSERT_EQ( - litert::ToLiteRtAny(std::any(static_cast(1234)))->int_value, - 1234); -} - -TEST(Any, ConversionReal) { - LiteRtAny litert_any; - litert_any.type = kLiteRtAnyTypeReal; - litert_any.real_value = 123.4; - ASSERT_EQ(std::any_cast(litert::ToStdAny(litert_any)), 123.4); - - ASSERT_EQ(litert::ToLiteRtAny(std::any(static_cast(1.2)))->type, - kLiteRtAnyTypeReal); - EXPECT_NEAR( - litert::ToLiteRtAny(std::any(static_cast(1.2)))->real_value, 1.2, - 1e-7); - ASSERT_EQ(litert::ToLiteRtAny(std::any(static_cast(1.2)))->type, - kLiteRtAnyTypeReal); - EXPECT_NEAR( - litert::ToLiteRtAny(std::any(static_cast(1.2)))->real_value, 1.2, - 1e-7); -} - -TEST(Any, ConversionString) { - constexpr const char* kTestString = "test"; - LiteRtAny litert_any; - litert_any.type = kLiteRtAnyTypeString; - litert_any.str_value = kTestString; - ASSERT_EQ(std::any_cast(litert::ToStdAny(litert_any)), - kTestString); - - ASSERT_EQ(litert::ToLiteRtAny(std::any("test"))->type, kLiteRtAnyTypeString); - EXPECT_STREQ(litert::ToLiteRtAny(std::any("test"))->str_value, "test"); -} - -TEST(Any, ConversionPtr) { - const void* kTestPtr = reinterpret_cast(1234); - LiteRtAny litert_any; - litert_any.type = kLiteRtAnyTypeVoidPtr; - litert_any.ptr_value = kTestPtr; - ASSERT_EQ(std::any_cast(litert::ToStdAny(litert_any)), kTestPtr); - - ASSERT_EQ(litert::ToLiteRtAny(std::any(kTestPtr))->type, - kLiteRtAnyTypeVoidPtr); - EXPECT_EQ(litert::ToLiteRtAny(std::any(kTestPtr))->ptr_value, kTestPtr); -} diff --git a/tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h b/tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h deleted file mode 100644 index 
c81b5d12524afc..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h +++ /dev/null @@ -1,356 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_BUFFER_REF_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_BUFFER_REF_H_ - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" - -namespace litert { - -//===----------------------------------------------------------------------===// -// -// << BUFFER REF >> -// -// Read, read/write, and owning views of buffers of arbitrary byte width types. -// -// Serialized model artifacts and assets are frequently large strings that with -// (annoyingly) non-standard char type and left padded. The following classes -// simplify handling such buffers in an efficient copy free manner. They also -// provide read and write left-padded aware interpretebility through standard -// signed char strings types. This is used for making manual edits to flatbuffer -// metadata or dierctly to serialized flatbuffer. -// NOTE: std::basic_xxx not supported by our C++ toolchain. -// -// Pre-allocated buffers can be transferred to these classes or allocation can -// be internalized. 
XBufferRefs can be implictly upcasted to non-owning -// read/write or read-only to provide other routines with an appropriate view of -// the data. E.g.: -// -// ``` -// void ReadBuffer(BufferRef r_buf) { std::cerr << r_buf.StrView(); } -// void WriteToBuffer(MutableBufferRef rw_buf) { rw_buf.WriteTo("SomeData"); } -// ... -// OwningBuffer buf(size); -// WriteToBuffer(buf); // Implicitly convert to read/write with no ownership. -// ReadBuffer(buf); // Implicitly convert to read-only. -// ``` -// -//===----------------------------------------------------------------------===// - -// Allocation/Deallocation behavior for owning buffer refs. An allocator is a -// trivially constructible/destructible object that overrides () for allocating -// and freeing memory. - -// Malloc/free based memory. -template -struct Mallocator { - void operator()(ByteT* d) { - if (d != nullptr) { - free(d); - } - } - - ByteT* operator()(size_t bytes) { - return reinterpret_cast(malloc(bytes)); - } -}; - -// New/delete based memory. -template -struct Newlocator { - void operator()(ByteT* d) { - if (d != nullptr) { - delete[] d; - } - } - - ByteT* operator()(size_t bytes) { return new ByteT[bytes]; } -}; - -// -// Read-Only Bytes -// - -// Immutable and non-owning view of a buffer. -template -class BufferRef { - public: - using TupleT = std::tuple; - - // Null buffer. - BufferRef() : size_(0), offset_(0), data_(nullptr) {} - - // Construct from already allocated buffer. Methods will only expose - // data[offset, offset + size]. - BufferRef(const ByteT* data, size_t size, size_t offset = 0) - : size_(size), offset_(offset), data_(const_cast(data)) {} - BufferRef(const void* data, size_t size, size_t offset = 0) - : size_(size), - offset_(offset), - data_(const_cast(reinterpret_cast(data))) {} - explicit BufferRef(absl::Span data) - : size_(data.size()), - offset_(0), - data_(const_cast(data.data())) {} - - // Start of actual data. 
- const ByteT* Data() const { return data_ + offset_; } - - // Size of actual data. - size_t Size() const { return size_ - offset_; } - - // Get buffer details in tuple form. - TupleT Get() const { return TupleT(data_, size_, offset_); } - - // Start of actual data as signed char. Might not be null terminated. - const char* StrData() const { return reinterpret_cast(Data()); } - - // Convenience view of actual data as a string. Makes null terminated. - absl::string_view StrView() const { - return absl::string_view(StrData(), Size()); - } - - // Const view of actual data. - absl::Span Span() const { - return absl::MakeConstSpan(Data(), Size()); - } - - // Copy the buffer data to a vector. - std::vector ToVec() const { - return std::vector(StrData(), StrData() + Size()); - } - - // Write the string data to a stream. - void WriteStr(std::ostream& out) const { out.write(StrData(), Size()); } - - // Print info about this buffer. - void Dump(std::ostream& out) const { - out << absl::StreamFormat("%s[%lu:%lu]\n", TypeName(), offset_, size_); - } - - BufferRef(const BufferRef& other) = default; - BufferRef& operator=(const BufferRef& other) = default; - - virtual ~BufferRef() = default; - - protected: - size_t size_; - size_t offset_; - ByteT* data_ = nullptr; - - // Debug name. - virtual absl::string_view TypeName() const { return "BufferRef"; } -}; -template -BufferRef(const ByteT*, size_t, size_t) -> BufferRef; - -// -// Read-Write Non-Owning Bytes -// - -// Writeable (but still non-owning) version of BufferRef. -template -class MutableBufferRef : public BufferRef { - public: - using TupleT = std::tuple; - - // Null buffer. - MutableBufferRef() - : BufferRef((ByteT*)nullptr, /*size*/ 0, /*offset*/ 0) {} - - // Create a mutable view from pre-allocated non-const buffer. 
- MutableBufferRef(ByteT* data, size_t size, size_t offset = 0) - : BufferRef(data, size, offset) {} - MutableBufferRef(void* data, size_t size, size_t offset = 0) - : BufferRef(data, size, offset) {} - explicit MutableBufferRef(absl::Span data) : BufferRef(data) {} - explicit MutableBufferRef(absl::Span data) = delete; - MutableBufferRef(const ByteT*, size_t, size_t) = delete; - MutableBufferRef(const void*, size_t, size_t) = delete; - - // Mutable start of actual data. - ByteT* Data() { return this->data_ + this->offset_; } - - // Get the mutable start of actual data as a char pointer. - char* StrData() { return reinterpret_cast(Data()); } - - // Get buffer info in tuple form. - TupleT Get() { return TupleT(this->data_, this->size_, this->offset_); } - - // Mutable span of actual data. - absl::Span Span() { return absl::MakeSpan(Data(), this->Size()); } - - // Write string into the actual buffer at offset. Returns false if the entire - // string cannot fit into the actual buffer. - bool WriteInto(absl::string_view str, size_t offset = 0) { - if (str.size() > this->Size() - offset) { - return false; - } - std::memcpy(Data() + offset, str.data(), str.size()); - return true; - } - - MutableBufferRef(const MutableBufferRef& other) = default; - MutableBufferRef& operator=(const MutableBufferRef& other) = default; - - protected: - // Debug name. - absl::string_view TypeName() const override { return "MutableBufferRef"; } -}; -template -MutableBufferRef(ByteT*, size_t, size_t) -> MutableBufferRef; - -// -// Read-Write Owning Bytes -// - -// Writable and owning buffer reference. Can allocate new buffers internally and -// take ownership of existing buffers. Does not support resizing. -template > -class OwningBufferRef : public MutableBufferRef { - public: - using TupleT = std::tuple; - using WeakTupleT = std::tuple; - - // Null buffer. 
- OwningBufferRef() - : MutableBufferRef(/*data*/ (ByteT*)nullptr, /*size*/ 0, - /*offset*/ 0) {} - - // Initialize a new buffer reference and allocate internally. - explicit OwningBufferRef(size_t size) - : MutableBufferRef(/*data*/ (ByteT*)nullptr, size, /*offset*/ 0) { - this->data_ = (ByteT*)Allocator()(size); - } - - // Take ownership of given buffer. - OwningBufferRef(ByteT* data, size_t size, size_t offset = 0) - : MutableBufferRef(data, size, offset) {} - OwningBufferRef(void* data, size_t size, size_t offset = 0) - : MutableBufferRef(data, size, offset) {} - explicit OwningBufferRef(absl::Span data) - : MutableBufferRef(data) {} - - // Copy the given buffer. - OwningBufferRef(const ByteT* data, size_t size) - : MutableBufferRef(/*data*/ (ByteT*)nullptr, size, - /*offset*/ 0) { - this->data_ = (ByteT*)Allocator()(size); - std::memcpy(this->data_, data, size); - } - explicit OwningBufferRef(absl::Span data) - : OwningBufferRef(data.data(), data.size()) {} - - // Copy data from givens string. - explicit OwningBufferRef(absl::string_view data) - : OwningBufferRef( - reinterpret_cast(data.data()), data.size()) {} - - // Copy data from given c-style string. - explicit OwningBufferRef(const char* data) - : OwningBufferRef(absl::string_view(data)) {} - - // Drop reference to any owned memory. - void Drop() { - this->data_ = nullptr; - this->size_ = 0; - this->offset_ = 0; - } - - // Get the buffer details and drop references to them. - TupleT Release() { - auto res = std::make_tuple(this->data_, this->size_, this->offset_); - Drop(); - return res; - } - - // Get weak references to buffer data. Takes ownership of anything that - // is swapped in. - WeakTupleT GetWeak() { - return WeakTupleT(this->data_, this->size_, this->offset_); - } - - // Free any owned memory. - void Reset() { - Allocator()(this->data_); - Drop(); - } - - // Reset any existing data and copy in given ro buffer. 
- void Assign(const ByteT* buf, size_t size, size_t offset = 0) { - Reset(); - this->size_ = size; - this->data_ = (ByteT*)Allocator()(this->size_); - std::memcpy(this->data_, buf, this->size_); - this->offset_ = offset; - } - - OwningBufferRef(OwningBufferRef&& other) - : MutableBufferRef(other.data_, other.size_, other.offset_) { - other.Drop(); - } - - OwningBufferRef& operator=(OwningBufferRef&& other) { - if (this != &other) { - Reset(); - this->data_ = other.data_; - this->size_ = other.size_; - this->offset_ = other.offset_; - other.Drop(); - } - return *this; - } - - OwningBufferRef(const OwningBufferRef& other) - : MutableBufferRef(/*data*/ (ByteT*)nullptr, other.size_, - other.offset_) { - Assign(other.data_, other.size_, other.offset_); - } - - OwningBufferRef& operator=(const OwningBufferRef& other) { - Assign(other.data_, other.size_, other.offset_); - return *this; - } - - ~OwningBufferRef() override { Reset(); } - - protected: - // Debug string. - absl::string_view TypeName() const override { return "OwningBufferRef"; } -}; - -template > -OwningBufferRef(const ByteT*, size_t) -> OwningBufferRef; - -template > -OwningBufferRef(ByteT*, size_t) -> OwningBufferRef; - -template > -OwningBufferRef(const char*) -> OwningBufferRef; - -} // namespace litert - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_BUFFER_REF_H_ diff --git a/tensorflow/lite/experimental/litert/cc/litert_buffer_ref_test.cc b/tensorflow/lite/experimental/litert/cc/litert_buffer_ref_test.cc deleted file mode 100644 index a2900d0c8946fd..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/litert_buffer_ref_test.cc +++ /dev/null @@ -1,332 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" - -#include -#include -#include -#include -#include -#include - -#include -#include -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h" - -using litert::BufferRef; -using litert::Mallocator; -using litert::MutableBufferRef; -using litert::Newlocator; -using litert::OwningBufferRef; -using litert::internal::FbBufToStr; -using testing::ElementsAreArray; -using testing::Eq; -using testing::Pointwise; -using testing::StartsWith; - -namespace { - -static constexpr size_t kOffset = 4; - -static constexpr absl::string_view kData = "SomeRawBuffer"; -static constexpr absl::string_view kOtherData = "SOMERawBuffer"; - -absl::Span MakeConstFbData(absl::string_view data) { - const uint8_t* fb_data = reinterpret_cast(data.data()); - return absl::MakeConstSpan(fb_data, data.size()); -} - -absl::Span MakeFbData(absl::string_view data) { - const uint8_t* c_fb_data = reinterpret_cast(data.data()); - uint8_t* fb_data = const_cast(c_fb_data); - return absl::MakeSpan(fb_data, data.size()); -} - -std::vector MakeFbDataVec(absl::string_view data) { - const uint8_t* c_fb_data = reinterpret_cast(data.data()); - uint8_t* fb_data = const_cast(c_fb_data); - return std::vector(fb_data, fb_data + data.size()); -} - -template , typename ByteT = uint8_t> -absl::Span MakeInternalTestBuffer(absl::string_view data) { - ByteT* buffer = Allocator()(data.size()); - std::memcpy(buffer, data.data(), data.size()); - return 
absl::MakeSpan(reinterpret_cast(buffer), data.size()); -} - -// -// flatbuffer_tools.h -// - -TEST(FbBufToStringTest, ConstSpan) { - EXPECT_THAT(FbBufToStr(MakeConstFbData(kData)), Pointwise(Eq(), kData)); -} - -TEST(FbBufToStringTest, Span) { - EXPECT_THAT(FbBufToStr(MakeFbData(kData)), Pointwise(Eq(), kData)); -} - -TEST(FbBufToStringTest, ConstPointer) { - auto data = MakeConstFbData(kData); - EXPECT_THAT(FbBufToStr(data.data(), data.size()), Pointwise(Eq(), kData)); -} - -TEST(FbBufToStringTest, Pointer) { - auto data = MakeFbData(kData); - EXPECT_THAT(FbBufToStr(data.data(), data.size()), Pointwise(Eq(), kData)); -} - -// -// BufferRef (read-only) -// - -TEST(BufferRefTest, Dump) { - BufferRef buf(kData.data(), kData.size()); - std::stringstream out; - buf.Dump(out); - EXPECT_THAT(out.str(), StartsWith("BufferRef")); -} - -TEST(BufferRefTest, WithData) { - auto data = MakeConstFbData(kData); - BufferRef buf(data.data(), data.size()); - EXPECT_EQ(buf.Span(), data); - EXPECT_EQ(buf.StrView(), kData); -} - -TEST(BufferRefTest, WithDataAndOffset) { - auto data = MakeConstFbData(kData); - BufferRef buf(data.data(), data.size(), kOffset); - EXPECT_EQ(buf.Span(), data.subspan(kOffset, buf.Size())); - EXPECT_EQ(buf.StrView(), kData.substr(kOffset, buf.Size())); -} - -TEST(BufferRefTest, ToVec) { - auto data = MakeConstFbData(kData); - BufferRef buf(data.data(), data.size()); - EXPECT_THAT(buf.ToVec(), ElementsAreArray(data)); -} - -TEST(BufferRefTest, WriteStr) { - auto data = MakeConstFbData(kData); - BufferRef buf(data.data(), data.size()); - std::stringstream out; - buf.WriteStr(out); - EXPECT_EQ(out.str(), kData); -} - -TEST(BufferRefTest, WriteStrOffset) { - auto data = MakeConstFbData(kData); - BufferRef buf(data.data(), data.size(), kOffset); - std::stringstream out; - buf.WriteStr(out); - EXPECT_EQ(out.str(), kData.substr(kOffset, buf.Size())); -} - -TEST(BufferRefTest, TupleGet) { - auto input = MakeConstFbData(kData); - BufferRef buf(input); - auto [data, 
size, offset] = buf.Get(); - ASSERT_EQ(offset, 0); - EXPECT_EQ(input, buf.Span()); -} - -// -// MutableBufferRef (read/write) -// - -TEST(MutableBufferRefTest, Dump) { - MutableBufferRef buf; - std::stringstream out; - buf.Dump(out); - EXPECT_THAT(out.str(), StartsWith("MutableBufferRef")); -} - -TEST(MutableBufferRefTest, WriteInto) { - auto v_data = MakeFbDataVec(kOtherData); - MutableBufferRef buf(v_data.data(), v_data.size()); - ASSERT_TRUE(buf.WriteInto("Some")); - EXPECT_THAT(buf.Span(), ElementsAreArray(v_data)); - EXPECT_EQ(buf.StrView(), kData); -} - -TEST(MutableBufferRefTest, WriteIntoOffsetBuf) { - auto v_data = MakeFbDataVec(kOtherData); - static constexpr absl::string_view kExpData = "RAWBuffer"; - MutableBufferRef buf(v_data.data(), v_data.size(), kOffset); - ASSERT_TRUE(buf.WriteInto("RAW")); - EXPECT_THAT(buf.Span(), ElementsAreArray(MakeConstFbData(kExpData))); - EXPECT_EQ(buf.StrView(), kExpData); -} - -TEST(MutableBufferRefTest, WriteIntoOffsetData) { - auto v_data = MakeFbDataVec(kOtherData); - static constexpr absl::string_view kExpData = "SOMERAWBuffer"; - MutableBufferRef buf(v_data.data(), v_data.size()); - ASSERT_TRUE(buf.WriteInto("RAW", kOffset)); - EXPECT_THAT(buf.Span(), ElementsAreArray(MakeConstFbData(kExpData))); - EXPECT_EQ(buf.StrView(), kExpData); -} - -TEST(MutableBufferRefTest, TupleGet) { - auto input = MakeInternalTestBuffer("FOO"); - MutableBufferRef buf(input); - auto [data, size, offset] = buf.Get(); - *data = 'b'; - EXPECT_EQ(buf.StrView(), "bOO"); - delete[] input.data(); -} - -// -// OwningBufferRef (read/write with memory management) -// - -TEST(OwningBufferRefTest, Dump) { - OwningBufferRef buf; - std::stringstream out; - buf.Dump(out); - EXPECT_THAT(out.str(), StartsWith("OwningBufferRef")); -} - -TEST(OwningBufferRefTest, MoveCstor) { - auto raw = MakeInternalTestBuffer>(kData); - OwningBufferRef> buf(raw.data(), raw.size()); - OwningBufferRef> other(std::move(buf)); - EXPECT_EQ(other.StrView(), kData); -} - 
-TEST(OwningBufferRefTest, MoveAssign) { - auto raw = MakeInternalTestBuffer>(kData); - OwningBufferRef> buf(raw.data(), raw.size()); - OwningBufferRef> other = std::move(buf); - EXPECT_EQ(other.StrView(), kData); -} - -TEST(OwningBufferRefTest, CopyCstor) { - auto raw = MakeInternalTestBuffer>(kData); - OwningBufferRef> buf(raw.data(), raw.size()); - OwningBufferRef> other(buf); - other.WriteInto("SOME"); - EXPECT_EQ(buf.StrView(), kData); - EXPECT_EQ(other.StrView(), "SOMERawBuffer"); -} - -TEST(OwningBufferRefTest, CopyAssign) { - auto raw = MakeInternalTestBuffer>(kData); - OwningBufferRef> buf(raw.data(), raw.size()); - OwningBufferRef> other = buf; - other.WriteInto("SOME"); - EXPECT_EQ(buf.StrView(), kData); - EXPECT_EQ(other.StrView(), "SOMERawBuffer"); -} - -TEST(OwningBufferRefTest, InternalMalloc) { - OwningBufferRef> buf(kData.size()); - ASSERT_EQ(buf.Size(), kData.size()); - ASSERT_NE(buf.Data(), nullptr); - - buf.WriteInto(kData); - EXPECT_EQ(buf.StrView(), kData); -} - -TEST(OwningBufferRefTest, InternalNew) { - OwningBufferRef buf(kData.size()); - ASSERT_EQ(buf.Size(), kData.size()); - ASSERT_NE(buf.Data(), nullptr); - - buf.WriteInto(kData); - EXPECT_EQ(buf.StrView(), kData); -} - -TEST(OwningBufferRefTest, TakeOwnershipMalloc) { - auto malloc_buffer = MakeInternalTestBuffer>(kData); - OwningBufferRef> buf(malloc_buffer.data(), - malloc_buffer.size()); - EXPECT_EQ(buf.StrView(), kData); -} - -TEST(OwningBufferRefTest, TakeOwnershipNew) { - auto new_buffer = MakeInternalTestBuffer(kData); - OwningBufferRef buf(new_buffer.data(), new_buffer.size()); - EXPECT_EQ(buf.StrView(), kData); -} - -TEST(OwningBufferRefTest, TakeOwnershipOffset) { - auto malloc_buffer = MakeInternalTestBuffer>(kData); - OwningBufferRef> buf(malloc_buffer.data(), - malloc_buffer.size(), - /*offset=*/4); - EXPECT_EQ(buf.StrView(), "RawBuffer"); -} - -TEST(OwningBufferRefTest, CopyBuffer) { - auto const_buf = MakeConstFbData(kData); - OwningBufferRef buf(const_buf.data(), 
const_buf.size()); - buf.WriteInto("SOME"); - EXPECT_EQ(buf.StrView(), "SOMERawBuffer"); - EXPECT_EQ(FbBufToStr(const_buf), "SomeRawBuffer"); -} - -TEST(OwningBufferRefTest, ImplicitUpCasts) { - OwningBufferRef buf(kData.size()); - BufferRef c_buf = buf; - - buf.WriteInto(kData); - EXPECT_EQ(c_buf.StrView(), buf.StrView()); -} - -TEST(OwningBufferRefTest, TupleGetWeak) { - auto input = MakeInternalTestBuffer("FOO"); - - OwningBufferRef buf; - auto [data, size, offset] = buf.GetWeak(); - - data = input.data(); - size = input.size(); - offset = 0; - - ASSERT_EQ(buf.Size(), input.size()); - ASSERT_EQ(buf.Size(), input.size()); - - buf.WriteInto("BAR"); - - EXPECT_EQ(buf.StrView(), "BAR"); - EXPECT_EQ(buf.Span(), input); -} - -TEST(OwningBufferRefTest, TupleRelease) { - OwningBufferRef buf("BAZ"); - - auto [data, size, offset] = buf.Release(); - - EXPECT_EQ(buf.Size(), 0); - EXPECT_EQ(absl::string_view(data, size), "BAZ"); - - delete[] data; -} - -TEST(OwningBufferRefTest, Assign) { - auto const_buf = MakeConstFbData(kData); - OwningBufferRef buf; - buf.Assign(const_buf.data(), const_buf.size()); - buf.WriteInto("SOME"); - EXPECT_EQ(buf.StrView(), "SOMERawBuffer"); - EXPECT_EQ(FbBufToStr(const_buf), "SomeRawBuffer"); -} - -} // namespace diff --git a/tensorflow/lite/experimental/litert/cc/litert_compilation_options.h b/tensorflow/lite/experimental/litert/cc/litert_compilation_options.h deleted file mode 100644 index 8a21f22d120a79..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/litert_compilation_options.h +++ /dev/null @@ -1,78 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_COMPILATION_OPTIONS_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_COMPILATION_OPTIONS_H_ - -#include "tensorflow/lite/experimental/litert/c/litert_accelerator_compilation_options.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_compilation_options.h" -#include "tensorflow/lite/experimental/litert/cc/litert_accelerator_compilation_options.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_handle.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" - -namespace litert { - -class CompilationOptions - : public internal::Handle { - public: - CompilationOptions() = default; - - // Parameter `owned` indicates if the created CompilationOptions object - // should take ownership of the provided `compilation_options` handle. 
- explicit CompilationOptions(LiteRtCompilationOptions compilation_options, - bool owned = true) - : internal::Handle(compilation_options, - owned) {} - - static Expected Create() { - LiteRtCompilationOptions options; - LITERT_RETURN_IF_ERROR(LiteRtCreateCompilationOptions(&options)); - return CompilationOptions(options); - } - - Expected SetHardwareAccelerators(LiteRtHwAcceleratorSet accelerators) { - LITERT_RETURN_IF_ERROR( - LiteRtSetCompilationOptionsHardwareAccelerators(Get(), accelerators)); - return {}; - } - - Expected GetHardwareAccelerators() { - LiteRtHwAcceleratorSet accelerators; - LITERT_RETURN_IF_ERROR( - LiteRtGetCompilationOptionsHardwareAccelerators(Get(), &accelerators)); - return accelerators; - } - - Expected AddAcceleratorCompilationOptions( - AcceleratorCompilationOptions&& options) { - LITERT_RETURN_IF_ERROR( - LiteRtAddAcceleratorCompilationOptions(Get(), options.Release())); - return {}; - } - - Expected GetAcceleratorCompilationOptions() { - LiteRtAcceleratorCompilationOptions options; - LITERT_RETURN_IF_ERROR( - LiteRtGetAcceleratorCompilationOptions(Get(), &options)); - return AcceleratorCompilationOptions(options, /*owned=*/false); - } -}; - -} // namespace litert - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_COMPILATION_OPTIONS_H_ diff --git a/tensorflow/lite/experimental/litert/cc/litert_compiled_model.cc b/tensorflow/lite/experimental/litert/cc/litert_compiled_model.cc deleted file mode 100644 index 9a6658dcaa481c..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/litert_compiled_model.cc +++ /dev/null @@ -1,214 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/cc/litert_compiled_model.h" - -#include -#include -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_compiled_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_requirements.h" - -namespace litert { - -Expected CompiledModel::FindInputIndex( - size_t signature_index, absl::string_view input_name) const { - LITERT_ASSIGN_OR_RETURN(const Signature& signature, - model_.GetSignature(signature_index)); - const std::vector& input_names = signature.InputNames(); - auto it = std::find(input_names.begin(), input_names.end(), input_name); - if (it != input_names.end()) { - return std::distance(input_names.begin(), it); - } - return Unexpected(kLiteRtStatusErrorNotFound, "Failed to find input"); -} - -Expected CompiledModel::FindOutputIndex( - size_t signature_index, absl::string_view output_name) const { - LITERT_ASSIGN_OR_RETURN(const Signature& signature, - model_.GetSignature(signature_index)); - const std::vector& output_names = 
signature.OutputNames(); - auto it = std::find(output_names.begin(), output_names.end(), output_name); - if (it != output_names.end()) { - return std::distance(output_names.begin(), it); - } - return Unexpected(kLiteRtStatusErrorNotFound, "Failed to find output"); -} - -Expected CompiledModel::CreateBufferImpl( - const TensorBufferRequirements& buffer_requirements, - const RankedTensorType& tensor_type) { - LITERT_ASSIGN_OR_RETURN( - const std::vector& supported_types, - buffer_requirements.SupportedTypes()); - if (supported_types.empty()) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Input doesn't support any tensor buffer types"); - } - // For simplicity we just pick the first supported tensor buffer type. - LiteRtTensorBufferType tensor_buffer_type = supported_types[0]; - LITERT_ASSIGN_OR_RETURN(size_t buffer_size, buffer_requirements.BufferSize()); - - LITERT_ASSIGN_OR_RETURN(TensorBuffer buffer, - TensorBuffer::CreateManaged( - tensor_buffer_type, tensor_type, buffer_size)); - return buffer; -} - -Expected CompiledModel::CreateInputOutputBuffer( - size_t signature_index, absl::string_view tensor_name, - bool is_input) const { - LITERT_ASSIGN_OR_RETURN(Signature signature, - model_.GetSignature(signature_index)); - - LITERT_ASSIGN_OR_RETURN(Subgraph subgraph, model_.Subgraph(signature.Key())); - - Expected tensor_expected = - is_input ? subgraph.Input(tensor_name) : subgraph.Output(tensor_name); - Expected buffer_requirements_expected = - is_input ? 
GetInputBufferRequirements(signature_index, tensor_name) - : GetOutputBufferRequirements(signature_index, tensor_name); - - LITERT_ASSIGN_OR_RETURN(const Tensor& tensor, tensor_expected); - LITERT_ASSIGN_OR_RETURN(const TensorBufferRequirements& buffer_requirements, - buffer_requirements_expected); - LITERT_ASSIGN_OR_RETURN(const RankedTensorType& tensor_type, - tensor.RankedTensorType()); - - return CreateBufferImpl(buffer_requirements, tensor_type); -} - -Expected> CompiledModel::CreateInputOutputBuffers( - size_t signature_index, bool is_input) const { - LITERT_ASSIGN_OR_RETURN(const Signature& signature, - model_.GetSignature(signature_index)); - LITERT_ASSIGN_OR_RETURN(const Subgraph subgraph, - model_.Subgraph(signature.Key())); - std::vector tensor_buffers; - std::vector tensor_names; - - tensor_names = is_input ? signature.InputNames() : signature.OutputNames(); - tensor_buffers.reserve(tensor_names.size()); - - for (int i = 0; i < tensor_names.size(); ++i) { - LITERT_ASSIGN_OR_RETURN( - TensorBuffer tensor_buffer, - CreateInputOutputBuffer(signature.Key(), tensor_names[i], is_input)); - tensor_buffers.push_back(std::move(tensor_buffer)); - } - - return tensor_buffers; -} - -Expected CompiledModel::RunCApiHelper(LiteRtParamIndex signature_index, - size_t num_input_buffers, - LiteRtTensorBuffer* input_buffers, - size_t num_output_buffers, - LiteRtTensorBuffer* output_buffers, - bool& async) const { - LiteRtStatus status = - async ? 
LiteRtRunCompiledModelAsync( - Get(), signature_index, num_input_buffers, input_buffers, - num_output_buffers, output_buffers, &async) - : LiteRtRunCompiledModel(Get(), signature_index, num_input_buffers, - input_buffers, num_output_buffers, - output_buffers); - if (status != kLiteRtStatusOk) { - return Unexpected(status, "Failed to invoke the compiled model"); - } - return {}; -} - -Expected CompiledModel::RunHelper( - size_t signature_index, const std::vector& input_buffers, - const std::vector& output_buffers, bool& async) const { - auto input_buffers_ptr = - std::make_unique(input_buffers.size()); - for (int i = 0; i < input_buffers.size(); ++i) { - input_buffers_ptr[i] = input_buffers[i].Get(); - } - auto output_buffers_ptr = - std::make_unique(output_buffers.size()); - for (int i = 0; i < output_buffers.size(); ++i) { - output_buffers_ptr[i] = output_buffers[i].Get(); - } - return RunCApiHelper(signature_index, input_buffers.size(), - input_buffers_ptr.get(), output_buffers.size(), - output_buffers_ptr.get(), async); -} - -Expected CompiledModel::RunMapHelper( - absl::string_view signature_key, - const absl::flat_hash_map& input_map, - const absl::flat_hash_map& output_map, - bool& async) const { - auto signature_index = model_.GetSignatureIndex(signature_key); - if (!signature_index) { - return Unexpected(kLiteRtStatusErrorNotFound, - "Failed to get signature_index"); - } - auto subgraph = model_.Subgraph(signature_key); - if (!subgraph) { - return Unexpected(kLiteRtStatusErrorNotFound, "Failed to get subgraph"); - } - return RunMapWithIndexHelper(*signature_index, *subgraph, input_map, - output_map, async); -} - -Expected CompiledModel::RunMapWithIndexHelper( - size_t signature_index, const Subgraph& subgraph, - const absl::flat_hash_map& input_map, - const absl::flat_hash_map& output_map, - bool& async) const { - auto input_tensors = subgraph.Inputs(); - size_t num_inputs = input_tensors.size(); - auto input_buffers_ptr = std::make_unique(num_inputs); - 
for (int i = 0; i < num_inputs; ++i) { - absl::string_view input_name = input_tensors[i].Name(); - auto it = input_map.find(input_name); - if (it == input_map.end()) { - return Unexpected(kLiteRtStatusErrorNotFound, - "The given map is missing some input TensorBuffers"); - } - input_buffers_ptr[i] = it->second.Get(); - } - auto output_tensors = subgraph.Outputs(); - size_t num_outputs = output_tensors.size(); - auto output_buffers_ptr = std::make_unique(num_outputs); - for (int i = 0; i < num_outputs; ++i) { - absl::string_view output_name = output_tensors[i].Name(); - auto it = output_map.find(output_name); - if (it == output_map.end()) { - return Unexpected(kLiteRtStatusErrorNotFound, - "The given map is missing some output TensorBuffers"); - } - output_buffers_ptr[i] = it->second.Get(); - } - return RunCApiHelper(signature_index, num_inputs, input_buffers_ptr.get(), - num_outputs, output_buffers_ptr.get(), async); -} - -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/cc/litert_compiled_model.h b/tensorflow/lite/experimental/litert/cc/litert_compiled_model.h deleted file mode 100644 index 7ad7207ad569ce..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/litert_compiled_model.h +++ /dev/null @@ -1,438 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_COMPILED_MODEL_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_COMPILED_MODEL_H_ - -#include -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_compiled_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_environment.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h" -#include "tensorflow/lite/experimental/litert/cc/litert_compilation_options.h" -#include "tensorflow/lite/experimental/litert/cc/litert_environment.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_handle.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_requirements.h" - -namespace litert { - -// The CompiledModel is a higher level inference API. It is created by -// provided model with compilation options. Internally, it instantiates runtime -// and applies Delegates mapped to the compilation options. -// It also supports getting BufferRequirements to create input/output -// TensorBuffers, and it allows to invoke the model with the input/output -// TensorBuffers. -// -// Example user flow: -// -// 1. Create CompiledModel -// 2. Query the model input/output requirements -// 3. Create input/output TensorBuffers -// 4. Fill the input TensorBuffers with input data -// 5. Invoke the model with the input/output TensorBuffers -// 6. 
Evaluate the output TensorBuffers - -class CompiledModel - : public internal::Handle { - public: - CompiledModel() = default; - - // Creates a CompiledModel instance. - // - // If `owned` is `true`, then the created object takes ownership of the - // `compiled_model` handle. - explicit CompiledModel(LiteRtModel litert_model, - LiteRtCompiledModel compiled_model, bool owned = true) - : internal::Handle( - compiled_model, owned), - model_(Model::CreateFromNonOwnedHandle(litert_model)) {} - - // Creates a CompiledModel from a TFLite file. - // - // The model is loaded into memory and the caller takes ownership of the - // returned CompiledModel object. The caller should keep the model alive - // until the CompiledModel is destroyed. - // The given `compilation_options` is used for JIT compilation of the model. - // - // Note: The given environment must outlive the compiled model and any - // execution running it. - // Note: If the model is fully AOT compiled for NPU, NPU accelerator is used - // automatically which means the provided `compilation_options` are - // meaningless. - static Expected Create( - litert::Environment& env, litert::Model& model, - const CompilationOptions& jit_compilation_options) { - LiteRtModel litert_model = model.Get(); - LiteRtCompiledModel compiled_model; - LITERT_RETURN_IF_ERROR(LiteRtCreateCompiledModel( - env.Get(), litert_model, jit_compilation_options.Get(), - &compiled_model)); - return CompiledModel(litert_model, compiled_model); - } - - // Simpler version of Create() that uses the default compilation options. - // The provided hardware accelerator is used for JIT compilation of the model. - // - // Note: If the model is fully AOT compiled for NPU, NPU accelerator - // is used automatically which means the provided `hardware_accelerator` is - // meaningless. 
- static Expected Create( - litert::Environment& env, litert::Model& model, - LiteRtHwAccelerators hardware_accelerator = kLiteRtHwAcceleratorCpu) { - LITERT_ASSIGN_OR_RETURN(auto jit_compilation_options, - CompilationOptions::Create()); - jit_compilation_options.SetHardwareAccelerators(hardware_accelerator); - return Create(env, model, jit_compilation_options); - } - - // Get input buffer requirements for the given signature and input name. - Expected GetInputBufferRequirements( - absl::string_view signature_name, absl::string_view input_name) { - LITERT_ASSIGN_OR_RETURN(size_t signature_index, - model_.GetSignatureIndex(signature_name)); - return GetInputBufferRequirements(signature_index, input_name); - } - - // Returns the buffer requirements for the given n-th input tensor. The - // returned TensorBufferRequirements is used to create the input tensor - // buffer. - Expected GetInputBufferRequirements( - size_t signature_index, size_t input_index) const { - LiteRtTensorBufferRequirements buffer_requirements; - LITERT_RETURN_IF_ERROR(LiteRtGetCompiledModelInputBufferRequirements( - Get(), signature_index, input_index, &buffer_requirements)); - return TensorBufferRequirements(buffer_requirements, /*owned=*/false); - } - - // The same as above except this function takes input tensor name. - Expected GetInputBufferRequirements( - size_t signature_index, absl::string_view input_name) const { - LITERT_ASSIGN_OR_RETURN(size_t input_index, - FindInputIndex(signature_index, input_name)); - return GetInputBufferRequirements(signature_index, input_index); - } - - // Get input buffer requirements of the default signature for the given n-th - // input tensor. - Expected GetInputBufferRequirements( - size_t input_index) const { - return GetInputBufferRequirements(/*signature_index=*/0, input_index); - } - - // Get input buffer requirements of the default signature for input name. 
- Expected GetInputBufferRequirements( - absl::string_view input_name) const { - return GetInputBufferRequirements(/*signature_index=*/0, input_name); - } - - // Get output buffer requirements for the given signature and output name. - Expected GetOutputBufferRequirements( - absl::string_view signature_name, absl::string_view output_name) { - LITERT_ASSIGN_OR_RETURN(size_t signature_index, - model_.GetSignatureIndex(signature_name)); - return GetOutputBufferRequirements(signature_index, output_name); - } - - // Returns the buffer requirements for the given output tensor. The returned - // TensorBufferRequirements is used to create the output tensor - // buffer. - Expected GetOutputBufferRequirements( - size_t signature_index, size_t output_index) const { - LiteRtTensorBufferRequirements buffer_requirements; - LITERT_RETURN_IF_ERROR(LiteRtGetCompiledModelOutputBufferRequirements( - Get(), signature_index, output_index, &buffer_requirements)); - return TensorBufferRequirements(buffer_requirements, /*owned=*/false); - } - - // The same as above except this function takes output tensor name. - Expected GetOutputBufferRequirements( - size_t signature_index, absl::string_view output_name) const { - LITERT_ASSIGN_OR_RETURN(size_t output_index, - FindOutputIndex(signature_index, output_name)); - return GetOutputBufferRequirements(signature_index, output_index); - } - - // Get input buffer requirements of the default signature for the given n-th - // input tensor. - Expected GetOutputBufferRequirements( - size_t output_index) const { - return GetOutputBufferRequirements(/*signature_index=*/0, output_index); - } - - // Get input buffer requirements of the default signature for input name. - Expected GetOutputBufferRequirements( - absl::string_view output_name) const { - return GetOutputBufferRequirements(/*signature_index=*/0, output_name); - } - - // Creates an input tensor buffer for the given signature and input name. 
- Expected CreateInputBuffer(absl::string_view signature_name, - absl::string_view input_name) const { - return CreateInputOutputBuffer(signature_name, input_name, - /*is_input=*/true); - } - - // Creates an input tensor buffer of the default signature for the given input - // name. - Expected CreateInputBuffer(absl::string_view input_name) const { - return CreateInputOutputBuffer(/*signature_index=*/0, input_name, - /*is_input=*/true); - } - - // Creates an output tensor buffer for the given signature and output name. - Expected CreateOutputBuffer( - absl::string_view signature_name, absl::string_view output_name) const { - return CreateInputOutputBuffer(signature_name, output_name, - /*is_input=*/false); - } - - // Creates an output tensor buffer of the default signature for the given - // output name. - Expected CreateOutputBuffer( - absl::string_view output_name) const { - return CreateInputOutputBuffer(/*signature_index=*/0, output_name, - /*is_input=*/false); - } - - // A helper function to create input tensor buffers for the given signature. - // It uses BufferRequirements and RankedTensorType to create the input tensor - // buffers. - Expected> CreateInputBuffers( - absl::string_view signature_name) const { - LITERT_ASSIGN_OR_RETURN(size_t signature_index, - model_.GetSignatureIndex(signature_name)); - return CreateInputOutputBuffers(signature_index, /*is_input=*/true); - } - - // A helper function to creates the input tensor buffers for the given - // signature. It uses BufferRequirements and RankedTensorType to create the - // input tensor buffers. - Expected> CreateInputBuffers( - size_t signature_index) const { - return CreateInputOutputBuffers(signature_index, /*is_input=*/true); - } - - // A helper function to creates the input tensor buffers for the default - // signature. It uses BufferRequirements and RankedTensorType to create the - // input tensor buffers. 
- Expected> CreateInputBuffers() const { - return CreateInputOutputBuffers(/*signature_index=*/0, /*is_input=*/true); - } - - // A helper function to create output tensor buffers for the given signature. - // It uses BufferRequirements and RankedTensorType to create the output tensor - // buffers. - Expected> CreateOutputBuffers( - absl::string_view signature_name) const { - LITERT_ASSIGN_OR_RETURN(size_t signature_index, - model_.GetSignatureIndex(signature_name)); - return CreateOutputBuffers(signature_index); - } - - // A helper function to creates the output tensor buffers for the given - // signature. It uses BufferRequirements and RankedTensorType to create the - // output tensor buffers. - Expected> CreateOutputBuffers( - size_t signature_index) const { - return CreateInputOutputBuffers(signature_index, /*is_input=*/false); - } - - // A helper function to creates the output tensor buffers for the default - // signature. It uses BufferRequirements and RankedTensorType to create the - // output tensor buffers. - Expected> CreateOutputBuffers() const { - return CreateInputOutputBuffers(/*signature_index=*/0, /*is_input=*/false); - } - - // Runs the model of the given signature index synchronously with the provided - // input/output TensorBuffers. - Expected Run(size_t signature_index, - const std::vector& input_buffers, - const std::vector& output_buffers) const { - bool async = false; - return RunHelper(signature_index, input_buffers, output_buffers, async); - } - - // Runs the model of the default signature synchronously with the provided - // input/output TensorBuffers. - Expected Run(const std::vector& input_buffers, - const std::vector& output_buffers) const { - bool async = false; - return RunHelper(/*signature_index=*/0, input_buffers, output_buffers, - async); - } - - // Runs the model of the given signature index asynchronously, if possible, - // with the provided input/output TensorBuffers. 
If asynchronous execution is - // possible then the function returns true in parameter `async`; otherwise the - // function runs the model synchronously. - Expected RunAsync(size_t signature_index, - const std::vector& input_buffers, - const std::vector& output_buffers, - bool& async) const { - async = true; - return RunHelper(signature_index, input_buffers, output_buffers, async); - } - - // Runs the model of the default signature asynchronously, if possible, - // with the provided input/output TensorBuffers. If asynchronous execution is - // possible then the function returns true in parameter `async`; otherwise the - // function runs the model synchronously. - Expected RunAsync(const std::vector& input_buffers, - const std::vector& output_buffers, - bool& async) const { - async = true; - return RunHelper(/*signature_index=*/0, input_buffers, output_buffers, - async); - } - - // Runs the model of the given signature key synchronously with the provided - // input/output TensorBuffers. - Expected Run(absl::string_view signature_key, - const std::vector& input_buffers, - const std::vector& output_buffers) const { - LITERT_ASSIGN_OR_RETURN(size_t signature_index, - model_.GetSignatureIndex(signature_key)); - return Run(signature_index, input_buffers, output_buffers); - } - - // Runs the model of the given signature key asynchronously, if possible, with - // the provided input/output TensorBuffers. If asynchronous execution is - // possible then the function returns true in parameter `async`; otherwise the - // function runs the model synchronously. 
- Expected RunAsync(absl::string_view signature_key, - const std::vector& input_buffers, - const std::vector& output_buffers, - bool& async) const { - async = true; - LITERT_ASSIGN_OR_RETURN(size_t signature_index, - model_.GetSignatureIndex(signature_key)); - return RunAsync(signature_index, input_buffers, output_buffers, async); - } - - // Runs the model of the given signature key synchronously with the provided - // input/output TensorBuffer map. - Expected Run( - absl::string_view signature_key, - const absl::flat_hash_map& input_map, - const absl::flat_hash_map& output_map) - const { - bool async = false; - return RunMapHelper(signature_key, input_map, output_map, async); - } - - // Runs the model of the default signature synchronously with the provided - // input/output TensorBuffer map. - Expected Run( - const absl::flat_hash_map& input_map, - const absl::flat_hash_map& output_map) - const { - bool async = false; - auto subgraph = model_.MainSubgraph(); - if (!subgraph) { - return Unexpected(kLiteRtStatusErrorNotFound, - "Failed to get main subgraph"); - } - return RunMapWithIndexHelper(/*signature_index=*/0, *subgraph, input_map, - output_map, async); - } - - // Runs the model of the given signature key asynchronously, if possible, with - // the provided input/output TensorBuffer map. If asynchronous execution is - // possible then the function returns true in parameter `async`; otherwise the - // function runs the model synchronously. - Expected RunAsync( - absl::string_view signature_key, - const absl::flat_hash_map& input_map, - const absl::flat_hash_map& output_map, - bool& async) const { - async = true; - return RunMapHelper(signature_key, input_map, output_map, async); - } - - private: - // Returns the signature input index for the given input tensor name. - Expected FindInputIndex(size_t signature_index, - absl::string_view input_name) const; - - // Returns the signature output index for the given output tensor name. 
- Expected FindOutputIndex(size_t signature_index, - absl::string_view output_name) const; - - // Creates a TensorBuffer with the given buffer requirements and tensor type. - static Expected CreateBufferImpl( - const TensorBufferRequirements& buffer_requirements, - const RankedTensorType& tensor_type); - - // Creates a TensorBuffer for the given signature index and tensor name. - Expected CreateInputOutputBuffer(size_t signature_index, - absl::string_view tensor_name, - bool is_input) const; - - // Creates a TensorBuffer for the given signature and tensor name. - Expected CreateInputOutputBuffer( - absl::string_view signature_name, absl::string_view tensor_name, - bool is_input) const { - LITERT_ASSIGN_OR_RETURN(size_t signature_index, - model_.GetSignatureIndex(signature_name)); - return CreateInputOutputBuffer(signature_index, tensor_name, is_input); - } - - // Creates a vector of TensorBuffers for the given signature subgraph. - Expected> CreateInputOutputBuffers( - size_t signature_index, bool is_input) const; - - Expected RunCApiHelper(LiteRtParamIndex signature_index, - size_t num_input_buffers, - LiteRtTensorBuffer* input_buffers, - size_t num_output_buffers, - LiteRtTensorBuffer* output_buffers, - bool& async) const; - - Expected RunHelper(size_t signature_index, - const std::vector& input_buffers, - const std::vector& output_buffers, - bool& async) const; - - Expected RunMapHelper( - absl::string_view signature_key, - const absl::flat_hash_map& input_map, - const absl::flat_hash_map& output_map, - bool& async) const; - - Expected RunMapWithIndexHelper( - size_t signature_index, const Subgraph& subgraph, - const absl::flat_hash_map& input_map, - const absl::flat_hash_map& output_map, - bool& async) const; - - Model model_; -}; - -} // namespace litert - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_COMPILED_MODEL_H_ diff --git a/tensorflow/lite/experimental/litert/cc/litert_compiled_model_gpu_test.cc 
b/tensorflow/lite/experimental/litert/cc/litert_compiled_model_gpu_test.cc deleted file mode 100644 index 425658907802cb..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/litert_compiled_model_gpu_test.cc +++ /dev/null @@ -1,203 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include - -#include -#include -#include "absl/debugging/leak_check.h" -#include "absl/log/absl_log.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_event.h" -#include "tensorflow/lite/experimental/litert/c/litert_event_type.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_types.h" -#include "tensorflow/lite/experimental/litert/cc/litert_compiled_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_environment.h" -#include "tensorflow/lite/experimental/litert/cc/litert_event.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/test/common.h" -#include "tensorflow/lite/experimental/litert/test/matchers.h" -#include "tensorflow/lite/experimental/litert/test/testdata/simple_model_test_vectors.h" - -using testing::FloatNear; -using testing::Pointwise; - -namespace litert { -namespace { - -void BasicTest() { - auto 
model = testing::LoadTestFileModel(kModelFileName); - ASSERT_TRUE(model); - - auto env = litert::Environment::Create({}); - ASSERT_TRUE(env); - - LITERT_ASSERT_OK_AND_ASSIGN( - auto compiled_model, - CompiledModel::Create(*env, model, kLiteRtHwAcceleratorGpu)); - auto signatures = model.GetSignatures().Value(); - EXPECT_EQ(signatures.size(), 1); - - auto signature_key = signatures[0].Key(); - EXPECT_EQ(signature_key, Model::DefaultSignatureKey()); - size_t signature_index = 0; - - LITERT_ASSERT_OK_AND_ASSIGN( - auto input_buffers, compiled_model.CreateInputBuffers(signature_index)); - - LITERT_ASSERT_OK_AND_ASSIGN( - auto output_buffers, compiled_model.CreateOutputBuffers(signature_index)); - - // Fill model inputs. - auto input_names = signatures[0].InputNames(); - EXPECT_EQ(input_names.size(), 2); - EXPECT_EQ(input_names.at(0), "arg0"); - EXPECT_EQ(input_names.at(1), "arg1"); - EXPECT_EQ(*input_buffers[0].BufferType(), kLiteRtTensorBufferTypeOpenCl); - ASSERT_TRUE(input_buffers[0].Write( - absl::MakeConstSpan(kTestInput0Tensor, kTestInput0Size))); - EXPECT_EQ(*input_buffers[1].BufferType(), kLiteRtTensorBufferTypeOpenCl); - ASSERT_TRUE(input_buffers[1].Write( - absl::MakeConstSpan(kTestInput1Tensor, kTestInput1Size))); - - // Execute model. - compiled_model.Run(signature_index, input_buffers, output_buffers); - - // Check model output. 
- auto output_names = signatures[0].OutputNames(); - EXPECT_EQ(output_names.size(), 1); - EXPECT_EQ(output_names.at(0), "tfl.add"); - EXPECT_EQ(*output_buffers[0].BufferType(), kLiteRtTensorBufferTypeOpenCl); - { - auto lock_and_addr = - litert::TensorBufferScopedLock::Create(output_buffers[0]); - ASSERT_TRUE(lock_and_addr); - auto output = absl::MakeSpan(lock_and_addr->second, kTestOutputSize); - for (auto i = 0; i < kTestOutputSize; ++i) { - ABSL_LOG(INFO) << "Result: " << output[i] << "\t" << kTestOutputTensor[i]; - } - EXPECT_THAT(output, Pointwise(FloatNear(1e-5), kTestOutputTensor)); - } -} - -TEST(CompiledModelGpuTest, Basic) { - // MSAN does not support GPU tests. -#if defined(MEMORY_SANITIZER) || defined(THREAD_SANITIZER) - GTEST_SKIP() << "GPU tests are not supported in MSAN"; -#endif - // To workaround the memory leak in Nvidia's driver - absl::LeakCheckDisabler disable_leak_check; - - BasicTest(); -} - -TEST(CompiledModelGpuTest, Basic2nd) { - // MSAN does not support GPU tests. -#if defined(MEMORY_SANITIZER) || defined(THREAD_SANITIZER) - GTEST_SKIP() << "GPU tests are not supported in MSAN"; -#endif - // To workaround the memory leak in Nvidia's driver - absl::LeakCheckDisabler disable_leak_check; - - // Run the test twice to verify that the CL environment is shared between - // instances. - BasicTest(); -} - -TEST(CompiledModelGpuTest, Async) { - // MSAN does not support GPU tests. 
-#if defined(MEMORY_SANITIZER) || defined(THREAD_SANITIZER) - GTEST_SKIP() << "GPU tests are not supported in MSAN"; -#endif - // To workaround the memory leak in Nvidia's driver - absl::LeakCheckDisabler disable_leak_check; - - auto model = testing::LoadTestFileModel(kModelFileName); - ASSERT_TRUE(model); - - auto env = litert::Environment::Create({}); - ASSERT_TRUE(env); - - LITERT_ASSERT_OK_AND_ASSIGN( - auto compiled_model, - CompiledModel::Create(*env, model, kLiteRtHwAcceleratorGpu)); - - auto signatures = model.GetSignatures().Value(); - EXPECT_EQ(signatures.size(), 1); - - auto signature_key = signatures[0].Key(); - EXPECT_EQ(signature_key, Model::DefaultSignatureKey()); - size_t signature_index = 0; - - LITERT_ASSERT_OK_AND_ASSIGN( - auto input_buffers, compiled_model.CreateInputBuffers(signature_index)); - - LITERT_ASSERT_OK_AND_ASSIGN(auto input_event, - Event::CreateManaged(LiteRtEventTypeOpenCl)); - // Copy of the event to trigger the signal since the ownership of the - // input_event is transferred to the input_buffers[0]. - LiteRtEvent litert_input_event = input_event.Get(); - - // Fill model inputs. - auto input_names = signatures[0].InputNames(); - EXPECT_EQ(input_names.size(), 2); - EXPECT_EQ(input_names.at(0), "arg0"); - EXPECT_EQ(input_names.at(1), "arg1"); - EXPECT_EQ(*input_buffers[0].BufferType(), kLiteRtTensorBufferTypeOpenCl); - ASSERT_TRUE(input_buffers[0].Write( - absl::MakeConstSpan(kTestInput0Tensor, kTestInput0Size))); - EXPECT_EQ(*input_buffers[1].BufferType(), kLiteRtTensorBufferTypeOpenCl); - ASSERT_TRUE(input_buffers[1].Write( - absl::MakeConstSpan(kTestInput1Tensor, kTestInput1Size))); - - // Bind the input event to the input buffers. - // Note: The task should be done after the input buffers are filled. - // Otherwise the input_buffers[0].Write<> will be blocked by the associated - // event. 
- input_buffers[0].SetEvent(std::move(input_event)); - - LITERT_ASSERT_OK_AND_ASSIGN( - auto output_buffers, compiled_model.CreateOutputBuffers(signature_index)); - - // Execute model asynchronously. - bool async_execution_mode = true; - compiled_model.RunAsync(signature_index, input_buffers, output_buffers, - async_execution_mode); - - // Signal the input event to resume the async execution. - LiteRtEventSignal(litert_input_event); - - // Check model output. - auto output_names = signatures[0].OutputNames(); - EXPECT_EQ(output_names.size(), 1); - EXPECT_EQ(output_names.at(0), "tfl.add"); - EXPECT_EQ(*output_buffers[0].BufferType(), kLiteRtTensorBufferTypeOpenCl); - { - auto lock_and_addr = - litert::TensorBufferScopedLock::Create(output_buffers[0]); - ASSERT_TRUE(lock_and_addr); - auto output = absl::MakeSpan(lock_and_addr->second, kTestOutputSize); - for (auto i = 0; i < kTestOutputSize; ++i) { - ABSL_LOG(INFO) << "Result: " << output[i] << "\t" << kTestOutputTensor[i]; - } - EXPECT_THAT(output, Pointwise(FloatNear(1e-5), kTestOutputTensor)); - } -} - -} // namespace -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/cc/litert_compiled_model_integration_test.cc b/tensorflow/lite/experimental/litert/cc/litert_compiled_model_integration_test.cc deleted file mode 100644 index 7bc36e48268ed1..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/litert_compiled_model_integration_test.cc +++ /dev/null @@ -1,376 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include -#include -#include - -#include -#include -#include "absl/log/absl_log.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_gl_types.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_types.h" -#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" -#include "tensorflow/lite/experimental/litert/cc/litert_compiled_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_environment.h" -#include "tensorflow/lite/experimental/litert/cc/litert_event.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/core/model/model_buffer.h" -#include "tensorflow/lite/experimental/litert/runtime/ahwb_buffer.h" -#include "tensorflow/lite/experimental/litert/runtime/gl_buffer.h" -#include "tensorflow/lite/experimental/litert/test/common.h" -#include "tensorflow/lite/experimental/litert/test/matchers.h" -#include "tensorflow/lite/experimental/litert/test/testdata/simple_model_test_vectors.h" -#if LITERT_HAS_OPENGL_SUPPORT -#include "tensorflow/lite/delegates/gpu/gl/egl_environment.h" -#endif // LITERT_HAS_OPENGL_SUPPORT - -namespace litert { -namespace { - -using ::testing::Eq; -using ::testing::FloatNear; -using ::testing::Pointwise; -using ::testing::SizeIs; - -constexpr absl::string_view kNpuFile = kGoogleTensorModelFileName; -constexpr absl::string_view kTfliteFile = "simple_model_npu.tflite"; -constexpr absl::string_view kDispatchLibraryDir = "/data/local/tmp"; - -TEST(CompiledModelTest, RunWithGoogleTensorModel) { - if (!litert::internal::AhwbBuffer::IsSupported()) { - 
GTEST_SKIP() - << "The rest of this test is specific to Android devices with a " - "GoogleTensor eTPU"; - } - - // Environment setup. - const std::vector environment_options = { - litert::Environment::Option{ - litert::Environment::OptionTag::DispatchLibraryDir, - kDispatchLibraryDir, - }, - }; - LITERT_ASSERT_OK_AND_ASSIGN(Environment env, - litert::Environment::Create(environment_options)); - - // Create Model. - - // TODO(gcarranza): Replace internal API with C++ API or single npu tflite - // file. - LITERT_ASSERT_OK_AND_ASSIGN( - BufferRef model_with_byte_code, - internal::GetModelBufWithByteCode(testing::GetTestFilePath(kTfliteFile), - testing::GetTestFilePath(kNpuFile))); - - LITERT_ASSERT_OK_AND_ASSIGN(Model model, - Model::CreateFromBuffer(model_with_byte_code)); - // Create CompiledModel. - LITERT_ASSERT_OK_AND_ASSIGN(CompiledModel compiled_model, - CompiledModel::Create(env, model)); - - LITERT_ASSERT_OK_AND_ASSIGN( - std::vector input_buffers, - compiled_model.CreateInputBuffers(model.DefaultSignatureKey())); - LITERT_ASSERT_OK_AND_ASSIGN( - std::vector output_buffers, - compiled_model.CreateOutputBuffers(model.DefaultSignatureKey())); - - ASSERT_THAT(input_buffers, SizeIs(2)); - ASSERT_THAT(output_buffers, SizeIs(1)); - - // Confirm input and output buffers are AHWB. - EXPECT_THAT(*input_buffers[0].BufferType(), Eq(kLiteRtTensorBufferTypeAhwb)); - EXPECT_THAT(*input_buffers[1].BufferType(), Eq(kLiteRtTensorBufferTypeAhwb)); - EXPECT_THAT(*output_buffers[0].BufferType(), Eq(kLiteRtTensorBufferTypeAhwb)); - - LITERT_ASSERT_OK(input_buffers[0].Write( - absl::MakeConstSpan(kTestInput0Tensor, kTestInput0Size))); - LITERT_ASSERT_OK(input_buffers[1].Write( - absl::MakeConstSpan(kTestInput1Tensor, kTestInput1Size))); - - // Run compiled model. - compiled_model.Run(model.DefaultSignatureKey(), input_buffers, - output_buffers); - - // Check model output. 
- { - LITERT_ASSERT_OK_AND_ASSIGN( - auto lock_and_addr, - litert::TensorBufferScopedLock::Create(output_buffers[0])); - auto output = absl::MakeSpan(lock_and_addr.second, kTestOutputSize); - for (auto i = 0; i < kTestOutputSize; ++i) { - ABSL_LOG(INFO) << "Result: " << output[i] << "\t" << kTestOutputTensor[i]; - } - EXPECT_THAT(output, Pointwise(FloatNear(1e-5), kTestOutputTensor)); - } -} - -TEST(CompiledModel, RunAsyncWithGoogleTensorModel) { - if (!litert::internal::AhwbBuffer::IsSupported()) { - GTEST_SKIP() - << "The rest of this test is specific to Android devices with a " - "GoogleTensor eTPU"; - } - - // Environment setup. - const std::vector environment_options = { - litert::Environment::Option{ - litert::Environment::OptionTag::DispatchLibraryDir, - kDispatchLibraryDir, - }, - }; - LITERT_ASSERT_OK_AND_ASSIGN(Environment env, - litert::Environment::Create(environment_options)); - - // Create Model. - - // TODO(gcarranza): Replace internal API with C++ API or single npu tflite - // file. - LITERT_ASSERT_OK_AND_ASSIGN( - BufferRef model_with_byte_code, - internal::GetModelBufWithByteCode(testing::GetTestFilePath(kTfliteFile), - testing::GetTestFilePath(kNpuFile))); - - LITERT_ASSERT_OK_AND_ASSIGN(Model model, - Model::CreateFromBuffer(model_with_byte_code)); - // Create CompiledModel. - LITERT_ASSERT_OK_AND_ASSIGN(CompiledModel compiled_model, - CompiledModel::Create(env, model)); - - LITERT_ASSERT_OK_AND_ASSIGN( - std::vector input_buffers, - compiled_model.CreateInputBuffers(model.DefaultSignatureKey())); - LITERT_ASSERT_OK_AND_ASSIGN( - std::vector output_buffers, - compiled_model.CreateOutputBuffers(model.DefaultSignatureKey())); - - ASSERT_THAT(input_buffers, SizeIs(2)); - ASSERT_THAT(output_buffers, SizeIs(1)); - - // Confirm input and output buffers are AHWB. 
- EXPECT_THAT(*input_buffers[0].BufferType(), Eq(kLiteRtTensorBufferTypeAhwb)); - EXPECT_THAT(*input_buffers[1].BufferType(), Eq(kLiteRtTensorBufferTypeAhwb)); - EXPECT_THAT(*output_buffers[0].BufferType(), Eq(kLiteRtTensorBufferTypeAhwb)); - - LITERT_ASSERT_OK(input_buffers[0].Write( - absl::MakeConstSpan(kTestInput0Tensor, kTestInput0Size))); - LITERT_ASSERT_OK(input_buffers[1].Write( - absl::MakeConstSpan(kTestInput1Tensor, kTestInput1Size))); - - // Run compiled model. - bool async; - compiled_model.RunAsync(model.DefaultSignatureKey(), input_buffers, - output_buffers, async); - // Since output buffers have events, async should be true. - ASSERT_TRUE(async); - - // Check model output. - { - LITERT_ASSERT_OK_AND_ASSIGN( - auto lock_and_addr, - litert::TensorBufferScopedLock::Create(output_buffers[0])); - auto output = absl::MakeSpan(lock_and_addr.second, kTestOutputSize); - for (auto i = 0; i < kTestOutputSize; ++i) { - ABSL_LOG(INFO) << "Result: " << output[i] << "\t" << kTestOutputTensor[i]; - } - EXPECT_THAT(output, Pointwise(FloatNear(1e-5), kTestOutputTensor)); - } -} - -void FillGlBuffer1(LiteRtGLuint id, size_t size) { -#if LITERT_HAS_OPENGL_SUPPORT - std::string shader_source = R"( #version 310 es - precision highp float; - layout(local_size_x = 1, local_size_y = 1) in; - layout(std430, binding = 0) buffer Output {float elements[];} output_data; - void main() { - uint v = gl_GlobalInvocationID.x * 2u; - output_data.elements[v] = float(v + 1u) / 1.0; - output_data.elements[v + 1u] = float(v + 2u) / 1.0; - })"; - GLuint shader = glCreateShader(GL_COMPUTE_SHADER); - const GLchar* sources[] = {shader_source.c_str()}; - glShaderSource(shader, 1, sources, nullptr); - ABSL_CHECK(glGetError() == GL_NO_ERROR); - glCompileShader(shader); - ABSL_CHECK(glGetError() == GL_NO_ERROR); - - GLuint to_buffer_program = glCreateProgram(); - glAttachShader(to_buffer_program, shader); - ABSL_CHECK(glGetError() == GL_NO_ERROR); - glDeleteShader(shader); - 
ABSL_CHECK(glGetError() == GL_NO_ERROR); - glLinkProgram(to_buffer_program); - ABSL_CHECK(glGetError() == GL_NO_ERROR); - - glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 0, id); - ABSL_CHECK(glGetError() == GL_NO_ERROR); - glUseProgram(to_buffer_program); - ABSL_CHECK(glGetError() == GL_NO_ERROR); - glDispatchCompute(size / 2, 1, 1); - ABSL_CHECK(glGetError() == GL_NO_ERROR); - glBindBuffer(GL_SHADER_STORAGE_BUFFER, 0); - ABSL_CHECK(glGetError() == GL_NO_ERROR); - glDeleteProgram(to_buffer_program); - ABSL_CHECK(glGetError() == GL_NO_ERROR); -#endif // LITERT_HAS_OPENGL_SUPPORT -} - -void FillGlBuffer2(LiteRtGLuint id, size_t size) { -#if LITERT_HAS_OPENGL_SUPPORT - std::string shader_source = R"( #version 310 es - precision highp float; - layout(local_size_x = 1, local_size_y = 1) in; - layout(std430, binding = 0) buffer Output {float elements[];} output_data; - void main() { - uint v = gl_GlobalInvocationID.x * 2u; - output_data.elements[v] = float(v + 1u) / 0.1; - output_data.elements[v + 1u] = float(v + 2u) / 0.1; - })"; - GLuint shader = glCreateShader(GL_COMPUTE_SHADER); - const GLchar* sources[] = {shader_source.c_str()}; - glShaderSource(shader, 1, sources, nullptr); - ABSL_CHECK(glGetError() == GL_NO_ERROR); - glCompileShader(shader); - ABSL_CHECK(glGetError() == GL_NO_ERROR); - - GLuint to_buffer_program = glCreateProgram(); - glAttachShader(to_buffer_program, shader); - ABSL_CHECK(glGetError() == GL_NO_ERROR); - glDeleteShader(shader); - ABSL_CHECK(glGetError() == GL_NO_ERROR); - glLinkProgram(to_buffer_program); - ABSL_CHECK(glGetError() == GL_NO_ERROR); - - glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 0, id); - ABSL_CHECK(glGetError() == GL_NO_ERROR); - glUseProgram(to_buffer_program); - ABSL_CHECK(glGetError() == GL_NO_ERROR); - glDispatchCompute(size / 2, 1, 1); - ABSL_CHECK(glGetError() == GL_NO_ERROR); - glBindBuffer(GL_SHADER_STORAGE_BUFFER, 0); - ABSL_CHECK(glGetError() == GL_NO_ERROR); - glDeleteProgram(to_buffer_program); - ABSL_CHECK(glGetError() 
== GL_NO_ERROR); -#endif // LITERT_HAS_OPENGL_SUPPORT -} - -TEST(CompiledModel, RunAsyncWithGoogleTensorModelUseAhwbGlInterop) { - if (!litert::internal::AhwbBuffer::IsSupported()) { - GTEST_SKIP() - << "The rest of this test is specific to Android devices with a " - "GoogleTensor eTPU"; - } - - // Environment setup. - const std::vector environment_options = { - litert::Environment::Option{ - litert::Environment::OptionTag::DispatchLibraryDir, - kDispatchLibraryDir, - }, - }; - LITERT_ASSERT_OK_AND_ASSIGN(Environment env, - litert::Environment::Create(environment_options)); - - // Create Model. - - // TODO(gcarranza): Replace internal API with C++ API or single npu tflite - // file. - LITERT_ASSERT_OK_AND_ASSIGN( - BufferRef model_with_byte_code, - internal::GetModelBufWithByteCode(testing::GetTestFilePath(kTfliteFile), - testing::GetTestFilePath(kNpuFile))); - - LITERT_ASSERT_OK_AND_ASSIGN(Model model, - Model::CreateFromBuffer(model_with_byte_code)); - // Create CompiledModel. - LITERT_ASSERT_OK_AND_ASSIGN(CompiledModel compiled_model, - CompiledModel::Create(env, model)); - - LITERT_ASSERT_OK_AND_ASSIGN( - std::vector input_buffers, - compiled_model.CreateInputBuffers(model.DefaultSignatureKey())); - LITERT_ASSERT_OK_AND_ASSIGN( - std::vector output_buffers, - compiled_model.CreateOutputBuffers(model.DefaultSignatureKey())); - - ASSERT_THAT(input_buffers, SizeIs(2)); - ASSERT_THAT(output_buffers, SizeIs(1)); - - // Confirm input and output buffers are AHWB. - EXPECT_THAT(*input_buffers[0].BufferType(), Eq(kLiteRtTensorBufferTypeAhwb)); - EXPECT_THAT(*input_buffers[1].BufferType(), Eq(kLiteRtTensorBufferTypeAhwb)); - EXPECT_THAT(*output_buffers[0].BufferType(), Eq(kLiteRtTensorBufferTypeAhwb)); - - // TODO(gcarranza): Integrate with LiteRT Environment. 
-#if LITERT_HAS_OPENGL_SUPPORT - std::unique_ptr egl_env; - ASSERT_TRUE( - tflite::gpu::gl::EglEnvironment::NewEglEnvironment(&egl_env).ok()); - LITERT_LOG(LITERT_INFO, "Initialized EGL environment"); -#else - LITERT_LOG(LITERT_INFO, "EGL environment not initialized"); -#endif // LITERT_HAS_OPENGL_SUPPORT - - // Write to input buffers on GPU. - LITERT_ASSERT_OK_AND_ASSIGN(auto gl_buffer_1, input_buffers[0].GetGlBuffer()); - FillGlBuffer1(gl_buffer_1.id, 2); - LITERT_ASSERT_OK_AND_ASSIGN(auto gl_buffer_2, input_buffers[1].GetGlBuffer()); - FillGlBuffer2(gl_buffer_2.id, 2); - - // Create EGL sync and fence before AHWB read. - // TODO(gcarranza): Integrate into LiteRT C++ API. - LITERT_ASSERT_OK_AND_ASSIGN( - int native_fence, ::litert::internal::GlBuffer::CreateEglSyncAndFence()); - - LITERT_ASSERT_OK_AND_ASSIGN( - Event event_1, - Event::CreateFromSyncFenceFd(native_fence, /*owns_fd=*/false)); - LITERT_ASSERT_OK_AND_ASSIGN( - Event event_2, - Event::CreateFromSyncFenceFd(native_fence, /*owns_fd=*/false)); - - // Set event so that AHWB read is blocked by GPU write. - input_buffers[0].SetEvent(std::move(event_1)); - input_buffers[1].SetEvent(std::move(event_2)); - - // Run compiled model asynchronously. - bool async; - compiled_model.RunAsync(model.DefaultSignatureKey(), input_buffers, - output_buffers, async); - // Since output buffers have events, async should be true. - ASSERT_TRUE(async); - - // Check model output. 
- { - LITERT_ASSERT_OK_AND_ASSIGN( - auto lock_and_addr, - litert::TensorBufferScopedLock::Create(output_buffers[0])); - auto output = absl::MakeSpan(lock_and_addr.second, kTestOutputSize); - for (auto i = 0; i < kTestOutputSize; ++i) { - ABSL_LOG(INFO) << "Result: " << output[i] << "\t" << kTestOutputTensor[i]; - } - EXPECT_THAT(output, Pointwise(FloatNear(1e-5), kTestOutputTensor)); - } -} - -} // namespace -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/cc/litert_compiled_model_test.cc b/tensorflow/lite/experimental/litert/cc/litert_compiled_model_test.cc deleted file mode 100644 index 426817fc792e1d..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/litert_compiled_model_test.cc +++ /dev/null @@ -1,359 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/cc/litert_compiled_model.h" - -#include -#include -#include - -#include -#include -#include "absl/container/flat_hash_map.h" -#include "absl/log/absl_log.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_types.h" -#include "tensorflow/lite/experimental/litert/cc/litert_environment.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_requirements.h" -#include "tensorflow/lite/experimental/litert/test/common.h" -#include "tensorflow/lite/experimental/litert/test/matchers.h" -#include "tensorflow/lite/experimental/litert/test/testdata/simple_model_test_vectors.h" - -using ::testing::ElementsAre; -using ::testing::Eq; -using testing::FloatNear; -using testing::Pointwise; -using ::testing::SizeIs; - -namespace litert { -namespace { - -TEST(CompiledModelTest, Basic) { - // Environment setup. - LITERT_ASSERT_OK_AND_ASSIGN(Environment env, litert::Environment::Create({})); - - // Create Model. - Model model = testing::LoadTestFileModel(kModelFileName); - ASSERT_TRUE(model); - - // Create CompiledModel. - LITERT_ASSERT_OK_AND_ASSIGN(CompiledModel compiled_model, - CompiledModel::Create(env, model)); - - // Check CompiledModel buffer requirements. - // input and output expect host memory. 
- LITERT_ASSERT_OK_AND_ASSIGN( - TensorBufferRequirements input_buffer_requirements_arg0, - compiled_model.GetInputBufferRequirements(/*input_name=*/"arg0")); - LITERT_ASSERT_OK_AND_ASSIGN( - std::vector input_buffer_types_arg0, - input_buffer_requirements_arg0.SupportedTypes()); - EXPECT_THAT(input_buffer_types_arg0, - ElementsAre(kLiteRtTensorBufferTypeHostMemory)); - - LITERT_ASSERT_OK_AND_ASSIGN( - TensorBufferRequirements input_buffer_requirements_arg1, - compiled_model.GetInputBufferRequirements(/*input_name=*/"arg1")); - LITERT_ASSERT_OK_AND_ASSIGN( - std::vector input_buffer_types_arg1, - input_buffer_requirements_arg1.SupportedTypes()); - EXPECT_THAT(input_buffer_types_arg1, - ElementsAre(kLiteRtTensorBufferTypeHostMemory)); - - LITERT_ASSERT_OK_AND_ASSIGN( - TensorBufferRequirements output_buffer_requirements, - compiled_model.GetOutputBufferRequirements(/*output_name=*/"tfl.add")); - LITERT_ASSERT_OK_AND_ASSIGN( - std::vector output_buffer_types, - output_buffer_requirements.SupportedTypes()); - EXPECT_THAT(output_buffer_types, - ElementsAre(kLiteRtTensorBufferTypeHostMemory)); - - // Create and fill input and output buffers. - LITERT_ASSERT_OK_AND_ASSIGN(std::vector input_buffers, - compiled_model.CreateInputBuffers()); - - LITERT_ASSERT_OK_AND_ASSIGN(std::vector output_buffers, - compiled_model.CreateOutputBuffers()); - - ASSERT_TRUE(input_buffers[0].Write( - absl::MakeConstSpan(kTestInput0Tensor, kTestInput0Size))); - ASSERT_TRUE(input_buffers[1].Write( - absl::MakeConstSpan(kTestInput1Tensor, kTestInput1Size))); - - // Execute model with input and output buffers. - compiled_model.Run(input_buffers, output_buffers); - - // Check model output. 
- { - LITERT_ASSERT_OK_AND_ASSIGN( - auto lock_and_addr, - litert::TensorBufferScopedLock::Create(output_buffers[0])); - auto output = absl::MakeSpan(lock_and_addr.second, kTestOutputSize); - for (auto i = 0; i < kTestOutputSize; ++i) { - ABSL_LOG(INFO) << "Result: " << output[i] << "\t" << kTestOutputTensor[i]; - } - EXPECT_THAT(output, Pointwise(FloatNear(1e-5), kTestOutputTensor)); - } -} - -TEST(CompiledModelTest, BasicSignatureIndex) { - // Environment setup. - LITERT_ASSERT_OK_AND_ASSIGN(Environment env, litert::Environment::Create({})); - - // Create Model and check signatures. - Model model = testing::LoadTestFileModel(kModelFileName); - ASSERT_TRUE(model); - - LITERT_ASSERT_OK_AND_ASSIGN(std::vector signatures, - model.GetSignatures()); - EXPECT_EQ(signatures.size(), 1); - absl::string_view signature_key = signatures[0].Key(); - EXPECT_EQ(signature_key, Model::DefaultSignatureKey()); - size_t signature_index = 0; - - std::vector input_names = signatures[0].InputNames(); - EXPECT_THAT(input_names, ElementsAre("arg0", "arg1")); - - std::vector output_names = signatures[0].OutputNames(); - EXPECT_THAT(output_names, ElementsAre("tfl.add")); - - // Create CompiledModel. - LITERT_ASSERT_OK_AND_ASSIGN(CompiledModel compiled_model, - CompiledModel::Create(env, model)); - - // Check CompiledModel buffer requirements. - // input and output expect host memory. 
- LITERT_ASSERT_OK_AND_ASSIGN( - TensorBufferRequirements input_buffer_requirements_arg0, - compiled_model.GetInputBufferRequirements(signature_index, - /*input_name=*/"arg0")); - LITERT_ASSERT_OK_AND_ASSIGN( - std::vector input_buffer_types_arg0, - input_buffer_requirements_arg0.SupportedTypes()); - EXPECT_THAT(input_buffer_types_arg0, - ElementsAre(kLiteRtTensorBufferTypeHostMemory)); - - LITERT_ASSERT_OK_AND_ASSIGN( - TensorBufferRequirements input_buffer_requirements_arg1, - compiled_model.GetInputBufferRequirements(signature_index, - /*input_name=*/"arg1")); - LITERT_ASSERT_OK_AND_ASSIGN( - std::vector input_buffer_types_arg1, - input_buffer_requirements_arg1.SupportedTypes()); - EXPECT_THAT(input_buffer_types_arg1, - ElementsAre(kLiteRtTensorBufferTypeHostMemory)); - - LITERT_ASSERT_OK_AND_ASSIGN( - TensorBufferRequirements output_buffer_requirements, - compiled_model.GetOutputBufferRequirements(signature_index, - /*output_name=*/"tfl.add")); - LITERT_ASSERT_OK_AND_ASSIGN( - std::vector output_buffer_types, - output_buffer_requirements.SupportedTypes()); - EXPECT_THAT(output_buffer_types, - ElementsAre(kLiteRtTensorBufferTypeHostMemory)); - - // Create and fill input and output buffers. - LITERT_ASSERT_OK_AND_ASSIGN( - std::vector input_buffers, - compiled_model.CreateInputBuffers(signature_index)); - - LITERT_ASSERT_OK_AND_ASSIGN( - std::vector output_buffers, - compiled_model.CreateOutputBuffers(signature_index)); - - ASSERT_TRUE(input_buffers[0].Write( - absl::MakeConstSpan(kTestInput0Tensor, kTestInput0Size))); - ASSERT_TRUE(input_buffers[1].Write( - absl::MakeConstSpan(kTestInput1Tensor, kTestInput1Size))); - - // Execute model with input and output buffers. - compiled_model.Run(signature_index, input_buffers, output_buffers); - - // Check model output. 
- { - LITERT_ASSERT_OK_AND_ASSIGN( - auto lock_and_addr, - litert::TensorBufferScopedLock::Create(output_buffers[0])); - auto output = absl::MakeSpan(lock_and_addr.second, kTestOutputSize); - for (auto i = 0; i < kTestOutputSize; ++i) { - ABSL_LOG(INFO) << "Result: " << output[i] << "\t" << kTestOutputTensor[i]; - } - EXPECT_THAT(output, Pointwise(FloatNear(1e-5), kTestOutputTensor)); - } -} - -TEST(CompiledModelTest, RunWithInputOutputMap) { - // Environment setup. - LITERT_ASSERT_OK_AND_ASSIGN(Environment env, litert::Environment::Create({})); - - // Create Model and check signatures. - Model model = testing::LoadTestFileModel(kModelFileName); - ASSERT_TRUE(model); - - LITERT_ASSERT_OK_AND_ASSIGN(std::vector signatures, - model.GetSignatures()); - EXPECT_EQ(signatures.size(), 1); - absl::string_view signature_key = signatures[0].Key(); - EXPECT_EQ(signature_key, Model::DefaultSignatureKey()); - size_t signature_index = 0; - - std::vector input_names = signatures[0].InputNames(); - EXPECT_THAT(input_names, ElementsAre("arg0", "arg1")); - - std::vector output_names = signatures[0].OutputNames(); - EXPECT_THAT(output_names, ElementsAre("tfl.add")); - - // Create CompiledModel. - LITERT_ASSERT_OK_AND_ASSIGN(CompiledModel compiled_model, - CompiledModel::Create(env, model)); - - // Check CompiledModel buffer requirements. - // input and output expect host memory. 
- LITERT_ASSERT_OK_AND_ASSIGN( - TensorBufferRequirements input_buffer_requirements_arg0, - compiled_model.GetInputBufferRequirements(signature_index, - /*input_name=*/"arg0")); - LITERT_ASSERT_OK_AND_ASSIGN( - std::vector input_buffer_types_arg0, - input_buffer_requirements_arg0.SupportedTypes()); - EXPECT_THAT(input_buffer_types_arg0, - ElementsAre(kLiteRtTensorBufferTypeHostMemory)); - - LITERT_ASSERT_OK_AND_ASSIGN( - TensorBufferRequirements input_buffer_requirements_arg1, - compiled_model.GetInputBufferRequirements(signature_index, - /*input_name=*/"arg1")); - LITERT_ASSERT_OK_AND_ASSIGN( - std::vector input_buffer_types_arg1, - input_buffer_requirements_arg1.SupportedTypes()); - EXPECT_THAT(input_buffer_types_arg1, - ElementsAre(kLiteRtTensorBufferTypeHostMemory)); - - LITERT_ASSERT_OK_AND_ASSIGN( - TensorBufferRequirements output_buffer_requirements, - compiled_model.GetOutputBufferRequirements(signature_index, - /*output_name=*/"tfl.add")); - LITERT_ASSERT_OK_AND_ASSIGN( - std::vector output_buffer_types, - output_buffer_requirements.SupportedTypes()); - EXPECT_THAT(output_buffer_types, - ElementsAre(kLiteRtTensorBufferTypeHostMemory)); - - // Create and fill input and output buffers. - LITERT_ASSERT_OK_AND_ASSIGN( - TensorBuffer input_buffer0, - compiled_model.CreateInputBuffer(signature_key, "arg0")); - LITERT_ASSERT_OK_AND_ASSIGN( - TensorBuffer input_buffer1, - compiled_model.CreateInputBuffer(signature_key, "arg1")); - LITERT_ASSERT_OK_AND_ASSIGN( - TensorBuffer output_buffer0, - compiled_model.CreateOutputBuffer(signature_key, "tfl.add")); - - ASSERT_TRUE(input_buffer0.Write( - absl::MakeConstSpan(kTestInput0Tensor, kTestInput0Size))); - ASSERT_TRUE(input_buffer1.Write( - absl::MakeConstSpan(kTestInput1Tensor, kTestInput1Size))); - - // Create input and output map. 
- absl::flat_hash_map input_map; - input_map["arg0"] = std::move(input_buffer0); - input_map["arg1"] = std::move(input_buffer1); - - absl::flat_hash_map output_map; - output_map["tfl.add"] = std::move(output_buffer0); - - // Execute model with input and output maps instead of buffers. - compiled_model.Run(signature_key, input_map, output_map); - - // Check model output. - { - LITERT_ASSERT_OK_AND_ASSIGN( - auto lock_and_addr, litert::TensorBufferScopedLock::Create( - output_map["tfl.add"])); - auto output = absl::MakeSpan(lock_and_addr.second, kTestOutputSize); - for (auto i = 0; i < kTestOutputSize; ++i) { - ABSL_LOG(INFO) << "Result: " << output[i] << "\t" << kTestOutputTensor[i]; - } - EXPECT_THAT(output, Pointwise(FloatNear(1e-5), kTestOutputTensor)); - } -} - -// Tests Compiled Model async API on CPU. In the CPU case, the async API should -// always return false. -TEST(CompiledModelTest, RunAsyncReturnsFalse) { - // Environment setup. - LITERT_ASSERT_OK_AND_ASSIGN(Environment env, litert::Environment::Create({})); - - // Create Model and check signatures. - Model model = testing::LoadTestFileModel(kModelFileName); - ASSERT_TRUE(model); - - // Create CompiledModel. - LITERT_ASSERT_OK_AND_ASSIGN(CompiledModel compiled_model, - CompiledModel::Create(env, model)); - - // Create input and output buffers. - LITERT_ASSERT_OK_AND_ASSIGN( - std::vector input_buffers, - compiled_model.CreateInputBuffers(model.DefaultSignatureKey())); - LITERT_ASSERT_OK_AND_ASSIGN( - std::vector output_buffers, - compiled_model.CreateOutputBuffers(model.DefaultSignatureKey())); - - // Confirm input and output buffers are host memory. 
- EXPECT_THAT(*input_buffers[0].BufferType(), - Eq(kLiteRtTensorBufferTypeHostMemory)); - EXPECT_THAT(*input_buffers[1].BufferType(), - Eq(kLiteRtTensorBufferTypeHostMemory)); - EXPECT_THAT(*output_buffers[0].BufferType(), - Eq(kLiteRtTensorBufferTypeHostMemory)); - - ASSERT_THAT(input_buffers, SizeIs(2)); - ASSERT_THAT(output_buffers, SizeIs(1)); - - ASSERT_TRUE(input_buffers[0].Write( - absl::MakeConstSpan(kTestInput0Tensor, kTestInput0Size))); - ASSERT_TRUE(input_buffers[1].Write( - absl::MakeConstSpan(kTestInput1Tensor, kTestInput1Size))); - - // Execute model with input and output buffers. - bool async; - compiled_model.RunAsync(model.DefaultSignatureKey(), input_buffers, - output_buffers, async); - // Since there are no events on the output buffers, async should be false. - ASSERT_FALSE(async); - - // Check model output. - { - LITERT_ASSERT_OK_AND_ASSIGN( - auto lock_and_addr, - litert::TensorBufferScopedLock::Create(output_buffers[0])); - auto output = absl::MakeSpan(lock_and_addr.second, kTestOutputSize); - for (auto i = 0; i < kTestOutputSize; ++i) { - ABSL_LOG(INFO) << "Result: " << output[i] << "\t" << kTestOutputTensor[i]; - } - EXPECT_THAT(output, Pointwise(FloatNear(1e-5), kTestOutputTensor)); - } -} - -} // namespace -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/cc/litert_consts.h b/tensorflow/lite/experimental/litert/cc/litert_consts.h deleted file mode 100644 index 14ac9a0b00e832..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/litert_consts.h +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_CONSTS_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_CONSTS_H_ - -#include - -namespace litert { - -// The following constants are used to properly size absl::InlinedVector<> -// uses used in the LiteRT code. Their values don't need to be exact; they -// are just optimization hints. -static constexpr size_t kExpectedMaxTensorRank = 6; -static constexpr size_t kExpectedMaxNumOfTensorUses = 8; -static constexpr size_t kExpectedMaxNumOfOpInputs = 4; -static constexpr size_t kExpectedMaxNumOfOpOutputs = 8; -static constexpr size_t kExpectedMaxNumOfSubgraphInputs = 4; -static constexpr size_t kExpectedMaxNumOfSubgraphOutputs = 4; - -} // namespace litert - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_CONSTS_H_ diff --git a/tensorflow/lite/experimental/litert/cc/litert_detail.h b/tensorflow/lite/experimental/litert/cc/litert_detail.h deleted file mode 100644 index 566d8468fa8148..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/litert_detail.h +++ /dev/null @@ -1,105 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_DETAIL_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_DETAIL_H_ - -#include -#include -#include -#include - -#include "absl/log/absl_check.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" - -namespace litert { - -// See "std::construct_at" from C++20. -template -T* ConstructAt(T* p, Args&&... args) { - return ::new (static_cast(p)) T(std::forward(args)...); -} - -// Reduce all over zipped iters of same size. -template -bool AllZip(const LeftVals& lhs, const RightVals& rhs, - std::function - bin_pred) { - if (lhs.size() != rhs.size()) { - return false; - } - for (auto i = 0; i < lhs.size(); ++i) { - if (!bin_pred(lhs.at(i), rhs.at(i))) { - return false; - } - } - return true; -} - -// Reduce any over zipped iters of same size. -template -bool AnyZip(const LeftVals& lhs, const RightVals& rhs, - std::function - bin_pred) { - auto neg = [&](const auto& l, const auto& r) { return !bin_pred(l, r); }; - return !(AllZip(lhs, rhs, neg)); -} - -// Does element exist in range. -template -bool Contains(It begin, It end, const T& val) { - return std::find(begin, end, val) != end; -} - -// Does element exist in range satisfying pred. -template -bool ContainsIf(It begin, It end, UPred u_pred) { - return std::find_if(begin, end, u_pred) != end; -} - -// Get the ind of the given element if it is present. -template -std::optional FindInd(It begin, It end, T val) { - auto it = std::find(begin, end, val); - return (it == end) ? std::nullopt : std::make_optional(it - begin); -} - -namespace internal { - -// Call function "get" and assert it returns value equal to given expected -// value. -template -inline void AssertEq(F get, Expected expected, Args&&... 
args) { - auto status = get(std::forward(args)...); - ABSL_CHECK_EQ(status, expected); -} - -// Call function "get" and assert it returns true. -template -inline void AssertTrue(F get, Args&&... args) { - AssertEq(get, true, std::forward(args)...); -} - -// Call function "get" and assert it returns an OK LiteRtStatus. -template -inline void AssertOk(F get, Args&&... args) { - AssertEq(get, kLiteRtStatusOk, std::forward(args)...); -} - -} // namespace internal -} // namespace litert - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_DETAIL_H_ diff --git a/tensorflow/lite/experimental/litert/cc/litert_dispatch_delegate.h b/tensorflow/lite/experimental/litert/cc/litert_dispatch_delegate.h deleted file mode 100644 index bdbb3a0c4df8c7..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/litert_dispatch_delegate.h +++ /dev/null @@ -1,41 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_DISPATCH_DELEGATE_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_DISPATCH_DELEGATE_H_ - -#include - -#include "tensorflow/lite/delegates/utils/simple_opaque_delegate.h" -#include "tensorflow/lite/experimental/litert/c/litert_dispatch_delegate.h" -#include "tensorflow/lite/experimental/litert/c/litert_environment_options.h" - -namespace litert { - -using DispatchDelegateOptionsPtr = - std::unique_ptr; - -using DispatchDelegatePtr = tflite::TfLiteOpaqueDelegateUniquePtr; - -DispatchDelegateOptionsPtr CreateDispatchDelegateOptionsPtr( - LiteRtEnvironmentOptions environment_options); - -DispatchDelegatePtr CreateDispatchDelegatePtr( - LiteRtEnvironmentOptions environment_options, - DispatchDelegateOptionsPtr&& options); - -} // namespace litert - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_DISPATCH_DELEGATE_H_ diff --git a/tensorflow/lite/experimental/litert/cc/litert_element_type.h b/tensorflow/lite/experimental/litert/cc/litert_element_type.h deleted file mode 100644 index 84b032b3820a7a..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/litert_element_type.h +++ /dev/null @@ -1,157 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_ELEMENT_TYPE_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_ELEMENT_TYPE_H_ - -#include -#include -#include - -#include "tensorflow/lite/experimental/litert/c/litert_model.h" - -namespace litert { - -// Data type of tensor elements. C++ equivalent to LiteRtElementType. -enum class ElementType { - None = kLiteRtElementTypeNone, - Bool = kLiteRtElementTypeBool, - Int4 = kLiteRtElementTypeInt4, - Int8 = kLiteRtElementTypeInt8, - Int16 = kLiteRtElementTypeInt16, - Int32 = kLiteRtElementTypeInt32, - Int64 = kLiteRtElementTypeInt64, - UInt8 = kLiteRtElementTypeUInt8, - UInt16 = kLiteRtElementTypeUInt16, - UInt32 = kLiteRtElementTypeUInt32, - UInt64 = kLiteRtElementTypeUInt64, - Float16 = kLiteRtElementTypeFloat16, - BFloat16 = kLiteRtElementTypeBFloat16, - Float32 = kLiteRtElementTypeFloat32, - Float64 = kLiteRtElementTypeFloat64, - Complex64 = kLiteRtElementTypeComplex64, - Complex128 = kLiteRtElementTypeComplex128, - TfResource = kLiteRtElementTypeTfResource, - TfString = kLiteRtElementTypeTfString, - TfVariant = kLiteRtElementTypeTfVariant, -}; - -// Get number of bytes of a single element of given type. -inline constexpr std::optional GetByteWidth(ElementType ty) { - if (ty == ElementType::Bool) - return 1; - else if (ty == ElementType::Int8) - return 1; - else if (ty == ElementType::Int16) - return 2; - else if (ty == ElementType::Int32) - return 4; - else if (ty == ElementType::Int64) - return 8; - else if (ty == ElementType::UInt8) - return 1; - else if (ty == ElementType::UInt16) - return 2; - else if (ty == ElementType::UInt32) - return 4; - else if (ty == ElementType::UInt64) - return 8; - else if (ty == ElementType::Float16) - return 2; - else if (ty == ElementType::BFloat16) - return 2; - else if (ty == ElementType::Float32) - return 4; - else if (ty == ElementType::Float64) - return 8; - else - return std::nullopt; -} - -// Get number of bytes of a single element of given type via template. 
-template -inline constexpr size_t GetByteWidth() { - constexpr auto byte_width = GetByteWidth(Ty); - static_assert(byte_width.has_value(), "Type does not have byte width"); - return byte_width.value(); -} - -template -constexpr bool dependent_false = false; // workaround before CWG2518/P2593R1 - -// Get the litert::ElementType associated with given C++ type. -template -inline constexpr ElementType GetElementType() { - static_assert(dependent_false, "Uknown C++ type"); - return ElementType::None; -} - -template <> -inline constexpr ElementType GetElementType() { - return ElementType::Bool; -} - -template <> -inline constexpr ElementType GetElementType() { - return ElementType::Int8; -} - -template <> -inline constexpr ElementType GetElementType() { - return ElementType::UInt8; -} - -template <> -inline constexpr ElementType GetElementType() { - return ElementType::Int16; -} - -template <> -inline constexpr ElementType GetElementType() { - return ElementType::UInt16; -} - -template <> -inline constexpr ElementType GetElementType() { - return ElementType::Int32; -} - -template <> -inline constexpr ElementType GetElementType() { - return ElementType::UInt32; -} - -template <> -inline constexpr ElementType GetElementType() { - return ElementType::Int64; -} - -template <> -inline constexpr ElementType GetElementType() { - return ElementType::UInt64; -} - -template <> -inline constexpr ElementType GetElementType() { - return ElementType::Float32; -} - -template <> -inline constexpr ElementType GetElementType() { - return ElementType::Float64; -} - -} // namespace litert - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_ELEMENT_TYPE_H_ diff --git a/tensorflow/lite/experimental/litert/cc/litert_element_type_test.cc b/tensorflow/lite/experimental/litert/cc/litert_element_type_test.cc deleted file mode 100644 index 929bc499f32c63..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/litert_element_type_test.cc +++ /dev/null @@ -1,48 +0,0 @@ -// Copyright 2024 
Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/cc/litert_element_type.h" - -#include -#include - -#include - -namespace litert { - -namespace { - -template -class ElementTypeTest : public ::testing::Test { - public: - size_t Size() const { return sizeof(T); } -}; - -TYPED_TEST_SUITE_P(ElementTypeTest); - -TYPED_TEST_P(ElementTypeTest, TypeAndSize) { - const size_t size = GetByteWidth()>(); - EXPECT_EQ(size, this->Size()); -} - -REGISTER_TYPED_TEST_SUITE_P(ElementTypeTest, TypeAndSize); - -using Types = - ::testing::Types; - -INSTANTIATE_TYPED_TEST_SUITE_P(ElementTypeTestSuite, ElementTypeTest, Types); - -} // namespace -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/cc/litert_environment.h b/tensorflow/lite/experimental/litert/cc/litert_environment.h deleted file mode 100644 index 69faebdea892d3..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/litert_environment.h +++ /dev/null @@ -1,86 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_ENVIRONMENT_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_ENVIRONMENT_H_ - -#include -#include - -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_environment.h" -#include "tensorflow/lite/experimental/litert/cc/litert_any.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_handle.h" -namespace litert { - -class Environment - : public internal::Handle { - public: - explicit Environment(LiteRtEnvironment env) - : internal::Handle(env, - true) {} - - enum class OptionTag { - CompilerPluginLibraryDir = kLiteRtEnvOptionTagCompilerPluginLibraryDir, - DispatchLibraryDir = kLiteRtEnvOptionTagDispatchLibraryDir, - }; - - struct Option { - OptionTag tag; - std::any value; - }; - - static Expected Create(absl::Span options) { - auto c_options = ConvertOptions(options); - if (!c_options) { - return c_options.Error(); - } - LiteRtEnvironment env; - if (auto status = - LiteRtEnvironmentCreate(c_options->size(), c_options->data(), &env); - status != kLiteRtStatusOk) { - return Error(status); - } else { - return Environment(env); - } - } - - private: - static Expected> ConvertOptions( - absl::Span options) { - std::vector c_options; - c_options.reserve(options.size()); - - for (auto& option : options) { - auto litert_any = ToLiteRtAny(option.value); - if (!litert_any) { - return litert_any.Error(); - } - - 
LiteRtEnvOption c_option = { - /*.tag=*/static_cast(option.tag), - /*.value=*/*litert_any, - }; - c_options.push_back(c_option); - } - - return c_options; - } -}; - -} // namespace litert - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_ENVIRONMENT_H_ diff --git a/tensorflow/lite/experimental/litert/cc/litert_environment_test.cc b/tensorflow/lite/experimental/litert/cc/litert_environment_test.cc deleted file mode 100644 index 0f012aedfee66b..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/litert_environment_test.cc +++ /dev/null @@ -1,91 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/cc/litert_environment.h" - -#include -#include - -#include -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/cc/litert_compiled_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/test/common.h" -#include "tensorflow/lite/experimental/litert/test/matchers.h" -#include "tensorflow/lite/experimental/litert/test/testdata/simple_model_test_vectors.h" - -namespace litert { -namespace { - -TEST(EnvironmentTest, Default) { - auto env = litert::Environment::Create({}); - EXPECT_TRUE(env); -} - -TEST(EnvironmentTest, Options) { - constexpr absl::string_view kDispatchLibraryDir = "/data/local/tmp"; - const std::vector environment_options = { - litert::Environment::Option{ - litert::Environment::OptionTag::DispatchLibraryDir, - kDispatchLibraryDir, - }, - }; - auto env = - litert::Environment::Create(absl::MakeConstSpan(environment_options)); - EXPECT_TRUE(env); -} - -TEST(EnvironmentTest, CompiledModelBasic) { - // Environment setup. - LITERT_ASSERT_OK_AND_ASSIGN(Environment env, litert::Environment::Create({})); - - // Create Model and check signatures. - Model model = testing::LoadTestFileModel(kModelFileName); - ASSERT_TRUE(model); - - // Create CompiledModel. - auto compiled_model = CompiledModel::Create(env, model); - EXPECT_TRUE(compiled_model); -} - -TEST(EnvironmentTest, StringLifeCycle) { - std::string dispatch_library_dir = "/data/local/tmp"; - const std::vector environment_options = { - litert::Environment::Option{ - litert::Environment::OptionTag::DispatchLibraryDir, - absl::string_view(dispatch_library_dir), - }, - }; - - auto env = - litert::Environment::Create(absl::MakeConstSpan(environment_options)); - - EXPECT_TRUE(env); - - // Change the string value but the environment should still have a copy. 
- dispatch_library_dir = ""; - - // Create Model and check signatures. - Model model = testing::LoadTestFileModel(kModelFileName); - ASSERT_TRUE(model); - - // Create CompiledModel. - auto compiled_model = CompiledModel::Create(*env, model); - EXPECT_TRUE(compiled_model); -} - -} // namespace -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/cc/litert_event.h b/tensorflow/lite/experimental/litert/cc/litert_event.h deleted file mode 100644 index 0f8582205d9c39..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/litert_event.h +++ /dev/null @@ -1,100 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_EVENT_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_EVENT_H_ - -#include - -#include "tensorflow/lite/experimental/litert/c/litert_event.h" -#include "tensorflow/lite/experimental/litert/c/litert_event_type.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_handle.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" - -extern "C" { -// Forward declaration of OpenCL event to avoid including OpenCL headers. 
-typedef struct _cl_event* cl_event; -} - -namespace litert { - -class Event : public internal::Handle { - public: - // Parameter `owned` indicates if the created TensorBufferRequirements object - // should take ownership of the provided `requirements` handle. - explicit Event(LiteRtEvent event, bool owned = true) - : internal::Handle(event, owned) {} - - // Creates an Event object with the given `sync_fence_fd`. - static Expected CreateFromSyncFenceFd(int sync_fence_fd, - bool owns_fd) { - LiteRtEvent event; - LITERT_RETURN_IF_ERROR( - LiteRtCreateEventFromSyncFenceFd(sync_fence_fd, owns_fd, &event)); - return Event(event); - } - - // Creates an Event object with the given `cl_event`. - static Expected CreateFromOpenClEvent(cl_event cl_event) { - LiteRtEvent event; - LITERT_RETURN_IF_ERROR(LiteRtCreateEventFromOpenClEvent(cl_event, &event)); - return Event(event); - } - - // Creates a managed event of the given `type`. Currently only - // LiteRtEventTypeOpenCl is supported. - static Expected CreateManaged(LiteRtEventType type) { - LiteRtEvent event; - LITERT_RETURN_IF_ERROR(LiteRtCreateManagedEvent(type, &event)); - return Event(event); - } - - Expected GetSyncFenceFd() { - int fd; - LITERT_RETURN_IF_ERROR(LiteRtGetEventSyncFenceFd(Get(), &fd)); - return fd; - } - - // Returns the underlying OpenCL event if the event type is OpenCL. - Expected GetOpenClEvent() { - cl_event cl_event; - LITERT_RETURN_IF_ERROR(LiteRtGetEventOpenClEvent(Get(), &cl_event)); - return cl_event; - } - - // Pass -1 for timeout_in_ms for indefinite wait. - Expected Wait(int64_t timeout_in_ms) { - LITERT_RETURN_IF_ERROR(LiteRtEventWait(Get(), timeout_in_ms)); - return {}; - } - - // Singal the event. - // Note: This is only supported for OpenCL events. - Expected Signal() { - LITERT_RETURN_IF_ERROR(LiteRtEventSignal(Get())); - return {}; - } - - // Returns the underlying event type. 
- LiteRtEventType Type() const { - LiteRtEventType type; - LiteRtGetEventEventType(Get(), &type); - return type; - } -}; - -} // namespace litert - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_EVENT_H_ diff --git a/tensorflow/lite/experimental/litert/cc/litert_event_test.cc b/tensorflow/lite/experimental/litert/cc/litert_event_test.cc deleted file mode 100644 index 752e8c0a6c3ce6..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/litert_event_test.cc +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/cc/litert_event.h" - -#include -#include -#include "tensorflow/lite/experimental/litert/test/matchers.h" - -namespace litert { -namespace { - -using ::testing::Eq; - -constexpr int kFakeSyncFenceFd = 1; - -TEST(Event, NoEvent) { - LITERT_ASSERT_OK_AND_ASSIGN( - Event event, Event::CreateFromSyncFenceFd(kFakeSyncFenceFd, true)); - LITERT_ASSERT_OK_AND_ASSIGN(int fd, event.GetSyncFenceFd()); - EXPECT_THAT(fd, Eq(kFakeSyncFenceFd)); -} - -} // namespace -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/cc/litert_expected.h b/tensorflow/lite/experimental/litert/cc/litert_expected.h deleted file mode 100644 index 5d60e86094391a..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/litert_expected.h +++ /dev/null @@ -1,390 +0,0 @@ -// Copyright 2024 Google LLC. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_EXPECTED_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_EXPECTED_H_ - -#include -#include -#include -#include -#include -#include -#include - -#include "absl/log/absl_check.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_format.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_detail.h" - -namespace litert { - -// An "Expected" incapsulates the result of some routine which may have an -// unexpected result. Unexpected results in this context are a standard -// LiteRtStatus plus extra usability data such as error messages. This is -// similar to an absl::StatusOr or std::expected (C++23) but better integrated -// with LiteRtStatus as the canonical status code. - -// C++ wrapper around LiteRtStatus code. Provides a status as well -// as an error message. -class Error { - public: - // Construct Unexpected from status and optional error message. - // - // NOTE: kLiteRtStatusOk should not be passed to Unexpected. - explicit Error(LiteRtStatus status, std::string message = "") - : status_(status), message_(std::move(message)) { - ABSL_DCHECK(status != kLiteRtStatusOk); - } - - // Get the status. - constexpr LiteRtStatus Status() const { return status_; } - - // Get the error message, empty string if none was attached. 
- const std::string& Message() const { return message_; } - - friend std::ostream& operator<<(std::ostream& stream, const Error& error) { - stream << LiteRtGetStatusString(error.Status()); - if (!error.Message().empty()) { - stream << ": " << error.Message(); - } - return stream; - } - - template - friend void AbslStringify(Sink& sink, const Error& error) { - absl::Format(&sink, "%s", LiteRtGetStatusString(error.Status())); - if (!error.Message().empty()) { - absl::Format(&sink, ": %v", error.Message()); - } - } - - private: - LiteRtStatus status_; - std::string message_; -}; - -class Unexpected { - public: - template - constexpr explicit Unexpected(Args&&... args) - : error_(std::forward(args)...) {} - - // Allow for implicit conversion from convertible Error value inplace. - // NOLINTNEXTLINE(*-explicit-constructor) - Unexpected(class Error&& e) : error_(std::move(e)) {} - - Unexpected(Unexpected&& other) = default; - Unexpected(const Unexpected& other) = default; - Unexpected& operator=(Unexpected&& other) = default; - Unexpected& operator=(const Unexpected& other) = default; - - constexpr const class Error& Error() const& noexcept { return error_; } - constexpr class Error& Error() & noexcept { return error_; } - constexpr const class Error&& Error() const&& noexcept { - return std::move(error_); - } - constexpr class Error&& Error() && noexcept { return std::move(error_); } - - template - friend void AbslStringify(Sink& sink, const Unexpected& unexpected) { - AbslStringify(sink, unexpected.Error()); - } - - private: - class Error error_; -}; - -// Utility for generic return values that may be a statused failure. Expecteds -// store and own the lifetime of either an Unexpected, or a T. T may be any -// type, primitive or non-primitive. -// -// No dynamic allocations occur during initialization, so the underlying T is -// only movable (as opposed to something like "release"). 
Arguments should be -// constructed in place at the time of initializing the expected if possible. -// -// Unexpected&& and T&& may be implicitly casted -// to an Expected. For example, -// -// Expected Bar() { -// bool success = ... -// if (!success) { return Unexpected(kLiteRtStatus, "Bad Baz"); } -// return Foo(); -// } -// -template -class Expected { - public: - // Construct Expected with T inplace. - - // Construct T from initializer list inplace. - template - Expected(std::initializer_list il) : has_value_(true), value_(il) {} - - // Construct T from forwarded args inplace. - template - explicit Expected(Args&&... args) - : has_value_(true), value_(std::forward(args)...) {} - - // NOLINTBEGIN(*-explicit-constructor) - - // Allow for implicit conversion from convertible T value inplace. - Expected(const T& t) : has_value_(true), value_(t) {} - Expected(T&& t) : has_value_(true), value_(std::move(t)) {} - - // Construct from Unexpected inplace. - - // Allow for implicit conversion from Error. 
- Expected(const Unexpected& err) : has_value_(false), unexpected_(err) {} - Expected(Unexpected&& err) : has_value_(false), unexpected_(std::move(err)) {} - Expected(const class Error& e) : has_value_(false), unexpected_(e) {} - - // NOLINTEND(*-explicit-constructor) - - // Copy/move - - Expected(Expected&& other) : has_value_(other.HasValue()) { - if (HasValue()) { - ConstructAt(std::addressof(value_), std::move(other.value_)); - } else { - ConstructAt(std::addressof(unexpected_), std::move(other.unexpected_)); - } - } - - Expected(const Expected& other) : has_value_(other.has_value_) { - if (HasValue()) { - ConstructAt(std::addressof(value_), other.value_); - value_ = other.value_; - } else { - ConstructAt(std::addressof(unexpected_), other.unexpected_); - } - } - - Expected& operator=(Expected&& other) { - if (this != &other) { - if (HasValue()) { - if (other.HasValue()) { - value_ = std::move(other.value_); - } else { - value_.~T(); - ConstructAt(std::addressof(unexpected_), - std::move(other.unexpected_)); - } - } else { - if (other.HasValue()) { - unexpected_.~Unexpected(); - ConstructAt(std::addressof(value_), std::move(other.value_)); - } else { - unexpected_ = std::move(other.unexpected_); - } - } - has_value_ = other.has_value_; - } - return *this; - } - - Expected& operator=(const Expected& other) { - if (this != &other) { - if (HasValue()) { - if (other.HasValue()) { - value_ = other.value_; - } else { - value_.~T(); - ConstructAt(std::addressof(unexpected_), other.unexpected_); - } - } else { - if (other.HasValue()) { - unexpected_.~Unexpected(); - ConstructAt(std::addressof(value_), other.value_); - } else { - unexpected_ = other.unexpected_; - } - } - has_value_ = other.has_value_; - } - return *this; - } - - ~Expected() { - if (has_value_ && std::is_destructible()) { - value_.~T(); - } else { - unexpected_.~Unexpected(); - } - } - - // Observers for T value, program exits if it doesn't have one. 
- const T& Value() const& { - CheckVal(); - return value_; - } - - T& Value() & { - CheckVal(); - return value_; - } - - const T&& Value() const&& { - CheckVal(); - return std::move(value_); - } - - T&& Value() && { - CheckVal(); - return std::move(value_); - } - - const T* operator->() const { - CheckVal(); - return &value_; - } - - T* operator->() { - CheckVal(); - return &value_; - } - - const T& operator*() const& { return Value(); } - - T& operator*() & { return Value(); } - - const T&& operator*() const&& { return std::move(Value()); } - - T&& operator*() && { return std::move(Value()); } - - // Observer for Unexpected, program exits if it doesn't have one. - const class Error& Error() const& { - CheckNoVal(); - return unexpected_.Error(); - } - - class Error& Error() & { - CheckNoVal(); - return unexpected_.Error(); - } - - const class Error&& Error() const&& { - CheckNoVal(); - return std::move(unexpected_.Error()); - } - - class Error&& Error() && { - CheckNoVal(); - return std::move(unexpected_.Error()); - } - - // Does this expected contain a T Value. It contains an unexpected if not. - bool HasValue() const { return has_value_; } - - // Convert to bool for HasValue. - explicit operator bool() const { return HasValue(); } - - private: - bool has_value_; - union { - T value_; - Unexpected unexpected_; - }; - void CheckNoVal() const { ABSL_CHECK(!HasValue()); } - void CheckVal() const { ABSL_CHECK(HasValue()); } -}; - -namespace internal { -template -struct CanBeAbslFormated { - template - static constexpr auto Check(int) - -> decltype(absl::StrCat(std::declval()), true) { - return true; - } - template - static constexpr bool Check(...) 
{ - return false; - } - enum { value = Check(0) }; -}; -} // namespace internal - -template -void AbslStringify(Sink& sink, const Expected& expected) { - if (!expected.HasValue()) { - absl::Format(&sink, "%v", expected.Error()); - } else { - if constexpr (std::is_same_v) { - sink.Append("void expected value"); - } else { - if constexpr (internal::CanBeAbslFormated::value) { - absl::Format(&sink, "%v", expected.Value()); - } else { - absl::Format(&sink, "unformattable expected value"); - } - } - } -} - -template <> -class Expected { - public: - // Implicit construction is used to simplify returning a valid value, e.g., in - // "return {};" - Expected() : unexpected_(std::nullopt) {} - - // NOLINTBEGIN(*-explicit-constructor) - - // Construct from Unexpected inplace. - Expected(const Unexpected& err) : unexpected_(err) {} - Expected(Unexpected&& err) : unexpected_(std::move(err)) {} - - // Allow for implicit conversion from Error. - Expected(const Error& e) : unexpected_(e) {} - - // NOLINTEND(*-explicit-constructor) - - // Observer for Unexpected, program exits if it doesn't have one. - const class Error& Error() const& { - CheckNoVal(); - return unexpected_->Error(); - } - - class Error& Error() & { - CheckNoVal(); - return unexpected_->Error(); - } - - const class Error&& Error() const&& { - CheckNoVal(); - return std::move(unexpected_->Error()); - } - - class Error&& Error() && { - CheckNoVal(); - return std::move(unexpected_->Error()); - } - - // Does this expected contain a T Value. It contains an unexpected if not. - bool HasValue() const { return !unexpected_.has_value(); } - - // Convert to bool for HasValue. 
- explicit operator bool() const { return HasValue(); } - - private: - std::optional unexpected_; - void CheckNoVal() const { ABSL_CHECK(!HasValue()); } - void CheckVal() const { ABSL_CHECK(HasValue()); } -}; - -} // namespace litert - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_EXPECTED_H_ diff --git a/tensorflow/lite/experimental/litert/cc/litert_expected_test.cc b/tensorflow/lite/experimental/litert/cc/litert_expected_test.cc deleted file mode 100644 index ad68a834dbe80f..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/litert_expected_test.cc +++ /dev/null @@ -1,233 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" - -#include -#include -#include -#include -#include -#include - -#include -#include -#include "absl/strings/str_cat.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" - -namespace litert { - -namespace { -using testing::StrEq; - -static constexpr LiteRtStatus kErrorStatus = kLiteRtStatusErrorInvalidArgument; - -struct TypeWithAllocation { - TypeWithAllocation(std::initializer_list il) : allocated(il) {} - std::vector allocated; -}; - -struct TypeWithFields { - TypeWithFields(int i_, int j_) : i(i_), j(j_) {} - int i; - int j; -}; - -TEST(ExpectedTest, PrimitiveExplicit) { - Expected exp(1.0); - ASSERT_TRUE(exp.HasValue()); -} - -TEST(ExpectedTest, PrimitiveImplicit) { - Expected exp = 1.0; - ASSERT_TRUE(exp.HasValue()); -} - -TEST(ExpectedTest, ClassWithAllocation) { - Expected exp(TypeWithAllocation({1, 2, 3})); - ASSERT_TRUE(exp.HasValue()); -} - -TEST(ExpectedTest, ClassWithFields) { - Expected exp(TypeWithFields(1, 2)); - ASSERT_TRUE(exp.HasValue()); -} - -TEST(ExpectedTest, FromErrorExplicit) { - Expected exp((Unexpected(kErrorStatus, "MESSAGE"))); - ASSERT_FALSE(exp.HasValue()); -} - -TEST(ExpectedTest, FromErrorImplicit) { - Expected exp = Unexpected(kErrorStatus); - ASSERT_FALSE(exp.HasValue()); -} - -TEST(ExpectedTest, CopyCstorError) { - const Expected exp = Unexpected(kErrorStatus); - Expected other(exp); - ASSERT_FALSE(other.HasValue()); - EXPECT_EQ(other.Error().Status(), kErrorStatus); -} - -TEST(ExpectedTest, CopyCstorVal) { - const Expected exp = 2; - Expected other(exp); - ASSERT_TRUE(other.HasValue()); - EXPECT_EQ(other.Value(), 2); -} - -TEST(ExpectedTest, CopyAssignError) { - const Expected exp = Unexpected(kErrorStatus); - ASSERT_FALSE(exp.HasValue()); - Expected other = exp; - ASSERT_FALSE(other.HasValue()); - EXPECT_EQ(other.Error().Status(), kErrorStatus); -} - -TEST(ExpectedTest, 
CopyAssignVal) { - const Expected exp = 2; - Expected other = exp; - ASSERT_TRUE(other.HasValue()); - EXPECT_EQ(other.Value(), 2); -} - -TEST(ExpectedTest, MoveCstorError) { - Expected exp = Unexpected(kErrorStatus); - Expected other(std::move(exp)); - ASSERT_FALSE(other.HasValue()); - EXPECT_EQ(other.Error().Status(), kErrorStatus); -} - -TEST(ExpectedTest, MoveCstorVal) { - Expected exp = 2; - Expected other(std::move(exp)); - ASSERT_TRUE(other.HasValue()); - EXPECT_EQ(other.Value(), 2); -} - -TEST(ExpectedTest, MoveAssignError) { - Expected exp = Unexpected(kErrorStatus); - Expected other = std::move(exp); - ASSERT_FALSE(other.HasValue()); - EXPECT_EQ(other.Error().Status(), kErrorStatus); -} - -TEST(ExpectedTest, MoveAssignVal) { - Expected exp = 2; - Expected other = std::move(exp); - ASSERT_TRUE(other.HasValue()); - EXPECT_EQ(other.Value(), 2); -} - -TEST(ExpectedTest, Indirection) { - Expected exp(TypeWithFields(1, 2)); - EXPECT_EQ(exp->i, 1); - EXPECT_EQ(exp->j, 2); -} - -TEST(ExpectedTest, Dereference) { - Expected exp(TypeWithFields(1, 2)); - const auto& val = *exp; - EXPECT_EQ(val.i, 1); - EXPECT_EQ(val.j, 2); -} - -TEST(UnexpectedTest, WithStatus) { - Unexpected err(kErrorStatus); - EXPECT_EQ(err.Error().Status(), kErrorStatus); - EXPECT_TRUE(err.Error().Message().empty()); -} - -TEST(UnexpectedTest, WithMessage) { - Unexpected err(kErrorStatus, "MESSAGE"); - EXPECT_EQ(err.Error().Status(), kErrorStatus); - EXPECT_EQ(err.Error().Message(), "MESSAGE"); -} - -TEST(UnexpectedTest, WithLocalMessageString) { - // Message is a string with scoped lifetime. - Unexpected err(kErrorStatus, absl::StrCat("MESSAGE", 1)); - EXPECT_EQ(err.Error().Status(), kErrorStatus); - EXPECT_EQ(err.Error().Message(), "MESSAGE1"); -} - -Expected> Go() { - std::string data = "21234"; - OwningBufferRef buf(data.c_str()); - return buf; -} - -Expected> Forward() { - auto thing = Go(); - if (!thing.HasValue()) { - return thing.Error(); - } - // No copy elision here. 
- return thing; -} - -TEST(ExpectedTest, ForwardBufThroughFuncs) { - auto res = Forward(); - EXPECT_TRUE(res.HasValue()); - EXPECT_EQ(res->StrView(), "21234"); -} - -TEST(ExpectedWithNoValue, WithoutError) { - Expected expected = {}; - EXPECT_TRUE(expected.HasValue()); -} - -TEST(ExpectedWithNoValue, WithError) { - Expected expected(Unexpected(kErrorStatus, "MESSAGE")); - EXPECT_FALSE(expected.HasValue()); - EXPECT_EQ(expected.Error().Status(), kErrorStatus); - EXPECT_EQ(expected.Error().Message(), "MESSAGE"); -} - -TEST(ExpectedWithNoValue, OStreamOutput) { - Expected expected(Unexpected(kErrorStatus, "MESSAGE")); - std::ostringstream oss; - oss << expected.Error(); - EXPECT_THAT(oss.str(), testing::HasSubstr("MESSAGE")); -} - -TEST(ExpectedTest, PrintingWorks) { - EXPECT_THAT(absl::StrCat(Expected(3)), StrEq("3")); - - EXPECT_THAT(absl::StrCat(Expected()), StrEq("void expected value")); - - EXPECT_THAT(absl::StrCat(Unexpected(kLiteRtStatusErrorNotFound)), - StrEq("kLiteRtStatusErrorNotFound")); - - EXPECT_THAT(absl::StrCat(Unexpected(kLiteRtStatusErrorNotFound, - "Error not found message")), - StrEq("kLiteRtStatusErrorNotFound: Error not found message")); - - EXPECT_THAT(absl::StrCat(Error(kLiteRtStatusErrorNotFound)), - StrEq("kLiteRtStatusErrorNotFound")); - - EXPECT_THAT(absl::StrCat( - Error(kLiteRtStatusErrorNotFound, "Error not found message")), - StrEq("kLiteRtStatusErrorNotFound: Error not found message")); - - struct UnknownStruct {}; - EXPECT_THAT(absl::StrCat(Expected({})), - StrEq("unformattable expected value")); -} - -} // namespace - -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/cc/litert_handle.h b/tensorflow/lite/experimental/litert/cc/litert_handle.h deleted file mode 100644 index 503eaad335b764..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/litert_handle.h +++ /dev/null @@ -1,74 +0,0 @@ -// Copyright 2024 Google LLC. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_HANDLE_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_HANDLE_H_ - -#include -#include - -namespace litert { -namespace internal { - -template -inline void DummyDeleter(H) {} - -// This class is used to wrap and manage the lifetime of opaque handles from the -// C API into an equivalent C++ object. The class is a wrapper on -// std::unique_ptr<> that has a default constructor and doesn't crash if the -// deleter is null. -template -class Handle { - public: - Handle() = default; - explicit Handle(H handle, bool owned) noexcept - : ptr_(handle, owned ? deleter : DummyDeleter) {} - - Handle(Handle&& other) noexcept { *this = std::move(other); } - - Handle& operator=(Handle&& other) noexcept { - std::swap(ptr_, other.ptr_); - return *this; - } - - // Return true if the underlying LiteRtTensorBuffer handle is valid. - explicit operator bool() const noexcept { return static_cast(ptr_); } - - // Return the underlying LiteRtTensorBuffer handle. - H Get() const noexcept { return ptr_.get(); } - - H Release() noexcept { return ptr_.release(); } - - bool IsOwned() const noexcept { - return ptr_.get_deleter() != DummyDeleter; - } - - private: - std::unique_ptr, void (*)(H)> ptr_ = {nullptr, - DummyDeleter}; -}; - -// This class is similar to Handle, but the managed opaque handle is not owned -// (i.e., it will not be destroyed). 
-template -class NonOwnedHandle : public Handle> { - public: - explicit NonOwnedHandle(H handle) noexcept - : Handle>(handle, /*owned=*/false) {} -}; - -} // namespace internal -} // namespace litert - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_HANDLE_H_ diff --git a/tensorflow/lite/experimental/litert/cc/litert_layout.h b/tensorflow/lite/experimental/litert/cc/litert_layout.h deleted file mode 100644 index a8f90ac6dc1069..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/litert_layout.h +++ /dev/null @@ -1,158 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_LAYOUT_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_LAYOUT_H_ - -#include -#include -#include -#include -#include -#include - -#include "absl/container/inlined_vector.h" -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_layout.h" -#include "tensorflow/lite/experimental/litert/cc/litert_consts.h" - -namespace litert { - -using Dimensions = absl::InlinedVector; -using Strides = absl::InlinedVector; - -// Small standalone helper functions for working with the C layout API. - -// Build layout from given iterator of dimensions. 
-template -inline constexpr LiteRtLayout BuildLayout(Begin begin, End end, - const uint32_t* strides = nullptr) { - LiteRtLayout res{static_cast(end - begin), {}, strides}; - auto i = 0; - - for (auto* it = begin; it != end; ++it) { - res.dimensions[i] = *it; - ++i; - } - - return res; -} - -// Build layout from given iterable of dimensions. -template -inline constexpr LiteRtLayout BuildLayout(const Dims& dims, - const uint32_t* strides = nullptr) { - return BuildLayout(std::cbegin(dims), std::cend(dims), strides); -} - -// Build layout from literal dimensions. -inline constexpr LiteRtLayout BuildLayout(std::initializer_list dims, - const uint32_t* strides = nullptr) { - return BuildLayout(dims.begin(), dims.end(), strides); -} - -// Compute the number of elements in dims iterator. Nullopt if there exists -// a dynamic dimension. -template -inline constexpr std::optional NumElements(Begin begin, End end) { - if (end - begin == 0) { - return {}; - } - size_t res = 1; - for (auto* it = begin; it != end; ++it) { - if (*it < 0) { - return {}; - } - res *= *it; - } - return res; -} - -// Override for layouts. -inline constexpr std::optional NumElements(const LiteRtLayout& layout) { - auto* b = std::cbegin(layout.dimensions); - return NumElements(b, b + layout.rank); -} - -// Get dims as span. -inline constexpr absl::Span DimsSpan( - const LiteRtLayout& layout) { - return absl::MakeConstSpan(layout.dimensions, layout.rank); -} - -// Get strides as span if they exist. -inline constexpr std::optional> StridesSpan( - const LiteRtLayout& layout) { - if (layout.strides) { - return absl::MakeConstSpan(layout.strides, layout.rank); - } - return {}; -} - -// Tensor layout. C++ equivalent to LiteRtLayout. 
-class Layout { - public: - explicit Layout(litert::Dimensions&& dimensions, - litert::Strides&& strides = litert::Strides()) - : dimensions_(std::move(dimensions)), strides_(std::move(strides)) {} - - explicit Layout(const LiteRtLayout& layout) - : dimensions_(layout.dimensions, layout.dimensions + layout.rank) { - if (layout.strides) { - strides_.assign(layout.strides, layout.strides + layout.rank); - } - } - - // Cast the existing Layout to a LiteRtLayout. Note that the present Layout - // object must outlive the returned LiteRtLayout, otherwise pointers in the - // latter may become dangling. - explicit operator LiteRtLayout() const { - auto res = BuildLayout(dimensions_); - res.strides = HasStrides() ? strides_.data() : nullptr; - return res; - } - - bool operator==(const Layout& other) const { - return dimensions_ == other.dimensions_ && strides_ == other.strides_; - } - - uint32_t Rank() const { return dimensions_.size(); } - - absl::Span Dimensions() const { - return absl::MakeSpan(dimensions_.data(), dimensions_.size()); - } - - bool HasStrides() const { return !strides_.empty(); } - - absl::Span Strides() const { - if (HasStrides()) - return {strides_.data(), Rank()}; - else - return {}; - } - - // Get the number of scalar elements in this tensor type. std::nullopt if - // not fully static. - std::optional NumElements() const { - return ::litert::NumElements(dimensions_.cbegin(), dimensions_.cend()); - } - - private: - litert::Dimensions dimensions_; - litert::Strides strides_; -}; - -} // namespace litert - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_LAYOUT_H_ diff --git a/tensorflow/lite/experimental/litert/cc/litert_layout_test.cc b/tensorflow/lite/experimental/litert/cc/litert_layout_test.cc deleted file mode 100644 index 40d9cb9873e045..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/litert_layout_test.cc +++ /dev/null @@ -1,62 +0,0 @@ -// Copyright 2024 Google LLC. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/cc/litert_layout.h" - -#include - -#include -#include - -namespace litert { -namespace { - -using ::testing::ElementsAreArray; - -static constexpr int32_t kStaticDims[] = {2, 2}; -static constexpr int32_t kDynDims[] = {-1, 2}; -static constexpr uint32_t kStrides[] = {1, 1}; - -TEST(LayoutTest, BuildFromDims) { - auto layout = BuildLayout(kStaticDims); - EXPECT_EQ(layout.rank, 2); - EXPECT_THAT(DimsSpan(layout), ElementsAreArray(kStaticDims)); - EXPECT_EQ(layout.strides, nullptr); - EXPECT_FALSE(StridesSpan(layout).has_value()); -} - -TEST(LayoutTest, BuildFromDimsWithStrides) { - auto layout = BuildLayout(kStaticDims, kStrides); - EXPECT_EQ(layout.rank, 2); - EXPECT_THAT(DimsSpan(layout), ElementsAreArray(kStaticDims)); - auto strides = StridesSpan(layout); - ASSERT_TRUE(strides.has_value()); - EXPECT_THAT(*strides, ElementsAreArray(kStrides)); -} - -TEST(LayoutTest, NumElements) { - auto layout = BuildLayout(kStaticDims); - auto num_elements = NumElements(layout); - ASSERT_TRUE(num_elements.has_value()); - EXPECT_EQ(*num_elements, 4); -} - -TEST(LayoutTest, NumElementsDynamic) { - auto layout = BuildLayout(kDynDims); - auto num_elements = NumElements(layout); - ASSERT_FALSE(num_elements.has_value()); -} - -} // namespace -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/cc/litert_macros.cc 
b/tensorflow/lite/experimental/litert/cc/litert_macros.cc deleted file mode 100644 index 7d01ca346818e9..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/litert_macros.cc +++ /dev/null @@ -1,74 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" - -#include "absl/status/status.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" - -namespace litert { - -ErrorStatusBuilder::operator absl::Status() const noexcept { - switch (error_.Status()) { - case kLiteRtStatusOk: - return absl::OkStatus(); - case kLiteRtStatusErrorInvalidArgument: - return absl::InvalidArgumentError(error_.Message()); - case kLiteRtStatusErrorMemoryAllocationFailure: - return absl::ResourceExhaustedError(error_.Message()); - case kLiteRtStatusErrorRuntimeFailure: - return absl::InternalError(error_.Message()); - case kLiteRtStatusErrorMissingInputTensor: - return absl::InvalidArgumentError(error_.Message()); - case kLiteRtStatusErrorUnsupported: - return absl::UnimplementedError(error_.Message()); - case kLiteRtStatusErrorNotFound: - return absl::NotFoundError(error_.Message()); - case kLiteRtStatusErrorTimeoutExpired: - return absl::DeadlineExceededError(error_.Message()); - case kLiteRtStatusErrorWrongVersion: - return absl::FailedPreconditionError(error_.Message()); - case 
kLiteRtStatusErrorUnknown: - return absl::UnknownError(error_.Message()); - case kLiteRtStatusErrorFileIO: - return absl::UnavailableError(error_.Message()); - case kLiteRtStatusErrorInvalidFlatbuffer: - return absl::InvalidArgumentError(error_.Message()); - case kLiteRtStatusErrorDynamicLoading: - return absl::UnavailableError(error_.Message()); - case kLiteRtStatusErrorSerialization: - return absl::InternalError(error_.Message()); - case kLiteRtStatusErrorCompilation: - return absl::InternalError(error_.Message()); - case kLiteRtStatusErrorIndexOOB: - return absl::OutOfRangeError(error_.Message()); - case kLiteRtStatusErrorInvalidIrType: - return absl::InvalidArgumentError(error_.Message()); - case kLiteRtStatusErrorInvalidGraphInvariant: - return absl::InvalidArgumentError(error_.Message()); - case kLiteRtStatusErrorGraphModification: - return absl::InternalError(error_.Message()); - case kLiteRtStatusErrorInvalidToolConfig: - return absl::InvalidArgumentError(error_.Message()); - case kLiteRtStatusLegalizeNoMatch: - return absl::NotFoundError(error_.Message()); - case kLiteRtStatusErrorInvalidLegalization: - return absl::InvalidArgumentError(error_.Message()); - default: - return absl::UnknownError(error_.Message()); - } -} - -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/cc/litert_macros.h b/tensorflow/lite/experimental/litert/cc/litert_macros.h deleted file mode 100644 index 299649061fb333..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/litert_macros.h +++ /dev/null @@ -1,349 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_MACROS_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_MACROS_H_ - -#include -#include -#include -#include -#include - -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" // IWYU pragma: keep -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" // IWYU pragma: keep -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" // IWYU pragma: keep - -#define _CONCAT_NAME_IMPL(x, y) x##y - -#define _CONCAT_NAME(x, y) _CONCAT_NAME_IMPL(x, y) - -#define _RETURN_VAL(val) return val - -#define LITERT_CHECK_STATUS_HAS_CODE(expr, code) ABSL_CHECK(expr == code); - -#define LITERT_CHECK_STATUS_OK(expr) \ - LITERT_CHECK_STATUS_HAS_CODE(expr, kLiteRtStatusOk); - -#define LITERT_ENSURE_SUPPORTED(cond, msg) \ - if (!(cond)) { \ - LITERT_LOG(LITERT_ERROR, "%s", msg); \ - return kLiteRtStatusErrorUnsupported; \ - } - -#define LITERT_ENSURE(expr, fail_stat, msg) \ - if (!(expr)) { \ - LITERT_LOG(LITERT_ERROR, "%s", msg); \ - return fail_stat; \ - } - -#define LITERT_RETURN_IF_ERROR_OR_NOT_MATCHED(expr) \ - if (LiteRtStatus status = expr; \ - (status != kLiteRtStatusOk && status != kLiteRtStatusLegalizeNoMatch)) \ - return status; - -#define LITERT_STACK_ARRAY(ty, var, size, init) \ - ty* var = (ty*)alloca(sizeof(ty) * size); \ - for (ty* e = var; e < var + size; ++e) { \ - *e = init; \ - } - -// LITERT_RETURN_IF_ERROR(expr); -// LITERT_RETURN_IF_ERROR(expr, return_value); -// -// Returns the result 
of `expr` if it represents an LiteRT error status (either -// `litert::Expected` holding an error, a `LiteRtStatus` or a bool that -// evaluated to `false`). -// -// Returns `return_value` if the result of `expr` represents an error. -// -// The result of `expr` may be referenced as `status` in `return_expr`. -// -// By default, the return value is an `ErrorStatusBuilder` built from using the -// result of `expr`. The error message of this builder can be customized by -// using its `*Log*()` functions and the << operator. -// -// ```cpp -// LITERT_RETURN_IF_ERROR(expr) << "Failed while trying to ..."; -// ``` -#define LITERT_RETURN_IF_ERROR(...) \ - LITERT_RETURN_IF_ERROR_SELECT_OVERLOAD( \ - (__VA_ARGS__, LITERT_RETURN_IF_ERROR_2, LITERT_RETURN_IF_ERROR_1))( \ - __VA_ARGS__) - -// ASSIGN_OR_RETURN(decl, expr) -// ASSIGN_OR_RETURN(decl, expr, return_value) -// -// Evaluates `expr` that should convert to a `litert::Expected` object. -// -// - If the object holds a value, move-assigns the value to `decl`. -// - If the object holds an error, returns the error, casting it to a -// `LiteRtStatus` if required. -// -// `return_value` may be specified to return a custom value in case of error. -// -// By when specifying `return_value`, an `ErrorStatusBuilder` variable called -// `_` can be used to customize the error message. -// -// ```cpp -// LITERT_ASSIGN_OR_RETURN(expr, _ << "Failed while trying to ..."); -// ``` -#define LITERT_ASSIGN_OR_RETURN(DECL, ...) 
\ - LITERT_ASSIGN_OR_RETURN_SELECT_OVERLOAD((DECL, __VA_ARGS__, \ - LITERT_ASSIGN_OR_RETURN_HELPER_3, \ - LITERT_ASSIGN_OR_RETURN_HELPER_2))( \ - _CONCAT_NAME(expected_value_or_error_, __LINE__), DECL, __VA_ARGS__) - -namespace litert { - -#if defined(__has_builtin) && __has_builtin(__builtin_FILE) && \ - __has_builtin(__builtin_LINE) -#define LITERT_INTERNAL_BUILTIN_FILE __builtin_FILE() -#define LITERT_INTERNAL_BUILTIN_LINE __builtin_LINE() -#else -#define LITERT_INTERNAL_BUILTIN_FILE "unknown" -#define LITERT_INTERNAL_BUILTIN_LINE 0 -#endif - -// Stores a file and a line number. -// -// Mimics a subset of `std::source_location` to be replaced by it when we update -// to C++20. -class SourceLocation { - // We have this to prevent `current()` parameters from begin modified. - struct PrivateTag {}; - - public: - // Creates a SourceLocation with the line and file corresponding to the - // call site. - static constexpr SourceLocation current( - PrivateTag = PrivateTag{}, - const char* file = LITERT_INTERNAL_BUILTIN_FILE, - uint32_t line = LITERT_INTERNAL_BUILTIN_LINE) { - return SourceLocation{file, line}; - } - - constexpr const char* file_name() const { return file_; } - constexpr uint32_t line() const { return line_; } - - private: - // Builds a SourceLocation object. - // - // Note: This is private as `std::source_location` doesn't provide a way of - // manually building a source location. - constexpr SourceLocation(const char* file, uint32_t line) - : file_(file), line_(line) {} - - const char* file_; - uint32_t line_; -}; - -// Converts implicitly to either `LiteRtStatus` or `litert::Expected` holding an -// error. This allows returning a status in functions using either of these as a -// return type in `LITERT_RETURN_IF_ERROR` and `LITERT_ASSIGN_OR_RETURN`. -// -// When a C++ error with a message is converted to a `LiteRtStatus`, the message -// is logged (as an error by default, use the `Log*()` functions to customize -// that). 
-// -// The error message may be completed with extra info by using the << operator. -class ErrorStatusBuilder { - public: - explicit ErrorStatusBuilder( - bool expr_result, - litert::SourceLocation loc = litert::SourceLocation::current()) - : error_(kLiteRtStatusErrorUnknown), loc_(loc) {} - - template - explicit ErrorStatusBuilder( - const litert::Expected& expected, - litert::SourceLocation loc = litert::SourceLocation::current()) - : error_(expected.Error()), loc_(loc) {} - - template - explicit ErrorStatusBuilder( - litert::Expected&& expected, - litert::SourceLocation loc = litert::SourceLocation::current()) - : error_(std::move(expected.Error())), loc_(loc) {} - - explicit ErrorStatusBuilder( - LiteRtStatus status, - litert::SourceLocation loc = litert::SourceLocation::current()) - : error_(status), loc_(loc) {} - - explicit ErrorStatusBuilder( - const litert::Unexpected& unexpected, - litert::SourceLocation loc = litert::SourceLocation::current()) - : error_(unexpected.Error()), loc_(loc) {} - - explicit ErrorStatusBuilder( - litert::Unexpected&& unexpected, - litert::SourceLocation loc = litert::SourceLocation::current()) - : error_(std::move(unexpected.Error())), loc_(loc) {} - - // NOLINTBEGIN(*-explicit-constructor): This class transparently converts to - // `LiteRtStatus` and `litert::Expected`. - - // Note: this conversion logs the error message if there is one unless NDEBUG - // is set (generally in case of optimized builds). 
- operator LiteRtStatus() const noexcept { -#ifndef NDEBUG - if (ShouldLog()) { - LiteRtLogger logger = LiteRtGetDefaultLogger(); - LiteRtLogSeverity __min_severity__; - if (LiteRtGetMinLoggerSeverity(logger, &__min_severity__) != - kLiteRtStatusOk) { - __min_severity__ = kLiteRtLogSeverityVerbose; - } - if (log_level_ >= __min_severity__) { - LiteRtLoggerLog(logger, log_level_, "[%s:%u] %s", loc_.file_name(), - loc_.line(), LogMessage().c_str()); - } - } -#endif - return error_.Status(); - } - - template - operator litert::Expected() const noexcept { - return litert::Unexpected(error_.Status(), LogMessage()); - } - - operator absl::Status() const noexcept; - - template - operator absl::StatusOr() const noexcept { - return static_cast(*this); - } - // NOLINTEND(*-explicit-constructor) - - static constexpr bool IsError(bool status) { return !status; } - - static constexpr bool IsError(LiteRtStatus status) { - return status != kLiteRtStatusOk; - } - - static constexpr bool IsError(const litert::Unexpected&) { return true; } - - template - static constexpr bool IsError(const litert::Expected& expected) { - return !expected.HasValue(); - } - - // Appends data to the error message. - template - ErrorStatusBuilder& operator<<(T&& val) { - if (!extra_log_) { - extra_log_ = std::make_unique(); - } - *extra_log_ << static_cast(val); - return *this; - } - - // Sets the log level used when converting to a `LiteRtStatus`. - ErrorStatusBuilder& Log(LiteRtLogSeverity log_level) noexcept { - log_level_ = log_level; - return *this; - } - - // Sets the log level used when converting to a `LiteRtStatus` to `error`. - ErrorStatusBuilder& LogVerbose() noexcept { - return Log(kLiteRtLogSeverityVerbose); - } - - // Sets the log level used when converting to a `LiteRtStatus` to `info`. - ErrorStatusBuilder& LogInfo() noexcept { return Log(kLiteRtLogSeverityInfo); } - - // Sets the log level used when converting to a `LiteRtStatus` to `error`. 
- ErrorStatusBuilder& LogWarning() noexcept { - return Log(kLiteRtLogSeverityWarning); - } - - // Sets the log level used when converting to a `LiteRtStatus` to `error`. - ErrorStatusBuilder& LogError() noexcept { - return Log(kLiteRtLogSeverityError); - } - - // Prevent logging any message when converting to a `LiteRtStatus`. - ErrorStatusBuilder& NoLog() noexcept { return Log(kLiteRtLogSeveritySilent); } - - private: - bool ShouldLog() const noexcept { - return log_level_ != kLiteRtLogSeveritySilent && - (!error_.Message().empty() || extra_log_); - } - - std::string LogMessage() const { - if (!error_.Message().empty() && extra_log_) { - std::string res; - res.reserve(error_.Message().size() + extra_log_->tellp() + 2); - res.append(error_.Message()); - res.append(" "); - res.append(extra_log_->str()); - return res; - } - if (!error_.Message().empty()) { - return error_.Message(); - } - if (extra_log_) { - return extra_log_->str(); - } - return {}; - } - - litert::Error error_; - litert::SourceLocation loc_; - std::unique_ptr extra_log_; - LiteRtLogSeverity log_level_ = kLiteRtLogSeverityError; -}; - -} // namespace litert - -//////////// Implementation details start here. /////////////////////// - -#define LITERT_RETURN_IF_ERROR_SELECT_OVERLOAD_HELPER(_1, _2, OVERLOAD, ...) \ - OVERLOAD - -#define LITERT_RETURN_IF_ERROR_SELECT_OVERLOAD(args) \ - LITERT_RETURN_IF_ERROR_SELECT_OVERLOAD_HELPER args - -#define LITERT_RETURN_IF_ERROR_1(EXPR) \ - LITERT_RETURN_IF_ERROR_2(EXPR, \ - ::litert::ErrorStatusBuilder{std::move(status)}) - -#define LITERT_RETURN_IF_ERROR_2(EXPR, RETURN_VALUE) \ - if (auto status = (EXPR); ::litert::ErrorStatusBuilder::IsError(status)) \ - return RETURN_VALUE - -#define LITERT_ASSIGN_OR_RETURN_SELECT_OVERLOAD_HELPER(_1, _2, _3, OVERLOAD, \ - ...) 
\ - OVERLOAD - -#define LITERT_ASSIGN_OR_RETURN_SELECT_OVERLOAD(args) \ - LITERT_ASSIGN_OR_RETURN_SELECT_OVERLOAD_HELPER args - -#define LITERT_ASSIGN_OR_RETURN_HELPER_2(TMP_VAR, DECL, EXPR) \ - LITERT_ASSIGN_OR_RETURN_HELPER_3(TMP_VAR, DECL, EXPR, _) - -#define LITERT_ASSIGN_OR_RETURN_HELPER_3(TMP_VAR, DECL, EXPR, RETURN_VALUE) \ - auto&& TMP_VAR = (EXPR); \ - if (::litert::ErrorStatusBuilder::IsError(TMP_VAR)) { \ - [[maybe_unused]] ::litert::ErrorStatusBuilder _(std::move(TMP_VAR)); \ - return RETURN_VALUE; \ - } \ - DECL = std::move(TMP_VAR.Value()); - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_MACROS_H_ diff --git a/tensorflow/lite/experimental/litert/cc/litert_macros_test.cc b/tensorflow/lite/experimental/litert/cc/litert_macros_test.cc deleted file mode 100644 index f1d0b66e6748bb..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/litert_macros_test.cc +++ /dev/null @@ -1,207 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" - -#include -#include -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" - -namespace litert { -namespace { - -using testing::AllOf; -using testing::Property; - -TEST(LiteRtReturnIfErrorTest, ConvertsResultToLiteRtStatus) { - EXPECT_EQ( - []() -> LiteRtStatus { - LITERT_RETURN_IF_ERROR( - Expected(Unexpected(kLiteRtStatusErrorNotFound))); - return kLiteRtStatusOk; - }(), - kLiteRtStatusErrorNotFound); - EXPECT_EQ( - []() -> LiteRtStatus { - LITERT_RETURN_IF_ERROR(Unexpected(kLiteRtStatusErrorNotFound)); - return kLiteRtStatusOk; - }(), - kLiteRtStatusErrorNotFound); - EXPECT_EQ( - []() -> LiteRtStatus { - LITERT_RETURN_IF_ERROR(kLiteRtStatusErrorNotFound); - return kLiteRtStatusOk; - }(), - kLiteRtStatusErrorNotFound); -} - -TEST(LiteRtReturnIfErrorTest, ConvertsResultToExpectedHoldingAnError) { - EXPECT_THAT( - []() -> Expected { - LITERT_RETURN_IF_ERROR( - Expected(Unexpected(kLiteRtStatusErrorNotFound))); - return {}; - }(), - AllOf(Property(&Expected::HasValue, false), - Property(&Expected::Error, - Property(&Error::Status, kLiteRtStatusErrorNotFound)))); - EXPECT_THAT( - []() -> Expected { - LITERT_RETURN_IF_ERROR(Unexpected(kLiteRtStatusErrorNotFound)); - return {}; - }(), - AllOf(Property(&Expected::HasValue, false), - Property(&Expected::Error, - Property(&Error::Status, kLiteRtStatusErrorNotFound)))); - EXPECT_THAT( - []() -> Expected { - LITERT_RETURN_IF_ERROR(kLiteRtStatusErrorNotFound); - return {}; - }(), - AllOf(Property(&Expected::HasValue, false), - Property(&Expected::Error, - Property(&Error::Status, kLiteRtStatusErrorNotFound)))); -} - -TEST(LiteRtReturnIfErrorTest, DoesntReturnOnSuccess) { - int canary_value = 0; - auto ReturnExpectedIfError = [&canary_value]() -> Expected { - LITERT_RETURN_IF_ERROR(Expected()); - canary_value = 1; - return {}; - }; - EXPECT_THAT(ReturnExpectedIfError(), - 
Property(&Expected::HasValue, true)); - EXPECT_EQ(canary_value, 1); - - [&canary_value]() -> LiteRtStatus { - LITERT_RETURN_IF_ERROR(kLiteRtStatusOk); - canary_value = 2; - return kLiteRtStatusOk; - }(); - EXPECT_EQ(canary_value, 2); -} - -TEST(LiteRtReturnIfErrorTest, ExtraLoggingWorks) { - int canary_value = 0; - [&canary_value]() -> LiteRtStatus { - LITERT_RETURN_IF_ERROR(false) << "Successful default level logging."; - canary_value = 2; - return kLiteRtStatusOk; - }(); - EXPECT_EQ(canary_value, 0); - - canary_value = 0; - [&canary_value]() -> LiteRtStatus { - LITERT_RETURN_IF_ERROR(false).LogVerbose() << "Successful verbose logging."; - canary_value = 2; - return kLiteRtStatusOk; - }(); - EXPECT_EQ(canary_value, 0); - - canary_value = 0; - [&canary_value]() -> LiteRtStatus { - LITERT_RETURN_IF_ERROR(false).LogInfo() << "Successful info logging."; - canary_value = 2; - return kLiteRtStatusOk; - }(); - EXPECT_EQ(canary_value, 0); - - canary_value = 0; - [&canary_value]() -> LiteRtStatus { - LITERT_RETURN_IF_ERROR(false).LogWarning() << "Successful warning logging."; - canary_value = 2; - return kLiteRtStatusOk; - }(); - EXPECT_EQ(canary_value, 0); - - canary_value = 0; - [&canary_value]() -> LiteRtStatus { - LITERT_RETURN_IF_ERROR(false).LogError() << "Successful error logging."; - canary_value = 2; - return kLiteRtStatusOk; - }(); - EXPECT_EQ(canary_value, 0); - - canary_value = 0; - [&canary_value]() -> LiteRtStatus { - LITERT_RETURN_IF_ERROR(false).NoLog() << "This should never be printed"; - canary_value = 2; - return kLiteRtStatusOk; - }(); - EXPECT_EQ(canary_value, 0); -} - -TEST(LiteRtAssignOrReturnTest, VariableAssignmentWorks) { - int canary_value = 0; - auto ChangeCanaryValue = [&canary_value]() -> LiteRtStatus { - LITERT_ASSIGN_OR_RETURN(canary_value, Expected(1)); - return kLiteRtStatusOk; - }; - EXPECT_EQ(ChangeCanaryValue(), kLiteRtStatusOk); - EXPECT_EQ(canary_value, 1); -} - -TEST(LiteRtAssignOrReturnTest, MoveOnlyVariableAssignmentWorks) { - 
struct MoveOnly { - explicit MoveOnly(int val) : val(val) {}; - MoveOnly(const MoveOnly&) = delete; - MoveOnly& operator=(const MoveOnly&) = delete; - MoveOnly(MoveOnly&&) = default; - MoveOnly& operator=(MoveOnly&&) = default; - int val = 1; - }; - - MoveOnly canary_value{0}; - auto ChangeCanaryValue = [&canary_value]() -> LiteRtStatus { - LITERT_ASSIGN_OR_RETURN(canary_value, Expected(1)); - return kLiteRtStatusOk; - }; - EXPECT_EQ(ChangeCanaryValue(), kLiteRtStatusOk); - EXPECT_EQ(canary_value.val, 1); -} - -TEST(LiteRtAssignOrReturnTest, ReturnsOnFailure) { - const Expected InvalidArgumentError = - Expected(Unexpected(kLiteRtStatusErrorInvalidArgument)); - - int canary_value = 0; - auto ErrorWithStatus = [&]() -> LiteRtStatus { - LITERT_ASSIGN_OR_RETURN(canary_value, InvalidArgumentError); - return kLiteRtStatusOk; - }; - EXPECT_EQ(ErrorWithStatus(), kLiteRtStatusErrorInvalidArgument); - EXPECT_EQ(canary_value, 0); - - auto ErrorWithCustomStatus = [&]() -> int { - LITERT_ASSIGN_OR_RETURN(canary_value, InvalidArgumentError, 42); - return 1; - }; - EXPECT_EQ(ErrorWithCustomStatus(), 42); - EXPECT_EQ(canary_value, 0); - - auto ErrorWithExpected = [&]() -> Expected { - LITERT_ASSIGN_OR_RETURN(canary_value, InvalidArgumentError); - return {}; - }; - auto expected_return = ErrorWithExpected(); - ASSERT_FALSE(expected_return.HasValue()); - EXPECT_EQ(expected_return.Error().Status(), - kLiteRtStatusErrorInvalidArgument); - EXPECT_EQ(canary_value, 0); -} - -} // namespace -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/cc/litert_model.cc b/tensorflow/lite/experimental/litert/cc/litert_model.cc deleted file mode 100644 index b67c5c75d2375a..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/litert_model.cc +++ /dev/null @@ -1,154 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" - -#include - -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_detail.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" - -namespace litert { - -bool Tensor::IsSubgraphOutput() const { return Uses().empty(); } - -bool Tensor::IsSubgraphInput() const { - // A special case for zero-sized tensors. - if (RankedTensorType()->Layout().Rank() == 1 && - RankedTensorType()->Layout().Dimensions()[0] == 0) { - return false; - } - return !HasWeights() && !DefiningOp().has_value(); -} - -bool Tensor::IsConstant() const { - return HasWeights() && !DefiningOp().has_value(); -} - -Tensor::TensorUses Tensor::Uses() const { - LiteRtParamIndex num_uses; - litert::internal::AssertOk(LiteRtGetNumTensorUses, Get(), &num_uses); - - TensorUses uses; - for (auto i = 0; i < num_uses; ++i) { - LiteRtOp user; - LiteRtParamIndex user_arg_index; - litert::internal::AssertOk(LiteRtGetTensorUse, Get(), i, &user, - &user_arg_index); - uses.emplace_back(TensorUse{Op(user), user_arg_index}); - } - return uses; -} - -OpInputs Op::Inputs() const { - LiteRtParamIndex num_inputs; - internal::AssertOk(LiteRtGetNumOpInputs, Get(), &num_inputs); - - OpInputs inputs; - for (auto i = 0; i < num_inputs; ++i) { - LiteRtTensor input; - internal::AssertOk(LiteRtGetOpInput, Get(), i, &input); - inputs.emplace_back(Tensor(input)); - } 
- return inputs; -} - -OpOutputs Op::Outputs() const { - LiteRtParamIndex num_outputs; - internal::AssertOk(LiteRtGetNumOpOutputs, Get(), &num_outputs); - - OpOutputs outputs; - for (auto i = 0; i < num_outputs; ++i) { - LiteRtTensor output; - internal::AssertOk(LiteRtGetOpOutput, Get(), i, &output); - outputs.emplace_back(Tensor(output)); - } - return outputs; -} - -SubgraphInputs Subgraph::Inputs() const { - LiteRtParamIndex num_inputs; - internal::AssertOk(LiteRtGetNumSubgraphInputs, Get(), &num_inputs); - - SubgraphInputs inputs; - for (auto i = 0; i < num_inputs; ++i) { - LiteRtTensor input; - internal::AssertOk(LiteRtGetSubgraphInput, Get(), i, &input); - inputs.emplace_back(Tensor(input)); - } - return inputs; -} - -Expected Subgraph::Input(absl::string_view name) const { - LiteRtParamIndex num_inputs; - internal::AssertOk(LiteRtGetNumSubgraphInputs, Get(), &num_inputs); - - for (auto i = 0; i < num_inputs; ++i) { - LiteRtTensor input; - internal::AssertOk(LiteRtGetSubgraphInput, Get(), i, &input); - const char* input_name; - internal::AssertOk(LiteRtGetTensorName, input, &input_name); - if (name == input_name) { - return Tensor(input); - } - } - return Unexpected(kLiteRtStatusErrorNotFound, "Failed to find input"); -} - -Expected Subgraph::Output(absl::string_view name) const { - LiteRtParamIndex num_outputs; - internal::AssertOk(LiteRtGetNumSubgraphOutputs, Get(), &num_outputs); - - for (auto i = 0; i < num_outputs; ++i) { - LiteRtTensor output; - internal::AssertOk(LiteRtGetSubgraphOutput, Get(), i, &output); - const char* output_name; - internal::AssertOk(LiteRtGetTensorName, output, &output_name); - if (name == output_name) { - return Tensor(output); - } - } - return Unexpected(kLiteRtStatusErrorNotFound, "Failed to find output"); -} - -SubgraphOutputs Subgraph::Outputs() const { - LiteRtParamIndex num_outputs; - internal::AssertOk(LiteRtGetNumSubgraphOutputs, Get(), &num_outputs); - - SubgraphOutputs outputs; - for (auto i = 0; i < num_outputs; ++i) { 
- LiteRtTensor output; - internal::AssertOk(LiteRtGetSubgraphOutput, Get(), i, &output); - outputs.emplace_back(Tensor(output)); - } - return outputs; -} - -std::vector Subgraph::Ops() const { - LiteRtParamIndex num_ops; - internal::AssertOk(LiteRtGetNumSubgraphOps, Get(), &num_ops); - - std::vector ops; - for (auto i = 0; i < num_ops; ++i) { - LiteRtOp op; - litert::internal::AssertOk(LiteRtGetSubgraphOp, Get(), i, &op); - ops.emplace_back(Op(op)); - } - return ops; -} - -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/cc/litert_model.h b/tensorflow/lite/experimental/litert/cc/litert_model.h deleted file mode 100644 index 579e97db9888e1..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/litert_model.h +++ /dev/null @@ -1,473 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_MODEL_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_MODEL_H_ - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "absl/container/inlined_vector.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" -#include "tensorflow/lite/experimental/litert/cc/litert_consts.h" -#include "tensorflow/lite/experimental/litert/cc/litert_detail.h" -#include "tensorflow/lite/experimental/litert/cc/litert_element_type.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_handle.h" -#include "tensorflow/lite/experimental/litert/cc/litert_layout.h" - -namespace litert { - -// Type for tensors with known dimensions. C++ equivalent to -// LiteRtRankedTensorType. -class RankedTensorType { - public: - RankedTensorType(enum ElementType element_type, class Layout&& layout) - : element_type_(element_type), layout_(std::move(layout)) {} - explicit RankedTensorType(const LiteRtRankedTensorType& type) - : element_type_(static_cast(type.element_type)), - layout_(type.layout) {} - - explicit operator LiteRtRankedTensorType() const { - return LiteRtRankedTensorType{ - /*.element_type=*/static_cast(element_type_), - /*layout=*/static_cast(layout_), - }; - } - - bool operator==(const RankedTensorType& other) const { - return ElementType() == other.ElementType() && Layout() == other.Layout(); - } - - enum ElementType ElementType() const { return element_type_; } - - const class Layout& Layout() const { return layout_; } - - private: - enum ElementType element_type_; - class Layout layout_; -}; - -// Tensor weights. C++ equivalent of LiteRtWeights. 
-class Weights : public internal::NonOwnedHandle { - public: - explicit Weights(LiteRtWeights weights) - : internal::NonOwnedHandle(weights) {} - - absl::Span Bytes() const { - size_t size; - const void* addr; - internal::AssertOk(LiteRtGetWeightsBytes, Get(), &addr, &size); - return absl::MakeSpan(static_cast(addr), size); - } -}; - -// Tensor. C++ equivalent of LiteRtTensor. -class Tensor : public internal::NonOwnedHandle { - public: - explicit Tensor(LiteRtTensor tensor) - : internal::NonOwnedHandle(tensor) {} - - enum ElementType ElementType() const { - if (TypeId() == kLiteRtUnrankedTensorType) { - return static_cast(UnrankedTensorType()->element_type); - } else { - return RankedTensorType()->ElementType(); - } - } - - LiteRtTensorTypeId TypeId() const { - LiteRtTensorTypeId type_id; - internal::AssertOk(LiteRtGetTensorTypeId, Get(), &type_id); - return type_id; - } - - Expected UnrankedTensorType() const { - if (TypeId() != kLiteRtUnrankedTensorType) { - return Error(kLiteRtStatusErrorInvalidArgument, - "Not an unranked invalid tensor"); - } - LiteRtUnrankedTensorType unranked_tensor_type; - internal::AssertOk(LiteRtGetUnrankedTensorType, Get(), - &unranked_tensor_type); - return unranked_tensor_type; - } - - Expected RankedTensorType() const { - if (TypeId() != kLiteRtRankedTensorType) { - return Error(kLiteRtStatusErrorInvalidArgument, - "Not a ranked tensor type"); - } - LiteRtRankedTensorType ranked_tensor_type; - internal::AssertOk(LiteRtGetRankedTensorType, Get(), &ranked_tensor_type); - return litert::RankedTensorType(ranked_tensor_type); - } - - LiteRtQuantizationTypeId QTypeId() const { - LiteRtQuantizationTypeId q_type_id; - internal::AssertOk(LiteRtGetQuantizationTypeId, Get(), &q_type_id); - return q_type_id; - } - - bool HasQuantization() const { return QTypeId() != kLiteRtQuantizationNone; } - - LiteRtQuantizationPerTensor PerTensorQuantization() const { - internal::AssertEq([&]() { return QTypeId(); }, - kLiteRtQuantizationPerTensor); - 
LiteRtQuantizationPerTensor per_tensor_quantization; - internal::AssertOk(LiteRtGetPerTensorQuantization, Get(), - &per_tensor_quantization); - return per_tensor_quantization; - } - - LiteRtQuantizationPerChannel PerChannelQuantization() const { - internal::AssertEq([&]() { return QTypeId(); }, - kLiteRtQuantizationPerChannel); - LiteRtQuantizationPerChannel per_channel_quantization; - internal::AssertOk(LiteRtGetPerChannelQuantization, Get(), - &per_channel_quantization); - return per_channel_quantization; - } - - bool HasWeights() const { - auto weights = Weights(); - return !weights.Bytes().empty(); - } - - class Weights Weights() const { - LiteRtWeights weights; - internal::AssertOk(LiteRtGetTensorWeights, Get(), &weights); - return litert::Weights(weights); - } - - absl::string_view Name() const { - const char* name; - internal::AssertOk(LiteRtGetTensorName, Get(), &name); - return absl::string_view(name); - } - - struct TensorUse; - using TensorUses = - absl::InlinedVector; - - TensorUses Uses() const; - - template - Expected> WeightsData() const { - auto ranked_tensor_type = RankedTensorType(); - if (!ranked_tensor_type) { - return ranked_tensor_type.Error(); - } - - const enum ElementType ty = ranked_tensor_type->ElementType(); - if (ty != GetElementType()) { - return litert::Unexpected(kLiteRtStatusErrorInvalidArgument); - } - - if (!HasWeights()) { - return litert::Unexpected(kLiteRtStatusErrorInvalidArgument); - } - const absl::Span weights = Weights().Bytes(); - - auto num_elements = ranked_tensor_type->Layout().NumElements(); - if (!num_elements.has_value()) { - return litert::Unexpected(kLiteRtStatusErrorInvalidArgument); - } - auto byte_width = GetByteWidth(ty); - if (!byte_width.has_value()) { - return litert::Unexpected(kLiteRtStatusErrorInvalidArgument); - } - - if (byte_width.value() * num_elements.value() != weights.size()) { - return litert::Unexpected(kLiteRtStatusErrorInvalidArgument); - } - - return 
absl::MakeConstSpan(reinterpret_cast(weights.data()), - num_elements.value()); - } - - std::optional DefiningOp() const { - bool has_defining_op; - LiteRtTensorDefiningOp defining_op; - internal::AssertOk(LiteRtGetTensorDefiningOp, Get(), &has_defining_op, - &defining_op); - if (has_defining_op) { - return defining_op; - } else { - return std::nullopt; - } - } - - bool IsSubgraphOutput() const; - bool IsSubgraphInput() const; - bool IsConstant() const; -}; - -using OpInputs = absl::InlinedVector; -using OpOutputs = absl::InlinedVector; - -// Operator. C++ equivalent of LiteRtOp. -class Op : public internal::NonOwnedHandle { - public: - explicit Op(LiteRtOp op) : internal::NonOwnedHandle(op) {} - - LiteRtOpCode Code() const { - LiteRtOpCode opcode; - internal::AssertOk(LiteRtGetOpCode, Get(), &opcode); - return opcode; - } - - OpInputs Inputs() const; - OpOutputs Outputs() const; -}; - -struct Tensor::TensorUse { - Op user; - LiteRtParamIndex user_arg_ind; -}; - -using SubgraphInputs = - absl::InlinedVector; -using SubgraphOutputs = - absl::InlinedVector; - -// Model subgraph. C++ equivalent of LiteRtSubgraph. -class Subgraph : public internal::NonOwnedHandle { - public: - explicit Subgraph(LiteRtSubgraph subgraph) - : internal::NonOwnedHandle(subgraph) {} - - SubgraphInputs Inputs() const; - SubgraphOutputs Outputs() const; - std::vector Ops() const; - - // Returns the input tensor with the given input signature name. - Expected Input(absl::string_view name) const; - - // Returns the output tensor with the given output signature name. - Expected Output(absl::string_view name) const; -}; - -// Model signature. C++ equivalent of LiteRtSignature. 
-class Signature : public internal::NonOwnedHandle { - public: - explicit Signature(LiteRtSignature signature) - : internal::NonOwnedHandle(signature) {} - - absl::string_view Key() const { - const char* key; - internal::AssertOk(LiteRtGetSignatureKey, Get(), &key); - return key; - } - - LiteRtSubgraph Subgraph() const { - LiteRtSubgraph subgraph; - internal::AssertOk(LiteRtGetSignatureSubgraph, Get(), &subgraph); - return subgraph; - } - - std::vector InputNames() const { - LiteRtParamIndex num_inputs; - internal::AssertOk(LiteRtGetNumSignatureInputs, Get(), &num_inputs); - std::vector input_names; - input_names.reserve(num_inputs); - for (int i = 0; i < num_inputs; ++i) { - const char* input_name; - internal::AssertOk(LiteRtGetSignatureInputName, Get(), i, &input_name); - input_names.push_back(input_name); - } - return input_names; - } - - std::vector OutputNames() const { - LiteRtParamIndex num_outputs; - internal::AssertOk(LiteRtGetNumSignatureOutputs, Get(), &num_outputs); - std::vector output_names; - output_names.reserve(num_outputs); - for (int i = 0; i < num_outputs; ++i) { - const char* output_name; - internal::AssertOk(LiteRtGetSignatureOutputName, Get(), i, &output_name); - output_names.push_back(output_name); - } - return output_names; - } -}; - -// Model. C++ equivalent of LiteRtModel. 
-class Model : public internal::Handle { - public: - Model() = default; - - static Model CreateFromOwnedHandle(LiteRtModel model) { - return Model(model, /*owned=*/true); - } - - static Model CreateFromNonOwnedHandle(LiteRtModel model) { - return Model(model, /*owned=*/false); - } - - static Expected CreateFromFile(const std::string& filename) { - LiteRtModel model; - if (auto status = LiteRtCreateModelFromFile(filename.c_str(), &model); - status != kLiteRtStatusOk) { - return Unexpected(status, "Failed to load model from file"); - } - return CreateFromOwnedHandle(model); - } - - static Expected CreateFromBuffer(BufferRef buffer) { - LiteRtModel model; - if (auto status = - LiteRtCreateModelFromBuffer(buffer.Data(), buffer.Size(), &model); - status != kLiteRtStatusOk) { - return Unexpected(status, "Failed to load model from buffer"); - } - return CreateFromOwnedHandle(model); - } - - Expected> Metadata( - const std::string& metadata_key) const { - const void* buffer; - size_t buffer_size; - if (LiteRtGetModelMetadata(Get(), metadata_key.data(), &buffer, - &buffer_size) != kLiteRtStatusOk) { - return Unexpected(kLiteRtStatusErrorNotFound, "Metadata key not found"); - } - return absl::MakeSpan(static_cast(buffer), buffer_size); - } - - Expected MainSubgraph() const { - LiteRtParamIndex main_subgraph_index; - internal::AssertOk(LiteRtGetMainModelSubgraphIndex, Get(), - &main_subgraph_index); - return this->Subgraph(main_subgraph_index); - } - - size_t NumSubgraphs() const { - LiteRtParamIndex num_subgraphs; - internal::AssertOk(LiteRtGetNumModelSubgraphs, Get(), &num_subgraphs); - return num_subgraphs; - } - - Expected Subgraph(size_t subgraph_index) const { - LiteRtSubgraph subgraph; - if (LiteRtGetModelSubgraph(Get(), subgraph_index, &subgraph) != - kLiteRtStatusOk) { - return Unexpected(kLiteRtStatusErrorNotFound, "Subgraph not found"); - } - return litert::Subgraph(subgraph); - } - - Expected Subgraph(absl::string_view signature_key) const { - auto signature = 
FindSignature(signature_key); - if (!signature) { - return Unexpected(kLiteRtStatusErrorNotFound, "Signature not found"); - } - return litert::Subgraph(signature->Subgraph()); - } - - size_t GetNumSignatures() const { - LiteRtParamIndex num_signatures; - internal::AssertOk(LiteRtGetNumModelSignatures, Get(), &num_signatures); - return num_signatures; - } - - // Returns the list of signatures defined in the model. - Expected> GetSignatures() const { - LiteRtParamIndex num_signatures; - internal::AssertOk(LiteRtGetNumModelSignatures, Get(), &num_signatures); - std::vector signatures; - signatures.reserve(num_signatures); - for (int i = 0; i < num_signatures; ++i) { - LiteRtSignature lite_rt_signature; - internal::AssertOk(LiteRtGetModelSignature, Get(), i, &lite_rt_signature); - Signature signature(lite_rt_signature); - signatures.push_back(std::move(signature)); - } - return std::move(signatures); - } - - // Returns the signature at the given index. - Expected GetSignature(size_t signature_index) const { - LiteRtSignature lite_rt_signature; - internal::AssertOk(LiteRtGetModelSignature, Get(), signature_index, - &lite_rt_signature); - return Signature(lite_rt_signature); - } - - // Returns the signature index for the given signature key. - Expected GetSignatureIndex(absl::string_view signature_key) const { - LiteRtParamIndex num_signatures; - internal::AssertOk(LiteRtGetNumModelSignatures, Get(), &num_signatures); - for (int i = 0; i < num_signatures; ++i) { - LiteRtSignature lite_rt_signature; - internal::AssertOk(LiteRtGetModelSignature, Get(), i, &lite_rt_signature); - const char* key_cstr; - internal::AssertOk(LiteRtGetSignatureKey, lite_rt_signature, &key_cstr); - if (absl::string_view(key_cstr) == signature_key) { - return i; - } - } - return Unexpected(kLiteRtStatusErrorNotFound, "Signature not found"); - } - - // Returns the Signature object for the given signature key. 
- Expected FindSignature( - absl::string_view signature_key) const { - LiteRtParamIndex num_signatures; - internal::AssertOk(LiteRtGetNumModelSignatures, Get(), &num_signatures); - for (int i = 0; i < num_signatures; ++i) { - LiteRtSignature lite_rt_signature; - internal::AssertOk(LiteRtGetModelSignature, Get(), i, &lite_rt_signature); - const char* key_cstr; - internal::AssertOk(LiteRtGetSignatureKey, lite_rt_signature, &key_cstr); - if (absl::string_view(key_cstr) == signature_key) { - return Signature(lite_rt_signature); - } - } - return Unexpected(kLiteRtStatusErrorNotFound, "Signature not found"); - } - - static absl::string_view DefaultSignatureKey() { - const char* key; - internal::AssertOk(LiteRtGetDefaultSignatureKey, &key); - return key; - } - - private: - // Parameter `owned` indicates if the created TensorBuffer object should take - // ownership of the provided `tensor_buffer` handle. - Model(LiteRtModel model, bool owned) - : internal::Handle(model, owned) {} -}; - -struct SerializationOptions { - static LiteRtModelSerializationOptions Defaults() { - return LiteRtModelSerializationOptions{}; - } -}; - -} // namespace litert - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_MODEL_H_ diff --git a/tensorflow/lite/experimental/litert/cc/litert_model_predicates.cc b/tensorflow/lite/experimental/litert/cc/litert_model_predicates.cc deleted file mode 100644 index 18efea56f7ffa4..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/litert_model_predicates.cc +++ /dev/null @@ -1,120 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/cc/litert_model_predicates.h" - -#include -#include -#include -#include - -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/cc/litert_detail.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" - -namespace litert { -namespace { - -template -bool Any(absl::Span vals, std::function unary_pred) { - for (const auto& val : vals) { - if (unary_pred(val)) { - return true; - } - } - return false; -} - -bool UseSoftEqual(const Tensor::TensorUse& actual_use, - const UseInfo& expected_use) { - if (expected_use.user_param_ind.has_value() && - actual_use.user_arg_ind != expected_use.user_param_ind.value()) { - return false; - } - if (expected_use.op_code.has_value() && - actual_use.user.Code() != expected_use.op_code.value()) { - return false; - } - return true; -} - -} // namespace - -// Does given tensor have given type and shape info. Optional values considered -// to be a vacous match. -bool MatchRankedTensorType(const RankedTensorType& tensor_type, - const TensorTypeInfo& expected) { - if (expected.element_type.has_value() && - (tensor_type.ElementType() != expected.element_type.value())) { - return false; - } - - if (expected.dims.has_value()) { - auto actual_dims = tensor_type.Layout().Dimensions(); - auto expected_dims = absl::MakeConstSpan(expected.dims.value()); - return AllZip(actual_dims, expected_dims, - [](auto l, auto r) -> bool { return l == r; }); - } - return true; -} - -// Does given op have signature matching given types. 
Optional values considered -// to be a vacous match. -bool MatchOpType( - const Op& op, - const std::vector>& expected_inputs, - const std::vector>& expected_outputs) { - auto match = [](const Tensor& actual, - const std::optional& expected) -> bool { - if (!expected.has_value()) { - return true; - } - auto actual_ranked_tensor_type = actual.RankedTensorType(); - // Don't return a match if the tensor is unranked. - if (!actual_ranked_tensor_type) { - return false; - } - return MatchRankedTensorType(*actual_ranked_tensor_type, expected.value()); - }; - - const bool inputs_match = AllZip(absl::MakeConstSpan(op.Inputs()), - absl::MakeConstSpan(expected_inputs), match); - const bool outputs_match = - AllZip(absl::MakeConstSpan(op.Outputs()), - absl::MakeConstSpan(expected_outputs), match); - return inputs_match && outputs_match; -} - -bool MatchUse(const Tensor& tensor, const UseInfo& expected_use) { - auto soft_equal = [&expected_use = std::as_const(expected_use)]( - const Tensor::TensorUse& actual_use) { - return UseSoftEqual(actual_use, expected_use); - }; - return Any(tensor.Uses(), soft_equal); -} - -bool MatchUses(const Tensor& tensor, const std::vector& expected_uses, - bool strict) { - const auto uses = tensor.Uses(); - if (strict && uses.size() != expected_uses.size()) { - return false; - } - auto not_use = [&tensor = - std::as_const(tensor)](const UseInfo& expected_use) { - return !MatchUse(tensor, expected_use); - }; - return !Any(expected_uses, not_use); -} - -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/cc/litert_model_predicates.h b/tensorflow/lite/experimental/litert/cc/litert_model_predicates.h deleted file mode 100644 index 238e9a455bbb9e..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/litert_model_predicates.h +++ /dev/null @@ -1,78 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_MODEL_PREDICATES_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_MODEL_PREDICATES_H_ - -#include -#include -#include - -#include "absl/container/inlined_vector.h" -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" - -// Predicates used for matching patterns in the graph. NOTE: All optionals in -// matcher arguments are considered to be a vacous match. - -namespace litert { - -struct TensorTypeInfo { - std::optional element_type = std::nullopt; - std::optional> dims = std::nullopt; - - explicit TensorTypeInfo(ElementType element_type) - : element_type(element_type) {} - explicit TensorTypeInfo(absl::InlinedVector dims) : dims(dims) {} - TensorTypeInfo(ElementType element_type, absl::InlinedVector dims) - : element_type(element_type), dims(dims) {} -}; - -struct UseInfo { - std::optional op_code = std::nullopt; - std::optional user_param_ind = std::nullopt; -}; - -// Does this tensor have given type and shape info. -bool MatchRankedTensorType(const RankedTensorType& tensor_type, - const TensorTypeInfo& expected); - -// Does this op have signature matching given types. -bool MatchOpType( - const Op& op, - const std::vector>& expected_inputs, - const std::vector>& expected_outputs); - -// Does this tensor contain weights whose values match expected_data. 
-template -inline bool MatchWeights(const Tensor& tensor, - absl::Span expected_data) { - auto weights = tensor.WeightsData(); - return weights.HasValue() && *weights == expected_data; -} - -// Does this tensor have a user with the given information. -bool MatchUse(const Tensor& tensor, const UseInfo& expected_use); - -// Does this tensor have matching users. If "strict" is true, then expected_uses -// size must equal the number of actual uses, otherwise just checks each -// expected_use match an actual use. -bool MatchUses(const Tensor& tensor, const std::vector& expected_uses, - bool strict = true); - -} // namespace litert - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_MODEL_PREDICATES_H_ diff --git a/tensorflow/lite/experimental/litert/cc/litert_model_predicates_test.cc b/tensorflow/lite/experimental/litert/cc/litert_model_predicates_test.cc deleted file mode 100644 index f16bc764e560c4..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/litert_model_predicates_test.cc +++ /dev/null @@ -1,215 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/cc/litert_model_predicates.h" - -#include - -#include -#include "absl/container/inlined_vector.h" -#include "absl/log/absl_check.h" -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_element_type.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/test/common.h" - -namespace litert { - -namespace { - -using ::litert::testing::LoadTestFileModel; - -TEST(MatchRankedTensorTypeTest, HasAll) { - auto litert_model = LoadTestFileModel("one_mul.tflite"); - auto subgraph = litert_model.MainSubgraph(); - ABSL_CHECK(subgraph); - auto ops = subgraph->Ops(); - const auto inputs = ops.front().Inputs(); - const auto& input = inputs.front(); - auto input_tensor_type = input.RankedTensorType(); - EXPECT_TRUE(input_tensor_type); - EXPECT_TRUE(MatchRankedTensorType( - *input_tensor_type, TensorTypeInfo(ElementType::Float32, {2, 2}))); -} - -TEST(MatchRankedTensorTypeTest, NoMatch) { - auto litert_model = LoadTestFileModel("one_mul.tflite"); - auto subgraph = litert_model.MainSubgraph(); - ABSL_CHECK(subgraph); - auto ops = subgraph->Ops(); - const auto inputs = ops.front().Inputs(); - const auto& input = inputs.front(); - auto input_tensor_type = input.RankedTensorType(); - EXPECT_TRUE(input_tensor_type); - EXPECT_FALSE(MatchRankedTensorType( - *input_tensor_type, TensorTypeInfo(ElementType::Float32, {3, 2}))); -} - -TEST(MatchRankedTensorTypeTest, AnyDims) { - auto litert_model = LoadTestFileModel("one_mul.tflite"); - auto subgraph = litert_model.MainSubgraph(); - ABSL_CHECK(subgraph); - auto ops = subgraph->Ops(); - const auto inputs = ops.front().Inputs(); - const auto& input = inputs.front(); - auto input_tensor_type = input.RankedTensorType(); - EXPECT_TRUE(input_tensor_type); - EXPECT_TRUE(MatchRankedTensorType(*input_tensor_type, - 
TensorTypeInfo(ElementType::Float32))); -} - -TEST(MatchRankedTensorTypeTest, AnyElementType) { - auto litert_model = LoadTestFileModel("one_mul.tflite"); - auto subgraph = litert_model.MainSubgraph(); - ABSL_CHECK(subgraph); - auto ops = subgraph->Ops(); - const auto inputs = ops.front().Inputs(); - const auto& input = inputs.front(); - auto input_tensor_type = input.RankedTensorType(); - EXPECT_TRUE(input_tensor_type); - EXPECT_TRUE( - MatchRankedTensorType(*input_tensor_type, TensorTypeInfo({2, 2}))); -} - -TEST(MatchOpTypeTest, HasAll) { - auto litert_model = LoadTestFileModel("one_mul.tflite"); - auto subgraph = litert_model.MainSubgraph(); - ABSL_CHECK(subgraph); - auto ops = subgraph->Ops(); - TensorTypeInfo expected_type(ElementType::Float32, {2, 2}); - EXPECT_TRUE(MatchOpType(ops.front(), {expected_type, expected_type}, - {expected_type})); -} - -TEST(MatchOpTypeTest, NoMatch) { - auto litert_model = LoadTestFileModel("one_mul.tflite"); - auto subgraph = litert_model.MainSubgraph(); - ABSL_CHECK(subgraph); - auto ops = subgraph->Ops(); - TensorTypeInfo expected_type(ElementType::Float32, {2, 2}); - TensorTypeInfo not_expected_type(ElementType::Int32, {2, 2}); - EXPECT_FALSE(MatchOpType(ops.front(), {not_expected_type, expected_type}, - {expected_type})); -} - -TEST(MatchOpTypeTest, AnyInput) { - auto litert_model = LoadTestFileModel("one_mul.tflite"); - auto subgraph = litert_model.MainSubgraph(); - ABSL_CHECK(subgraph); - auto ops = subgraph->Ops(); - TensorTypeInfo expected_type(ElementType::Float32, {2, 2}); - EXPECT_TRUE( - MatchOpType(ops.front(), {std::nullopt, expected_type}, {expected_type})); -} - -TEST(MatchOpTypeTest, AnyOutput) { - auto litert_model = LoadTestFileModel("one_mul.tflite"); - auto subgraph = litert_model.MainSubgraph(); - ABSL_CHECK(subgraph); - auto ops = subgraph->Ops(); - TensorTypeInfo expected_type(ElementType::Float32, {2, 2}); - EXPECT_TRUE( - MatchOpType(ops.front(), {std::nullopt, expected_type}, {std::nullopt})); -} - 
-TEST(MatchWeightsTest, Matches) { - auto litert_model = LoadTestFileModel("add_cst.tflite"); - auto subgraph = litert_model.MainSubgraph(); - ABSL_CHECK(subgraph); - auto ops = subgraph->Ops(); - const auto inputs = ops.front().Inputs(); - const auto& cst = inputs.back(); - EXPECT_TRUE(MatchWeights(cst, absl::Span({1.0, 2.0, 3.0, 4.0}))); -} - -TEST(MatchWeightsTest, NoMatchBadType) { - auto litert_model = LoadTestFileModel("add_cst.tflite"); - auto subgraph = litert_model.MainSubgraph(); - ABSL_CHECK(subgraph); - auto ops = subgraph->Ops(); - const auto inputs = ops.front().Inputs(); - const auto& cst = inputs.back(); - EXPECT_FALSE( - MatchWeights(cst, absl::Span({1.0, 2.0, 3.0, 4.0}))); -} -TEST(MatchWeightsTest, NoMatchBadVals) { - auto litert_model = LoadTestFileModel("add_cst.tflite"); - auto subgraph = litert_model.MainSubgraph(); - ABSL_CHECK(subgraph); - auto ops = subgraph->Ops(); - const auto inputs = ops.front().Inputs(); - const auto& cst = inputs.back(); - EXPECT_FALSE( - MatchWeights(cst, absl::Span({3.0, 2.0, 3.0, 5.0}))); -} - -TEST(MatchUseTest, Match) { - auto litert_model = LoadTestFileModel("add_cst.tflite"); - auto subgraph = litert_model.MainSubgraph(); - ABSL_CHECK(subgraph); - auto ops = subgraph->Ops(); - const auto inputs = ops.front().Inputs(); - EXPECT_TRUE(MatchUse(inputs.back(), UseInfo{kLiteRtOpCodeTflAdd, 1})); -} - -TEST(MatchUseTest, MatchAnyCode) { - auto litert_model = LoadTestFileModel("add_cst.tflite"); - auto subgraph = litert_model.MainSubgraph(); - ABSL_CHECK(subgraph); - auto ops = subgraph->Ops(); - const auto inputs = ops.front().Inputs(); - EXPECT_TRUE(MatchUse(inputs.back(), UseInfo{std::nullopt, 1})); -} - -TEST(MatchUseTest, NoMatch) { - auto litert_model = LoadTestFileModel("add_cst.tflite"); - auto subgraph = litert_model.MainSubgraph(); - ABSL_CHECK(subgraph); - auto ops = subgraph->Ops(); - const auto inputs = ops.front().Inputs(); - EXPECT_FALSE(MatchUse(inputs.back(), UseInfo{std::nullopt, 2})); -} - 
-TEST(MatchUsesTest, StrictMatch) { - auto litert_model = LoadTestFileModel("add_simple.tflite"); - auto subgraph = litert_model.MainSubgraph(); - ABSL_CHECK(subgraph); - auto subgraph_inputs = subgraph->Inputs(); - const auto& tensor = subgraph_inputs.front(); - EXPECT_TRUE( - MatchUses(tensor, {{kLiteRtOpCodeTflAdd, 0}, {kLiteRtOpCodeTflAdd, 1}})); -} - -TEST(MatchUsesTest, StrictNoMatch) { - auto litert_model = LoadTestFileModel("add_simple.tflite"); - auto subgraph = litert_model.MainSubgraph(); - ABSL_CHECK(subgraph); - auto subgraph_inputs = subgraph->Inputs(); - const auto& tensor = subgraph_inputs.front(); - EXPECT_FALSE(MatchUses(tensor, {{kLiteRtOpCodeTflAdd, 0}})); -} - -TEST(MatchUsesTest, NonStrict) { - auto litert_model = LoadTestFileModel("add_simple.tflite"); - auto subgraph = litert_model.MainSubgraph(); - ABSL_CHECK(subgraph); - auto subgraph_inputs = subgraph->Inputs(); - const auto& tensor = subgraph_inputs.front(); - EXPECT_TRUE(MatchUses(tensor, {{kLiteRtOpCodeTflAdd, 0}}, /*strict=*/false)); -} - -} // namespace - -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/cc/litert_model_test.cc b/tensorflow/lite/experimental/litert/cc/litert_model_test.cc deleted file mode 100644 index a1a80f82e5f397..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/litert_model_test.cc +++ /dev/null @@ -1,359 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" - -#include -#include -#include -#include - -#include -#include -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_element_type.h" -#include "tensorflow/lite/experimental/litert/cc/litert_layout.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" -#include "tensorflow/lite/experimental/litert/test/common.h" - -// Tests for CC Wrapper classes around public C api. - -namespace litert { - -namespace { - -static constexpr const int32_t kTensorDimensions[] = {1, 2, 3}; - -static constexpr const auto kRank = - sizeof(kTensorDimensions) / sizeof(kTensorDimensions[0]); - -static constexpr const uint32_t kTensorStrides[] = {6, 3, 1}; - -static constexpr const LiteRtLayout kLayout = BuildLayout(kTensorDimensions); - -static constexpr const LiteRtLayout kLayoutWithStrides = - BuildLayout(kTensorDimensions, kTensorStrides); - -static constexpr const LiteRtRankedTensorType kTensorType = { - /*.element_type=*/kLiteRtElementTypeFloat32, - /*.layout=*/kLayout, -}; - -//===----------------------------------------------------------------------===// -// CC Model // -//===----------------------------------------------------------------------===// - -TEST(CcModelTest, SimpleModel) { - auto model = testing::LoadTestFileModel("one_mul.tflite"); - - LiteRtParamIndex num_subgraphs; - ASSERT_EQ(LiteRtGetNumModelSubgraphs(model.Get(), &num_subgraphs), - kLiteRtStatusOk); - EXPECT_EQ(model.NumSubgraphs(), num_subgraphs); - EXPECT_EQ(model.NumSubgraphs(), 1); - - LiteRtParamIndex main_subgraph_index; - ASSERT_EQ(LiteRtGetMainModelSubgraphIndex(model.Get(), &main_subgraph_index), - kLiteRtStatusOk); - EXPECT_EQ(main_subgraph_index, 0); - - 
LiteRtSubgraph litert_subgraph_0; - ASSERT_EQ(LiteRtGetModelSubgraph(model.Get(), /*subgraph_index=*/0, - &litert_subgraph_0), - kLiteRtStatusOk); - - auto subgraph_0 = model.Subgraph(0); - ASSERT_TRUE(subgraph_0); - EXPECT_EQ(subgraph_0->Get(), litert_subgraph_0); - - auto main_subgraph = model.MainSubgraph(); - EXPECT_EQ(main_subgraph->Get(), subgraph_0->Get()); -} - -//===----------------------------------------------------------------------===// -// CC Signature // -//===----------------------------------------------------------------------===// - -TEST(CcSignatureTest, Basic) { - auto model = testing::LoadTestFileModel("one_mul.tflite"); - - auto signatures = model.GetSignatures(); - ASSERT_TRUE(signatures); - ASSERT_EQ(signatures->size(), 1); - auto& signature = signatures->at(0); - EXPECT_THAT(signature.Key(), Model::DefaultSignatureKey()); - auto input_names = signature.InputNames(); - EXPECT_THAT(input_names[0], "arg0"); - EXPECT_THAT(input_names[1], "arg1"); - auto output_names = signature.OutputNames(); - EXPECT_THAT(output_names[0], "tfl.mul"); -} - -TEST(CcSignatureTest, Lookup) { - auto model = testing::LoadTestFileModel("one_mul.tflite"); - - { - auto signature = model.FindSignature("nonexistent"); - ASSERT_FALSE(signature); - } - auto signature = model.FindSignature(Model::DefaultSignatureKey()); - ASSERT_TRUE(signature); - EXPECT_THAT(signature->Key(), Model::DefaultSignatureKey()); - auto input_names = signature->InputNames(); - EXPECT_THAT(input_names[0], "arg0"); - EXPECT_THAT(input_names[1], "arg1"); - auto output_names = signature->OutputNames(); - EXPECT_THAT(output_names[0], "tfl.mul"); -} - -//===----------------------------------------------------------------------===// -// CC Layout // -//===----------------------------------------------------------------------===// - -TEST(CcLayoutTest, NoStrides) { - Layout layout(kLayout); - - ASSERT_EQ(layout.Rank(), kLayout.rank); - for (auto i = 0; i < layout.Rank(); ++i) { - 
ASSERT_EQ(layout.Dimensions()[i], kLayout.dimensions[i]); - } - ASSERT_FALSE(layout.HasStrides()); -} - -TEST(CcLayoutTest, WithStrides) { - Layout layout(kLayoutWithStrides); - - ASSERT_EQ(layout.Rank(), kLayoutWithStrides.rank); - for (auto i = 0; i < layout.Rank(); ++i) { - ASSERT_EQ(layout.Dimensions()[i], kLayoutWithStrides.dimensions[i]); - } - ASSERT_TRUE(layout.HasStrides()); - for (auto i = 0; i < layout.Rank(); ++i) { - ASSERT_EQ(layout.Strides()[i], kLayoutWithStrides.strides[i]); - } -} - -TEST(CcLayoutTest, Equal) { - auto&& dims = {2, 2}; - Layout layout1(BuildLayout(dims)); - Layout layout2(BuildLayout({2, 2})); - ASSERT_TRUE(layout1 == layout2); -} - -TEST(CcLayoutTest, NotEqual) { - Layout layout1(BuildLayout({2, 2}, nullptr)); - Layout layout2(BuildLayout({2, 2}, kTensorStrides)); - ASSERT_FALSE(layout1 == layout2); -} - -TEST(CcLayoutTest, NumElements) { - Layout layout(BuildLayout({2, 2, 3})); - auto num_elements = layout.NumElements(); - ASSERT_TRUE(num_elements.has_value()); - EXPECT_EQ(num_elements.value(), 12); -} - -//===----------------------------------------------------------------------===// -// CC Op // -//===----------------------------------------------------------------------===// - -TEST(CcOpTest, SimpleSupportedOp) { - auto litert_model = testing::LoadTestFileModel("one_mul.tflite"); - auto subgraph = litert_model.MainSubgraph(); - const auto ops = subgraph->Ops(); - const auto& op = ops.front(); - - EXPECT_EQ(op.Code(), kLiteRtOpCodeTflMul); - EXPECT_EQ(op.Inputs().size(), 2); - EXPECT_EQ(op.Outputs().size(), 1); -} - -//===----------------------------------------------------------------------===// -// CC RankedTensorType // -//===----------------------------------------------------------------------===// - -TEST(CcRankedTensorTypeTest, Accessors) { - Layout layout(kLayout); - RankedTensorType tensor_type(kTensorType); - ASSERT_EQ(tensor_type.ElementType(), - static_cast(kTensorType.element_type)); - 
ASSERT_TRUE(tensor_type.Layout() == layout); -} - -//===----------------------------------------------------------------------===// -// CC Tensor // -//===----------------------------------------------------------------------===// - -TEST(CcTensorTest, SimpleModel) { - auto litert_model = testing::LoadTestFileModel("one_mul.tflite"); - auto subgraph = litert_model.MainSubgraph(); - - auto inputs = subgraph->Inputs(); - ASSERT_EQ(inputs.size(), 2); - - { - const Tensor& input_tensor = inputs.front(); - ASSERT_EQ(input_tensor.TypeId(), kLiteRtRankedTensorType); - - auto input_ranked_tensor_type = input_tensor.RankedTensorType(); - EXPECT_TRUE(input_ranked_tensor_type); - ASSERT_EQ(input_ranked_tensor_type->ElementType(), ElementType::Float32); - - EXPECT_FALSE(input_tensor.HasWeights()); - - auto input_weights = input_tensor.Weights(); - ASSERT_EQ(input_weights.Bytes().size(), 0); - - ASSERT_EQ(input_tensor.DefiningOp(), std::nullopt); - - const auto uses = input_tensor.Uses(); - ASSERT_EQ(uses.size(), 1); - } - - auto outputs = subgraph->Outputs(); - ASSERT_EQ(outputs.size(), 1); - - { - const Tensor& output_tensor = outputs.front(); - ASSERT_EQ(output_tensor.TypeId(), kLiteRtRankedTensorType); - - auto output_defining_op = output_tensor.DefiningOp(); - EXPECT_TRUE(output_defining_op.has_value()); - - ASSERT_TRUE(output_tensor.Uses().empty()); - } -} - -TEST(CcTensorTest, WeightsData) { - auto litert_model = testing::LoadTestFileModel("add_cst.tflite"); - auto subgraph = litert_model.MainSubgraph(); - - auto data = subgraph->Ops().front().Inputs().back().WeightsData(); - ASSERT_TRUE(data.HasValue()); - EXPECT_THAT(data.Value(), ::testing::ElementsAreArray({1.0, 2.0, 3.0, 4.0})); -} - -TEST(CcTensorTest, Name) { - static constexpr absl::string_view kName = "foo"; - LiteRtTensorT tensor; - tensor.SetName(std::string(kName)); - - Tensor cc_tensor(&tensor); - EXPECT_EQ(cc_tensor.Name(), kName); -} - -TEST(CcTensorTest, QuantizationNone) { - LiteRtTensorT litert_tensor; 
- litert_tensor.Qparams().first = kLiteRtQuantizationNone; - - Tensor tensor(&litert_tensor); - EXPECT_EQ(tensor.QTypeId(), kLiteRtQuantizationNone); - EXPECT_FALSE(tensor.HasQuantization()); -} - -TEST(CcTensorTest, QuantizationPerTensor) { - static constexpr auto kScale = 1.0; - static constexpr auto kZeroPoint = 1; - - LiteRtTensorT litert_tensor; - litert_tensor.SetQarams(MakePerTensorQuantization(kScale, kZeroPoint)); - - Tensor tensor(&litert_tensor); - ASSERT_EQ(tensor.QTypeId(), kLiteRtQuantizationPerTensor); - ASSERT_TRUE(tensor.HasQuantization()); - - const auto per_tensor_quantization = tensor.PerTensorQuantization(); - EXPECT_EQ(per_tensor_quantization.scale, kScale); - EXPECT_EQ(per_tensor_quantization.zero_point, kZeroPoint); -} - -TEST(CcTensorTest, QuantizationPerChannel) { - static constexpr auto kNumChannels = 2; - static constexpr auto kQuantizedDimension = 0; - static constexpr float kScales[kNumChannels] = {1.0, 2.0}; - static constexpr int64_t kZeroPoints[kNumChannels] = {0, 0}; - - LiteRtTensorT litert_tensor; - auto per_channel = MakePerChannelQuantization( - kScales, kZeroPoints, kQuantizedDimension, litert_tensor); - litert_tensor.SetQarams(per_channel); - - Tensor tensor(&litert_tensor); - ASSERT_EQ(tensor.QTypeId(), kLiteRtQuantizationPerChannel); - ASSERT_TRUE(tensor.HasQuantization()); - - const auto per_channel_quantization = tensor.PerChannelQuantization(); - EXPECT_THAT( - absl::MakeConstSpan(per_channel_quantization.scales, kNumChannels), - ::testing::ElementsAreArray(kScales)); - EXPECT_THAT( - absl::MakeConstSpan(per_channel_quantization.zero_points, kNumChannels), - ::testing::ElementsAreArray(kZeroPoints)); - EXPECT_EQ(per_channel_quantization.num_channels, kNumChannels); - EXPECT_EQ(per_channel_quantization.quantized_dimension, kQuantizedDimension); -} - -TEST(CcTensorTest, ZeroSizeTensorTest) { - auto litert_model = testing::LoadTestFileModel("scala_reshape.tflite"); - auto subgraph = litert_model.MainSubgraph(); - const auto 
ops = subgraph->Ops(); - const auto& op = ops.front(); - EXPECT_FALSE(op.Inputs().at(1).IsSubgraphInput()); -} - -//===----------------------------------------------------------------------===// -// CC Subgraph // -//===----------------------------------------------------------------------===// - -TEST(CcSubgraphTest, SimpleModel) { - auto model = testing::LoadTestFileModel("one_mul.tflite"); - auto subgraph = model.MainSubgraph(); - - ASSERT_EQ(subgraph->Inputs().size(), 2); - ASSERT_EQ(subgraph->Outputs().size(), 1); - ASSERT_EQ(subgraph->Ops().size(), 1); - - auto input0_tensor = subgraph->Input("arg0"); - ASSERT_TRUE(input0_tensor.HasValue()); - auto input1_tensor = subgraph->Input("arg1"); - ASSERT_TRUE(input1_tensor.HasValue()); - - auto output_tensor = subgraph->Output("tfl.mul"); - ASSERT_TRUE(output_tensor.HasValue()); - ASSERT_EQ(output_tensor->TypeId(), kLiteRtRankedTensorType); - auto output_ranked_tensor_type = output_tensor->RankedTensorType(); - EXPECT_TRUE(output_ranked_tensor_type); - ASSERT_EQ(output_ranked_tensor_type->ElementType(), ElementType::Float32); -} - -//===----------------------------------------------------------------------===// -// CC ElementType // -//===----------------------------------------------------------------------===// - -TEST(CcElementTypeTest, GetByteWidth) { - const size_t width = GetByteWidth(); - EXPECT_EQ(width, 1); -} - -TEST(CcElementTypeTest, GetElementType) { - ElementType ty = GetElementType(); - EXPECT_EQ(ty, ElementType::Float32); -} - -} // namespace -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/cc/litert_op_options.cc b/tensorflow/lite/experimental/litert/cc/litert_op_options.cc deleted file mode 100644 index c2cdfc6e2d0c7a..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/litert_op_options.cc +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright 2024 Google LLC. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/cc/litert_op_options.h" - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/c/litert_options.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" - -namespace litert { - -LiteRtStatus CompositeOptions::InitFromOp(LiteRtOp op) { - LiteRtOpCode opcode; - LITERT_RETURN_IF_ERROR(LiteRtGetOpCode(op, &opcode)); - if (opcode != kLiteRtOpCodeShloComposite) { - return kLiteRtStatusErrorInvalidArgument; - } - - const char* op_name; - LITERT_RETURN_IF_ERROR(LiteRtGetSHLOCompositeOpName(op, &op_name)); - name = op_name; - - LITERT_RETURN_IF_ERROR( - LiteRtGetSHLOCompositeOpDecompositionSubgraphIndex(op, &subgraph)); - - this->op = op; - - return kLiteRtStatusOk; -} - -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/cc/litert_op_options.h b/tensorflow/lite/experimental/litert/cc/litert_op_options.h deleted file mode 100644 index 70f6de4a38007e..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/litert_op_options.h +++ /dev/null @@ -1,63 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_OP_OPTIONS_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_OP_OPTIONS_H_ - -#include - -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" - -namespace litert { - -struct OpOptions { - virtual LiteRtStatus InitFromOp(LiteRtOp op) = 0; - virtual ~OpOptions() = default; -}; - -// Struct to hold LiteRt composite ops. -struct CompositeOptions : public OpOptions { - // Name for special composites representing manual partitions. - static constexpr absl::string_view kNpuCall = "odml.npu_call"; - - // The root op. - LiteRtOp op; - // Decomposition subgraph. - int subgraph; - // The name of the composite op (stored in model). - absl::string_view name; - - LiteRtStatus InitFromOp(LiteRtOp op) override; -}; - -// Returns the composite info for the given op if it is a composite op. -template -Expected GetOptionsAs(LiteRtOp op) { - if constexpr (std::is_same_v) { - CompositeOptions options; - LITERT_RETURN_IF_ERROR(options.InitFromOp(op)); - return options; - } else { - // TODO: Add more as needed. 
- return Unexpected(kLiteRtStatusErrorInvalidArgument); - } -} - -} // namespace litert - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_OP_OPTIONS_H_ diff --git a/tensorflow/lite/experimental/litert/cc/litert_op_options_test.cc b/tensorflow/lite/experimental/litert/cc/litert_op_options_test.cc deleted file mode 100644 index 4be92d3e22f9d4..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/litert_op_options_test.cc +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/cc/litert_op_options.h" - -#include - -#include -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" -#include "tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h" -#include "tensorflow/lite/schema/schema_generated.h" - -namespace litert { -namespace { - -TEST(OpOptionsTest, GetCompositeOptions) { - static constexpr auto kOptsType = - ::tflite::BuiltinOptions2_StableHLOCompositeOptions; - static constexpr absl::string_view kName = "test.composite"; - static constexpr int kSubgraph = 1; - - LiteRtOpT op; - op.SetOpCode(kLiteRtOpCodeShloComposite); - - tflite::StableHLOCompositeOptionsT options; - options.name = kName; - options.decomposition_subgraph_index = kSubgraph; - - internal::TflOptions2 tfl_options; - tfl_options.type = kOptsType; - tfl_options.Set(std::move(options)); - litert::internal::SetTflOptions2(op, std::move(tfl_options)); - - auto res = GetOptionsAs(&op); - ASSERT_TRUE(res); - EXPECT_EQ(res->name, kName); - EXPECT_EQ(res->subgraph, kSubgraph); -} - -TEST(OpOptionsTest, GetUnsupportedOptions) { - LiteRtOpT op; - op.SetOpCode(kLiteRtOpCodeShloAdd); - ASSERT_FALSE(GetOptionsAs(&op)); -} - -} // namespace -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/cc/litert_shared_library.cc b/tensorflow/lite/experimental/litert/cc/litert_shared_library.cc deleted file mode 100644 index 7e769ed2b3a1bf..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/litert_shared_library.cc +++ /dev/null @@ -1,211 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/cc/litert_shared_library.h" - -#if !LITERT_WINDOWS_OS -#include -#endif - -#if defined(_GNU_SOURCE) && !defined(__ANDROID__) && !defined(__APPLE__) -#define LITERT_IMPLEMENT_SHARED_LIBRARY_INFO 1 -#include -#endif - -#include -#include -#include - -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" // IWYU pragma: keep -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" - -// When using an address sanitizer, `RTLD_DEEPBIND` is not supported. When using -// one, we discard the flag and log an error. -#if defined(__SANITIZE_ADDRESS__) || \ - defined(__has_feature) && \ - (__has_feature(address_sanitizer) || __has_feature(memory_sanitizer)) -#define LITERT_SANITIZER_BUILD 1 -#endif - -#if LITERT_SANITIZER_BUILD && defined(RTLD_DEEPBIND) -namespace litert { -namespace { -RtldFlags SanitizeFlagsInCaseOfAsan(RtldFlags flags) { - LITERT_LOG( - LITERT_WARNING, - "Trying to load a library using `RTLD_DEEPBIND` is not supported by " - "address sanitizers. In an effort to enable testing we strip the flag. " - "If this leads to unintended behaviour, either remove the " - "`RTLD_DEEPBIND` flag or run without an address sanitizer. 
" - "See https://github.com/google/sanitizers/issues/611 for more " - "information."); - flags.flags &= ~RTLD_DEEPBIND; - return flags; -} -} // namespace -} // namespace litert -#else -#define SanitizeFlagsInCaseOfAsan(flags) (flags) -#endif - -#if LITERT_WINDOWS_OS -// Implement dummy functions from dlfnc.h on Windows. -namespace { - -const char* dlerror() { - return "Windows is not supported for loading shared libraries."; -} - -void* dlopen(const char*, int) { return NULL; } - -void dlclose(void*) {} - -void* dlsym(void*, const char*) { return NULL; } - -int dlinfo(void*, int, void*) { return -1; } - -#define RTLD_NEXT (void*)-1; -#define RTLD_DEFAULT (void*)0; - -} // namespace -#endif - -namespace litert { - -SharedLibrary::~SharedLibrary() noexcept { Close(); } - -SharedLibrary::SharedLibrary(SharedLibrary&& other) noexcept - : handle_kind_(other.handle_kind_), - path_(std::move(other.path_)), - handle_(other.handle_) { - other.handle_kind_ = HandleKind::kInvalid; - other.handle_ = nullptr; -} - -SharedLibrary& SharedLibrary::operator=(SharedLibrary&& other) noexcept { - Close(); - handle_kind_ = other.handle_kind_; - path_ = std::move(other.path_); - handle_ = other.handle_; - other.handle_kind_ = HandleKind::kInvalid; - other.handle_ = nullptr; - return *this; -} - -void SharedLibrary::Close() noexcept { - if (handle_kind_ == HandleKind::kPath) { - dlclose(handle_); - } - handle_kind_ = HandleKind::kInvalid; - path_.clear(); -} - -absl::string_view SharedLibrary::DlError() noexcept { - const char* error = dlerror(); - if (!error) { - return {}; - } - return error; -} - -Expected SharedLibrary::LoadImpl( - SharedLibrary::HandleKind handle_kind, absl::string_view path, - RtldFlags flags) { - SharedLibrary lib; - switch (handle_kind) { - case HandleKind::kInvalid: - return Error(kLiteRtStatusErrorDynamicLoading, - "This is a logic error. 
LoadImpl should not be called with " - "HandleKind::kInvalid"); - case HandleKind::kPath: - if (path.empty()) { - return Error(kLiteRtStatusErrorDynamicLoading, - "Cannot not load shared library: empty path."); - } - lib.path_ = path; - lib.handle_ = - dlopen(lib.Path().c_str(), SanitizeFlagsInCaseOfAsan(flags)); - if (!lib.handle_) { - return Error(kLiteRtStatusErrorDynamicLoading, - absl::StrFormat("Could not load shared library %s: %s.", - lib.path_, DlError())); - } - break; - case HandleKind::kRtldNext: - lib.handle_ = RTLD_NEXT; - break; - case HandleKind::kRtldDefault: - lib.handle_ = RTLD_DEFAULT; - break; - } - lib.handle_kind_ = handle_kind; - return lib; -} - -Expected SharedLibrary::LookupSymbolImpl(const char* symbol_name) const { - void* symbol = dlsym(handle_, symbol_name); - - if (!symbol) { - return Error(kLiteRtStatusErrorDynamicLoading, - absl::StrFormat("Could not load symbol %s: %s.", symbol_name, - DlError())); - } - return symbol; -} - -std::ostream& operator<<(std::ostream& os, const SharedLibrary& lib) { - static constexpr absl::string_view kHeader = "/// DLL Info ///\n"; - static constexpr absl::string_view kFooter = "////////////////\n"; - - if (lib.handle_ == nullptr) { - os << kHeader << "Handle is nullptr.\n" << kFooter; - return os; - } - - os << kHeader; -#ifdef RTLD_DI_LMID - if (Lmid_t dl_ns_idx; dlinfo(lib.handle_, RTLD_DI_LMID, &dl_ns_idx) != 0) { - os << "Error getting lib namespace index: " << dlerror() << ".\n"; - } else { - os << "LIB NAMESPACE INDEX: " << dl_ns_idx << "\n"; - } -#else - os << "Cannot retrieve namespace index on this platform.\n"; -#endif - -#ifdef RTLD_DI_LINKMAP - if (link_map* lm; dlinfo(lib.handle_, RTLD_DI_LINKMAP, &lm) != 0) { - os << "Error getting linked objects: " << dlerror() << ".\n"; - } else { - os << "LINKED OBJECTS:\n"; - // Rewind to the start of the linked list. 
- const link_map* link = lm; - while (link->l_prev) { - link = link->l_prev; - } - // Print all list elements - for (; link != nullptr; link = link->l_next) { - os << (link != lm ? " " : "***") << link->l_name << "\n"; - } - } -#else - os << "Cannot retrieve lib map on this platform.\n"; -#endif - return os << kFooter; -} - -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/cc/litert_shared_library.h b/tensorflow/lite/experimental/litert/cc/litert_shared_library.h deleted file mode 100644 index b28f1f7e9a604c..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/litert_shared_library.h +++ /dev/null @@ -1,172 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_SHARED_LIBRARY_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_SHARED_LIBRARY_H_ - -#if defined(WIN32) || defined(_WIN32) || defined(__WIN32__) || \ - defined(__NT__) || defined(_WIN64) -#define LITERT_WINDOWS_OS 1 -#endif - -#if !LITERT_WINDOWS_OS -#include -#endif - -#include -#include - -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" - -namespace litert { - -struct RtldFlags { - int flags; - - static constexpr struct NextTag { - } kNext; - static constexpr struct DefaultTag { - } kDefault; - - // NOLINTNEXTLINE(*-explicit-constructor): we want this to be passed as flags. - operator int() { return flags; } - - static constexpr RtldFlags Lazy() { - return { -#if defined(RTLD_LAZY) - RTLD_LAZY -#endif - }; - } - static constexpr RtldFlags Now() { - return { -#if defined(RTLD_NOW) - RTLD_NOW -#endif - }; - } - static constexpr RtldFlags Default() { return Lazy().Local().DeepBind(); } - constexpr RtldFlags& Global() { -#if defined(RTLD_GLOBAL) - flags |= RTLD_GLOBAL; -#endif - return *this; - } - constexpr RtldFlags& Local() { -#if defined(RTLD_LOCAL) - flags |= RTLD_LOCAL; -#endif - return *this; - } - constexpr RtldFlags& NoDelete() { -#if defined(RTLD_NODELETE) - flags |= RTLD_NODELETE; -#endif - return *this; - } - constexpr RtldFlags& NoLoad() { -#if defined(RTLD_NOLOAD) - flags |= RTLD_NOLOAD; -#endif - return *this; - } - constexpr RtldFlags& DeepBind() { -#if defined(RTLD_DEEPBIND) - flags |= RTLD_DEEPBIND; -#endif - return *this; - } -}; - -// Wraps a dynamically loaded shared library to offer RAII semantics. 
-class SharedLibrary { - public: - SharedLibrary() = default; - SharedLibrary(const SharedLibrary&) = delete; - SharedLibrary& operator=(const SharedLibrary&) = delete; - SharedLibrary(SharedLibrary&&) noexcept; - SharedLibrary& operator=(SharedLibrary&&) noexcept; - ~SharedLibrary() noexcept; - - // Loads the library at the given path. - static Expected Load(absl::string_view path, - RtldFlags flags) noexcept { - return LoadImpl(HandleKind::kPath, path, flags); - } - - // Loads the library as the RTLD_NEXT special handle. - static Expected Load(RtldFlags::NextTag) noexcept { - return LoadImpl(HandleKind::kRtldNext, "", RtldFlags{}); - } - - // Loads the library as the RTLD_DEFAULT special handle. - static Expected Load(RtldFlags::DefaultTag) noexcept { - return LoadImpl(HandleKind::kRtldDefault, "", RtldFlags{}); - } - - // Gets the last shared library operation error if there was one. - // - // If there was no error, returns an empty view. - static absl::string_view DlError() noexcept; - - friend std::ostream& operator<<(std::ostream& os, const SharedLibrary& lib); - - bool Loaded() const noexcept { return handle_kind_ != HandleKind::kInvalid; } - - // Unloads the shared library. - // - // Note: this is automatically done when the object is destroyed. - void Close() noexcept; - - // Looks up a symbol in the shared library. - // - // Note: This takes a `char*` because the underlying system call requires a - // null terminated string which a string view doesn't guarantee. - template - Expected LookupSymbol(const char* symbol) const noexcept { - static_assert(std::is_pointer_v, - "The template parameter should always be a pointer."); - LITERT_ASSIGN_OR_RETURN(void* const raw_symbol, LookupSymbolImpl(symbol)); - return reinterpret_cast(raw_symbol); - } - - // Returns the loaded library path. - const std::string& Path() const noexcept { return path_; } - - // Returns the underlying shared library handle. - // - // Warning: some special handle value may be NULL. 
Do not rely on this value - // to check whether a library is loaded or not. - const void* Handle() const noexcept { return handle_; } - void* Handle() noexcept { return handle_; } - - private: - enum class HandleKind { kInvalid, kPath, kRtldNext, kRtldDefault }; - - static Expected LoadImpl(HandleKind handle_kind, - absl::string_view path, - RtldFlags flags); - - Expected LookupSymbolImpl(const char* symbol) const; - - HandleKind handle_kind_ = HandleKind::kInvalid; - std::string path_; - void* handle_ = nullptr; -}; - -} // namespace litert - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_SHARED_LIBRARY_H_ diff --git a/tensorflow/lite/experimental/litert/cc/litert_shared_library_test.cc b/tensorflow/lite/experimental/litert/cc/litert_shared_library_test.cc deleted file mode 100644 index 5a6fb051d0d0e0..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/litert_shared_library_test.cc +++ /dev/null @@ -1,142 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/cc/litert_shared_library.h" - -#include - -#include -#include - -#include -#include -#include "absl/strings/str_cat.h" -#include "tensorflow/lite/experimental/litert/test/matchers.h" - -using testing::Eq; -using testing::NotNull; -using testing::StrEq; -using testing::litert::IsError; - -namespace litert { -namespace { - -extern "C" { - -const char* TestFunction() { return "local_test_function"; } - -} // extern "C" - -TEST(RtldFlagsTest, FlagFactoryWorks) { - EXPECT_THAT(static_cast(RtldFlags::Now()), Eq(RTLD_NOW)); - EXPECT_THAT(static_cast(RtldFlags::Lazy()), Eq(RTLD_LAZY)); - EXPECT_THAT(static_cast(RtldFlags::Lazy().Global()), - Eq(RTLD_LAZY | RTLD_GLOBAL)); - EXPECT_THAT(static_cast(RtldFlags::Lazy().Local()), - Eq(RTLD_LAZY | RTLD_LOCAL)); - EXPECT_THAT(static_cast(RtldFlags::Lazy().NoDelete()), - Eq(RTLD_LAZY | RTLD_NODELETE)); - EXPECT_THAT(static_cast(RtldFlags::Lazy().NoLoad()), - Eq(RTLD_LAZY | RTLD_NOLOAD)); - EXPECT_THAT(static_cast(RtldFlags::Lazy().DeepBind()), - Eq(RTLD_LAZY | RTLD_DEEPBIND)); -} - -TEST(SharedLibraryTest, LoadRtldDefaultWorks) { - LITERT_ASSERT_OK_AND_ASSIGN(SharedLibrary lib, - SharedLibrary::Load(RtldFlags::kDefault)); - - EXPECT_THAT(lib.Path(), StrEq("")); - EXPECT_EQ(lib.Handle(), RTLD_DEFAULT); - - auto maybe_test_function = - lib.LookupSymbol("TestFunction"); - if (!maybe_test_function.HasValue()) { - GTEST_SKIP() << "TestFunction symbol was stripped from binary."; - } - - decltype(&TestFunction) test_function = maybe_test_function.Value(); - ASSERT_NE(test_function, nullptr); - EXPECT_THAT(test_function(), StrEq(TestFunction())); -} - -TEST(SharedLibraryTest, LoadRtldNextWorks) { - LITERT_ASSERT_OK_AND_ASSIGN(SharedLibrary lib, - SharedLibrary::Load(RtldFlags::kNext)); - - EXPECT_THAT(lib.Path(), StrEq("")); - EXPECT_EQ(lib.Handle(), RTLD_NEXT); -} - -TEST(SharedLibraryTest, LoadEmptyPathFails) { - EXPECT_THAT(SharedLibrary::Load("", RtldFlags::Now().Local()), IsError()); -} 
- -TEST(SharedLibraryTest, LoadPathWorks) { - const std::string lib_path = absl::StrCat( - "third_party/tensorflow/lite/experimental/litert/cc/" - "test_shared_library.so"); - LITERT_ASSERT_OK_AND_ASSIGN( - SharedLibrary lib, - SharedLibrary::Load(lib_path, RtldFlags::Now().Local())); - - EXPECT_TRUE(lib.Loaded()); - EXPECT_THAT(lib.Path(), StrEq(lib_path)); - EXPECT_THAT(lib.Handle(), NotNull()); - - using TestFunctionSignature = char* (*)(); - - LITERT_ASSERT_OK_AND_ASSIGN(TestFunctionSignature test_function, - lib.LookupSymbol("TestFunction")); - ASSERT_NE(test_function, nullptr); - EXPECT_THAT(test_function(), StrEq("test_shared_library")); - - lib.Close(); - EXPECT_THAT(lib.Path(), StrEq("")); - EXPECT_FALSE(lib.Loaded()); -} - -TEST(SharedLibraryTest, ConstructionAndAssignmentWork) { - const std::string lib_path = absl::StrCat( - "third_party/tensorflow/lite/experimental/litert/cc/" - "test_shared_library.so"); - LITERT_ASSERT_OK_AND_ASSIGN( - SharedLibrary lib, - SharedLibrary::Load(lib_path, RtldFlags::Now().Local())); - - const void* const lib_handle = lib.Handle(); - - SharedLibrary lib2(std::move(lib)); - - // NOLINTBEGIN(bugprone-use-after-move): Tests that moving clears up the - // object. 
- EXPECT_THAT(lib.Path(), StrEq("")); - EXPECT_FALSE(lib.Loaded()); - - EXPECT_TRUE(lib2.Loaded()); - EXPECT_THAT(lib2.Path(), StrEq(lib_path)); - EXPECT_THAT(lib2.Handle(), Eq(lib_handle)); - - lib = std::move(lib2); - EXPECT_THAT(lib2.Path(), StrEq("")); - EXPECT_FALSE(lib2.Loaded()); - - EXPECT_TRUE(lib.Loaded()); - EXPECT_THAT(lib.Path(), StrEq(lib_path)); - EXPECT_THAT(lib.Handle(), Eq(lib_handle)); - // NOLINTEND(bugprone-use-after-move) -} - -} // namespace -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/cc/litert_tensor_buffer.h b/tensorflow/lite/experimental/litert/cc/litert_tensor_buffer.h deleted file mode 100644 index d7fef8cec7e0d7..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/litert_tensor_buffer.h +++ /dev/null @@ -1,350 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_TENSOR_BUFFER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_TENSOR_BUFFER_H_ - -#include -#include -#include - -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_event.h" -#include "tensorflow/lite/experimental/litert/c/litert_gl_types.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_types.h" -#include "tensorflow/lite/experimental/litert/cc/litert_detail.h" -#include "tensorflow/lite/experimental/litert/cc/litert_event.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_handle.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" - -#if LITERT_HAS_OPENCL_SUPPORT -#include -#else -typedef struct _cl_mem* cl_mem; -#endif - -namespace litert { - -// Tensor and associated backing buffer. C++ equivalent of LiteRtTensorBuffer. -class TensorBuffer - : public internal::Handle { - public: - TensorBuffer() = default; - - // Parameter `owned` indicates if the created TensorBuffer object should take - // ownership of the provided `tensor_buffer` handle. - explicit TensorBuffer(LiteRtTensorBuffer tensor_buffer, bool owned = true) - : internal::Handle( - tensor_buffer, owned) {} - - // Creates a duplicate of the current TensorBuffer object. The returned - // object is reference counted so the underlying LiteRtTensorBuffer handle is - // not released with the destructor until the last reference is removed. 
- Expected Duplicate() const { - if (!IsOwned()) { - return Unexpected(kLiteRtStatusErrorInvalidArgument, - "Cannot duplicate a non-owned tensor buffer"); - } - LITERT_RETURN_IF_ERROR(LiteRtDuplicateTensorBuffer(Get())); - return TensorBuffer(Get()); - } - - static Expected CreateManaged( - LiteRtTensorBufferType buffer_type, const RankedTensorType& tensor_type, - size_t buffer_size) { - LiteRtTensorBuffer tensor_buffer; - auto litert_tensor_type = static_cast(tensor_type); - LITERT_RETURN_IF_ERROR(LiteRtCreateManagedTensorBuffer( - buffer_type, &litert_tensor_type, buffer_size, &tensor_buffer)); - return TensorBuffer(tensor_buffer); - } - - // Creates a TensorBuffer object that wraps the provided host memory. - // The provided host memory is not owned by the TensorBuffer object and must - // outlive the TensorBuffer object. - static Expected CreateFromHostMemory( - const RankedTensorType& tensor_type, void* host_mem_addr, - size_t buffer_size) { - LiteRtTensorBuffer tensor_buffer; - auto litert_tensor_type = static_cast(tensor_type); - - LITERT_RETURN_IF_ERROR(LiteRtCreateTensorBufferFromHostMemory( - &litert_tensor_type, host_mem_addr, buffer_size, - /*deallocator=*/nullptr, &tensor_buffer)); - return TensorBuffer(tensor_buffer); - } - - // Creates a TensorBuffer object that wraps an Android Hardware Buffer. Note - // that the provided AHardwareBuffer is not owned by the TensorBuffer object - // and must outlive the TensorBuffer object. The `ahwb_offset` parameter - // specifies the offset in bytes from the start of the AHardwareBuffer where - // the tensor data starts. 
- static Expected CreateFromAhwb( - const RankedTensorType& tensor_type, AHardwareBuffer* ahwb, - size_t ahwb_offset) { -#if LITERT_HAS_AHWB_SUPPORT - LiteRtTensorBuffer tensor_buffer; - auto litert_tensor_type = static_cast(tensor_type); - - LITERT_RETURN_IF_ERROR(LiteRtCreateTensorBufferFromAhwb( - &litert_tensor_type, ahwb, ahwb_offset, - /*deallocator=*/nullptr, &tensor_buffer)); - return TensorBuffer(tensor_buffer); -#else - return litert::Unexpected( - kLiteRtStatusErrorRuntimeFailure, - "AHardwareBuffer is not supported on this platform"); -#endif - } - - litert::Expected GetAhwb() const { -#if LITERT_HAS_AHWB_SUPPORT - AHardwareBuffer* ahwb; - LITERT_RETURN_IF_ERROR(LiteRtGetTensorBufferAhwb(Get(), &ahwb)); - return ahwb; -#else - return litert::Unexpected( - kLiteRtStatusErrorRuntimeFailure, - "AHardwareBuffer is not supported on this platform"); -#endif - } - - struct DmaBuf { - void* addr; - int fd; - }; - - litert::Expected GetDmaBuf() const { -#if LITERT_HAS_DMABUF_SUPPORT - DmaBuf dma_buf; - LITERT_RETURN_IF_ERROR( - LiteRtGetTensorBufferDmaBufBuffer(Get(), &dma_buf.addr, &dma_buf.fd)); - return dma_buf; -#else - return litert::Unexpected(kLiteRtStatusErrorRuntimeFailure, - "DMA-BUF is not supported on this platform"); -#endif - } - - Expected GetOpenClBuffer() const { -#if LITERT_HAS_OPENCL_SUPPORT - cl_mem cl_mem; - LITERT_RETURN_IF_ERROR(LiteRtGetTensorBufferOpenClBuffer(Get(), &cl_mem)); - return cl_mem; -#else - return litert::Unexpected(kLiteRtStatusErrorRuntimeFailure, - "OpenCL is not supported on this platform"); -#endif - } - - struct GlBuffer { - LiteRtGLenum target; - LiteRtGLuint id; - size_t size_bytes; - size_t offset; - }; - - static Expected CreateFromGlBuffer( - const RankedTensorType& tensor_type, LiteRtGLenum target, LiteRtGLuint id, - size_t size_bytes, size_t offset) { - LiteRtTensorBuffer tensor_buffer; - auto litert_tensor_type = static_cast(tensor_type); - LITERT_RETURN_IF_ERROR(LiteRtCreateTensorBufferFromGlBuffer( - 
&litert_tensor_type, target, id, size_bytes, offset, - /*deallocator=*/nullptr, &tensor_buffer)); - return TensorBuffer(tensor_buffer); - } - - Expected GetGlBuffer() const { - GlBuffer gl_buffer; - LITERT_RETURN_IF_ERROR(LiteRtGetTensorBufferGlBuffer( - Get(), &gl_buffer.target, &gl_buffer.id, &gl_buffer.size_bytes, - &gl_buffer.offset)); - return gl_buffer; - } - struct GlTexture { - LiteRtGLenum target; - LiteRtGLuint id; - LiteRtGLenum format; - size_t size_bytes; - LiteRtGLint layer; - }; - static Expected CreateFromGlTexture( - const RankedTensorType& tensor_type, LiteRtGLenum target, LiteRtGLuint id, - LiteRtGLenum format, size_t size_bytes, LiteRtGLint layer) { - LiteRtTensorBuffer tensor_buffer; - auto litert_tensor_type = static_cast(tensor_type); - LITERT_RETURN_IF_ERROR(LiteRtCreateTensorBufferFromGlTexture( - &litert_tensor_type, target, id, format, size_bytes, layer, - /*deallocator=*/nullptr, &tensor_buffer)); - return TensorBuffer(tensor_buffer); - } - - Expected GetGlTexture() const { - GlTexture gl_texture; - LITERT_RETURN_IF_ERROR(LiteRtGetTensorBufferGlTexture( - Get(), &gl_texture.target, &gl_texture.id, &gl_texture.format, - &gl_texture.size_bytes, &gl_texture.layer)); - return gl_texture; - } - - Expected BufferType() const { - LiteRtTensorBufferType tensor_buffer_type; - LITERT_RETURN_IF_ERROR( - LiteRtGetTensorBufferType(Get(), &tensor_buffer_type)); - return tensor_buffer_type; - } - - Expected TensorType() const { - LiteRtRankedTensorType tensor_type; - if (auto status = LiteRtGetTensorBufferTensorType(Get(), &tensor_type); - status != kLiteRtStatusOk) { - return Unexpected(status, "Failed to get tensor type"); - } - return RankedTensorType(tensor_type); - } - - Expected Size() const { - size_t size; - LITERT_RETURN_IF_ERROR(LiteRtGetTensorBufferSize(Get(), &size)); - return size; - } - - Expected Offset() const { - size_t offset; - LITERT_RETURN_IF_ERROR(LiteRtGetTensorBufferOffset(Get(), &offset)); - return offset; - } - - bool 
HasEvent() const { - bool has_event; - internal::AssertOk(LiteRtHasTensorBufferEvent, Get(), &has_event); - return has_event; - } - - Expected GetEvent() const { - LiteRtEvent event; - LITERT_RETURN_IF_ERROR(LiteRtGetTensorBufferEvent(Get(), &event)); - return Event(event, /*owned=*/false); - } - - // Set the C++ Event object for the tensor buffer. - // The function takes ownership of the passed Event object. - Expected SetEvent(Event&& event) { - if (!event.IsOwned()) { - return Error(kLiteRtStatusErrorInvalidArgument, - "Expected an owned event"); - } - LITERT_RETURN_IF_ERROR(LiteRtSetTensorBufferEvent(Get(), event.Release())); - return {}; - } - - // Set the C LiteRtEvent object for the tensor buffer. - // The function takes ownership of the passed LiteRtEvent object. - Expected SetLiteRtEvent(LiteRtEvent& litert_event) { - LITERT_RETURN_IF_ERROR(LiteRtSetTensorBufferEvent(Get(), litert_event)); - return {}; - } - - Expected ClearEvent() { - LITERT_RETURN_IF_ERROR(LiteRtClearTensorBufferEvent(Get())); - return {}; - } - - Expected Lock() { - void* host_mem_addr; - LITERT_RETURN_IF_ERROR(LiteRtLockTensorBuffer(Get(), &host_mem_addr)); - return host_mem_addr; - } - - Expected Unlock() { - LITERT_RETURN_IF_ERROR(LiteRtUnlockTensorBuffer(Get())); - return {}; - } - - // Writes data from the user provided Span to the tensor buffer. - // It returns an error if the provided buffer is bigger than the size of the - // tensor buffer. - template - Expected Write(absl::Span data) { - LITERT_ASSIGN_OR_RETURN(void* host_mem_addr, Lock()); - LITERT_ASSIGN_OR_RETURN(size_t size, Size()); - if (size < data.size() * sizeof(T)) { - return Unexpected( - kLiteRtStatusErrorRuntimeFailure, - "TensorBuffer size is smaller than the given data size"); - } - std::memcpy(host_mem_addr, data.data(), data.size() * sizeof(T)); - Unlock(); - return {}; - } - - // Reads data into the user provided Span from the tensor buffer. 
- // If the provided buffer is smaller than the size of the tensor buffer, the - // data will be read up to the size of the provided buffer. - // It returns an error if the provided buffer is bigger than the size of the - // tensor buffer. - template - Expected Read(absl::Span data) { - LITERT_ASSIGN_OR_RETURN(void* host_mem_addr, Lock()); - LITERT_ASSIGN_OR_RETURN(size_t size, Size()); - size_t total_read_size = data.size() * sizeof(T); - if (size < total_read_size) { - return Unexpected( - kLiteRtStatusErrorRuntimeFailure, - "TensorBuffer size is smaller than the given data size"); - } - std::memcpy(data.data(), host_mem_addr, total_read_size); - Unlock(); - return {}; - } -}; - -class TensorBufferScopedLock { - public: - TensorBufferScopedLock(const TensorBufferScopedLock& arg) = delete; - TensorBufferScopedLock(TensorBufferScopedLock&& arg) = default; - ~TensorBufferScopedLock() { (void)LiteRtUnlockTensorBuffer(tensor_buffer_); } - - template - static Expected> Create( - TensorBuffer& tensor_buffer) { - return Create(tensor_buffer.Get()); - } - - template - static Expected> Create( - LiteRtTensorBuffer tensor_buffer) { - void* host_mem_addr; - LITERT_RETURN_IF_ERROR( - LiteRtLockTensorBuffer(tensor_buffer, &host_mem_addr)); - return std::make_pair(TensorBufferScopedLock(tensor_buffer), - static_cast(host_mem_addr)); - } - - private: - explicit TensorBufferScopedLock(LiteRtTensorBuffer& tensor_buffer) - : tensor_buffer_(tensor_buffer) {} - - LiteRtTensorBuffer tensor_buffer_; -}; - -} // namespace litert - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_TENSOR_BUFFER_H_ diff --git a/tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_requirements.h b/tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_requirements.h deleted file mode 100644 index 881e3662a2fff6..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_requirements.h +++ /dev/null @@ -1,92 +0,0 @@ -// Copyright 2024 Google LLC. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_TENSOR_BUFFER_REQUIREMENTS_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_TENSOR_BUFFER_REQUIREMENTS_H_ - -#include -#include -#include - -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_handle.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" - -namespace litert { - -// Requirements for allocating a TensorBuffer, typically specified by a HW -// accelerator for a given I/O tensor. C++ equivalent to -// LiteRtTensorBufferRequirements. -class TensorBufferRequirements - : public internal::Handle { - public: - TensorBufferRequirements() = default; - - // Parameter `owned` indicates if the created TensorBufferRequirements object - // should take ownership of the provided `requirements` handle. 
- explicit TensorBufferRequirements(LiteRtTensorBufferRequirements requirements, - bool owned = true) - : internal::Handle(requirements, - owned) {} - - static Expected Create( - absl::Span buffer_types, size_t buffer_size, - absl::Span strides = - absl::MakeSpan(static_cast(nullptr), 0)) { - LiteRtTensorBufferRequirements tensor_buffer_requirements; - LITERT_RETURN_IF_ERROR(LiteRtCreateTensorBufferRequirements( - buffer_types.size(), buffer_types.data(), buffer_size, strides.size(), - strides.data(), &tensor_buffer_requirements)); - return TensorBufferRequirements(tensor_buffer_requirements); - } - - Expected> SupportedTypes() const { - int num_types; - LITERT_RETURN_IF_ERROR( - LiteRtGetNumTensorBufferRequirementsSupportedBufferTypes(Get(), - &num_types)); - std::vector types(num_types); - for (auto i = 0; i < num_types; ++i) { - LITERT_RETURN_IF_ERROR( - LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( - Get(), i, &types[i])); - } - return types; - } - - Expected BufferSize() const { - size_t buffer_size; - LITERT_RETURN_IF_ERROR( - LiteRtGetTensorBufferRequirementsBufferSize(Get(), &buffer_size)); - return buffer_size; - } - - Expected> Strides() const { - int num_strides; - const uint32_t* strides; - LITERT_RETURN_IF_ERROR(LiteRtGetTensorBufferRequirementsStrides( - Get(), &num_strides, &strides)); - return absl::MakeSpan(strides, num_strides); - } -}; - -} // namespace litert - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_TENSOR_BUFFER_REQUIREMENTS_H_ diff --git a/tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_requirements_test.cc b/tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_requirements_test.cc deleted file mode 100644 index 0dba6aaac27641..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_requirements_test.cc +++ /dev/null @@ -1,104 +0,0 @@ -// Copyright 2024 Google LLC. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include -#include - -#include // NOLINT: Need when ANDROID_API_LEVEL >= 26 -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h" -#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_requirements.h" - -namespace { - -constexpr const LiteRtTensorBufferType kSupportedTensorBufferTypes[] = { - kLiteRtTensorBufferTypeHostMemory, - kLiteRtTensorBufferTypeAhwb, - kLiteRtTensorBufferTypeIon, - kLiteRtTensorBufferTypeFastRpc, -}; - -constexpr const size_t kNumSupportedTensorBufferTypes = - sizeof(kSupportedTensorBufferTypes) / - sizeof(kSupportedTensorBufferTypes[0]); - -constexpr const size_t kBufferSize = 1234; - -} // namespace - -TEST(TensorBufferRequirements, Owned) { - auto requirements = litert::TensorBufferRequirements::Create( - absl::MakeSpan(kSupportedTensorBufferTypes, - kNumSupportedTensorBufferTypes), - kBufferSize); - ASSERT_TRUE(requirements); - - auto supported_types = requirements->SupportedTypes(); - ASSERT_TRUE(supported_types); - ASSERT_EQ(supported_types->size(), kNumSupportedTensorBufferTypes); - for (auto i = 0; i < supported_types->size(); ++i) { - ASSERT_EQ((*supported_types)[i], kSupportedTensorBufferTypes[i]); - } - - auto size = requirements->BufferSize(); - 
ASSERT_TRUE(size); - ASSERT_EQ(*size, kBufferSize); -} - -TEST(TensorBufferRequirements, NotOwned) { - LiteRtTensorBufferRequirements litert_requirements; - ASSERT_EQ(LiteRtCreateTensorBufferRequirements( - kNumSupportedTensorBufferTypes, kSupportedTensorBufferTypes, - kBufferSize, /*num_strides=*/0, /*strides=*/nullptr, - &litert_requirements), - kLiteRtStatusOk); - - litert::TensorBufferRequirements requirements(litert_requirements, - /*owned=*/false); - - auto supported_types = requirements.SupportedTypes(); - ASSERT_TRUE(supported_types); - ASSERT_EQ(supported_types->size(), kNumSupportedTensorBufferTypes); - for (auto i = 0; i < supported_types->size(); ++i) { - ASSERT_EQ((*supported_types)[i], kSupportedTensorBufferTypes[i]); - } - - auto size = requirements.BufferSize(); - ASSERT_TRUE(size); - ASSERT_EQ(*size, kBufferSize); - - ASSERT_EQ(requirements.Get(), litert_requirements); - - LiteRtDestroyTensorBufferRequirements(litert_requirements); -} - -TEST(TensorBufferRequirements, WithStrides) { - constexpr std::array kStrides = {1, 2, 3}; - - auto requirements = litert::TensorBufferRequirements::Create( - absl::MakeSpan(kSupportedTensorBufferTypes, - kNumSupportedTensorBufferTypes), - kBufferSize, absl::MakeSpan(kStrides.data(), kStrides.size())); - ASSERT_TRUE(requirements); - - auto strides = requirements->Strides(); - ASSERT_TRUE(strides); - ASSERT_EQ(strides->size(), kStrides.size()); - for (auto i = 0; i < kStrides.size(); ++i) { - ASSERT_EQ((*strides)[i], kStrides[i]); - } -} diff --git a/tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_test.cc b/tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_test.cc deleted file mode 100644 index c366f2081c1b9b..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_test.cc +++ /dev/null @@ -1,613 +0,0 @@ -// Copyright 2024 Google LLC. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" - -#include -#include -#include -#include -#include - -#include -#include // NOLINT: Need when ANDROID_API_LEVEL >= 26 -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_types.h" -#include "tensorflow/lite/experimental/litert/cc/litert_element_type.h" -#include "tensorflow/lite/experimental/litert/cc/litert_event.h" -#include "tensorflow/lite/experimental/litert/cc/litert_layout.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/runtime/ahwb_buffer.h" // IWYU pragma: keep -#include "tensorflow/lite/experimental/litert/runtime/dmabuf_buffer.h" // IWYU pragma: keep -#include "tensorflow/lite/experimental/litert/runtime/fastrpc_buffer.h" // IWYU pragma: keep -#include "tensorflow/lite/experimental/litert/runtime/ion_buffer.h" // IWYU pragma: keep -#include "tensorflow/lite/experimental/litert/runtime/tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/test/matchers.h" - -#if LITERT_HAS_AHWB_SUPPORT -#include -#endif // LITERT_HAS_AHWB_SUPPORT - -#if LITERT_HAS_OPENGL_SUPPORT -#include 
"tensorflow/lite/delegates/gpu/gl/egl_environment.h" -#endif // LITERT_HAS_OPENGL_SUPPORT - -namespace litert { -namespace { - -using ::testing::Eq; -using ::testing::Ne; - -constexpr const float kTensorData[] = {10, 20, 30, 40}; - -constexpr const int32_t kTensorDimensions[] = {sizeof(kTensorData) / - sizeof(kTensorData[0])}; - -constexpr int kFakeSyncFenceFd = 1; - -constexpr const LiteRtRankedTensorType kTestTensorType = { - /*.element_type=*/kLiteRtElementTypeFloat32, - BuildLayout(kTensorDimensions)}; - -int GetReferenceCount(const TensorBuffer& tensor_buffer) { - LiteRtTensorBufferT* internal_tensor_buffer = - static_cast(tensor_buffer.Get()); - return internal_tensor_buffer->RefCount(); -} - -TEST(TensorBuffer, HostMemory) { - const RankedTensorType kTensorType(kTestTensorType); - constexpr auto kTensorBufferType = kLiteRtTensorBufferTypeHostMemory; - - auto tensor_buffer = TensorBuffer::CreateManaged( - kTensorBufferType, kTensorType, sizeof(kTensorData)); - ASSERT_TRUE(tensor_buffer); - - auto tensor_buffer_type = tensor_buffer->BufferType(); - ASSERT_TRUE(tensor_buffer_type); - ASSERT_EQ(*tensor_buffer_type, kTensorBufferType); - - auto tensor_type = tensor_buffer->TensorType(); - ASSERT_TRUE(tensor_type); - - ASSERT_EQ(tensor_type->ElementType(), ElementType::Float32); - ASSERT_EQ(tensor_type->Layout().Rank(), 1); - ASSERT_EQ(tensor_type->Layout().Dimensions()[0], - kTensorType.Layout().Dimensions()[0]); - ASSERT_FALSE(tensor_type->Layout().HasStrides()); - - auto size = tensor_buffer->Size(); - ASSERT_TRUE(size); - ASSERT_EQ(*size, sizeof(kTensorData)); - - auto offset = tensor_buffer->Offset(); - ASSERT_TRUE(offset); - ASSERT_EQ(*offset, 0); - - { - auto lock_and_addr = TensorBufferScopedLock::Create(*tensor_buffer); - ASSERT_TRUE(lock_and_addr); - std::memcpy(lock_and_addr->second, kTensorData, sizeof(kTensorData)); - } - - { - auto lock_and_addr = TensorBufferScopedLock::Create(*tensor_buffer); - ASSERT_TRUE(lock_and_addr); - ASSERT_EQ( - 
std::memcmp(lock_and_addr->second, kTensorData, sizeof(kTensorData)), - 0); - } -} - -TEST(TensorBuffer, Ahwb) { - if (!internal::AhwbBuffer::IsSupported()) { - GTEST_SKIP() << "AHardwareBuffers are not supported on this platform; " - "skipping the test"; - } - - const RankedTensorType kTensorType(kTestTensorType); - constexpr auto kTensorBufferType = kLiteRtTensorBufferTypeAhwb; - - auto tensor_buffer = TensorBuffer::CreateManaged( - kTensorBufferType, kTensorType, sizeof(kTensorData)); - ASSERT_TRUE(tensor_buffer); - - auto tensor_buffer_type = tensor_buffer->BufferType(); - ASSERT_TRUE(tensor_buffer_type); - ASSERT_EQ(*tensor_buffer_type, kTensorBufferType); - - auto tensor_type = tensor_buffer->TensorType(); - ASSERT_TRUE(tensor_type); - - ASSERT_EQ(tensor_type->ElementType(), ElementType::Float32); - ASSERT_EQ(tensor_type->Layout().Rank(), 1); - ASSERT_EQ(tensor_type->Layout().Dimensions()[0], - kTensorType.Layout().Dimensions()[0]); - ASSERT_FALSE(tensor_type->Layout().HasStrides()); - - auto size = tensor_buffer->Size(); - ASSERT_TRUE(size); - ASSERT_EQ(*size, sizeof(kTensorData)); - - auto offset = tensor_buffer->Offset(); - ASSERT_TRUE(offset); - ASSERT_EQ(*offset, 0); - - { - auto lock_and_addr = TensorBufferScopedLock::Create(*tensor_buffer); - ASSERT_TRUE(lock_and_addr); - std::memcpy(lock_and_addr->second, kTensorData, sizeof(kTensorData)); - } - - { - auto lock_and_addr = TensorBufferScopedLock::Create(*tensor_buffer); - ASSERT_TRUE(lock_and_addr); - ASSERT_EQ( - std::memcmp(lock_and_addr->second, kTensorData, sizeof(kTensorData)), - 0); - } -} - -TEST(TensorBuffer, Ion) { - if (!internal::IonBuffer::IsSupported()) { - GTEST_SKIP() - << "ION buffers are not supported on this platform; skipping the test"; - } - - const RankedTensorType kTensorType(kTestTensorType); - constexpr auto kTensorBufferType = kLiteRtTensorBufferTypeIon; - - auto tensor_buffer = TensorBuffer::CreateManaged( - kTensorBufferType, kTensorType, sizeof(kTensorData)); - 
ASSERT_TRUE(tensor_buffer); - - auto tensor_buffer_type = tensor_buffer->BufferType(); - ASSERT_TRUE(tensor_buffer_type); - ASSERT_EQ(*tensor_buffer_type, kTensorBufferType); - - auto tensor_type = tensor_buffer->TensorType(); - ASSERT_TRUE(tensor_type); - - ASSERT_EQ(tensor_type->ElementType(), ElementType::Float32); - ASSERT_EQ(tensor_type->Layout().Rank(), 1); - ASSERT_EQ(tensor_type->Layout().Dimensions()[0], - kTensorType.Layout().Dimensions()[0]); - ASSERT_FALSE(tensor_type->Layout().HasStrides()); - - auto size = tensor_buffer->Size(); - ASSERT_TRUE(size); - ASSERT_EQ(*size, sizeof(kTensorData)); - - auto offset = tensor_buffer->Offset(); - ASSERT_TRUE(offset); - ASSERT_EQ(*offset, 0); - - { - auto lock_and_addr = TensorBufferScopedLock::Create(*tensor_buffer); - ASSERT_TRUE(lock_and_addr); - std::memcpy(lock_and_addr->second, kTensorData, sizeof(kTensorData)); - } - - { - auto lock_and_addr = TensorBufferScopedLock::Create(*tensor_buffer); - ASSERT_TRUE(lock_and_addr); - ASSERT_EQ( - std::memcmp(lock_and_addr->second, kTensorData, sizeof(kTensorData)), - 0); - } -} - -TEST(TensorBuffer, DmaBuf) { - if (!internal::DmaBufBuffer::IsSupported()) { - GTEST_SKIP() - << "DMA-BUF buffers are not supported on this platform; skipping " - "the test"; - } - - const RankedTensorType kTensorType(kTestTensorType); - constexpr auto kTensorBufferType = kLiteRtTensorBufferTypeDmaBuf; - - auto tensor_buffer = TensorBuffer::CreateManaged( - kTensorBufferType, kTensorType, sizeof(kTensorData)); - ASSERT_TRUE(tensor_buffer); - - auto tensor_buffer_type = tensor_buffer->BufferType(); - ASSERT_TRUE(tensor_buffer_type); - ASSERT_EQ(*tensor_buffer_type, kTensorBufferType); - - auto tensor_type = tensor_buffer->TensorType(); - ASSERT_TRUE(tensor_type); - - ASSERT_EQ(tensor_type->ElementType(), ElementType::Float32); - ASSERT_EQ(tensor_type->Layout().Rank(), 1); - ASSERT_EQ(tensor_type->Layout().Dimensions()[0], - kTensorType.Layout().Dimensions()[0]); - 
ASSERT_FALSE(tensor_type->Layout().HasStrides()); - - auto size = tensor_buffer->Size(); - ASSERT_TRUE(size); - ASSERT_EQ(*size, sizeof(kTensorData)); - - auto offset = tensor_buffer->Offset(); - ASSERT_TRUE(offset); - ASSERT_EQ(*offset, 0); - - { - auto lock_and_addr = TensorBufferScopedLock::Create(*tensor_buffer); - ASSERT_TRUE(lock_and_addr); - std::memcpy(lock_and_addr->second, kTensorData, sizeof(kTensorData)); - } - - { - auto lock_and_addr = TensorBufferScopedLock::Create(*tensor_buffer); - ASSERT_TRUE(lock_and_addr); - ASSERT_EQ( - std::memcmp(lock_and_addr->second, kTensorData, sizeof(kTensorData)), - 0); - } -} - -TEST(TensorBuffer, FastRpc) { - if (!internal::FastRpcBuffer::IsSupported()) { - GTEST_SKIP() - << "FastRPC buffers are not supported on this platform; skipping " - "the test"; - } - - const RankedTensorType kTensorType(kTestTensorType); - constexpr auto kTensorBufferType = kLiteRtTensorBufferTypeFastRpc; - - auto tensor_buffer = TensorBuffer::CreateManaged( - kTensorBufferType, kTensorType, sizeof(kTensorData)); - ASSERT_TRUE(tensor_buffer); - - auto tensor_buffer_type = tensor_buffer->BufferType(); - ASSERT_TRUE(tensor_buffer_type); - ASSERT_EQ(*tensor_buffer_type, kTensorBufferType); - - auto tensor_type = tensor_buffer->TensorType(); - ASSERT_TRUE(tensor_type); - - ASSERT_EQ(tensor_type->ElementType(), ElementType::Float32); - ASSERT_EQ(tensor_type->Layout().Rank(), 1); - ASSERT_EQ(tensor_type->Layout().Dimensions()[0], - kTensorType.Layout().Dimensions()[0]); - ASSERT_FALSE(tensor_type->Layout().HasStrides()); - - auto size = tensor_buffer->Size(); - ASSERT_TRUE(size); - ASSERT_EQ(*size, sizeof(kTensorData)); - - auto offset = tensor_buffer->Offset(); - ASSERT_TRUE(offset); - ASSERT_EQ(*offset, 0); - - { - auto lock_and_addr = TensorBufferScopedLock::Create(*tensor_buffer); - ASSERT_TRUE(lock_and_addr); - std::memcpy(lock_and_addr->second, kTensorData, sizeof(kTensorData)); - } - - { - auto lock_and_addr = 
TensorBufferScopedLock::Create(*tensor_buffer); - ASSERT_TRUE(lock_and_addr); - ASSERT_EQ( - std::memcmp(lock_and_addr->second, kTensorData, sizeof(kTensorData)), - 0); - } -} - -TEST(TensorBuffer, NotOwned) { - LiteRtTensorBuffer litert_tensor_buffer; - ASSERT_EQ(LiteRtCreateManagedTensorBuffer( - kLiteRtTensorBufferTypeHostMemory, &kTestTensorType, - sizeof(kTensorData), &litert_tensor_buffer), - kLiteRtStatusOk); - - TensorBuffer tensor_buffer(litert_tensor_buffer, /*owned=*/false); - ASSERT_EQ(tensor_buffer.Get(), litert_tensor_buffer); - - LiteRtDestroyTensorBuffer(litert_tensor_buffer); -} - -TEST(TensorBuffer, CreateFromExternalHostMemory) { - // Allocate a tensor buffer with host memory. - const int kTensorBufferSize = - std::max(sizeof(kTensorData), LITERT_HOST_MEMORY_BUFFER_ALIGNMENT); - const RankedTensorType kTensorType(kTestTensorType); - void* host_memory_ptr; - ASSERT_EQ( - ::posix_memalign(&host_memory_ptr, LITERT_HOST_MEMORY_BUFFER_ALIGNMENT, - kTensorBufferSize), - 0); - - std::memcpy(host_memory_ptr, kTensorData, sizeof(kTensorData)); - - // Create a tensor buffer that wraps the host memory. 
- auto tensor_buffer_from_external_memory = TensorBuffer::CreateFromHostMemory( - kTensorType, host_memory_ptr, kTensorBufferSize); - - auto lock_and_addr_external_memory = - TensorBufferScopedLock::Create(*tensor_buffer_from_external_memory); - ASSERT_TRUE(lock_and_addr_external_memory); - ASSERT_EQ(std::memcmp(lock_and_addr_external_memory->second, kTensorData, - sizeof(kTensorData)), - 0); - - free(host_memory_ptr); -} - -#if LITERT_HAS_AHWB_SUPPORT -TEST(TensorBuffer, CreateFromAhwb) { - AHardwareBuffer* ahw_buffer = nullptr; - if (__builtin_available(android 26, *)) { - int error = 0; - AHardwareBuffer_Desc desc = { - .width = LITERT_HOST_MEMORY_BUFFER_ALIGNMENT, - .height = 1, - .layers = 1, - .format = AHARDWAREBUFFER_FORMAT_BLOB, - .usage = AHARDWAREBUFFER_USAGE_CPU_WRITE_RARELY | - AHARDWAREBUFFER_USAGE_CPU_READ_RARELY}; - error = AHardwareBuffer_allocate(&desc, &ahw_buffer); - ASSERT_EQ(error, 0); - - void* host_memory_ptr = nullptr; - error = - AHardwareBuffer_lock(ahw_buffer, AHARDWAREBUFFER_USAGE_CPU_WRITE_RARELY, - -1, nullptr, &host_memory_ptr); - ASSERT_EQ(error, 0); - - std::memcpy(host_memory_ptr, kTensorData, sizeof(kTensorData)); - - int fence_file_descriptor = -1; - error = AHardwareBuffer_unlock(ahw_buffer, &fence_file_descriptor); - ASSERT_EQ(error, 0); - } else { - GTEST_SKIP() << "AHardwareBuffers are not supported on this platform; " - "skipping the test"; - } - - { - // Create a tensor buffer that wraps the AHardwareBuffer. 
- const RankedTensorType kTensorType(kTestTensorType); - auto tensor_buffer_from_ahwb = - TensorBuffer::CreateFromAhwb(kTensorType, ahw_buffer, - /*ahwb_offset=*/0); - - auto lock_and_addr_external_memory = - TensorBufferScopedLock::Create(*tensor_buffer_from_ahwb); - ASSERT_TRUE(lock_and_addr_external_memory); - ASSERT_EQ(std::memcmp(lock_and_addr_external_memory->second, kTensorData, - sizeof(kTensorData)), - 0); - } - - if (__builtin_available(android 26, *)) { - AHardwareBuffer_release(ahw_buffer); - } -} -#endif // LITERT_HAS_AHWB_SUPPORT - -TEST(TensorBuffer, Duplicate) { - LiteRtTensorBuffer litert_tensor_buffer; - ASSERT_EQ(LiteRtCreateManagedTensorBuffer( - kLiteRtTensorBufferTypeHostMemory, &kTestTensorType, - sizeof(kTensorData), &litert_tensor_buffer), - kLiteRtStatusOk); - - TensorBuffer tensor_buffer(litert_tensor_buffer, /*owned=*/true); - ASSERT_EQ(GetReferenceCount(tensor_buffer), 1); - { - auto duplicated_tensor_buffer = tensor_buffer.Duplicate(); - ASSERT_TRUE(duplicated_tensor_buffer); - ASSERT_EQ(GetReferenceCount(*duplicated_tensor_buffer), 2); - // The duplicated tensor buffer should point to the same underlying - // LiteRtTensorBuffer object. - ASSERT_EQ(duplicated_tensor_buffer->Get(), tensor_buffer.Get()); - - // Update tensor buffer using the duplicated tensor buffer. - auto lock_and_addr = - TensorBufferScopedLock::Create(*duplicated_tensor_buffer); - ASSERT_TRUE(lock_and_addr); - std::memcpy(lock_and_addr->second, kTensorData, sizeof(kTensorData)); - - // When the scope ends, the duplicated tensor buffer should be destroyed. - // This should not affect the original tensor buffer. - } - - ASSERT_EQ(GetReferenceCount(tensor_buffer), 1); - // Check that the original tensor buffer is not affected. 
- { - auto lock_and_addr = TensorBufferScopedLock::Create(tensor_buffer); - ASSERT_TRUE(lock_and_addr); - ASSERT_EQ( - std::memcmp(lock_and_addr->second, kTensorData, sizeof(kTensorData)), - 0); - } -} - -TEST(TensorBuffer, ReadWriteBasic) { - LiteRtTensorBuffer litert_tensor_buffer; - ASSERT_EQ(LiteRtCreateManagedTensorBuffer( - kLiteRtTensorBufferTypeHostMemory, &kTestTensorType, - sizeof(kTensorData), &litert_tensor_buffer), - kLiteRtStatusOk); - - TensorBuffer tensor_buffer(litert_tensor_buffer, /*owned=*/true); - auto write_success = tensor_buffer.Write(absl::MakeSpan( - kTensorData, sizeof(kTensorData) / sizeof(kTensorData[0]))); - ASSERT_TRUE(write_success); - float read_data[sizeof(kTensorData) / sizeof(kTensorData[0])]; - auto read_success = tensor_buffer.Read(absl::MakeSpan(read_data)); - ASSERT_TRUE(read_success); - ASSERT_EQ(std::memcmp(read_data, kTensorData, sizeof(kTensorData)), 0); -} - -TEST(TensorBuffer, ReadWriteBufferSizeMismatch) { - LITERT_ASSERT_OK_AND_ASSIGN( - TensorBuffer tensor_buffer, - TensorBuffer::CreateManaged(kLiteRtTensorBufferTypeHostMemory, - RankedTensorType(kTestTensorType), - sizeof(kTensorData))); - { - // Write with smaller size of data. - auto write_success = - tensor_buffer.Write(absl::MakeSpan(kTensorData, 1)); - ASSERT_TRUE(write_success); - } - { - constexpr const float big_data[] = {10, 20, 30, 40, 50}; - // Write with larger size of data. - auto write_success = - tensor_buffer.Write(absl::MakeSpan(big_data, 5)); - ASSERT_FALSE(write_success); - } - auto write_success = tensor_buffer.Write(absl::MakeSpan( - kTensorData, sizeof(kTensorData) / sizeof(kTensorData[0]))); - ASSERT_TRUE(write_success); - { - // Read with smaller size of buffer. - float read_data[1]; - auto read_success = tensor_buffer.Read(absl::MakeSpan(read_data, 1)); - ASSERT_TRUE(read_success); - ASSERT_EQ(read_data[0], kTensorData[0]); - } - { - // Read with larger size of buffer. 
- float read_data[5]; - auto read_success = tensor_buffer.Read(absl::MakeSpan(read_data, 5)); - ASSERT_FALSE(read_success); - } -} - -#if LITERT_HAS_OPENGL_SUPPORT -TEST(TensorBuffer, CreateFromGlTexture) { - std::unique_ptr env; - ASSERT_TRUE(tflite::gpu::gl::EglEnvironment::NewEglEnvironment(&env).ok()); - - // Create GL texture. - tflite::gpu::gl::GlTexture gl_texture(GL_TEXTURE_2D, 1, GL_RGBA8, 1, 1, - /*has_ownership=*/true); - ASSERT_TRUE(gl_texture.is_valid()); - - // Create tensor buffer from existing GL texture (e.g. this could be from - // Android Camera API). - LITERT_ASSERT_OK_AND_ASSIGN( - TensorBuffer tensor_buffer, - TensorBuffer::CreateFromGlTexture( - RankedTensorType(kTensorType), gl_texture.target(), gl_texture.id(), - gl_texture.format(), gl_texture.bytes_size(), gl_texture.layer())); -} - -tflite::gpu::gl::GlBuffer CreateTestGlBuffer(size_t size_bytes) { - tflite::gpu::gl::GlBuffer gl_buffer; - CHECK_OK(tflite::gpu::gl::CreateReadWriteShaderStorageBuffer( - size_bytes, &gl_buffer)); - return gl_buffer; -} - -TEST(TensorBuffer, CreateFromGlBuffer) { - std::unique_ptr env; - ASSERT_TRUE(tflite::gpu::gl::EglEnvironment::NewEglEnvironment(&env).ok()); - - // Create GL buffer. - tflite::gpu::gl::GlBuffer gl_buffer = CreateTestGlBuffer(sizeof(kTensorData)); - - // Create tensor buffer from existing GL buffer. - LITERT_ASSERT_OK_AND_ASSIGN( - TensorBuffer tensor_buffer, - TensorBuffer::CreateFromGlBuffer( - RankedTensorType(kTensorType), gl_buffer.target(), gl_buffer.id(), - gl_buffer.bytes_size(), gl_buffer.offset())); -} - -#if LITERT_HAS_AHWB_SUPPORT -TEST(TensorBuffer, GetGlBufferFromAhwb) { - std::unique_ptr env; - ASSERT_TRUE(tflite::gpu::gl::EglEnvironment::NewEglEnvironment(&env).ok()); - - // Create AHWB Tensor buffer. - LITERT_ASSERT_OK_AND_ASSIGN( - TensorBuffer ahwb_tensor_buffer, - TensorBuffer::CreateManaged(kLiteRtTensorBufferTypeAhwb, - RankedTensorType(kTensorType), - sizeof(kTensorData))); - - // Write to AHWB Tensor buffer. 
- LITERT_ASSERT_OK(ahwb_tensor_buffer.Write(absl::MakeConstSpan( - kTensorData, sizeof(kTensorData) / sizeof(kTensorData[0])))); - - LITERT_ASSERT_OK_AND_ASSIGN(TensorBuffer::GlBuffer gl_buffer, - ahwb_tensor_buffer.GetGlBuffer()); - EXPECT_THAT(gl_buffer.target, Eq(GL_SHADER_STORAGE_BUFFER)); - EXPECT_THAT(gl_buffer.id, Ne(0)); - EXPECT_THAT(gl_buffer.size_bytes, Eq(sizeof(kTensorData))); - EXPECT_THAT(gl_buffer.offset, Eq(0)); - - // Read from GL buffer. - // TODO(gcarranza): Add GlBuffer ReadLock functionality to LiteRT - // TensorBuffer. GlBuffer::Unlock currently writes to GL buffer. - tflite::gpu::gl::GlBuffer gl_buffer_from_ahwb( - gl_buffer.target, gl_buffer.id, gl_buffer.size_bytes, gl_buffer.offset, - /*has_ownership=*/false); - float read_data[sizeof(kTensorData) / sizeof(kTensorData[0])]; - ASSERT_OK(gl_buffer_from_ahwb.Read(absl::MakeSpan(read_data))); - ASSERT_EQ(std::memcmp(read_data, kTensorData, sizeof(kTensorData)), 0); -} -#endif // LITERT_HAS_AHWB_SUPPORT - -#endif // LITERT_HAS_OPENGL_SUPPORT - -TEST(TensorBuffer, GetAhwb) { - if (!internal::AhwbBuffer::IsSupported()) { - GTEST_SKIP() << "AHardwareBuffers are not supported on this platform; " - "skipping the test"; - } - LITERT_ASSERT_OK_AND_ASSIGN( - TensorBuffer tensor_buffer, - TensorBuffer::CreateManaged(kLiteRtTensorBufferTypeAhwb, - RankedTensorType(kTestTensorType), - sizeof(kTensorData))); - LITERT_ASSERT_OK_AND_ASSIGN(AHardwareBuffer * ahwb, tensor_buffer.GetAhwb()); - EXPECT_THAT(ahwb, Ne(nullptr)); -} - -TEST(TensorBuffer, Event) { - LITERT_ASSERT_OK_AND_ASSIGN( - TensorBuffer tensor_buffer, - TensorBuffer::CreateManaged(kLiteRtTensorBufferTypeHostMemory, - RankedTensorType(kTestTensorType), - sizeof(kTensorData))); - // Create event. - LITERT_ASSERT_OK_AND_ASSIGN( - Event event, Event::CreateFromSyncFenceFd(kFakeSyncFenceFd, true)); - // Move event into tensor buffer. 
- LITERT_EXPECT_OK(tensor_buffer.SetEvent(std::move(event))); - EXPECT_TRUE(tensor_buffer.HasEvent()); - LITERT_ASSERT_OK_AND_ASSIGN(Event tensor_buffer_event, - tensor_buffer.GetEvent()); - LITERT_ASSERT_OK_AND_ASSIGN(int fence_fd, - tensor_buffer_event.GetSyncFenceFd()); - EXPECT_THAT(fence_fd, Eq(kFakeSyncFenceFd)); - // Clear event. - LITERT_ASSERT_OK(tensor_buffer.ClearEvent()); - EXPECT_FALSE(tensor_buffer.HasEvent()); -} - -} // namespace -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_utils.cc b/tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_utils.cc deleted file mode 100644 index 67ef66c34291d4..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_utils.cc +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_utils.h" - -#include - -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_types.h" - -std::string BufferTypeToString(LiteRtTensorBufferType buffer_type) { - switch (buffer_type) { - case kLiteRtTensorBufferTypeUnknown: - return "Unknown"; - case kLiteRtTensorBufferTypeHostMemory: - return "HostMemory"; - case kLiteRtTensorBufferTypeAhwb: - return "Ahwb"; - case kLiteRtTensorBufferTypeIon: - return "Ion"; - case kLiteRtTensorBufferTypeDmaBuf: - return "DmaBuf"; - case kLiteRtTensorBufferTypeFastRpc: - return "FastRpc"; - case kLiteRtTensorBufferTypeOpenCl: - return "OpenCl"; - case kLiteRtTensorBufferTypeGlBuffer: - return "GlBuffer"; - case kLiteRtTensorBufferTypeGlTexture: - return "GlTexture"; - } - LITERT_LOG(LITERT_ERROR, "Unexpected value for LiteRtTensorBufferType: %d", - static_cast(buffer_type)); - return "UnexpectedBufferType"; -} diff --git a/tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_utils.h b/tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_utils.h deleted file mode 100644 index a2ccf427211007..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_utils.h +++ /dev/null @@ -1,24 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_TENSOR_BUFFER_UTILS_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_TENSOR_BUFFER_UTILS_H_ - -#include - -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_types.h" - -std::string BufferTypeToString(LiteRtTensorBufferType buffer_type); - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_TENSOR_BUFFER_UTILS_H_ diff --git a/tensorflow/lite/experimental/litert/cc/test_shared_library.cc b/tensorflow/lite/experimental/litert/cc/test_shared_library.cc deleted file mode 100644 index 37254390ea2018..00000000000000 --- a/tensorflow/lite/experimental/litert/cc/test_shared_library.cc +++ /dev/null @@ -1,19 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -extern "C" { - -const char* TestFunction() { return "test_shared_library"; } - -} // extern "C" diff --git a/tensorflow/lite/experimental/litert/compiler/BUILD b/tensorflow/lite/experimental/litert/compiler/BUILD deleted file mode 100644 index 23b07d5602d7c8..00000000000000 --- a/tensorflow/lite/experimental/litert/compiler/BUILD +++ /dev/null @@ -1,18 +0,0 @@ -# Copyright 2024 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], -) diff --git a/tensorflow/lite/experimental/litert/compiler/plugin/BUILD b/tensorflow/lite/experimental/litert/compiler/plugin/BUILD deleted file mode 100644 index 77f23b34399cc7..00000000000000 --- a/tensorflow/lite/experimental/litert/compiler/plugin/BUILD +++ /dev/null @@ -1,152 +0,0 @@ -# Copyright 2024 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], -) - -cc_library( - name = "compiler_plugin", - srcs = ["compiler_plugin.cc"], - hdrs = ["compiler_plugin.h"], - deps = [ - ":algo", - ":compiler_flags", - "//tensorflow/lite/experimental/litert/c:litert_any", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/cc:litert_buffer_ref", - "//tensorflow/lite/experimental/litert/cc:litert_detail", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_op_options", - "//tensorflow/lite/experimental/litert/cc:litert_shared_library", - "//tensorflow/lite/experimental/litert/core:build_stamp", - "//tensorflow/lite/experimental/litert/core:dynamic_loading", - "//tensorflow/lite/experimental/litert/core:environment", - "//tensorflow/lite/experimental/litert/core:filesystem", - "//tensorflow/lite/experimental/litert/core:version", - "//tensorflow/lite/experimental/litert/core/model", - "//tensorflow/lite/experimental/litert/core/model:buffer_manager", - "//tensorflow/lite/experimental/litert/core/model:ir_allocator", - "//tensorflow/lite/experimental/litert/core/util:flatbuffer_tools", - "//tensorflow/lite/experimental/litert/vendors/c:litert_compiler_plugin", - "//tensorflow/lite/experimental/litert/vendors/c:litert_compiler_plugin_api", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log:absl_check", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - ], -) - -# copybara:uncomment_begin(no OSS for unique-test-directory) -# 
cc_test( -# name = "compiler_plugin_test", -# srcs = ["compiler_plugin_test.cc"], -# data = [ -# "//tensorflow/lite/experimental/litert/test:mlir_test_data", -# "//tensorflow/lite/experimental/litert/vendors/examples:example_plugin_so", -# ], -# tags = [ -# # Sanitizer runtimes are incompatible with RTLD_DEEPBIND. -# "noasan", -# "nomsan", -# "nosan", -# "notsan", -# ], -# deps = [ -# ":compiler_plugin", -# "@com_google_googletest//:gtest_main", -# "@com_google_absl//absl/strings:string_view", -# "//tensorflow/lite/experimental/litert/c:litert_common", -# "//tensorflow/lite/experimental/litert/c:litert_model", -# "//tensorflow/lite/experimental/litert/c:litert_op_code", -# "//tensorflow/lite/experimental/litert/cc:litert_environment", -# "//tensorflow/lite/experimental/litert/cc:litert_op_options", -# "//tensorflow/lite/experimental/litert/core:build_stamp", -# "//tensorflow/lite/experimental/litert/core:filesystem", -# "//tensorflow/lite/experimental/litert/core/model", -# "//tensorflow/lite/experimental/litert/test:common", -# "//tensorflow/lite/experimental/litert/test:matchers", -# "//tensorflow/lite/experimental/litert/tools:dump", -# ], -# ) -# copybara:uncomment_end - -cc_library( - name = "algo", - srcs = ["algo.cc"], - hdrs = ["algo.h"], - deps = [ - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/core:insert_order_map", - "//tensorflow/lite/experimental/litert/core/model", - "//tensorflow/lite/experimental/litert/core/model:model_graph", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log:absl_check", - ], -) - -cc_test( - name = "algo_test", - srcs = ["algo_test.cc"], - data = [ - "//tensorflow/lite/experimental/litert/test:mlir_test_data", - ], - deps = [ - ":algo", - 
"//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", - "//tensorflow/lite/experimental/litert/core/model", - "//tensorflow/lite/experimental/litert/core/model:graph_validation", - "//tensorflow/lite/experimental/litert/test:common", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "compiler_flags", - srcs = ["compiler_flags.cc"], - hdrs = ["compiler_flags.h"], - deps = [ - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/vendors/c:litert_compiler_plugin", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:string_view", - ], -) - -cc_test( - name = "compiler_flags_test", - srcs = ["compiler_flags_test.cc"], - deps = [ - ":compiler_flags", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/test:matchers", - "//tensorflow/lite/experimental/litert/vendors/c:litert_compiler_plugin", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest_main", - ], -) diff --git a/tensorflow/lite/experimental/litert/compiler/plugin/algo.cc b/tensorflow/lite/experimental/litert/compiler/plugin/algo.cc deleted file mode 100644 index eb36733486b3b2..00000000000000 --- a/tensorflow/lite/experimental/litert/compiler/plugin/algo.cc +++ /dev/null @@ -1,331 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/compiler/plugin/algo.h" - -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/container/flat_hash_set.h" -#include "absl/log/absl_check.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/core/insert_order_map.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" -#include "tensorflow/lite/experimental/litert/core/model/model_graph.h" - -namespace litert::internal { -namespace { - -// -// flatlist to partition(s) -//===----------------------------------------------------------------------===// - -class DisjointSets { - public: - static std::vector> GetPartitionsFromFlatList( - const std::vector& flat_op_list); - - private: - void Insert(LiteRtOp op, LiteRtOp parent); - std::vector> GetBuckets(); - LiteRtOp GetBucket(LiteRtOp op); - InsertOrderMap map_; -}; - -//===----------------------------------------------------------------------===// -// LiteRt Core union-find algorithm. -// -// This algorithm is used to group partitions into sub DAGs. -// The input to the algorithm is a list of ops with the their partition index. -// -// [ (op_0, 0), -// (op_1, 0), -// (op_2, 0), -// ... -// (op_7, 1), -// (op_8, 1), ...] -// -// Union-find algorithm is run on each partition (list of ops with same -// partition index). 
-// -// For each partition, the input to the union find algorithm is a list of -// ops with the same partition index. For example, -// -// [ op_0, op_1, op_2, op3, op_4, op_5 ...] -// -// The output of the union find algorithm is a list of list of ops, where each -// list is a disjoint set(a sub DAG within the original Subgraph). For -// example, -// -// [ [op_0, op_1, op_6], -// [op_2, op_3], -// [op_4, op_5] ... ] -// -// Similarly, algorithm on the next parition would return something like -// -// [ [op_7, op_8, op_9], -// [op_10, op_11], -// [op_12, op_13] ... ] -// -// We aggregate all disjoint sets into the result buckets. For example, -// -// [ [op_0, op_1, op_6] -// [op_2, op_3] , -// [op_4, op_5], -// [op_7, op_8, op_9], -// [op_10, op_11], -// [op_12, op_13] ... ] -//===----------------------------------------------------------------------===// -std::vector> DisjointSets::GetPartitionsFromFlatList( - const std::vector& flat_op_list) { - // Find all unique partition indices. Use unique partition index as key and - // store the ops for each partition index as value of the map. - absl::flat_hash_map> partition_map; - for (int i = 0; i < flat_op_list.size(); ++i) { - partition_map[flat_op_list[i].second].push_back(flat_op_list[i].first); - } - - // A vector of disjoint sets, where each partition contains op with the same - // partition index. - std::vector partitions; - - // A vector of all unique partition indices for iterative access. We kept this - // vector so vendor plugin returned partition indices does not have to be - // zero-based. - std::vector flat_partition_indices; - for (auto& partition_index : partition_map) { - flat_partition_indices.push_back(partition_index.first); - } - - // Resize the partitions vector to the number of unique partition indices. - partitions.resize(flat_partition_indices.size()); - - // Resulting buckets of the union find algorithm. - std::vector> all_buckets; - - // Run union-find algorithm on each partition. 
- for (int i = 0; i < flat_partition_indices.size(); ++i) { - // For each partition, initialize the disjoint sets. - for (auto* op : partition_map[flat_partition_indices[i]]) { - partitions[i].map_.InsertOrAssign(op, op); - } - // For each partition, find all disjoint sets. - for (auto* op : partition_map[flat_partition_indices[i]]) { - for (auto* output : op->Outputs()) { - for (auto* user : output->Users()) { - if (!partitions[i].map_.Contains(user)) { - continue; - } - partitions[i].Insert(op, user); - } - } - } - // Aggregate all disjoint sets into the result buckets. - for (auto& bucket : partitions[i].GetBuckets()) { - all_buckets.push_back(std::move(bucket)); - } - } - return all_buckets; -} - -void DisjointSets::Insert(LiteRtOp op, LiteRtOp parent) { - auto* parent_bucket = GetBucket(parent); - auto* op_bucket = GetBucket(op); - if (op_bucket == parent_bucket) { - return; - } - map_.InsertOrAssign(op_bucket, parent_bucket); -} - -// Get all disjoint sets. -std::vector> DisjointSets::GetBuckets() { - // NOLINTBEGIN - std::unordered_map> invert_map; - // NOLINTEND - for (auto it = map_.Begin(); it != map_.End(); ++it) { - auto* bucket = GetBucket(it->first); - - if (invert_map.find(bucket) == invert_map.end()) { - invert_map.insert_or_assign(bucket, std::vector{}); - } - - invert_map[bucket].push_back(it->first); - } - - std::vector> res; - res.reserve(invert_map.size()); - - for (auto& entry : invert_map) { - res.push_back(std::move(entry.second)); - } - - return res; -} - -// Gets the pointer which serves as the key for given ops bucket. Collapses -// paths to amortize. 
-LiteRtOp DisjointSets::GetBucket(LiteRtOp op) { - auto it = map_.Find(op); - auto* parent = it->get().second; - if (op != parent) { - parent = GetBucket(parent); - map_.InsertOrAssign(op, parent); - } - return parent; -} - -// -// slice partitions out of a subgraph (into new subgraphs) -//===----------------------------------------------------------------------===// - -class GraphSlicer { - public: - // Slices "partitions" from "root" into the empty subgraph "slice". Assumes - // the partition is a valid sub-DAG, and replaces it witha single - // tfl.custom_op in "root". A reference to that op is returned. - static LiteRtOp SlicePartitionFromGraph(LiteRtSubgraphT& root, - LiteRtSubgraph slice, - std::vector& partition); - - private: - explicit GraphSlicer(LiteRtSubgraph slice) : slice_(slice) {} - - void CloneInto(const LiteRtOpT& op); - - void RerouteTensorsThroughCustomOp(const LiteRtSubgraphT& root); - - LiteRtSubgraph slice_; - // Maps tensor in old subgraph to tensor in new subgraph. - InsertOrderMap tensor_map_; - LiteRtOp dispatch_op_ = nullptr; -}; - -LiteRtOp GraphSlicer::SlicePartitionFromGraph( - LiteRtSubgraphT& root, LiteRtSubgraph slice, - std::vector& partition) { - GraphSlicer slicer(slice); - - // Register input tensors of the sliced partition WRT to their original order - // in the root subgraph. This ensures the order of input tensors of the - // later outlined custom op is the same as the order of input tensors of the - // GraphInputs. - absl::flat_hash_set used_tensors; - - // Get all tensors used in the partition. 
- for (auto* op : partition) { - used_tensors.insert(op->Inputs().cbegin(), op->Inputs().cend()); - } - for (auto* old_input : root.Inputs()) { - if (used_tensors.contains(old_input)) { - auto* new_input = &MakeClone(*slicer.slice_, *old_input); - slicer.slice_->Inputs().push_back(new_input); - slicer.tensor_map_.InsertOrAssign(old_input, new_input); - } - } - - for (auto* op : partition) { - slicer.CloneInto(*op); - } - - for (auto* op : partition) { - Drop(*op); - } - - // Reuse the storage from the last op in partition to maintain - // topological order. - slicer.dispatch_op_ = partition.back(); - - ABSL_DCHECK(slicer.dispatch_op_->Inputs().empty()); - ABSL_DCHECK(slicer.dispatch_op_->Outputs().empty()); - MakeDispatchOp(*slicer.dispatch_op_); - slicer.RerouteTensorsThroughCustomOp(root); - - DCE(root); - - return slicer.dispatch_op_; -} - -void GraphSlicer::RerouteTensorsThroughCustomOp(const LiteRtSubgraphT& root) { - for (auto it = tensor_map_.Begin(); it != tensor_map_.End(); ++it) { - auto* old_tensor = it->first; - auto* new_tensor = it->second; - - // Reroute tensors which need to be passed into the scope of the new - // subgraph to inputs of the custom op. - if (new_tensor->DefiningOp() == nullptr && !IsConstant(*new_tensor)) { - AttachInput(old_tensor, *dispatch_op_); - continue; - } - - // Reroute custom op as the definer of tensors within the removed partition - // and referenced later in the root graph. 
- if ((!old_tensor->Users().empty() && !IsConstant(*old_tensor)) || - FindOutput(root, *old_tensor)) { - AttachOutput(old_tensor, *dispatch_op_); - slice_->Outputs().push_back(new_tensor); - } - } -} - -void GraphSlicer::CloneInto(const LiteRtOpT& old_op) { - auto& new_op = MakeClone(*slice_, old_op); - - for (auto i = 0; i < old_op.NumInputs(); ++i) { - auto* old_input = old_op.Inputs().at(i); - LiteRtTensor new_input; - if (tensor_map_.Contains(old_input)) { - // If old_input is already in the map then map[input] is its cloned - // counterpart in the new graph. - auto it = tensor_map_.Find(old_input); - new_input = it->get().second; - } else { - // Otherwise, it must be a new subgraph input (or constant). - new_input = &MakeClone(*slice_, *old_input); - if (!IsConstant(*new_input)) { - slice_->Inputs().push_back(new_input); - } - - tensor_map_.InsertOrAssign(old_input, new_input); - } - - AttachInput(new_input, new_op); - } - - for (int i = 0; i < old_op.NumOutputs(); ++i) { - auto* old_output = old_op.Outputs().at(i); - auto* new_output = &MakeClone(*slice_, *old_output); - AttachOutput(new_output, new_op); - - // Update the values defined in scope of the new subgraph. - tensor_map_.InsertOrAssign(old_output, new_output); - } -} - -} // namespace - -std::vector> GroupPartitions( - const std::vector& ops) { - return DisjointSets::GetPartitionsFromFlatList(ops); -} - -LiteRtOp OutlinePartition(LiteRtSubgraphT& root, LiteRtSubgraph slice, - std::vector& partition) { - return GraphSlicer::SlicePartitionFromGraph(root, slice, partition); -} - -} // namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/compiler/plugin/algo.h b/tensorflow/lite/experimental/litert/compiler/plugin/algo.h deleted file mode 100644 index 8f82ca33ba0ffe..00000000000000 --- a/tensorflow/lite/experimental/litert/compiler/plugin/algo.h +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright 2024 Google LLC. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_COMPILER_PLUGIN_ALGO_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_COMPILER_PLUGIN_ALGO_H_ - -#include - -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" - -namespace litert::internal { - -// Identifies sub-DAGs of ops connected w.r.t. the use-def chain. Expects -// all "ops" belong to the same Subgraph. The ops in the input -// and output will always be the same. -std::vector> GroupPartitions( - const std::vector& ops); - -// Outlines "partition" from "root" into the empty subgraph "slice". Assumes -// the partition is a valid sub-DAG, and replaces it with a single -// tfl.custom_op in "root". A reference to that op is returned. -LiteRtOp OutlinePartition(LiteRtSubgraphT& root, LiteRtSubgraph slice, - std::vector& partition); - -} // namespace litert::internal - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_COMPILER_PLUGIN_ALGO_H_ diff --git a/tensorflow/lite/experimental/litert/compiler/plugin/algo_test.cc b/tensorflow/lite/experimental/litert/compiler/plugin/algo_test.cc deleted file mode 100644 index f756f649520c77..00000000000000 --- a/tensorflow/lite/experimental/litert/compiler/plugin/algo_test.cc +++ /dev/null @@ -1,304 +0,0 @@ -// Copyright 2024 Google LLC. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/compiler/plugin/algo.h" - -#include - -#include -#include "absl/container/flat_hash_set.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model_predicates.h" -#include "tensorflow/lite/experimental/litert/core/model/graph_validation.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" -#include "tensorflow/lite/experimental/litert/test/common.h" - -namespace litert::internal { -namespace { - -TEST(TestPartitionsFromFlatList, SimpleMultiOp) { - auto model = litert::testing::LoadTestFileModel("simple_multi_op.tflite"); - auto subgraph = model.MainSubgraph(); - EXPECT_TRUE(subgraph); - - auto ops = subgraph->Ops(); - - // func.func @main(arg0) - // 0 = tfl.add arg0, arg0 - // 1 = tfl.mul 0, 0 - // 2 = tfl.mul 1, 1 - // 3 = tfl.add 2, 2 - // return 3 - - { - std::vector selected_ops; - selected_ops.push_back({ops.at(1).Get(), 0}); - selected_ops.push_back({ops.at(2).Get(), 0}); - - auto partitions = GroupPartitions(selected_ops); - ASSERT_EQ(partitions.size(), 1); - ASSERT_EQ(partitions.front().size(), 2); - - EXPECT_EQ(partitions.front().at(0), selected_ops.at(0).first); - EXPECT_EQ(partitions.front().at(1), selected_ops.at(1).first); - } - - { 
- std::vector selected_ops; - selected_ops.push_back({ops.at(1).Get(), 0}); - selected_ops.push_back({ops.at(3).Get(), 0}); - - auto partitions = GroupPartitions(selected_ops); - ASSERT_EQ(partitions.size(), 2); - ASSERT_EQ(partitions.front().size(), 1); - ASSERT_EQ(partitions.back().size(), 1); - - auto p1_op_code = partitions.front().front()->OpCode(); - auto p2_op_code = partitions.back().front()->OpCode(); - - ASSERT_TRUE((p1_op_code == kLiteRtOpCodeTflMul && - p2_op_code == kLiteRtOpCodeTflAdd) || - (p1_op_code == kLiteRtOpCodeTflAdd && - p2_op_code == kLiteRtOpCodeTflMul)); - } - - { - std::vector selected_ops; - - auto partitions = GroupPartitions(selected_ops); - ASSERT_EQ(partitions.size(), 0); - } - - { - std::vector selected_ops; - selected_ops.push_back({ops.at(0).Get(), 0}); - selected_ops.push_back({ops.at(1).Get(), 0}); - selected_ops.push_back({ops.at(2).Get(), 0}); - selected_ops.push_back({ops.at(3).Get(), 0}); - - auto partitions = GroupPartitions(selected_ops); - ASSERT_EQ(partitions.size(), 1); - ASSERT_EQ(partitions.front().size(), 4); - - EXPECT_EQ(partitions.front().at(0), selected_ops.at(0).first); - EXPECT_EQ(partitions.front().at(1), selected_ops.at(1).first); - EXPECT_EQ(partitions.front().at(2), selected_ops.at(2).first); - EXPECT_EQ(partitions.front().at(3), selected_ops.at(3).first); - } -} - -TEST(TestSliceSubgraphSimpleMultiOp, OnePartition) { - auto model = litert::testing::LoadTestFileModel("simple_multi_op.tflite"); - auto subgraph = model.MainSubgraph(); - EXPECT_TRUE(subgraph); - - auto ops = subgraph->Ops(); - - // func.func @main(arg0) - // 0 = tfl.add arg0, arg0 - // 1 = tfl.mul 0, 0 - // 2 = tfl.mul 1, 1 - // 3 = tfl.add 2, 2 - // return 3 - - std::vector partition; - partition.push_back(ops.at(1).Get()); - partition.push_back(ops.at(2).Get()); - - auto sliced_graph = litert::Subgraph(&model.Get()->EmplaceSubgraph()); - auto* dispatch_op = - OutlinePartition(*subgraph->Get(), sliced_graph.Get(), partition); - - const auto& 
internal_sliced = *sliced_graph.Get(); - ASSERT_TRUE(ValidateSubgraphIO(internal_sliced)); - ASSERT_TRUE(ValidateLocalTopology(internal_sliced.Ops().cbegin(), - internal_sliced.Ops().cend())); - - auto edited_subgraph_ops = subgraph->Ops(); - - ASSERT_EQ(edited_subgraph_ops.size(), 3); - ASSERT_EQ(edited_subgraph_ops.at(0).Code(), kLiteRtOpCodeTflAdd); - ASSERT_EQ(edited_subgraph_ops.at(1).Code(), kLiteRtOpCodeTflCustom); - ASSERT_EQ(edited_subgraph_ops.at(2).Code(), kLiteRtOpCodeTflAdd); - - auto sliced_subgraph_ops = sliced_graph.Ops(); - - ASSERT_EQ(sliced_subgraph_ops.size(), 2); - ASSERT_EQ(sliced_subgraph_ops[0].Code(), kLiteRtOpCodeTflMul); - ASSERT_EQ(sliced_subgraph_ops[1].Code(), kLiteRtOpCodeTflMul); - - ASSERT_EQ(dispatch_op, edited_subgraph_ops.at(1).Get()); - const Op hal_call(dispatch_op); - - { - const auto dispatch_op_ins = hal_call.Inputs(); - - ASSERT_EQ(dispatch_op_ins.size(), 1); - - auto hal_input_defining_op = dispatch_op_ins.front().DefiningOp(); - ASSERT_EQ(hal_input_defining_op->op, edited_subgraph_ops.at(0).Get()); - ASSERT_EQ(hal_input_defining_op->op_output_index, 0); - - const auto sliced_subgraph_inputs = sliced_graph.Inputs(); - - ASSERT_EQ(sliced_subgraph_inputs.size(), 1); - - ASSERT_TRUE(MatchUses(sliced_subgraph_inputs.front(), - {UseInfo{sliced_subgraph_ops.front().Code(), 0}, - UseInfo{sliced_subgraph_ops.front().Code(), 0}})); - ASSERT_TRUE(sliced_subgraph_inputs.front().IsSubgraphInput()); - } - - { - const auto hal_call_outs = hal_call.Outputs(); - ASSERT_EQ(hal_call_outs.size(), 1); - const auto& hal_call_out = hal_call_outs.front(); - - ASSERT_TRUE(MatchUses(hal_call_out, - {UseInfo{edited_subgraph_ops.back().Code(), 0}, - UseInfo{edited_subgraph_ops.back().Code(), 1}})); - - auto sliced_subgraph_outputs = sliced_graph.Outputs(); - - ASSERT_EQ(sliced_subgraph_outputs.size(), 1); - - const auto defining_op = sliced_subgraph_outputs.front().DefiningOp(); - ASSERT_EQ(defining_op->op, sliced_subgraph_ops.back().Get()); - 
ASSERT_EQ(defining_op->op_output_index, 0); - - ASSERT_TRUE(sliced_subgraph_outputs.front().Uses().empty()); - } -} - -TEST(TestSliceSubgraphSimpleMultiOp, TwoPartitions) { - auto model = litert::testing::LoadTestFileModel("simple_multi_op.tflite"); - auto subgraph = model.MainSubgraph(); - EXPECT_TRUE(subgraph); - - auto ops = subgraph->Ops(); - - // func.func @main(arg0) - // 0 = tfl.add arg0, arg0 - // 1 = tfl.mul 0, 0 - // 2 = tfl.mul 1, 1 - // 3 = tfl.add 2, 2 - // return 3 - - std::vector partition_1; - partition_1.push_back(ops.at(0).Get()); - - auto sliced_graph_1 = litert::Subgraph(&model.Get()->EmplaceSubgraph()); - OutlinePartition(*(subgraph->Get()), sliced_graph_1.Get(), partition_1); - - const auto& internal_slice_1 = *sliced_graph_1.Get(); - ASSERT_TRUE(ValidateSubgraphIO(internal_slice_1)); - ASSERT_TRUE(ValidateLocalTopology(internal_slice_1.Ops().cbegin(), - internal_slice_1.Ops().cend())); - - std::vector partition_2; - partition_2.push_back(ops.at(2).Get()); - partition_2.push_back(ops.at(3).Get()); - - auto sliced_graph_2 = litert::Subgraph(&model.Get()->EmplaceSubgraph()); - OutlinePartition(*(subgraph->Get()), sliced_graph_2.Get(), partition_2); - - const auto& internal_slice_2 = *sliced_graph_2.Get(); - ASSERT_TRUE(ValidateSubgraphIO(internal_slice_2)); - ASSERT_TRUE(ValidateLocalTopology(internal_slice_2.Ops().cbegin(), - internal_slice_2.Ops().cend())); - - auto edited_subgraph_ops = subgraph->Ops(); - - ASSERT_EQ(edited_subgraph_ops.size(), 3); - ASSERT_EQ(edited_subgraph_ops.at(0).Code(), kLiteRtOpCodeTflCustom); - ASSERT_EQ(edited_subgraph_ops.at(1).Code(), kLiteRtOpCodeTflMul); - ASSERT_EQ(edited_subgraph_ops.at(2).Code(), kLiteRtOpCodeTflCustom); - - { - auto sliced_ops = sliced_graph_1.Ops(); - - ASSERT_EQ(sliced_ops.size(), 1); - ASSERT_EQ(sliced_ops.at(0).Code(), kLiteRtOpCodeTflAdd); - } - - { - auto sliced_ops = sliced_graph_2.Ops(); - - ASSERT_EQ(sliced_ops.size(), 2); - ASSERT_EQ(sliced_ops.at(0).Code(), kLiteRtOpCodeTflMul); - 
ASSERT_EQ(sliced_ops.at(1).Code(), kLiteRtOpCodeTflAdd); - } -} - -TEST(TestSliceSubgraphSimpleMultiOp, PartitionWithIndex) { - auto model = litert::testing::LoadTestFileModel("simple_multi_op.tflite"); - auto subgraph = model.MainSubgraph(); - EXPECT_TRUE(subgraph); - - auto ops = subgraph->Ops(); - - // func.func @main(arg0) - // 0 = tfl.add arg0, arg0 - // 1 = tfl.mul 0, 0 - // 2 = tfl.mul 1, 1 - // 3 = tfl.add 2, 2 - // return 3 - - { - std::vector selected_ops; - selected_ops.push_back({ops.at(1).Get(), 0}); - selected_ops.push_back({ops.at(2).Get(), 1}); - - auto partitions = GroupPartitions(selected_ops); - ASSERT_EQ(partitions.size(), 2); - ASSERT_EQ(partitions.front().size(), 1); - ASSERT_EQ(partitions.back().size(), 1); - - absl::flat_hash_set ops_in_partition; - for (int i = 0; i < partitions.size(); ++i) { - for (const auto& op : partitions.at(i)) { - ops_in_partition.insert(op); - } - } - for (int i = 0; i < partitions.size(); ++i) { - EXPECT_TRUE(ops_in_partition.contains(selected_ops.at(i).first)); - } - } - - { - std::vector selected_ops; - selected_ops.push_back({ops.at(0).Get(), 1}); - selected_ops.push_back({ops.at(1).Get(), 2}); - selected_ops.push_back({ops.at(2).Get(), 3}); - selected_ops.push_back({ops.at(3).Get(), 4}); - - auto partitions = GroupPartitions(selected_ops); - ASSERT_EQ(partitions.size(), 4); - - absl::flat_hash_set ops_in_partition; - for (int i = 0; i < partitions.size(); ++i) { - for (const auto& op : partitions.at(i)) { - ops_in_partition.insert(op); - } - } - for (int i = 0; i < partitions.size(); ++i) { - EXPECT_TRUE(ops_in_partition.contains(selected_ops.at(i).first)); - } - } -} - -} // namespace -} // namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/compiler/plugin/compiler_flags.cc b/tensorflow/lite/experimental/litert/compiler/plugin/compiler_flags.cc deleted file mode 100644 index 2cc0fea5f0909f..00000000000000 --- a/tensorflow/lite/experimental/litert/compiler/plugin/compiler_flags.cc +++ 
/dev/null @@ -1,88 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/compiler/plugin/compiler_flags.h" - -#include -#include -#include -#include - -#include "absl/strings/str_split.h" -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin.h" - -namespace { -static constexpr absl::string_view kPairChar = "="; -static constexpr absl::string_view kDelim = ","; -} // namespace - -namespace litert::internal { - -void CompilerFlags::Clear() { - keys_.clear(); - values_.clear(); -} - -void CompilerFlags::Push(std::string key, std::string value) { - keys_.push_back(std::move(key)); - values_.push_back(std::move(value)); -} - -LiteRtStatus CompilerFlags::SetPluginFlags( - LiteRtCompilerPlugin handle, - decltype(LiteRtCompilerPluginSetFlags) set_flags) const { - std::vector keys(keys_.size()); - std::vector values(values_.size()); - for (auto i = 0; i < keys_.size(); ++i) { - keys[i] = keys_[i].c_str(); - values[i] = values_[i].c_str(); - } - return set_flags(handle, keys.size(), keys.data(), values.data()); -} - -Expected ParseCompilerFlags(absl::string_view flags_str) { - using KeyVal = std::pair; - - CompilerFlags result; - if (flags_str.empty()) { - return result; - } - - for (const 
auto flag : absl::StrSplit(flags_str, kDelim)) { - KeyVal key_value = absl::StrSplit(flag, absl::MaxSplits(kPairChar, 1)); - result.Push(std::move(key_value.first), std::move(key_value.second)); - } - - return result; -} - -} // namespace litert::internal - -std::ostream& operator<<(std::ostream& os, - const litert::internal::CompilerFlags& flags) { - for (auto i = 0; i < flags.keys_.size(); ++i) { - os << flags.keys_[i]; - const auto& value = flags.values_[i]; - if (!value.empty()) { - os << kPairChar << value; - } - if (i < flags.keys_.size() - 1) { - os << kDelim; - } - } - return os; -} diff --git a/tensorflow/lite/experimental/litert/compiler/plugin/compiler_flags.h b/tensorflow/lite/experimental/litert/compiler/plugin/compiler_flags.h deleted file mode 100644 index 403ff1db527fa9..00000000000000 --- a/tensorflow/lite/experimental/litert/compiler/plugin/compiler_flags.h +++ /dev/null @@ -1,69 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_COMPILER_PLUGIN_COMPILER_FLAGS_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_COMPILER_PLUGIN_COMPILER_FLAGS_H_ - -#include -#include -#include - -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin.h" - -namespace litert::internal { -class CompilerFlags; -} - -// For logging. -std::ostream& operator<<(std::ostream& os, - const litert::internal::CompilerFlags& flags); - -namespace litert::internal { - -class CompilerFlags { - public: - CompilerFlags() = default; - - // Clears all flags. - void Clear(); - - // Pushes a new flag to the end of the list. - void Push(std::string key, std::string value = ""); - - // Sets the flags on the given plugin. - LiteRtStatus SetPluginFlags( - LiteRtCompilerPlugin handle, - decltype(LiteRtCompilerPluginSetFlags) set_flags) const; - - private: - friend std::ostream& ::operator<<(std::ostream& os, - const CompilerFlags& flags); - - std::vector keys_; - std::vector values_; -}; - -// Parses a comma-separated (no space) list of compiler flags. Flags may be -// key-value pairs in the format of "key=value", or just "key". E.g. -// "key1=value1,key2". -Expected ParseCompilerFlags(absl::string_view flags_str); - -} // namespace litert::internal - -// For logging. - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_COMPILER_PLUGIN_COMPILER_FLAGS_H_ diff --git a/tensorflow/lite/experimental/litert/compiler/plugin/compiler_flags_test.cc b/tensorflow/lite/experimental/litert/compiler/plugin/compiler_flags_test.cc deleted file mode 100644 index 0fcfdd72c52740..00000000000000 --- a/tensorflow/lite/experimental/litert/compiler/plugin/compiler_flags_test.cc +++ /dev/null @@ -1,100 +0,0 @@ -// Copyright 2024 Google LLC. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Utility types for mapping LiteRt IR to arbitrary backend specific -// types. Implementations of these types define mapping for ops and tensors -// that may be used in a stndalone fashion. They also may be composed -// to create lowerings of entire graphs with topology. - -#include "tensorflow/lite/experimental/litert/compiler/plugin/compiler_flags.h" - -#include -#include -#include - -#include -#include -#include "absl/strings/str_cat.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/test/matchers.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin.h" - -struct LiteRtCompilerPluginT { - using Flag = std::pair; - std::vector flags; -}; - -LiteRtStatus LiteRtCompilerPluginSetFlags(LiteRtCompilerPlugin compiler_plugin, - LiteRtParamIndex num_flags, - const char** keys, - const char** values) { - auto& flags = compiler_plugin->flags; - flags.resize(num_flags); - for (int i = 0; i < num_flags; ++i) { - auto& flag = flags[i]; - flag.first = std::string(keys[i]); - flag.second = std::string(values[i]); - } - return kLiteRtStatusOk; -} - -namespace litert::internal { -namespace { - -using ::testing::ElementsAre; -using ::testing::Pair; - -TEST(CompilerFlagsTest, SetPluginFlags) { - static constexpr const char* kKey1 = "key1"; - static constexpr const char* kKey2 = "key2"; - static constexpr const char* kKey3 = 
"key3"; - static constexpr const char* kValue1 = "value1"; - static constexpr const char* kEmtpyVal = ""; - - LiteRtCompilerPluginT plugin; - CompilerFlags flags; - flags.Push(kKey1, kValue1); - flags.Push(kKey2, kEmtpyVal); - flags.Push(kKey3); - LITERT_ASSERT_OK(flags.SetPluginFlags(&plugin, LiteRtCompilerPluginSetFlags)); - - EXPECT_THAT(plugin.flags, - ElementsAre(Pair(kKey1, kValue1), Pair(kKey2, kEmtpyVal), - Pair(kKey3, kEmtpyVal))); -} - -TEST(CompilerFlagsTest, ParseCompilerFlags) { - static constexpr const char* kKey1 = "key1"; - static constexpr const char* kKey2 = "key2"; - static constexpr const char* kKey3 = "key3"; - static constexpr const char* kValue1 = "value1"; - static constexpr const char* kEmtpyVal = ""; - - const auto flags_str = - absl::StrCat(kKey1, "=", kValue1, ",", kKey2, "=", kEmtpyVal, ",", kKey3); - - LiteRtCompilerPluginT plugin; - CompilerFlags flags; - flags.Push(kKey1, kValue1); - flags.Push(kKey2, kEmtpyVal); - flags.Push(kKey3); - LITERT_ASSERT_OK(flags.SetPluginFlags(&plugin, LiteRtCompilerPluginSetFlags)); - - EXPECT_THAT(plugin.flags, - ElementsAre(Pair(kKey1, kValue1), Pair(kKey2, kEmtpyVal), - Pair(kKey3, kEmtpyVal))); -} - -} // namespace -} // namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/compiler/plugin/compiler_plugin.cc b/tensorflow/lite/experimental/litert/compiler/plugin/compiler_plugin.cc deleted file mode 100644 index d593840081e293..00000000000000 --- a/tensorflow/lite/experimental/litert/compiler/plugin/compiler_plugin.cc +++ /dev/null @@ -1,649 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/compiler/plugin/compiler_plugin.h" - -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "absl/container/flat_hash_set.h" -#include "absl/log/absl_check.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_format.h" -#include "absl/strings/str_join.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_any.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_environment.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" -#include "tensorflow/lite/experimental/litert/cc/litert_detail.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_op_options.h" -#include "tensorflow/lite/experimental/litert/cc/litert_shared_library.h" -#include "tensorflow/lite/experimental/litert/compiler/plugin/algo.h" -#include "tensorflow/lite/experimental/litert/core/build_stamp.h" -#include "tensorflow/lite/experimental/litert/core/dynamic_loading.h" -#include "tensorflow/lite/experimental/litert/core/environment.h" -#include 
"tensorflow/lite/experimental/litert/core/filesystem.h" -#include "tensorflow/lite/experimental/litert/core/model/buffer_manager.h" -#include "tensorflow/lite/experimental/litert/core/model/ir_allocator.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" -#include "tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h" -#include "tensorflow/lite/experimental/litert/core/version.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin_api.h" - -namespace litert::internal { - -// -// CompiledResult -// - -Expected> CompiledResult::ByteCode( - LiteRtParamIndex byte_code_idx) const { - const void* data; - size_t size; - LITERT_RETURN_IF_ERROR(parent_.get_compiled_result_byte_code( - compiled_result_handle_, byte_code_idx, &data, &size)); - return BufferRef(data, size); -} - -Expected CompiledResult::NumByteCodeModules() const { - LiteRtParamIndex byte_code_idx; - LITERT_RETURN_IF_ERROR(parent_.get_compiled_result_num_byte_code( - compiled_result_handle_, &byte_code_idx)); - return byte_code_idx; -} - -Expected CompiledResult::NumCalls() const { - LiteRtParamIndex num_calls; - LITERT_RETURN_IF_ERROR(parent_.get_compiled_result_num_calls( - compiled_result_handle_, &num_calls)); - return num_calls; -} - -Expected CompiledResult::CallInfo(LiteRtParamIndex call_idx) const { - const void* data; - size_t size; - LiteRtParamIndex byte_code_idx; - - LITERT_RETURN_IF_ERROR(parent_.get_compiled_result_call_info( - compiled_result_handle_, call_idx, &data, &size, &byte_code_idx)); - - absl::string_view call_info_str(reinterpret_cast(data), size); - return ::litert::internal::CallInfo(call_info_str, byte_code_idx); -} - -CompiledResult::~CompiledResult() { - if (compiled_result_handle_ != nullptr) { - parent_.destroy_compiled_result(compiled_result_handle_); - } -} - -CompiledResult::CompiledResult(CompiledResult&& other) - : 
parent_(other.parent_), - compiled_result_handle_(other.compiled_result_handle_) { - other.parent_ = {}; - other.compiled_result_handle_ = nullptr; -} - -CompiledResult& CompiledResult::operator=(CompiledResult&& other) { - if (this != &other) { - parent_ = other.parent_; - other.parent_ = {}; - - compiled_result_handle_ = other.compiled_result_handle_; - other.compiled_result_handle_ = nullptr; - } - return *this; -} - -// -// CompilerPlugin -// - -namespace { - -#define RESOLVE_API_FUNC(name, dest) \ - LITERT_ASSIGN_OR_RETURN(dest, lib.LookupSymbol(name.data())); - -LiteRtStatus ResolvePluginApi(SharedLibrary& lib, - LiteRtCompilerPluginApi& result) { - RESOLVE_API_FUNC(kLiteRtGetCompilerPluginVersion, - result.get_compiler_plugin_version); - RESOLVE_API_FUNC(kLiteRtGetCompilerPluginSupportedHardware, - result.get_compiler_plugin_supported_hardware); - RESOLVE_API_FUNC(kLiteRtGetCompilerPluginSocManufacturer, - result.get_compiler_plugin_soc_manufacturer); - RESOLVE_API_FUNC(kLiteRtGetNumCompilerPluginSupportedSocModels, - result.get_num_compiler_plugin_supported_models); - RESOLVE_API_FUNC(kLiteRtGetCompilerPluginSupportedSocModel, - result.get_compiler_plugin_supported_soc_model); - - RESOLVE_API_FUNC(kLiteRtCreateCompilerPlugin, result.create_compiler_plugin); - RESOLVE_API_FUNC(kLiteRtDestroyCompilerPlugin, - result.destroy_compiler_plugin); - - RESOLVE_API_FUNC(kLiteRtCompilerPluginPartition, - result.compiler_plugin_partition); - RESOLVE_API_FUNC(kLiteRtCompilerPluginCompile, - result.compiler_plugin_compile); - - RESOLVE_API_FUNC(kLiteRtDestroyCompiledResult, - result.destroy_compiled_result); - RESOLVE_API_FUNC(kLiteRtCompiledResultNumByteCodeModules, - result.get_compiled_result_num_byte_code); - RESOLVE_API_FUNC(kLiteRtGetCompiledResultByteCode, - result.get_compiled_result_byte_code); - RESOLVE_API_FUNC(kLiteRtGetCompiledResultCallInfo, - result.get_compiled_result_call_info); - RESOLVE_API_FUNC(kLiteRtGetNumCompiledResultCalls, - 
result.get_compiled_result_num_calls); - RESOLVE_API_FUNC(kLiteRtCompilerPluginSetFlags, result.set_flags); - - return kLiteRtStatusOk; -} - -Expected> GetSocModels( - const LiteRtCompilerPluginApi& api, LiteRtCompilerPlugin plugin_handle) { - std::vector soc_models; - - LiteRtParamIndex num_models; - LITERT_RETURN_IF_ERROR( - api.get_num_compiler_plugin_supported_models(plugin_handle, &num_models)); - - for (LiteRtParamIndex i = 0; i < num_models; ++i) { - const char* model; - if (api.get_compiler_plugin_supported_soc_model(plugin_handle, i, &model) != - kLiteRtStatusOk) { - continue; - } - soc_models.push_back(std::string(model)); - } - - return soc_models; -} - -// Sort plugins so that we first apply those supporting NPU, then those -// supporting GPU, and finally those supporting CPU. -void SortPlugins(std::vector& compiler_plugins) { - std::sort(compiler_plugins.begin(), compiler_plugins.end(), - [](auto& x, auto& y) { - auto x_supported_hardware = x.SupportedHardware(); - auto y_supported_hardware = y.SupportedHardware(); - if (x_supported_hardware && y_supported_hardware) { - bool x_npu = (*x_supported_hardware & kLiteRtHwAcceleratorNpu); - bool x_gpu = (*x_supported_hardware & kLiteRtHwAcceleratorGpu); - bool x_cpu = (*x_supported_hardware & kLiteRtHwAcceleratorCpu); - bool y_npu = (*y_supported_hardware & kLiteRtHwAcceleratorNpu); - bool y_gpu = (*y_supported_hardware & kLiteRtHwAcceleratorGpu); - bool y_cpu = (*y_supported_hardware & kLiteRtHwAcceleratorCpu); - int x_score = 100 * x_npu + 10 * x_gpu + x_cpu; - int y_score = 100 * y_npu + 10 * y_gpu + y_cpu; - return x_score < y_score; - } - return true; - }); -} - -} // namespace - -Expected CompilerPlugin::LoadPlugin( - const absl::string_view lib_path) { - CompilerPlugin plugin; - LITERT_LOG(LITERT_INFO, "Loading plugin at: %s", lib_path.data()); - - LITERT_ASSIGN_OR_RETURN( - plugin.lib_, - SharedLibrary::Load(lib_path, RtldFlags::Now().Local().DeepBind())); - LITERT_LOG(LITERT_INFO, "Loaded plugin at: 
%s", lib_path.data()); - - LITERT_RETURN_IF_ERROR(ResolvePluginApi(plugin.lib_, plugin.plugin_api_)); - LITERT_LOG(LITERT_INFO, "Resolved plugin api at: %s", lib_path.data()); - - LITERT_RETURN_IF_ERROR( - plugin.plugin_api_.create_compiler_plugin(&plugin.plugin_handle_)); - LITERT_LOG(LITERT_INFO, "Initialize plugin at: %s", lib_path.data()); - - auto api_version = plugin.ApiVersion(); - if (!api_version) { - return api_version.Error(); - } - - LITERT_RETURN_IF_ERROR(litert::internal::IsSameVersionAsRuntime(*api_version), - Unexpected(kLiteRtStatusErrorWrongVersion, - "Unsupported compiler plugin version")); - - // This should never change throughout the lifetime of the compiler - // plugin so save to avoid recalling. - auto soc_models = GetSocModels(plugin.plugin_api_, plugin.plugin_handle_); - if (!soc_models) { - return soc_models.Error(); - } - plugin.soc_models_ = *soc_models; - - return plugin; -} - -Expected> CompilerPlugin::LoadPlugins( - absl::Span lib_search_paths) { - std::vector plugin_lib_paths; - for (auto search_path : lib_search_paths) { - // Skip paths that are not valid. - if (Exists(search_path)) { - LITERT_RETURN_IF_ERROR( - FindLiteRtCompilerPluginSharedLibs(search_path, plugin_lib_paths)); - } - } - - std::vector loaded_plugins; - loaded_plugins.reserve(lib_search_paths.size()); - - for (const auto& lib_path : plugin_lib_paths) { - LITERT_LOG(LITERT_INFO, "Loading plugin at: %s", lib_path.c_str()); - auto plugin = LoadPlugin(lib_path); - if (!plugin.HasValue()) { - continue; - } - loaded_plugins.push_back(std::move(plugin.Value())); - } - - // Sort plugins. 
- SortPlugins(loaded_plugins); - - return loaded_plugins; -} - -CompilerPlugin::CompilerPlugin(CompilerPlugin&& other) - : soc_models_(std::move(other.soc_models_)), - lib_(std::move(other.lib_)), - plugin_api_(std::move(other.plugin_api_)), - plugin_handle_(std::move(other.plugin_handle_)) { - other.soc_models_ = {}; - other.plugin_api_ = {}; - other.lib_.Close(); - other.plugin_handle_ = nullptr; -} - -CompilerPlugin& CompilerPlugin::operator=(CompilerPlugin&& other) { - if (this != &other) { - std::swap(soc_models_, other.soc_models_); - std::swap(lib_, other.lib_); - std::swap(plugin_api_, other.plugin_api_); - std::swap(plugin_handle_, other.plugin_handle_); - } - return *this; -} - -CompilerPlugin::~CompilerPlugin() { - if (plugin_handle_ != nullptr) { - plugin_api_.destroy_compiler_plugin(plugin_handle_); - } -} - -std::string CompilerPlugin::DebugString() const { - std::string version_str = "?"; - if (auto version = ApiVersion(); version) { - version_str = absl::StrFormat("%d.%d.%d", version->major, version->minor, - version->patch); - } - return absl::StrFormat("%s compiler plugin (ver %s)", SocManufacturer(), - version_str); -} - -Expected CompilerPlugin::ApiVersion() const { - LiteRtApiVersion api_version; - LITERT_RETURN_IF_ERROR(plugin_api_.get_compiler_plugin_version(&api_version)); - return api_version; -} - -Expected CompilerPlugin::SupportedHardware() const { - LiteRtHwAccelerators supported_hardware; - LITERT_RETURN_IF_ERROR(plugin_api_.get_compiler_plugin_supported_hardware( - plugin_handle_, &supported_hardware)); - return supported_hardware; -} - -Expected> CompilerPlugin::Partition( - const Subgraph& subgraph, absl::string_view soc_model) { - LiteRtOpListT ops; - const char* soc_model_str = !soc_model.empty() ? 
soc_model.data() : nullptr; - LITERT_RETURN_IF_ERROR(plugin_api_.compiler_plugin_partition( - plugin_handle_, soc_model_str, subgraph.Get(), &ops)); - return ops.Values(); -} - -Expected CompilerPlugin::Compile(LiteRtModel partitions, - absl::string_view soc_model) { - CompiledResult result = MakeResult(); - // If the user has passed an soc_model, then we use it; otherwise we let the - // backend pick the appropriate one by passing nullptr as soc_model. This is - // important for on-device compilation, where the backend must determine the - // SoC model based on the user device. - const char* soc_model_str = !soc_model.empty() ? soc_model.data() : nullptr; - LITERT_RETURN_IF_ERROR(plugin_api_.compiler_plugin_compile( - plugin_handle_, soc_model_str, partitions, - &result.compiled_result_handle_)); - return result; -} - -namespace { - -LiteRtStatus PartitionSubgraph( - std::vector selected_ops, - LiteRtSubgraphT& subgraph, PartitionResult& result, - BufferManager* buffer_manager) { - // Group selected ops into connected islands. - auto islands = GroupPartitions(selected_ops); - if (islands.empty()) { - return kLiteRtStatusOk; - } - - // For each connected island, slice into new subgraph and replace use with - // single dispatch op. - for (auto& island : islands) { - auto& new_subgraph = result.second.EmplaceBack(buffer_manager); - auto* dispatch_op = OutlinePartition(subgraph, &new_subgraph, island); - result.first.push_back(dispatch_op); - } - - return kLiteRtStatusOk; -} - -} // namespace - -Expected PartitionModel( - CompilerPlugin& compiler_plugin, LiteRtModelT& model, - const absl::flat_hash_set& subgraphs_to_partition) { - // This algorithm decides the subgraphs to be partitioned by the plugin. This - // is a trivial process with the exception of composite ops and their - // decomposition subgraphs. Currently, we deploy the most naive approach to - // handling composite ops. - // - // There are two cases to consider: - // 1. 
The composite op is an "odml.npu_call", in which case it represents a - // parition which was explictly requested by the model author. - // - // In this case, the the composite itself is always selected, regardless of - // whether the plugin selects it. Its subgraph is not passed to the partition - // function and it is passed in its entirety to the compilation function. - // - // More advanced behavior could include: - // * Ensuring the plugin can compile the entire partition, and inlining it if - // not. - // - // 2. Standard non npu_call composite ops. Currently these are treated as a - // regular op, and their decomposition subgraphs are completely ignored in all - // phases of plugin application. - // - // More advanced behavior could include: - // * Allowing the plugin to compile the decomposition subgraph in the case - // it cannot lower the composite directly. Potentially inline in this case - // contingent on the availability of a suitable CPU kernel for the composite - // op. - // - // ASSUMPTIONS: - // * npu_call ops ARE NOT nested within decompositions of other npu_call ops. - // * Standard composite ops ARE allowed to be nested within decompositions of - // npu_call ops. - // * No two npu_call ops share the same subgraph. - - // Find decomposition subgraphs and npu_call ops. These will be used to filter - // subgraphs passed to the plugin and pass on auto-selected npu_call - // partitions. - absl::flat_hash_set decomp_subgraphs; - std::vector npu_calls; - - ForEachIr(&model, [&](LiteRtOp op) { - auto info = GetOptionsAs(op); - if (!info) { - return; - } - decomp_subgraphs.insert(info->subgraph); - if (info->name == CompositeOptions::kNpuCall) { - npu_calls.push_back(std::move(*info)); - } - }); - - // Build partition result via calling plugin on non-decomposition subgraphs. 
- PartitionResult result; - for (auto i = 0; i < model.Subgraphs().size(); ++i) { - if (decomp_subgraphs.contains(i)) { - continue; - } - if (!subgraphs_to_partition.empty() && - !subgraphs_to_partition.contains(i)) { - continue; - } - auto* subgraph = model.Subgraphs()[i]; - auto selected_ops = compiler_plugin.Partition(Subgraph(subgraph)); - // TODO ensure selected ops don't contain npu_calls. - if (!selected_ops) { - return selected_ops.Error(); - } - auto num_selected_ops = selected_ops->size(); - auto num_ops = subgraph->Ops().size(); - - auto num_partitions = result.first.size(); - LITERT_RETURN_IF_ERROR(PartitionSubgraph( - std::move(*selected_ops), *subgraph, result, model.Buffers())); - num_partitions = result.first.size() - num_partitions; - LITERT_LOG(LITERT_INFO, - "PartitionSubgraph: %d, selected num ops: %lu, from totoal ops: " - "%lu, num partitions: %lu", - i, num_selected_ops, num_ops, num_partitions); - } - - // Add npu_call partitions to result. Update the npu_call ops to be dispatch - // ops. - std::vector decomps_to_compile; - for (auto& npu_call : npu_calls) { - auto* op = npu_call.op; - MakeDispatchOp(*op); - result.first.push_back(op); - decomps_to_compile.push_back(npu_call.subgraph); - } - model.TransferSubgraphTo(result.second, std::move(decomps_to_compile)); - - return result; -} - -Expected PartitionModelDirect( - std::vector selected_ops, LiteRtModelT& model) { - if (model.Subgraphs().size() != 1) { - // Only single subgraphs supported for direct partitioning. - return Unexpected(kLiteRtStatusErrorRuntimeFailure); - } - // Accumulate partition results for each subgraph in model. 
- PartitionResult result; - auto* subgraph = model.Subgraphs().front(); - LITERT_RETURN_IF_ERROR(PartitionSubgraph(std::move(selected_ops), *subgraph, - result, model.Buffers())); - ABSL_DCHECK_EQ(result.first.size(), result.second.Size()); - return result; -} - -Expected ApplyPluginWithPartition(CompilerPlugin& compiler_plugin, - LiteRtModelT& model, - PartitionResult partitions, - absl::string_view soc_model) { - auto& dispatch_ops = partitions.first; - auto& subgraphs = partitions.second; - - // Wrap the partitioned subgraphs in a LiteRtModel. - LiteRtModelT sliced_model; - sliced_model.TransferSubgraphsFrom(std::move(subgraphs)); - - // Copy op codes. - const auto& op_codes = litert::internal::GetTflOpCodes(model); - - LiteRtModelT::TflOpCodes codes; - codes.reserve(op_codes.size()); - for (const auto& op_code : op_codes) { - codes.emplace_back(std::make_unique(*op_code)); - } - - litert::internal::SetTflOpCodes(sliced_model, std::move(codes)); - - // Pass sliced subgraphs to plugin for compilation. - auto compiled_result = compiler_plugin.Compile(&sliced_model, soc_model); - if (!compiled_result) { - return compiled_result.Error(); - } - - // Register byte code buffers as external buffers. Map the byte code indices - // to the registered buffer ids. - auto num_byte_code = compiled_result->NumByteCodeModules(); - if (!num_byte_code) { - return num_byte_code.Error(); - } - - std::vector byte_code_idx_to_buf_id(*num_byte_code); - - for (auto i = 0; i < *num_byte_code; ++i) { - auto byte_code = compiled_result->ByteCode(i); - if (!byte_code) { - return byte_code.Error(); - } - - // TODO: This copy could probably be avoided. - OwningBufferRef owned_byte_code(byte_code->Data(), - byte_code->Size()); - const auto buf_id = - model.Buffers()->RegisterOwnedBuffer(std::move(owned_byte_code)); - - byte_code_idx_to_buf_id[i] = buf_id; - } - - // Register byte code buffers and add edges from dispatch ops to them. 
- for (auto i = 0; i < dispatch_ops.size(); ++i) { - auto* dispatch_op = dispatch_ops.at(i); - - auto call_info = compiled_result->CallInfo(i); - if (!call_info) { - return call_info.Error(); - } - auto [name, byte_code_idx] = *call_info; - const auto buf_id = byte_code_idx_to_buf_id[byte_code_idx]; - - model.AttachAssetToOp(dispatch_op, buf_id, std::string(name)); - } - - // Tag the model with make/model from the plugin. - auto build_stamp = - MakeBuildStamp(compiler_plugin.SocManufacturer(), soc_model); - if (!build_stamp) { - return build_stamp.Error(); - } - - if (auto status = - model.PushMetadata(kLiteRtBuildStampKey, std::move(*build_stamp)); - status != kLiteRtStatusOk) { - return Error(status); - } - - return {}; -} - -Expected ApplyPlugin( - CompilerPlugin& compiler_plugin, LiteRtModelT& model, - absl::string_view soc_model, - const absl::flat_hash_set& subgraphs_to_partition) { - // Collect partitions to pass to compilation. - auto partitions = - PartitionModel(compiler_plugin, model, subgraphs_to_partition); - if (!partitions) { - return partitions.Error(); - } - return ApplyPluginWithPartition(compiler_plugin, model, - std::move(*partitions), soc_model); -} - -Expected ApplyPlugins( - LiteRtEnvironment environment, LiteRtModel model, - LiteRtHwAcceleratorSet selected_hw_accelerators, bool* mutated) { - auto option = - environment->GetOption(kLiteRtEnvOptionTagCompilerPluginLibraryDir); - if (!option.has_value() || option->type != kLiteRtAnyTypeString) { - return litert::Error(kLiteRtStatusErrorRuntimeFailure, - "Compiler plugin is not configured"); - } - std::string compiler_plugin_lib_path = option->str_value; - - const std::array - compiler_plugin_lib_search_paths = {compiler_plugin_lib_path}; - - auto compiler_plugins = litert::internal::CompilerPlugin::LoadPlugins( - compiler_plugin_lib_search_paths); - if (!compiler_plugins) { - return compiler_plugins.Error(); - } - if (compiler_plugins->empty()) { - return 
litert::Error(kLiteRtStatusErrorRuntimeFailure, - "No compiler plugin found"); - } - - std::vector success_messages; - std::vector error_messages; - - ApplyPluginsResult result; - result.num_applied_plugins = 0; - for (auto& compiler_plugin : *compiler_plugins) { - auto plugin_name = compiler_plugin.DebugString(); - - auto plugin_supported_hardware = compiler_plugin.SupportedHardware(); - if (!plugin_supported_hardware) { - error_messages.push_back(absl::StrCat( - plugin_name, " ", plugin_supported_hardware.Error().Message())); - continue; - } - - if (*plugin_supported_hardware & selected_hw_accelerators) { - auto status = ApplyPlugin(compiler_plugin, *model); - if (mutated != nullptr) { - *mutated = true; - } - if (!status) { - error_messages.push_back( - absl::StrCat(plugin_name, " ", status.Error().Message())); - continue; - } - - success_messages.push_back(absl::StrCat(plugin_name)); - result.num_applied_plugins++; - } - } - - result.success_message = absl::StrJoin(success_messages, ", "); - result.error_message = absl::StrJoin(error_messages, ", "); - - return result; -} - -} // namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/compiler/plugin/compiler_plugin.h b/tensorflow/lite/experimental/litert/compiler/plugin/compiler_plugin.h deleted file mode 100644 index 76c6ccbdc1b2df..00000000000000 --- a/tensorflow/lite/experimental/litert/compiler/plugin/compiler_plugin.h +++ /dev/null @@ -1,202 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_COMPILER_PLUGIN_COMPILER_PLUGIN_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_COMPILER_PLUGIN_COMPILER_PLUGIN_H_ - -#include -#include -#include -#include -#include -#include - -#include "absl/container/flat_hash_set.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_environment.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_shared_library.h" -#include "tensorflow/lite/experimental/litert/compiler/plugin/compiler_flags.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin_api.h" - -// C++ wrappers and high-level functions for managing compiler plugins -// and applying them to models. - -namespace litert::internal { - -// Name and index of byte code. -using CallInfo = std::pair; - -// Wraps vendor compiled result. Must be outlived by the CompilerPlugin -// the generated it. -class CompiledResult { - public: - friend class CompilerPlugin; - - // Number of byte code modules compiled by the plugin. - Expected NumByteCodeModules() const; - - // Get the single module of compiled byte code. This contains the - // compilation result for all entry points. - Expected> ByteCode( - LiteRtParamIndex byte_code_idx = 0) const; - - // Get information regarding the "ith" entry points in the compiled module. 
- // There will be oe entry point for each subgraph compiled for. - Expected CallInfo(LiteRtParamIndex call_idx) const; - - // Get the number of entry points in the compiled module. This will be equal - // to the number of subgraphs passed to the compilation step. - Expected NumCalls() const; - - explicit CompiledResult(const LiteRtCompilerPluginApi& parent) - : parent_(parent) {} - - CompiledResult(CompiledResult&& other); - CompiledResult& operator=(CompiledResult&& other); - CompiledResult(const CompiledResult& other) = delete; - CompiledResult& operator=(const CompiledResult& other) = delete; - - ~CompiledResult(); - - private: - LiteRtCompilerPluginApi parent_; - LiteRtCompiledResult compiled_result_handle_ = nullptr; -}; - -// Wraps vendor compiler plugin. -class CompilerPlugin { - public: - std::string DebugString() const; - - // Get the compiler plugin's API version. - Expected ApiVersion() const; - - // Get the supported HW accelerators (e.g., GPU, NPU). - Expected SupportedHardware() const; - - // Get the manufacturer associated with this plugin. NOTE: SocManufacturer - // string returned by the underlying plugin are expected to have static - // lifetime. - absl::string_view SocManufacturer() const { - return plugin_api_.get_compiler_plugin_soc_manufacturer(); - } - - // Get list of unique soc models targetable by this plugin. - const std::vector& SocModels() const { return soc_models_; } - - // Selects ops for the plugin to compile. - Expected> Partition( - const Subgraph& subgraph, absl::string_view soc_model = ""); - - // Compile given LiteRtSubgraphs. Result object must be outlived by - // this CompilerPlugin. - Expected Compile(LiteRtModel partitions, - absl::string_view soc_model = ""); - - // Search for shared library files with prefix "libLiteRtCompilerPlugin" in - // the directories passed through "lib_search_paths". Populates - // "loaded_plugins" with resolved plugin apis for each found library that can - // be successfully loaded. 
Additionally initializes the compiler plugin - // instances and stores handle. - static Expected> LoadPlugins( - absl::Span lib_search_paths); - - // Set compiler flags within the plugin. - LiteRtStatus SetFlags(const CompilerFlags& flags) { - return flags.SetPluginFlags(plugin_handle_, plugin_api_.set_flags); - } - - CompilerPlugin(CompilerPlugin&& other); - CompilerPlugin& operator=(CompilerPlugin&& other); - CompilerPlugin(const CompilerPlugin& other) = delete; - CompilerPlugin& operator=(const CompilerPlugin& other) = delete; - - // Destroys any living `LiteRtCompilerPlugin` and frees reference - // to dynamically loaded library. - ~CompilerPlugin(); - - private: - static Expected LoadPlugin(absl::string_view lib_path); - CompilerPlugin() = default; - - std::vector soc_models_; - SharedLibrary lib_; - LiteRtCompilerPluginApi plugin_api_ = {}; - LiteRtCompilerPlugin plugin_handle_ = nullptr; - - // Internal LiteRtCompiledResult wrapper. - - CompiledResult MakeResult() const { return CompiledResult(plugin_api_); } -}; - -// Higher level functions for applying plugin to graph. -//===--------------------------------------------------------------------------- - -// Dispatch op references and their subgraph to be compiled. -using PartitionResult = - std::pair, typename LiteRtSubgraphT::Alloc>; - -// Applies just the partition phase of the plugin on the model. Returns -// references newly allocated subgraphs removed from input and their -// corresponding dispatch ops in the input. -Expected PartitionModel( - CompilerPlugin& compiler_plugin, LiteRtModelT& model, - const absl::flat_hash_set& subgraphs_to_partition = {}); - -// Same as "PartitionModel" choose partitions directly based on the selected -// ops. Selected ops may contain any ops in the the main subgraph of the model. -// This function will separate them into DAGs and slice the model accordingly. 
-Expected PartitionModelDirect( - std::vector selected_ops, LiteRtModelT& model); - -// Applies both the partition and compile steps to the model. Generated -// byte_code will be internalized within the model for later serialization. -Expected ApplyPlugin( - CompilerPlugin& compiler_plugin, LiteRtModelT& model, - absl::string_view soc_model = "", - const absl::flat_hash_set& subgraphs_to_partition = {}); - -// Applies the compilation step to the model given a predetermined partition. -Expected ApplyPluginWithPartition(CompilerPlugin& compiler_plugin, - LiteRtModelT& model, - PartitionResult partitions, - absl::string_view soc_model = ""); - -// Apply all available plugins providing the selected HW accelerators to the -// given model, modify the model accordingly, and return (1) the number of -// compiler plugins successfully applied, (2) a string listing the compiler -// plugins that were successfully applied, and (3) a string listing the compiler -// plugins that failed to apply with an associated error message. This mutates -// the given model. -struct ApplyPluginsResult { - size_t num_applied_plugins; - std::string success_message; - std::string error_message; -}; - -Expected ApplyPlugins( - LiteRtEnvironment environment, LiteRtModel model, - LiteRtHwAcceleratorSet selected_hw_accelerators, bool* mutated = nullptr); - -} // namespace litert::internal - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_COMPILER_PLUGIN_COMPILER_PLUGIN_H_ diff --git a/tensorflow/lite/experimental/litert/compiler/plugin/compiler_plugin_test.cc b/tensorflow/lite/experimental/litert/compiler/plugin/compiler_plugin_test.cc deleted file mode 100644 index 96403219ec93b3..00000000000000 --- a/tensorflow/lite/experimental/litert/compiler/plugin/compiler_plugin_test.cc +++ /dev/null @@ -1,498 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/compiler/plugin/compiler_plugin.h" - -#include -#include -#include -#include - -#include -#include -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_environment.h" -#include "tensorflow/lite/experimental/litert/cc/litert_op_options.h" -#include "tensorflow/lite/experimental/litert/core/build_stamp.h" -#include "tensorflow/lite/experimental/litert/core/filesystem.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" -#include "tensorflow/lite/experimental/litert/test/common.h" -#include "tensorflow/lite/experimental/litert/test/matchers.h" -#include "tensorflow/lite/experimental/litert/tools/dump.h" - -namespace litert::internal { -namespace { - -using testing::UniqueTestDirectory; - -constexpr absl::string_view kTestPluginSearchPath = - "third_party/tensorflow/lite/experimental/litert/vendors/examples"; - -constexpr absl::string_view kTestManufacturer = "ExampleSocManufacturer"; -constexpr absl::string_view kTestModels = "ExampleSocModel"; - -TEST(CompilerPluginTest, LoadTestPlugin) { - auto plugins = CompilerPlugin::LoadPlugins({kTestPluginSearchPath}); - - ASSERT_EQ(plugins->size(), 1); - EXPECT_EQ(plugins->front().SocManufacturer(), kTestManufacturer); - ASSERT_EQ(plugins->front().SocModels().size(), 1); - 
EXPECT_EQ(plugins->front().SocModels().front(), kTestModels); -} - -TEST(CompilerPluginTest, LoadTestPluginWithMalformed) { - const auto dir = UniqueTestDirectory::Create(); - ASSERT_TRUE(dir); - Touch(Join({dir->Str(), "notLibLiteRt.so"})); - - auto plugins = CompilerPlugin::LoadPlugins({kTestPluginSearchPath}); - - ASSERT_EQ(plugins->size(), 1); - EXPECT_EQ(plugins->front().SocManufacturer(), kTestManufacturer); -} - -TEST(CompilerPluginTest, MultipleValidPlugins) { - auto plugins = CompilerPlugin::LoadPlugins( - {kTestPluginSearchPath, kTestPluginSearchPath}); - - ASSERT_EQ(plugins->size(), 2); - EXPECT_EQ(plugins->front().SocManufacturer(), kTestManufacturer); - EXPECT_EQ(plugins->back().SocManufacturer(), kTestManufacturer); -} - -TEST(CompilerPluginTest, MoveAssign) { - auto plugins = CompilerPlugin::LoadPlugins({kTestPluginSearchPath}); - - ASSERT_EQ(plugins->size(), 1); - EXPECT_EQ(plugins->front().SocManufacturer(), kTestManufacturer); - - CompilerPlugin other = std::move(plugins->front()); - - EXPECT_EQ(other.SocManufacturer(), kTestManufacturer); -} - -TEST(CompilerPluginTest, MoveConstruct) { - auto plugins = CompilerPlugin::LoadPlugins({kTestPluginSearchPath}); - - ASSERT_EQ(plugins->size(), 1); - EXPECT_EQ(plugins->front().SocManufacturer(), kTestManufacturer); - - CompilerPlugin other(std::move(plugins->front())); - - EXPECT_EQ(other.SocManufacturer(), kTestManufacturer); -} - -TEST(CompilerPluginTest, SocModels) { - auto plugins = CompilerPlugin::LoadPlugins({kTestPluginSearchPath}); - ASSERT_EQ(plugins->size(), 1); - EXPECT_EQ(plugins->front().SocManufacturer(), kTestManufacturer); - - EXPECT_THAT(plugins->front().SocModels(), - ::testing::ElementsAreArray({kTestModels})); -} - -TEST(CompilerPluginTest, SetFlags) { - auto plugins = CompilerPlugin::LoadPlugins({kTestPluginSearchPath}); - ASSERT_EQ(plugins->size(), 1); - LITERT_ASSERT_OK(plugins->front().SetFlags(CompilerFlags())); -} - -TEST(CompilerPluginTest, Partition) { - auto plugins = 
CompilerPlugin::LoadPlugins({kTestPluginSearchPath}); - ASSERT_EQ(plugins->size(), 1); - EXPECT_EQ(plugins->front().SocManufacturer(), kTestManufacturer); - - auto model = testing::LoadTestFileModel("mul_simple.tflite"); - auto subgraph = model.MainSubgraph(); - auto ops = plugins->front().Partition(*subgraph); - ASSERT_TRUE(ops); - - EXPECT_EQ(ops->size(), 2); -} - -TEST(CompilerPluginTest, Compile) { - auto plugins = CompilerPlugin::LoadPlugins({kTestPluginSearchPath}); - ASSERT_EQ(plugins->size(), 1); - EXPECT_EQ(plugins->front().SocManufacturer(), kTestManufacturer); - - auto model_wrap = testing::LoadTestFileModel("mul_simple.tflite"); - auto& model = *model_wrap.Get(); - - auto result = plugins->front().Compile(&model); - ASSERT_TRUE(result); - - auto byte_code = result->ByteCode(); - ASSERT_TRUE(byte_code && byte_code->Size() > 0); - - auto num_calls = result->NumCalls(); - ASSERT_TRUE(num_calls); - ASSERT_EQ(*num_calls, 1); - - auto call_info = result->CallInfo(0); - ASSERT_TRUE(call_info); -} - -TEST(CompilerPluginTest, Dump) { - auto plugins = CompilerPlugin::LoadPlugins({kTestPluginSearchPath}); - ASSERT_EQ(plugins->size(), 1); - - std::stringstream dump; - Dump(plugins->front(), dump); - - ASSERT_EQ(dump.view(), - "SocManufacturer: ExampleSocManufacturer\nSocModels: { " - "ExampleSocModel }\n"); -} - -TEST(PartitionModelTest, Simple) { - auto model_wrap = testing::LoadTestFileModel("mul_simple.tflite"); - auto& model = *model_wrap.Get(); - - auto plugins = CompilerPlugin::LoadPlugins({kTestPluginSearchPath}); - ASSERT_EQ(plugins->size(), 1); - auto& plugin = plugins->front(); - - auto partition_result = PartitionModel(plugin, model); - ASSERT_TRUE(partition_result); - ASSERT_EQ(model.NumSubgraphs(), 1); - - const auto& [ops, subgraphs] = *partition_result; - - EXPECT_EQ(ops.size(), 1); - EXPECT_EQ(ops.front()->OpCode(), kLiteRtOpCodeTflCustom); - - EXPECT_EQ(subgraphs.Size(), 1); - EXPECT_EQ(subgraphs.Elements().front()->Ops().size(), 2); -} - 
-TEST(PartitionModelTest, PartitionDirect) { - auto model_wrap = testing::LoadTestFileModel("mul_simple.tflite"); - auto& model = *model_wrap.Get(); - - std::vector selected_ops = { - {model.MainSubgraph()->Ops().front(), 0}, - {model.MainSubgraph()->Ops().back(), 0}}; - - auto partition_result = PartitionModelDirect(std::move(selected_ops), model); - ASSERT_TRUE(partition_result); - ASSERT_EQ(model.NumSubgraphs(), 1); - - const auto& [ops, subgraphs] = *partition_result; - - EXPECT_EQ(ops.size(), 1); - EXPECT_EQ(ops.front()->OpCode(), kLiteRtOpCodeTflCustom); - - EXPECT_EQ(subgraphs.Size(), 1); - EXPECT_EQ(subgraphs.Elements().front()->Ops().size(), 2); -} - -TEST(PartitionModelTest, MultiSubgraph) { - auto model_wrap = testing::LoadTestFileModel("multi_subgraph_mul.tflite"); - auto& model = *model_wrap.Get(); - - auto plugins = CompilerPlugin::LoadPlugins({kTestPluginSearchPath}); - ASSERT_EQ(plugins->size(), 1); - auto& plugin = plugins->front(); - - auto partition_result = PartitionModel(plugin, model); - ASSERT_TRUE(partition_result); - ASSERT_EQ(model.NumSubgraphs(), 2); - - const auto& [ops, subgraphs] = *partition_result; - - EXPECT_EQ(ops.size(), 2); - EXPECT_EQ(ops.front()->OpCode(), kLiteRtOpCodeTflCustom); - EXPECT_EQ(ops.back()->OpCode(), kLiteRtOpCodeTflCustom); - - EXPECT_EQ(subgraphs.Size(), 2); - EXPECT_EQ(subgraphs.Elements().front()->Ops().size(), 1); - EXPECT_EQ(subgraphs.Elements().back()->Ops().size(), 1); -} - -TEST(PartitionModelTest, MultiSubgraphWithSelectedSubgraphs) { - auto model_wrap = testing::LoadTestFileModel("multi_subgraph_mul.tflite"); - auto& model = *model_wrap.Get(); - - auto plugins = CompilerPlugin::LoadPlugins({kTestPluginSearchPath}); - ASSERT_EQ(plugins->size(), 1); - auto& plugin = plugins->front(); - - auto partition_result = PartitionModel(plugin, model, {1}); - ASSERT_TRUE(partition_result); - ASSERT_EQ(model.NumSubgraphs(), 2); - - const auto& [ops, subgraphs] = *partition_result; - - EXPECT_EQ(ops.size(), 1); - 
EXPECT_EQ(ops.front()->OpCode(), kLiteRtOpCodeTflCustom); - - EXPECT_EQ(subgraphs.Size(), 1); - EXPECT_EQ(subgraphs.Elements().front()->Ops().size(), 1); -} - -TEST(PartitionModelTest, CstMultiSubgraph) { - auto model_wrap = testing::LoadTestFileModel("multi_use_cst.tflite"); - auto& model = *model_wrap.Get(); - ASSERT_EQ(model.MainSubgraph()->Ops().size(), 3); - - std::vector selected_ops = { - {model.MainSubgraph()->Ops().front(), 0}, - {model.MainSubgraph()->Ops().back(), 0}, - }; - auto partition_result = PartitionModelDirect(std::move(selected_ops), model); - ASSERT_TRUE(partition_result); - - const auto& [ops, subgraphs] = *partition_result; - - EXPECT_EQ(ops.size(), 2); - EXPECT_EQ(ops.front()->OpCode(), kLiteRtOpCodeTflCustom); - EXPECT_EQ(ops.back()->OpCode(), kLiteRtOpCodeTflCustom); - - EXPECT_EQ(subgraphs.Size(), 2); - EXPECT_EQ(subgraphs.Elements().front()->Ops().size(), 1); - EXPECT_EQ(subgraphs.Elements().back()->Ops().size(), 1); - - const auto& cst_1 = - subgraphs.Elements().front()->Ops().front()->Input(1).Weights(); - const auto& cst_2 = - subgraphs.Elements().back()->Ops().front()->Input(1).Weights(); - - // Both weights should have the same object managed by the same buffer - // manager. 
- ASSERT_EQ(cst_1.GetBufferManager(), model.Buffers()); - ASSERT_EQ(cst_2.GetBufferManager(), model.Buffers()); - ASSERT_GT(cst_1.Buffer().Size(), 0); - ASSERT_GT(cst_2.Buffer().Size(), 0); - EXPECT_EQ(cst_1.GetBufferId(), cst_2.GetBufferId()); - ASSERT_EQ(cst_1.Buffer().Data(), cst_2.Buffer().Data()); -} - -TEST(ApplyTest, Simple) { - auto plugins = CompilerPlugin::LoadPlugins({kTestPluginSearchPath}); - ASSERT_EQ(plugins->size(), 1); - auto model_wrap = testing::LoadTestFileModel("mul_simple.tflite"); - ASSERT_TRUE(model_wrap); - auto& model = *model_wrap.Get(); - - ASSERT_TRUE(ApplyPlugin(plugins->front(), model)); - ASSERT_EQ(model.NumSubgraphs(), 1); - - auto& subgraph = *model.MainSubgraph(); - ASSERT_EQ(subgraph.Ops().size(), 1); - - auto* op = subgraph.Ops().front(); - - EXPECT_EQ(op->OpCode(), kLiteRtOpCodeTflCustom); - EXPECT_TRUE(model.FindOpAsset(op)); - - EXPECT_TRUE(model.FindMetadata(kLiteRtBuildStampKey)); -} - -TEST(ApplyTest, WithPartition) { - auto model_wrap = testing::LoadTestFileModel("mul_simple.tflite"); - auto& model = *model_wrap.Get(); - - auto plugins = CompilerPlugin::LoadPlugins({kTestPluginSearchPath}); - ASSERT_EQ(plugins->size(), 1); - auto& plugin = plugins->front(); - - auto partition_result = PartitionModel(plugin, model); - ASSERT_TRUE(partition_result); - ASSERT_EQ(model.NumSubgraphs(), 1); - - ASSERT_TRUE(ApplyPluginWithPartition(plugins->front(), model, - std::move(*partition_result))); - - auto& subgraph = model.Subgraph(0); - ASSERT_EQ(subgraph.Ops().size(), 1); - - auto* op = subgraph.Ops().front(); - - EXPECT_EQ(op->OpCode(), kLiteRtOpCodeTflCustom); - EXPECT_TRUE(model.FindOpAsset(op)); -} - -TEST(ApplyTest, MultiSubgraph) { - auto plugins = CompilerPlugin::LoadPlugins({kTestPluginSearchPath}); - ASSERT_EQ(plugins->size(), 1); - auto model_wrap = testing::LoadTestFileModel("multi_subgraph_mul.tflite"); - ASSERT_TRUE(model_wrap); - auto& model = *model_wrap.Get(); - - ASSERT_TRUE(ApplyPlugin(plugins->front(), model)); - 
ASSERT_EQ(model.NumSubgraphs(), 2); - - { - auto& subgraph = model.Subgraph(0); - ASSERT_EQ(subgraph.Ops().size(), 1); - - auto* op = subgraph.Ops().front(); - - EXPECT_EQ(op->OpCode(), kLiteRtOpCodeTflCustom); - EXPECT_TRUE(model.FindOpAsset(op)); - } - - { - auto& subgraph = model.Subgraph(1); - ASSERT_EQ(subgraph.Ops().size(), 1); - - auto* op = subgraph.Ops().front(); - - EXPECT_EQ(op->OpCode(), kLiteRtOpCodeTflCustom); - EXPECT_TRUE(model.FindOpAsset(op)); - } - - EXPECT_TRUE(model.FindMetadata(kLiteRtBuildStampKey)); -} - -TEST(ApplyTest, ApplyPlugins) { - auto model_wrap = testing::LoadTestFileModel("mul_simple.tflite"); - ASSERT_TRUE(model_wrap); - auto& model = *model_wrap.Get(); - - const std::array environment_options = { - litert::Environment::Option{ - /*.tag=*/litert::Environment::OptionTag::CompilerPluginLibraryDir, - /*.value=*/kTestPluginSearchPath, - }, - }; - auto env = litert::Environment::Create(environment_options); - ASSERT_TRUE(env); - - LiteRtHwAccelerators compilation_options = static_cast( - kLiteRtHwAcceleratorCpu | kLiteRtHwAcceleratorGpu | - kLiteRtHwAcceleratorNpu); - auto result = - litert::internal::ApplyPlugins(env->Get(), &model, compilation_options); - ASSERT_TRUE(result); - - ASSERT_EQ(model.NumSubgraphs(), 1); - - auto& subgraph = *model.MainSubgraph(); - ASSERT_EQ(subgraph.Ops().size(), 1); - - auto* op = subgraph.Ops().front(); - - EXPECT_EQ(op->OpCode(), kLiteRtOpCodeTflCustom); - EXPECT_TRUE(model.FindOpAsset(op)); - - EXPECT_TRUE(model.FindMetadata(kLiteRtBuildStampKey)); -} - -TEST(PartitionTest, MappedCompositeOp) { - auto model_wrap = testing::LoadTestFileModel("rms_norm_composite.tflite"); - ASSERT_TRUE(model_wrap); - auto& model = *model_wrap.Get(); - auto plugins = CompilerPlugin::LoadPlugins({kTestPluginSearchPath}); - - auto partition_result = PartitionModel(plugins->front(), model); - ASSERT_TRUE(partition_result); - // One new subgraph for the consumed composite op only, decomp not consumed. 
- ASSERT_EQ(partition_result->second.Size(), 1); -} - -TEST(PartitionTest, SimpleNpuCallComposite) { - auto model_wrap = testing::LoadTestFileModel("simple_composite.tflite"); - ASSERT_TRUE(model_wrap); - auto& model = *model_wrap.Get(); - auto plugins = CompilerPlugin::LoadPlugins({kTestPluginSearchPath}); - - auto* decomp = model.Subgraphs()[1]; - - auto partition_result = PartitionModel(plugins->front(), model); - ASSERT_TRUE(partition_result); - - auto& ops = partition_result->first; - ASSERT_EQ(ops.size(), 1); - ASSERT_EQ(ops.front()->OpCode(), kLiteRtOpCodeTflCustom); - - auto& sgs = partition_result->second; - ASSERT_EQ(sgs.Size(), 1); - ASSERT_EQ(sgs.Elements().front(), decomp); -} - -TEST(PartitionTest, MultiNpuCallComposite) { - auto model_wrap = testing::LoadTestFileModel("multi_composite.tflite"); - ASSERT_TRUE(model_wrap); - auto& model = *model_wrap.Get(); - auto plugins = CompilerPlugin::LoadPlugins({kTestPluginSearchPath}); - - ASSERT_EQ(model.NumSubgraphs(), 4); - auto* decomp1 = model.Subgraphs()[1]; - auto* non_npu_call_decomop = model.Subgraphs()[2]; - auto* decomp2 = model.Subgraphs()[3]; - - auto partition_result = PartitionModel(plugins->front(), model); - ASSERT_TRUE(partition_result); - - { - // Subgraphs to be compiled will be moved to the result from the model. - // Non-npu-call decompositions will be reindexed. - ASSERT_EQ(model.NumSubgraphs(), 2); - ASSERT_EQ(model.Subgraphs()[1], non_npu_call_decomop); - auto opts = GetOptionsAs(model.Subgraph(0).Ops()[1]); - ASSERT_TRUE(opts); - ASSERT_EQ(opts->subgraph, 1); - } - - { - // All npu call ops are now dispatch ops. 
- auto& ops = partition_result->first; - - ASSERT_EQ(ops.size(), 2); - auto* first_dispatch_op = ops.front(); - auto* second_dispatch_op = ops.back(); - - ASSERT_EQ(first_dispatch_op->OpCode(), kLiteRtOpCodeTflCustom); - ASSERT_EQ(first_dispatch_op, model.Subgraphs()[0]->Ops().front()); - - ASSERT_EQ(second_dispatch_op->OpCode(), kLiteRtOpCodeTflCustom); - ASSERT_EQ(second_dispatch_op, model.Subgraphs()[0]->Ops().back()); - } - - { - // Bodies to compile are the decompositions of npu call ops. - auto& sgs = partition_result->second; - - ASSERT_EQ(sgs.Size(), 2); - ASSERT_EQ(sgs.Elements().front(), decomp1); - ASSERT_EQ(sgs.Elements().back(), decomp2); - } -} - -TEST(PartitionTest, NestedNpuCallComposite) { - auto model_wrap = testing::LoadTestFileModel("nested_composite.tflite"); - ASSERT_TRUE(model_wrap); - auto& model = *model_wrap.Get(); - auto plugins = CompilerPlugin::LoadPlugins({kTestPluginSearchPath}); - - ASSERT_EQ(model.NumSubgraphs(), 3); - - auto partition_result = PartitionModel(plugins->front(), model); - ASSERT_TRUE(partition_result); - - auto& ops = partition_result->first; - ASSERT_EQ(ops.size(), 1); - ASSERT_EQ(ops.front()->OpCode(), kLiteRtOpCodeTflCustom); - - auto& sgs = partition_result->second; - ASSERT_EQ(sgs.Size(), 1); - ASSERT_EQ(sgs.Elements().front()->Op(0).OpCode(), kLiteRtOpCodeShloComposite); -} - -} // namespace -} // namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/core/BUILD b/tensorflow/lite/experimental/litert/core/BUILD deleted file mode 100644 index 005ced8f23276c..00000000000000 --- a/tensorflow/lite/experimental/litert/core/BUILD +++ /dev/null @@ -1,218 +0,0 @@ -# Copyright 2024 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = [ - "//tensorflow/lite/experimental/litert:__subpackages__", - "//third_party/odml/infra/ml_drift_delegate/litert:__subpackages__", - ], -) - -cc_library( - name = "build_stamp", - srcs = ["build_stamp.cc"], - hdrs = ["build_stamp.h"], - deps = [ - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/cc:litert_buffer_ref", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - ], -) - -cc_test( - name = "build_stamp_test", - srcs = ["build_stamp_test.cc"], - data = [ - "//tensorflow/lite/experimental/litert/test:tflite_test_data", - ], - deps = [ - ":build_stamp", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/cc:litert_buffer_ref", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/test:matchers", - "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "dynamic_loading", - srcs = ["dynamic_loading.cc"], - hdrs = ["dynamic_loading.h"], - linkopts = ["-ldl"], - deps = [ - ":filesystem", - "//tensorflow/lite/experimental/litert/c:litert_common", - 
"//tensorflow/lite/experimental/litert/c:litert_logging", # buildcleaner: keep - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:string_view", - ], -) - -cc_library( - name = "insert_order_map", - hdrs = ["insert_order_map.h"], - deps = [ - "@com_google_absl//absl/container:flat_hash_map", - ], -) - -cc_test( - name = "insert_order_map_test", - srcs = ["insert_order_map_test.cc"], - deps = [ - ":insert_order_map", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "environment", - srcs = ["environment.cc"], - hdrs = [ - "environment.h", - "//tensorflow/lite/experimental/litert/c:litert_environment.h", - ], - deps = [ - ":environment_options", - "//tensorflow/lite/experimental/litert/c:litert_any", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_environment_options", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/runtime:accelerator_registry", - "@com_google_absl//absl/types:span", - ], -) - -cc_library( - name = "environment_options", - srcs = ["environment_options.cc"], - hdrs = ["environment_options.h"], - deps = [ - "//tensorflow/lite/experimental/litert/c:litert_any", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_environment_options_header", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - ], -) - -cc_test( - name = "environment_options_test", - srcs = ["environment_options_test.cc"], - deps = [ - ":environment_options", - "//tensorflow/lite/experimental/litert/c:litert_any", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_environment_options", - "//tensorflow/lite/experimental/litert/test:matchers", - "@com_google_googletest//:gtest_main", - ], -) - 
-cc_test( - name = "environment_test", - srcs = ["environment_test.cc"], - deps = [ - ":environment", - "//tensorflow/lite/experimental/litert/c:litert_any", - "//tensorflow/lite/experimental/litert/cc:litert_any", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "filesystem", - srcs = ["filesystem.cc"], - hdrs = ["filesystem.h"], - deps = [ - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/cc:litert_buffer_ref", - "//tensorflow/lite/experimental/litert/cc:litert_detail", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - ], -) - -cc_library( - name = "dispatch_op_schema", - srcs = ["dispatch_op_schema.cc"], - hdrs = ["dispatch_op_schema.h"], - copts = ["-DFLATBUFFERS_LOCALE_INDEPENDENT=0"], - deps = [ - "//tensorflow/lite/experimental/litert/cc:litert_buffer_ref", - "@flatbuffers//:runtime_cc", - ], -) - -cc_test( - name = "filesystem_test", - srcs = ["filesystem_test.cc"], - deps = [ - ":filesystem", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest_main", - ], -) - -# copybara:uncomment_begin(no OSS for unique-test-directory) -# cc_test( -# name = "dynamic_loading_test", -# srcs = ["dynamic_loading_test.cc"], -# tags = [ -# # Sanitizer runtimes are incompatible with RTLD_DEEPBIND. 
-# "noasan", -# "nomsan", -# "nosan", -# ], -# deps = [ -# ":dynamic_loading", -# ":filesystem", -# "@com_google_googletest//:gtest_main", -# "@com_google_absl//absl/strings:string_view", -# "//tensorflow/lite/experimental/litert/c:litert_logging", # buildcleaner: keep -# "//tensorflow/lite/experimental/litert/test:common", -# "//tensorflow/lite/experimental/litert/test:matchers", -# ], -# ) -# copybara:uncomment_end - -cc_test( - name = "dispatch_op_schema_test", - srcs = ["dispatch_op_schema_test.cc"], - deps = [ - ":dispatch_op_schema", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "version", - hdrs = ["version.h"], - deps = [ - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - ], -) diff --git a/tensorflow/lite/experimental/litert/core/build_stamp.cc b/tensorflow/lite/experimental/litert/core/build_stamp.cc deleted file mode 100644 index 9b7e942c36622f..00000000000000 --- a/tensorflow/lite/experimental/litert/core/build_stamp.cc +++ /dev/null @@ -1,63 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/core/build_stamp.h" - -#include -#include - -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" - -namespace litert::internal { - -namespace { -// Simple metadata added to the flatbuffer related to compiler plugin. -struct BuildStamp { - char soc_manufacturer[kSocManufacturerMaxLen + 1] = {}; - char soc_model[kSocModelMaxLen + 1] = {}; -}; - -} // namespace - -Expected> MakeBuildStamp( - absl::string_view soc_manufacturer, absl::string_view soc_model) { - if (soc_manufacturer.size() >= kSocManufacturerMaxLen || - soc_model.size() >= kSocModelMaxLen) { - LITERT_LOG(LITERT_ERROR, "%s", "Soc Make/Model strings too large\n"); - return Unexpected(kLiteRtStatusErrorInvalidArgument); - } - BuildStamp stamp; - soc_manufacturer.copy(stamp.soc_manufacturer, soc_manufacturer.size()); - soc_model.copy(stamp.soc_model, soc_model.size()); - return OwningBufferRef(reinterpret_cast(&stamp), - sizeof(stamp)); -} - -// Parse a serialized build stamp from the given buf. 
-Expected> ParseBuildStamp( - BufferRef buf) { - if (buf.Size() != sizeof(BuildStamp)) { - LITERT_LOG(LITERT_ERROR, "%s", "Build stamp size mismatch\n"); - return Unexpected(kLiteRtStatusErrorInvalidArgument); - } - const BuildStamp* stamp = reinterpret_cast(buf.Data()); - return std::make_tuple(absl::string_view(stamp->soc_manufacturer), - absl::string_view(stamp->soc_model)); -} - -} // namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/core/build_stamp.h b/tensorflow/lite/experimental/litert/core/build_stamp.h deleted file mode 100644 index bf9ee91934e503..00000000000000 --- a/tensorflow/lite/experimental/litert/core/build_stamp.h +++ /dev/null @@ -1,57 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_BUILD_STAMP_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_BUILD_STAMP_H_ - -#include - -#include -#include - -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" - -namespace litert::internal { - -// TODO update this library to use the flexbuffers api. - -// Shared "custom_code" for all dispatch ops. -static constexpr absl::string_view kLiteRtDispatchOpCustomCode = "DISPATCH_OP"; - -// -// Build Stamp -// - -// Maximum size of string for soc_manufacturer. 
-static constexpr size_t kSocManufacturerMaxLen = 124; - -// Maximum size of string for soc_model. -static constexpr size_t kSocModelMaxLen = 124; - -// Metadata key to lookup the build stamp. -static constexpr absl::string_view kLiteRtBuildStampKey = "LiteRtStamp"; - -// Make a serialized build stamp that can go directly in the flatbuffer. -Expected> MakeBuildStamp( - absl::string_view soc_manufacturer, absl::string_view soc_model); - -// Parse a serialized build stamp from the given buf. -Expected> ParseBuildStamp( - BufferRef buf); - -} // namespace litert::internal - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_BUILD_STAMP_H_ diff --git a/tensorflow/lite/experimental/litert/core/build_stamp_test.cc b/tensorflow/lite/experimental/litert/core/build_stamp_test.cc deleted file mode 100644 index a0c3ce4fbbf1d0..00000000000000 --- a/tensorflow/lite/experimental/litert/core/build_stamp_test.cc +++ /dev/null @@ -1,51 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/core/build_stamp.h" - -#include - -#include -#include -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/test/matchers.h" - -namespace litert::internal { - -namespace { - -using ::testing::litert::IsError; - -static constexpr absl::string_view kSocModel = "TestSocModel"; -static constexpr absl::string_view kSocMan = "TestSocMan"; - -TEST(TestBuildStamp, MakeBuildStampInputsTooLarge) { - // NOLINTNEXTLINE - std::string long_manufacturer(256, 'a'); - auto res = MakeBuildStamp(long_manufacturer, kSocModel); - EXPECT_THAT(res, IsError(kLiteRtStatusErrorInvalidArgument)); -} - -TEST(TestBuildStamp, MakeBuildStamp) { - auto stamp = MakeBuildStamp(kSocMan, kSocModel); - auto pstamp = ParseBuildStamp(*stamp); - auto [man, model] = *pstamp; - EXPECT_EQ(man, kSocMan); - EXPECT_EQ(model, kSocModel); -} - -} // namespace - -} // namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/core/dispatch_op_schema.cc b/tensorflow/lite/experimental/litert/core/dispatch_op_schema.cc deleted file mode 100644 index ed2226ef664cef..00000000000000 --- a/tensorflow/lite/experimental/litert/core/dispatch_op_schema.cc +++ /dev/null @@ -1,89 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/core/dispatch_op_schema.h" - -#include -#include -#include -#include - -#include "flatbuffers/flexbuffers.h" // from @flatbuffers -#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" - -namespace litert { -namespace internal { -namespace { - -static constexpr const char kBytecodeSizeKey[] = "bytecode_size"; -static constexpr const char kBytecodeOffsetKey[] = "bytecode_offset"; -static constexpr const char kNameKey[] = "name"; - -} // namespace - -OwningBufferRef MakeDispatchOpOptions(DispatchOpOptions options) { - flexbuffers::Builder fbb; - - // Set maximum width for scalars to 64 bits. This prevents any upsizing of - // the buffer when updating the bytecode size and offset in place. - fbb.ForceMinimumBitWidth(flexbuffers::BIT_WIDTH_64); - - auto start = fbb.StartMap(); - - fbb.UInt(kBytecodeSizeKey, options.bytecode_size); - fbb.UInt(kBytecodeOffsetKey, options.bytecode_offset); - fbb.String(kNameKey, options.name); - - fbb.EndMap(start); - fbb.Finish(); - - auto buf = fbb.GetBuffer(); - OwningBufferRef res; - res.Assign(buf.data(), buf.size()); - - return res; -} - -bool UpdateDispatchOpOptionsInPlace(DispatchOpOptions options, - MutableBufferRef buffer) { - auto opts = flexbuffers::GetRoot(buffer.Data(), buffer.Size()).AsMap(); - - // Update name if same len. - const auto name_ok = opts[kNameKey].MutateString(options.name); - - // Update bytecode size and offset. Since min scalar bit width is set to max - // possible value, it shouldn't fail in theory. 
- const auto size_ok = opts[kBytecodeSizeKey].MutateUInt(options.bytecode_size); - const auto offset_ok = - opts[kBytecodeOffsetKey].MutateUInt(options.bytecode_offset); - - return name_ok && size_ok && offset_ok; -} - -DispatchOpOptions GetDispatchOpOptions(BufferRef buffer) { - const auto opts = flexbuffers::GetRoot(buffer.Data(), buffer.Size()).AsMap(); - - const size_t bytecode_size = opts[kBytecodeSizeKey].AsUInt64(); - const size_t bytecode_offset = opts[kBytecodeOffsetKey].AsUInt64(); - std::string name(opts[kNameKey].AsString().c_str()); - - return DispatchOpOptions{ - bytecode_size, - bytecode_offset, - std::move(name), - }; -} - -} // namespace internal -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/core/dispatch_op_schema.h b/tensorflow/lite/experimental/litert/core/dispatch_op_schema.h deleted file mode 100644 index a6f6eb9216caba..00000000000000 --- a/tensorflow/lite/experimental/litert/core/dispatch_op_schema.h +++ /dev/null @@ -1,57 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_DISPATCH_OP_SCHEMA_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_DISPATCH_OP_SCHEMA_H_ - -#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" - -// Utilities for working with the dispatch op custom options buffer. These -// functions leverage the flexbuffer api under the hood which allows for inplace -// updates. 
- -namespace litert::internal { - -// Schema representing the custom options data for dispatch ops. Primarly used -// to for tracking location of bytecode. -struct DispatchOpOptions { - // The size of the bytecode for the dispatch op. - size_t bytecode_size; - - // The offset of the bytecode for the dispatch op relative to the start of the - // model file. - size_t bytecode_offset; - - // Name of specific dispatch op or entry point to be called in a shared - // bytecode module. - std::string name; -}; - -// Get a serialized representation of the dispatch op options. These should -// be stored directly in the custom options of the dispatch op. -OwningBufferRef MakeDispatchOpOptions(DispatchOpOptions options); - -// Update the dispatch op options in the given buffer with the given options. -// The buffer should be the custom options buffer of the dispatch op. Fails if -// the passed values would resize the buffer. -bool UpdateDispatchOpOptionsInPlace(DispatchOpOptions options, - MutableBufferRef buffer); - -// Get the dispatch op options from the given buffer. The buffer should be the -// custom options buffer of the dispatch op. -DispatchOpOptions GetDispatchOpOptions(BufferRef buffer); - -} // namespace litert::internal - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_DISPATCH_OP_SCHEMA_H_ diff --git a/tensorflow/lite/experimental/litert/core/dispatch_op_schema_test.cc b/tensorflow/lite/experimental/litert/core/dispatch_op_schema_test.cc deleted file mode 100644 index 53f784b50a674e..00000000000000 --- a/tensorflow/lite/experimental/litert/core/dispatch_op_schema_test.cc +++ /dev/null @@ -1,74 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/core/dispatch_op_schema.h" - -#include - -#include - -namespace litert { -namespace internal { -namespace { - -static constexpr size_t kBufferSize = 100; -static constexpr size_t kBufferOffset = 200; -static constexpr const char kName[] = "test_name"; - -TEST(DispatchOpSchemaTest, DispatchOpOptions) { - DispatchOpOptions options = { - kBufferSize, - kBufferOffset, - kName, - }; - - auto buffer = MakeDispatchOpOptions(options); - ASSERT_GT(buffer.Size(), 0); - - auto parsed_options = GetDispatchOpOptions(buffer); - ASSERT_EQ(parsed_options.bytecode_size, kBufferSize); - ASSERT_EQ(parsed_options.bytecode_offset, kBufferOffset); - ASSERT_EQ(parsed_options.name, kName); -} - -TEST(DispatchOpSchemaTest, UpdateDispatchOpOptions) { - DispatchOpOptions options = { - kBufferSize, - kBufferOffset, - kName, - }; - - auto buffer = MakeDispatchOpOptions(options); - ASSERT_GT(buffer.Size(), 0); - - static constexpr size_t kNewBufferSize = 1000; - static constexpr size_t kNewBufferOffset = 2000; - - DispatchOpOptions new_options = { - kNewBufferSize, - kNewBufferOffset, - kName, - }; - - ASSERT_TRUE(UpdateDispatchOpOptionsInPlace(new_options, buffer)); - - auto parsed_options = GetDispatchOpOptions(buffer); - ASSERT_EQ(parsed_options.bytecode_size, kNewBufferSize); - ASSERT_EQ(parsed_options.bytecode_offset, kNewBufferOffset); - ASSERT_EQ(parsed_options.name, kName); -} - -} // namespace -} // namespace internal -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/core/dynamic_loading.cc 
b/tensorflow/lite/experimental/litert/core/dynamic_loading.cc deleted file mode 100644 index 37c4ef2040dd86..00000000000000 --- a/tensorflow/lite/experimental/litert/core/dynamic_loading.cc +++ /dev/null @@ -1,148 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/core/dynamic_loading.h" - -#include -#include - -// clang-format off -#ifndef __ANDROID__ -#if __has_include() -#include -#endif -#endif -// clang-format on - -#include -#include // NOLINT -#include -#include - -#include "absl/strings/match.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/core/filesystem.h" - -namespace litert::internal { - -namespace { - -static constexpr absl::string_view kLdLibraryPath = "LD_LIBRARY_PATH"; - -bool EnvPathContains(absl::string_view path, absl::string_view var_value) { - return absl::EndsWith(var_value, path) || - absl::StrContains(var_value, absl::StrCat(path, ":")); -} - -} // namespace - -static constexpr absl::string_view kSo = ".so"; - -LiteRtStatus FindLiteRtSharedLibsHelper(const std::string& search_path, - const std::string& lib_pattern, - bool full_match, - std::vector& results) { - if 
(!Exists(search_path)) { - return kLiteRtStatusErrorInvalidArgument; - } - - // TODO implement path glob in core/filesystem.h and remove filesystem - // include from this file. - for (const auto& entry : std::filesystem::directory_iterator( - search_path, - std::filesystem::directory_options::skip_permission_denied)) { - const auto& path = entry.path(); - if (access(path.c_str(), R_OK) != 0) { - continue; - } - if (entry.is_regular_file()) { - if (full_match) { - if (path.string().find(lib_pattern) != -1) { - LITERT_LOG(LITERT_VERBOSE, "Found shared library: %s", path.c_str()); - results.push_back(path); - } - } else { - const auto stem = path.stem().string(); - const auto ext = path.extension().string(); - if (stem.find(lib_pattern) == 0 && kSo == ext) { - LITERT_LOG(LITERT_VERBOSE, "Found shared library: %s", path.c_str()); - results.push_back(path); - } - } - } else if (entry.is_directory()) { - FindLiteRtSharedLibsHelper(path, lib_pattern, full_match, results); - } - } - - return kLiteRtStatusOk; -} - -static const char kCompilerPluginLibPatternFmt[] = "CompilerPlugin"; - -LiteRtStatus FindLiteRtCompilerPluginSharedLibs( - absl::string_view search_path, std::vector& results) { - std::string root(search_path); - const std::string lib_pattern = - absl::StrCat(kLiteRtSharedLibPrefix, kCompilerPluginLibPatternFmt); - return FindLiteRtSharedLibsHelper(root, lib_pattern, /*full_match=*/false, - results); -} - -static const char kDispatchLibPatternFmt[] = "Dispatch"; - -LiteRtStatus FindLiteRtDispatchSharedLibs(absl::string_view search_path, - std::vector& results) { - std::string root(search_path.data()); - const std::string lib_pattern = - absl::StrCat(kLiteRtSharedLibPrefix, kDispatchLibPatternFmt); - return FindLiteRtSharedLibsHelper(root, lib_pattern, /*full_match=*/false, - results); -} - -LiteRtStatus PutLibOnLdPath(absl::string_view search_path, - absl::string_view lib_pattern) { - std::vector results; - LITERT_RETURN_IF_ERROR(FindLiteRtSharedLibsHelper( - 
std::string(search_path), std::string(lib_pattern), true, results)); - if (results.empty()) { - LITERT_LOG(LITERT_INFO, "No match found in %s", search_path.data()); - return kLiteRtStatusOk; - } - - const auto lib_dir = std::filesystem::path(results[0]).parent_path().string(); - absl::string_view ld = getenv(kLdLibraryPath.data()); - - if (EnvPathContains(lib_dir, ld)) { - LITERT_LOG(LITERT_INFO, "dir already in LD_LIBRARY_PATH"); - return kLiteRtStatusOk; - } - - std::string new_ld; - if (ld.empty()) { - new_ld = lib_dir; - } else { - new_ld = absl::StrCat(ld, ":", lib_dir); - } - - LITERT_LOG(LITERT_INFO, "Adding %s to LD_LIBRARY_PATH", new_ld.c_str()); - setenv(kLdLibraryPath.data(), new_ld.c_str(), /*overwrite=*/1); - - return kLiteRtStatusOk; -} - -} // namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/core/dynamic_loading.h b/tensorflow/lite/experimental/litert/core/dynamic_loading.h deleted file mode 100644 index 5f44f5aafaa851..00000000000000 --- a/tensorflow/lite/experimental/litert/core/dynamic_loading.h +++ /dev/null @@ -1,68 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_DYNAMIC_LOADING_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_DYNAMIC_LOADING_H_ - -#include -#include - -#include -#include -#include -#include - -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" - -namespace litert::internal { - -constexpr absl::string_view kLiteRtSharedLibPrefix = "libLiteRt"; - -// Loads shared library at given path. Logging can be disabled to probe for -// shared libraries. -LiteRtStatus OpenLib(absl::string_view so_path, void** lib_handle, - bool log_failure = true); - -// Find all litert shared libraries in "search_path" and return -// kLiteRtStatusErrorInvalidArgument if the provided search_path doesn't -// exist. All internal dynamically linked dependencies for litert should be -// prefixed with "libLiteRtCompilerPlugin". -LiteRtStatus FindLiteRtCompilerPluginSharedLibs( - absl::string_view search_path, std::vector& results); - -// Find all litert shared libraries in "search_path" and return -// kLiteRtStatusErrorInvalidArgument if the provided search_path doesn't -// exist. All internal dynamically linked dependencies for litert should be -// prefixed with "libLiteRtDispatch". -LiteRtStatus FindLiteRtDispatchSharedLibs(absl::string_view search_path, - std::vector& results); - -// Find shared libraries for a given pattern in "search_path" and return -// kLiteRtStatusErrorInvalidArgument if the provided search_path doesn't -// exist. -LiteRtStatus FindLiteRtSharedLibsHelper(const std::string& search_path, - const std::string& lib_pattern, - bool full_match, - std::vector& results); - -// Analogous to the above, but the first match identified, its immeidate parent -// directory will be appended to the LD_LIBRARY_PATH. 
-LiteRtStatus PutLibOnLdPath(absl::string_view search_path, - absl::string_view lib_pattern); - -} // namespace litert::internal - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_DYNAMIC_LOADING_H_ diff --git a/tensorflow/lite/experimental/litert/core/dynamic_loading_test.cc b/tensorflow/lite/experimental/litert/core/dynamic_loading_test.cc deleted file mode 100644 index 000e33947fb790..00000000000000 --- a/tensorflow/lite/experimental/litert/core/dynamic_loading_test.cc +++ /dev/null @@ -1,200 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/core/dynamic_loading.h" - -#include -#include // NOLINT -#include -#include - -#include -#include -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/core/filesystem.h" -#include "tensorflow/lite/experimental/litert/test/common.h" -#include "tensorflow/lite/experimental/litert/test/matchers.h" - -namespace litert::internal { -namespace { - -using litert::testing::UniqueTestDirectory; -using ::testing::Contains; -using ::testing::HasSubstr; - -constexpr absl::string_view kNotLiteRtSo = "notLibLiteRt.so"; -constexpr absl::string_view kLiteRtSo1 = "libLiteRtCompilerPlugin_1.so"; -constexpr absl::string_view kLiteRtSo2 = "libLiteRtCompilerPlugin_2.so"; -constexpr absl::string_view kLiteRtSo3 = "libLiteRtDispatch_1.so"; -constexpr absl::string_view kLiteRtSo4 = "libLiteRtDispatch_2.so"; -constexpr absl::string_view kLdLibraryPath = "LD_LIBRARY_PATH"; - -TEST(TestDynamicLoading, GlobNoMatch) { - const auto dir = UniqueTestDirectory::Create(); - ASSERT_TRUE(dir); - Touch(Join({dir->Str(), kNotLiteRtSo})); - - std::vector results; - LITERT_ASSERT_OK(litert::internal::FindLiteRtCompilerPluginSharedLibs( - dir->Str(), results)); - EXPECT_EQ(results.size(), 0); - std::vector results2; - LITERT_ASSERT_OK( - litert::internal::FindLiteRtDispatchSharedLibs(dir->Str(), results2)); - EXPECT_EQ(results2.size(), 0); -} - -TEST(TestDynamicLoading, GlobOneMatch) { - const auto dir = UniqueTestDirectory::Create(); - ASSERT_TRUE(dir); - Touch(Join({dir->Str(), kLiteRtSo1})); - Touch(Join({dir->Str(), kLiteRtSo3})); - Touch(Join({dir->Str(), kNotLiteRtSo})); - - std::vector results; - LITERT_ASSERT_OK(litert::internal::FindLiteRtCompilerPluginSharedLibs( - dir->Str(), results)); - ASSERT_EQ(results.size(), 1); - EXPECT_TRUE(absl::string_view(results.front()).ends_with(kLiteRtSo1)); - - std::vector results2; - LITERT_ASSERT_OK( - litert::internal::FindLiteRtDispatchSharedLibs(dir->Str(), results2)); - 
ASSERT_EQ(results2.size(), 1); - EXPECT_TRUE(absl::string_view(results2.front()).ends_with(kLiteRtSo3)); -} - -TEST(TestDynamicLoading, GlobMultiMatch) { - const auto dir = UniqueTestDirectory::Create(); - ASSERT_TRUE(dir); - Touch(Join({dir->Str(), kLiteRtSo1})); - Touch(Join({dir->Str(), kLiteRtSo2})); - Touch(Join({dir->Str(), kLiteRtSo3})); - Touch(Join({dir->Str(), kLiteRtSo4})); - Touch(Join({dir->Str(), kNotLiteRtSo})); - - std::vector results; - LITERT_ASSERT_OK(litert::internal::FindLiteRtCompilerPluginSharedLibs( - dir->Str(), results)); - ASSERT_EQ(results.size(), 2); - EXPECT_THAT(results, Contains(HasSubstr(kLiteRtSo1))); - EXPECT_THAT(results, Contains(HasSubstr(kLiteRtSo2))); - - std::vector results2; - LITERT_ASSERT_OK( - litert::internal::FindLiteRtDispatchSharedLibs(dir->Str(), results2)); - ASSERT_EQ(results2.size(), 2); - EXPECT_THAT(results2, Contains(HasSubstr(kLiteRtSo3))); - EXPECT_THAT(results2, Contains(HasSubstr(kLiteRtSo4))); -} - -TEST(TestDynamicLoadingHelper, HelperWithFullMatch) { - const auto dir = UniqueTestDirectory::Create(); - ASSERT_TRUE(dir); - Touch(Join({dir->Str(), kLiteRtSo1})); - Touch(Join({dir->Str(), kLiteRtSo2})); - Touch(Join({dir->Str(), kLiteRtSo3})); - Touch(Join({dir->Str(), kLiteRtSo4})); - Touch(Join({dir->Str(), kNotLiteRtSo})); - - std::vector results; - LITERT_ASSERT_OK(litert::internal::FindLiteRtSharedLibsHelper( - std::string(dir->Str()), std::string(kLiteRtSo4), true, results)); - ASSERT_EQ(results.size(), 1); - EXPECT_THAT(results, Contains(HasSubstr(kLiteRtSo4))); -} - -TEST(TestPutLibOnLdPath, AppendToEmptyLdPath) { - unsetenv(kLdLibraryPath.data()); - - const auto dir = UniqueTestDirectory::Create(); - ASSERT_TRUE(dir); - - const auto dir_path = dir->Str(); - const auto lib_path = Join({dir_path, kLiteRtSo1}); - Touch(lib_path); - - LITERT_ASSERT_OK(PutLibOnLdPath(dir_path, kLiteRtSo1)); - absl::string_view ld_library_path = getenv(kLdLibraryPath.data()); - EXPECT_THAT(ld_library_path, 
HasSubstr(dir_path)); -} - -TEST(TestPutLibOnLdPath, AppendToLdPathNoMatch) { - unsetenv(kLdLibraryPath.data()); - - const auto dir = UniqueTestDirectory::Create(); - ASSERT_TRUE(dir); - - const auto dir_path = dir->Str(); - - LITERT_ASSERT_OK(PutLibOnLdPath(dir_path, kLiteRtSo1)); - ASSERT_EQ(getenv(kLdLibraryPath.data()), nullptr); -} - -TEST(TestPutLibOnLdPath, AppendToExistingLdPath) { - static constexpr absl::string_view kExistingLdPath = "an/existing/path"; - - unsetenv(kLdLibraryPath.data()); - setenv(kLdLibraryPath.data(), kExistingLdPath.data(), /*overwrite=*/1); - - const auto dir = UniqueTestDirectory::Create(); - ASSERT_TRUE(dir); - - const auto dir_path = dir->Str(); - const auto lib_path = Join({dir_path, kLiteRtSo1}); - Touch(lib_path); - - LITERT_ASSERT_OK(PutLibOnLdPath(dir_path, kLiteRtSo1)); - absl::string_view ld_library_path = getenv(kLdLibraryPath.data()); - EXPECT_THAT(ld_library_path, HasSubstr(dir_path)); - EXPECT_THAT(ld_library_path, HasSubstr(kExistingLdPath)); -} - -TEST(TestPutLibOnLdPath, AppendToLdLibraryPathNoDupePath) { - unsetenv(kLdLibraryPath.data()); - - const auto dir = UniqueTestDirectory::Create(); - ASSERT_TRUE(dir); - - const auto dir_path = dir->Str(); - const auto lib_path = Join({dir_path, kLiteRtSo1}); - Touch(lib_path); - - setenv(kLdLibraryPath.data(), dir_path.data(), /*overwrite=*/1); - - LITERT_ASSERT_OK(PutLibOnLdPath(dir_path, kLiteRtSo1)); - absl::string_view ld_library_path = getenv(kLdLibraryPath.data()); - EXPECT_THAT(ld_library_path, HasSubstr(dir_path)); - EXPECT_EQ(ld_library_path.size(), dir_path.size()); -} - -TEST(TestPutLibOnLdPath, AppendToNestedLdPath) { - unsetenv(kLdLibraryPath.data()); - - const auto dir = UniqueTestDirectory::Create(); - ASSERT_TRUE(dir); - - const auto dir_path = dir->Str(); - const auto nested_dir_path = Join({dir_path, "another/dir"}); - const auto lib_path = Join({nested_dir_path, kLiteRtSo1}); - ASSERT_TRUE(std::filesystem::create_directories(nested_dir_path)); - 
Touch(lib_path); - - LITERT_ASSERT_OK(PutLibOnLdPath(dir_path, kLiteRtSo1)); - absl::string_view ld_library_path = getenv(kLdLibraryPath.data()); - EXPECT_THAT(ld_library_path, HasSubstr(nested_dir_path)); -} - -} // namespace -} // namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/core/environment.cc b/tensorflow/lite/experimental/litert/core/environment.cc deleted file mode 100644 index 838a587fc471e6..00000000000000 --- a/tensorflow/lite/experimental/litert/core/environment.cc +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/core/environment.h" - -#include - -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_environment_options.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" - -litert::Expected LiteRtEnvironmentT::CreateWithOptions( - absl::Span options) { - LITERT_LOG(LITERT_INFO, "Creating LiteRT environment with options"); - auto env = std::make_unique(); - for (const auto& opt : options) { - env->options_.SetOption(opt); - } - - return env; -} diff --git a/tensorflow/lite/experimental/litert/core/environment.h b/tensorflow/lite/experimental/litert/core/environment.h deleted file mode 100644 index 0ac5d42b0fc6b2..00000000000000 --- a/tensorflow/lite/experimental/litert/core/environment.h +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_ENVIRONMENT_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_ENVIRONMENT_H_ - -#include -#include - -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_any.h" -#include "tensorflow/lite/experimental/litert/c/litert_environment_options.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/core/environment_options.h" -#include "tensorflow/lite/experimental/litert/runtime/accelerator_registry.h" - -// A singleton class that contains global LiteRT environment options. -class LiteRtEnvironmentT { - public: - using Ptr = std::unique_ptr; - - LiteRtEnvironmentT() = default; - // Create an environment instance with options. - static litert::Expected CreateWithOptions( - absl::Span options); - - ~LiteRtEnvironmentT() = default; - - std::optional GetOption(LiteRtEnvOptionTag tag) const { - auto opt = options_.GetOption(tag); - return opt.HasValue() ? std::optional(opt.Value()) - : std::nullopt; - } - - LiteRtEnvironmentOptionsT& GetOptions() { return options_; } - const LiteRtEnvironmentOptionsT& GetOptions() const { return options_; } - - litert::internal::AcceleratorRegistry& GetAcceleratorRegistry() { - return accelerators_; - } - - private: - litert::internal::AcceleratorRegistry accelerators_; - LiteRtEnvironmentOptionsT options_; -}; - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_ENVIRONMENT_H_ diff --git a/tensorflow/lite/experimental/litert/core/environment_options.cc b/tensorflow/lite/experimental/litert/core/environment_options.cc deleted file mode 100644 index ce1ea724ae83de..00000000000000 --- a/tensorflow/lite/experimental/litert/core/environment_options.cc +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/core/environment_options.h" - -#include -#include - -#include "tensorflow/lite/experimental/litert/c/litert_any.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_environment_options.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" - -LiteRtEnvironmentOptionsT::LiteRtEnvironmentOptionsT( - LiteRtEnvironmentOptionsT&& other) - : options_(std::move(other.options_)), - string_option_values_(std::move(other.string_option_values_)) { - // Update the string pointers in case they have changed when moving the - // container. This can happen because of small string optimization. - RefreshStringOptionValuePointers(); -} - -LiteRtEnvironmentOptionsT& LiteRtEnvironmentOptionsT::operator=( - LiteRtEnvironmentOptionsT&& other) { - options_ = std::move(other.options_); - string_option_values_ = std::move(other.string_option_values_); - // Update the string pointers in case they have changed when moving the - // container. This can happen because of small string optimization. 
- RefreshStringOptionValuePointers(); - return *this; -} - -void LiteRtEnvironmentOptionsT::RefreshStringOptionValuePointers() { - for (const auto& [tag, value] : string_option_values_) { - options_[tag].str_value = value.c_str(); - } -} - -litert::Expected LiteRtEnvironmentOptionsT::GetOption( - LiteRtEnvOptionTag tag) const { - if (auto it = options_.find(tag); it != options_.end()) { - return it->second; - } - return litert::Error(kLiteRtStatusErrorNotFound, - "Option was not set for this environment."); -} - -litert::Expected LiteRtEnvironmentOptionsT::SetOption( - LiteRtEnvOption option) { - if (option.value.type == kLiteRtAnyTypeString) { - auto [string_it, _] = string_option_values_.insert_or_assign( - option.tag, option.value.str_value); - LiteRtAny value{/*type=*/kLiteRtAnyTypeString}; - value.str_value = string_it->second.c_str(); - options_[option.tag] = value; - } else { - options_[option.tag] = option.value; - } - return {}; -} diff --git a/tensorflow/lite/experimental/litert/core/environment_options.h b/tensorflow/lite/experimental/litert/core/environment_options.h deleted file mode 100644 index 92133fe59835c7..00000000000000 --- a/tensorflow/lite/experimental/litert/core/environment_options.h +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_ENVIRONMENT_OPTIONS_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_ENVIRONMENT_OPTIONS_H_ - -#include -#include - -#include "tensorflow/lite/experimental/litert/c/litert_any.h" -#include "tensorflow/lite/experimental/litert/c/litert_environment_options.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" - -class LiteRtEnvironmentOptionsT { - public: - LiteRtEnvironmentOptionsT() = default; - - LiteRtEnvironmentOptionsT(LiteRtEnvironmentOptionsT&& other); - LiteRtEnvironmentOptionsT& operator=(LiteRtEnvironmentOptionsT&& other); - - litert::Expected GetOption(LiteRtEnvOptionTag tag) const; - litert::Expected SetOption(LiteRtEnvOption option); - - private: - void RefreshStringOptionValuePointers(); - - std::unordered_map options_; - std::unordered_map string_option_values_; -}; - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_ENVIRONMENT_OPTIONS_H_ diff --git a/tensorflow/lite/experimental/litert/core/environment_options_test.cc b/tensorflow/lite/experimental/litert/core/environment_options_test.cc deleted file mode 100644 index 62294328590487..00000000000000 --- a/tensorflow/lite/experimental/litert/core/environment_options_test.cc +++ /dev/null @@ -1,78 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/core/environment_options.h" - -#include -#include -#include "tensorflow/lite/experimental/litert/c/litert_any.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_environment_options.h" -#include "tensorflow/lite/experimental/litert/test/matchers.h" - -namespace { - -using testing::Eq; -using testing::Ne; -using testing::litert::IsError; - -TEST(EnvironmentOptionsTest, SetGetStringOptionWorks) { - LiteRtEnvironmentOptionsT options; - constexpr const char* kStrValue = "string_value"; - LiteRtEnvOption env_option{/*tag=*/kLiteRtEnvOptionTagDispatchLibraryDir, - /*value=*/{/*type=*/kLiteRtAnyTypeString}}; - env_option.value.str_value = kStrValue; - options.SetOption(env_option); - - LITERT_ASSERT_OK_AND_ASSIGN( - LiteRtAny stored_option, - options.GetOption(kLiteRtEnvOptionTagDispatchLibraryDir)); - - EXPECT_THAT(stored_option.type, Eq(kLiteRtAnyTypeString)); - EXPECT_THAT(stored_option.str_value, Ne(nullptr)); - EXPECT_THAT(stored_option.str_value, Ne(kStrValue)); -} - -TEST(EnvironmentOptionsTest, SetGetIntOptionWorks) { - constexpr int kIntValue = 3; - LiteRtEnvironmentOptionsT options; - LiteRtEnvOption env_option{/*tag=*/kLiteRtEnvOptionTagOpenClDeviceId, - /*value=*/{/*type=*/kLiteRtAnyTypeInt}}; - env_option.value.int_value = kIntValue; - options.SetOption(env_option); - - LITERT_ASSERT_OK_AND_ASSIGN( - LiteRtAny stored_option, - options.GetOption(kLiteRtEnvOptionTagOpenClDeviceId)); - - EXPECT_THAT(stored_option.type, Eq(kLiteRtAnyTypeInt)); - EXPECT_THAT(stored_option.int_value, Eq(kIntValue)); -} - -TEST(EnvironmentOptionsTest, GetNotSetReturnsNotFound) { - LiteRtEnvironmentOptionsT options; - - // Add a non related option. 
- constexpr const char* kStrValue = "string_value"; - LiteRtEnvOption env_option{/*tag=*/kLiteRtEnvOptionTagDispatchLibraryDir, - /*value=*/{/*type=*/kLiteRtAnyTypeString}}; - env_option.value.str_value = kStrValue; - options.SetOption(env_option); - - // Request an option that wasn't added. - EXPECT_THAT(options.GetOption(kLiteRtEnvOptionTagOpenClDeviceId), - IsError(kLiteRtStatusErrorNotFound)); -} - -} // namespace diff --git a/tensorflow/lite/experimental/litert/core/environment_test.cc b/tensorflow/lite/experimental/litert/core/environment_test.cc deleted file mode 100644 index b0d199a53173cf..00000000000000 --- a/tensorflow/lite/experimental/litert/core/environment_test.cc +++ /dev/null @@ -1,68 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/core/environment.h" - -#include -#include -#include - -#include -#include "tensorflow/lite/experimental/litert/c/litert_any.h" -#include "tensorflow/lite/experimental/litert/c/litert_environment.h" -#include "tensorflow/lite/experimental/litert/cc/litert_any.h" - -namespace litert::internal { -namespace { - -TEST(LiteRtEnvironmentT, CreateWithOptions) { - const std::array environment_options = { - LiteRtEnvOption{ - kLiteRtEnvOptionTagCompilerPluginLibraryDir, - *ToLiteRtAny(std::any("sample path")), - }, - }; - auto env = LiteRtEnvironmentT::CreateWithOptions(environment_options); - ASSERT_TRUE(env); - - auto option = (*env)->GetOption(kLiteRtEnvOptionTagCompilerPluginLibraryDir); - ASSERT_TRUE(option.has_value()); - ASSERT_EQ(option->type, kLiteRtAnyTypeString); - ASSERT_STREQ(option->str_value, "sample path"); -} - -TEST(LiteRtEnvironmentT, CheckStringCopy) { - LiteRtEnvironmentT::Ptr env; - - // The passed string becomes obsolete after the scope. - { - const std::array environment_options = { - LiteRtEnvOption{ - kLiteRtEnvOptionTagCompilerPluginLibraryDir, - *ToLiteRtAny(std::any("sample path")), - }, - }; - auto res = LiteRtEnvironmentT::CreateWithOptions(environment_options); - ASSERT_TRUE(res); - env = std::move(*res); - } - - auto option = env->GetOption(kLiteRtEnvOptionTagCompilerPluginLibraryDir); - ASSERT_TRUE(option.has_value()); - ASSERT_EQ(option->type, kLiteRtAnyTypeString); - ASSERT_STREQ(option->str_value, "sample path"); -} - -} // namespace -} // namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/core/filesystem.cc b/tensorflow/lite/experimental/litert/core/filesystem.cc deleted file mode 100644 index e97a583aee27af..00000000000000 --- a/tensorflow/lite/experimental/litert/core/filesystem.cc +++ /dev/null @@ -1,102 +0,0 @@ -// Copyright 2024 Google LLC. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/core/filesystem.h" - -#include -#include -#include // NOLINT -#include -#include -#include - -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" -#include "tensorflow/lite/experimental/litert/cc/litert_detail.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" - -namespace litert::internal { - -namespace { - -using StdPath = std::filesystem::path; - -StdPath MakeStdPath(absl::string_view path) { - return StdPath(std::string(path.begin(), path.end())); -} - -bool StdExists(const StdPath& std_path) { - return std::filesystem::exists(std_path); -} - -size_t StdSize(const StdPath& std_path) { - return std::filesystem::file_size(std_path); -} - -LiteRtStatus StdIFRead(const StdPath& std_path, char* data, size_t size) { - std::ifstream in_file_stream(std_path, std::ifstream::binary); - if (!in_file_stream) { - return kLiteRtStatusErrorFileIO; - } - - in_file_stream.read(data, size); - if (!in_file_stream) { - return kLiteRtStatusErrorFileIO; - } - - in_file_stream.close(); - return kLiteRtStatusOk; -} - -} // namespace - -void Touch(absl::string_view path) { std::ofstream(MakeStdPath(path)); } - -std::string Join(const 
std::vector& paths) { - StdPath std_path; - for (auto subpath : paths) { - std_path /= MakeStdPath(subpath); - } - return std_path.generic_string(); -} - -bool Exists(absl::string_view path) { return StdExists(MakeStdPath(path)); } - -Expected Size(absl::string_view path) { - auto std_path = MakeStdPath(path); - if (!StdExists(std_path)) { - return Error(kLiteRtStatusErrorNotFound, - absl::StrFormat("File not found: %s", std_path.c_str())); - } - return StdSize(std_path); -} - -Expected> LoadBinaryFile(absl::string_view path) { - auto std_path = MakeStdPath(path); - - if (!StdExists(std_path)) { - return Error(kLiteRtStatusErrorNotFound, - absl::StrFormat("File not found: %s", std_path.c_str())); - } - - OwningBufferRef buf(StdSize(std_path)); - LITERT_RETURN_IF_ERROR(StdIFRead(std_path, buf.StrData(), buf.Size())); - - return buf; -} - -} // namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/core/filesystem.h b/tensorflow/lite/experimental/litert/core/filesystem.h deleted file mode 100644 index 3de517dfd4d5c6..00000000000000 --- a/tensorflow/lite/experimental/litert/core/filesystem.h +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_FILESYSTEM_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_FILESYSTEM_H_ - -#include -#include -#include -#include - -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" - -// Generic file operations. Try to encapsulate the std filesystem header as much -// as possible because its technically unapproved. - -namespace litert::internal { - -// Append all given subpaths together (e.g. os.path.join). -std::string Join(const std::vector& paths); - -// Make a new empty file at the given path. -void Touch(absl::string_view path); - -// Does this file exist. -bool Exists(absl::string_view path); - -// Get size of file. -Expected Size(absl::string_view path); - -// Load the bytes of the file at given path. -Expected> LoadBinaryFile(absl::string_view path); - -} // namespace litert::internal - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_FILESYSTEM_H_ diff --git a/tensorflow/lite/experimental/litert/core/filesystem_test.cc b/tensorflow/lite/experimental/litert/core/filesystem_test.cc deleted file mode 100644 index d961d469d10100..00000000000000 --- a/tensorflow/lite/experimental/litert/core/filesystem_test.cc +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/core/filesystem.h" - -#include -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" - -namespace litert::internal { -namespace { - -static constexpr absl::string_view kPrefix = "a/prefix"; -static constexpr absl::string_view kInfix = "an/infix"; -static constexpr absl::string_view kSuffix = "suffix.ext"; - -TEST(FilesystemTest, JoinTwo) { - const auto path = Join({kPrefix, kSuffix}); - EXPECT_EQ(path, absl::StrFormat("%s/%s", kPrefix, kSuffix)); -} - -TEST(FilesystemTest, JoinMany) { - const auto path = Join({kPrefix, kInfix, kSuffix}); - EXPECT_EQ(path, absl::StrFormat("%s/%s/%s", kPrefix, kInfix, kSuffix)); -} - -} // namespace -} // namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/core/insert_order_map.h b/tensorflow/lite/experimental/litert/core/insert_order_map.h deleted file mode 100644 index f1c9ca46804943..00000000000000 --- a/tensorflow/lite/experimental/litert/core/insert_order_map.h +++ /dev/null @@ -1,78 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_INSERT_ORDER_MAP_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_INSERT_ORDER_MAP_H_ - -#include -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" - -namespace litert::internal { - -// A map implementation that iterates in the same order as initial insertion. 
-template -class InsertOrderMap { - public: - using Pair = std::pair; - using Values = std::vector; - using ValRef = std::reference_wrapper; - using Map = absl::flat_hash_map; - using Iterator = typename Values::iterator; - - InsertOrderMap() = default; - - std::optional Find(const Key& key) { - if (auto it = map_.find(key); it != map_.end()) { - const auto ind = it->second; - return values_[ind]; - } - return {}; - } - - bool Contains(const Key& key) const { return map_.find(key) != map_.end(); } - - void InsertOrAssign(const Key& key, const Val& val) { - if (auto it = map_.find(key); it != map_.end()) { - const auto ind = it->second; - values_[ind].second = val; - } else { - values_.push_back({key, val}); - map_.insert({key, values_.size() - 1}); - } - } - - size_t Size() const { return values_.size(); } - - void Clear() { - values_.clear(); - map_.clear(); - } - - Iterator Begin() { return values_.begin(); } - - Iterator End() { return values_.end(); } - - private: - Values values_; - Map map_; -}; - -} // namespace litert::internal - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_INSERT_ORDER_MAP_H_ diff --git a/tensorflow/lite/experimental/litert/core/insert_order_map_test.cc b/tensorflow/lite/experimental/litert/core/insert_order_map_test.cc deleted file mode 100644 index 6c24a01be97bdf..00000000000000 --- a/tensorflow/lite/experimental/litert/core/insert_order_map_test.cc +++ /dev/null @@ -1,99 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/core/insert_order_map.h" - -#include -#include -#include - -#include -#include - -namespace litert::internal { -namespace { - -using ::testing::ElementsAre; - -using TestMap = InsertOrderMap; - -static constexpr int k1 = 1; -static constexpr int k2 = 2; -static constexpr int k3 = 3; -static constexpr int k4 = 4; -static constexpr const char kV1[] = "1"; -static constexpr const char kV2[] = "2"; -static constexpr const char kV3[] = "3"; -static constexpr const char kV4[] = "4"; - -TestMap MakeTestMap() { - TestMap map; - map.InsertOrAssign(k1, kV1); - map.InsertOrAssign(k2, kV2); - map.InsertOrAssign(k3, kV3); - return map; -} - -TEST(InsertOrderMapTest, IterateInInsertOrder) { - auto map = MakeTestMap(); - ASSERT_EQ(map.Size(), 3); - - std::vector values(map.Begin(), map.End()); - EXPECT_THAT(values, - ElementsAre(std::make_pair(k1, kV1), std::make_pair(k2, kV2), - std::make_pair(k3, kV3))); -} - -TEST(InsertOrderMapTest, IterateInInsertOrderWithUpdate) { - auto map = MakeTestMap(); - ASSERT_EQ(map.Size(), 3); - - map.InsertOrAssign(k1, kV4); - std::vector values(map.Begin(), map.End()); - EXPECT_THAT(values, - ElementsAre(std::make_pair(k1, kV4), std::make_pair(k2, kV2), - std::make_pair(k3, kV3))); -} - -TEST(InsertOrderMapTest, FindExisting) { - auto map = MakeTestMap(); - ASSERT_EQ(map.Size(), 3); - - auto val = map.Find(k1); - ASSERT_TRUE(val.has_value()); - EXPECT_EQ(val->get().first, k1); - EXPECT_EQ(val->get().second, kV1); - - EXPECT_TRUE(map.Contains(k1)); -} - -TEST(InsertOrderMapTest, FindMissing) { - auto map = MakeTestMap(); - ASSERT_EQ(map.Size(), 3); - - EXPECT_EQ(map.Find(k4), std::nullopt); - EXPECT_FALSE(map.Contains(k4)); -} - -TEST(InsertOrderMapTest, Clear) { - auto map = MakeTestMap(); - ASSERT_EQ(map.Size(), 3); - - map.Clear(); - EXPECT_EQ(map.Size(), 0); - EXPECT_EQ(map.Begin(), 
map.End()); -} - -} // namespace -} // namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/core/model/BUILD b/tensorflow/lite/experimental/litert/core/model/BUILD deleted file mode 100644 index 071d3200041830..00000000000000 --- a/tensorflow/lite/experimental/litert/core/model/BUILD +++ /dev/null @@ -1,349 +0,0 @@ -# Copyright 2024 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -load("//tensorflow/lite/experimental/litert/build_common:special_rule.bzl", "lite_rt_friends") - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = [ - "//tensorflow/lite/experimental/litert:__subpackages__", - ] + lite_rt_friends(), -) - -cc_library( - name = "model", - srcs = ["model.cc"], - hdrs = [ - "model.h", - "//tensorflow/lite/experimental/litert/c:litert_model_hdrs", - ], - deps = [ - ":buffer_manager", - ":ir_allocator", - "//tensorflow/compiler/mlir/lite/core:model_builder_base", - "//tensorflow/lite/core/c:c_api_types", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_layout", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/cc:litert_buffer_ref", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/core:build_stamp", - "//tensorflow/lite/experimental/litert/core/util:flatbuffer_tools", - 
"@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/log:absl_check", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - ], -) - -cc_test( - name = "model_test", - srcs = ["model_test.cc"], - data = [ - "//tensorflow/lite/experimental/litert/test:testdata/simple_model.tflite", - ], - deps = [ - ":buffer_manager", - ":model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/cc:litert_buffer_ref", - "//tensorflow/lite/experimental/litert/core:build_stamp", - "//tensorflow/lite/experimental/litert/core/util:flatbuffer_tools", - "//tensorflow/lite/experimental/litert/test:matchers", - "//tensorflow/lite/schema:schema_fbs", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "model_load", - srcs = ["model_load.cc"], - hdrs = ["model_load.h"], - deps = [ - ":buffer_manager", - ":flatbuffer_to_litert", - ":model", - ":model_graph", - "//tensorflow/compiler/mlir/lite/core:model_builder_base", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/cc:litert_buffer_ref", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/core/util:flatbuffer_tools", - "//tensorflow/lite/schema:schema_fbs", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/strings:string_view", - ], -) - -cc_test( - name = "model_file_test", - srcs = ["model_file_test.cc"], - data = [ - "//tensorflow/lite/experimental/litert/test:mlir_test_data", - "//tensorflow/lite/experimental/litert/test:tflite_test_data", - # copybara:uncomment 
"//tensorflow/lite/java/demo/app/src/main/assets:mobilenet_v1_1.0_224.tflite", - ], - deps = [ - ":buffer_manager", - ":graph_validation", - ":model", - ":model_file_test_util", - ":model_load", - ":model_serialize", - "//tensorflow/compiler/mlir/lite/schema:schema_fbs_with_mutable", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/cc:litert_buffer_ref", - "//tensorflow/lite/experimental/litert/cc:litert_element_type", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", - "//tensorflow/lite/experimental/litert/core:dispatch_op_schema", - "//tensorflow/lite/experimental/litert/core/util:flatbuffer_tools", - "//tensorflow/lite/experimental/litert/test:common", - "//tensorflow/lite/experimental/litert/test:matchers", - "//tensorflow/lite/experimental/litert/test:test_models", - "//tensorflow/lite/schema:schema_fbs_with_mutable", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "model_serialize", - srcs = ["model_serialize.cc"], - hdrs = ["model_serialize.h"], - deps = [ - ":litert_to_flatbuffer", - ":model", - "//tensorflow/compiler/mlir/lite/schema:schema_fbs_with_mutable", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/cc:litert_buffer_ref", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/core:build_stamp", - "//tensorflow/lite/experimental/litert/core:dispatch_op_schema", - 
"//tensorflow/lite/experimental/litert/core:insert_order_map", - "//tensorflow/lite/experimental/litert/core/util:flatbuffer_tools", - "//tensorflow/lite/schema:schema_fbs_with_mutable", - "@com_google_absl//absl/container:flat_hash_map", - ], -) - -cc_library( - name = "flatbuffer_to_litert", - srcs = ["flatbuffer_to_litert.cc"], - hdrs = ["flatbuffer_to_litert.h"], - deps = [ - ":model", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_layout", - "//tensorflow/lite/experimental/litert/core/util:flatbuffer_tools", - "//tensorflow/lite/schema:schema_fbs", - ], -) - -cc_test( - name = "flatbuffer_to_litert_test", - srcs = ["flatbuffer_to_litert_test.cc"], - deps = [ - ":flatbuffer_to_litert", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/core/util:flatbuffer_tools", - "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "litert_to_flatbuffer", - srcs = ["litert_to_flatbuffer.cc"], - hdrs = ["litert_to_flatbuffer.h"], - deps = [ - ":model", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/core/util:flatbuffer_tools", - "//tensorflow/lite/schema:schema_fbs", - "@com_google_absl//absl/types:span", - ], -) - -cc_test( - name = "litert_to_flatbuffer_test", - srcs = ["litert_to_flatbuffer_test.cc"], - deps = [ - ":litert_to_flatbuffer", - ":model", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_layout", - "//tensorflow/lite/experimental/litert/core/util:flatbuffer_tools", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "model_buffer", - srcs = ["model_buffer.cc"], - hdrs = 
["model_buffer.h"], - deps = [ - ":model", - ":model_load", - ":model_serialize", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/cc:litert_buffer_ref", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/core:filesystem", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - ], -) - -cc_library( - name = "model_file_test_util", - testonly = 1, - srcs = ["model_file_test_util.cc"], - hdrs = ["model_file_test_util.h"], - deps = [ - ":flatbuffer_to_litert", - ":model", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/cc:litert_detail", - "//tensorflow/lite/experimental/litert/core/util:flatbuffer_tools", - "@com_google_absl//absl/types:span", - ], -) - -cc_library( - name = "ir_allocator", - hdrs = ["ir_allocator.h"], - deps = ["@com_google_absl//absl/types:span"], -) - -cc_test( - name = "ir_allocator_test", - srcs = ["ir_allocator_test.cc"], - deps = [ - ":ir_allocator", - ":model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "model_graph", - srcs = ["model_graph.cc"], - hdrs = [ - "model_graph.h", - "//tensorflow/lite/experimental/litert/cc:litert_consts.h", - ], - deps = [ - ":model", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/cc:litert_buffer_ref", - "//tensorflow/lite/experimental/litert/cc:litert_detail", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/log:absl_check", - ], -) - -cc_library( - name = 
"graph_validation", - srcs = ["graph_validation.cc"], - hdrs = ["graph_validation.h"], - deps = [ - ":model", - ":model_graph", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/cc:litert_detail", - ], -) - -cc_library( - name = "buffer_manager", - hdrs = ["buffer_manager.h"], - deps = [ - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/cc:litert_buffer_ref", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - ], -) - -cc_test( - name = "model_graph_test", - srcs = ["model_graph_test.cc"], - deps = [ - ":graph_validation", - ":ir_allocator", - ":model", - ":model_graph", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/cc:litert_buffer_ref", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest_main", - ], -) - -cc_test( - name = "model_buffer_test", - srcs = ["model_buffer_test.cc"], - deps = [ - ":model", - ":model_buffer", - ":model_load", - "//tensorflow/compiler/mlir/lite:allocation", - "//tensorflow/lite:framework", - "//tensorflow/lite:model_builder", - "//tensorflow/lite/c:c_api_opaque", - "//tensorflow/lite/c:common", - "//tensorflow/lite/core:cc_api_stable", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/core:dispatch_op_schema", - "//tensorflow/lite/experimental/litert/test:common", - "//tensorflow/lite/experimental/litert/test:simple_cascade_model_npu", - "//tensorflow/lite/experimental/litert/test:simple_model_npu", - "//tensorflow/lite/kernels:builtin_ops", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest_main", - ], -) - -cc_test( - name = "buffer_manager_test", - srcs = ["buffer_manager_test.cc"], - deps = [ - ":buffer_manager", - 
"//tensorflow/lite/experimental/litert/cc:litert_buffer_ref", - "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest_main", - ], -) diff --git a/tensorflow/lite/experimental/litert/core/model/buffer_manager.h b/tensorflow/lite/experimental/litert/core/model/buffer_manager.h deleted file mode 100644 index af6b97f15c052b..00000000000000 --- a/tensorflow/lite/experimental/litert/core/model/buffer_manager.h +++ /dev/null @@ -1,115 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_BUFFER_MANAGER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_BUFFER_MANAGER_H_ - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" - -namespace litert::internal { - -// Extra info about how the buffer is handled during load or serialization. -struct BufferContext { - using Ref = std::reference_wrapper; - - // Whether the buffer should be appended to the flatbuffer during - // serialization. - bool should_append = false; -}; - -// Container type for efficiently holding data buffers used by the model. These -// buffers may be owned or non-owned by the model. Uses id based indexing. 
-class BufferManager { - public: - using Ptr = std::unique_ptr; - - // Unique identifier for a buffer. 0 is reserved for empty buffers. - using BufferId = uint32_t; - static constexpr BufferId kEmptyBufferId = 0; - - // Register a buffer that is not owned by the model. Caller must ensure the - // buffer outlives the model. - BufferId RegisterNonOwnedBuffer( - BufferRef buffer, - std::optional context = std::nullopt) { - auto&& ctx = context.has_value() ? std::move(*context) : BufferContext{}; - buffers_.emplace_back(BufferWithContext(buffer, std::move(ctx))); - return buffers_.size() - 1; - } - - // Register a buffer that is owned by the model. - BufferId RegisterOwnedBuffer( - OwningBufferRef&& buffer, - std::optional context = std::nullopt) { - auto&& ctx = context.has_value() ? std::move(*context) : BufferContext{}; - buffers_.emplace_back(BufferWithContext(buffer, std::move(ctx))); - return buffers_.size() - 1; - } - - // Get a view of the buffer at the given id. - Expected> GetBuffer(BufferId id) { - if (id >= buffers_.size()) { - return Error(kLiteRtStatusErrorIndexOOB); - } - return GetView(buffers_[id].first); - } - - // Get the context of the buffer at the given id. - Expected GetContext(BufferId id) { - if (id >= buffers_.size()) { - return Error(kLiteRtStatusErrorIndexOOB); - } - return std::ref(buffers_[id].second); - } - - // Number of buffers. Ids will be 0 <-> num - 1. - size_t NumBuffers() const { return buffers_.size(); } - - BufferManager() { - // Zero is reserved for empty buffers. 
- buffers_.emplace_back( - BufferWithContext(BufferRef(), BufferContext{})); - } - BufferManager(const BufferManager&) = delete; - BufferManager& operator=(const BufferManager&) = delete; - BufferManager(BufferManager&& other) = default; - BufferManager& operator=(BufferManager&& other) = default; - - private: - using BufferType = std::variant, OwningBufferRef>; - using BufferWithContext = std::pair; - - static BufferRef GetView(const BufferType& buffer) { - BufferRef res; - std::visit([&res](auto&& arg) { res = arg; }, buffer); - return res; - } - - std::vector buffers_; -}; - -} // namespace litert::internal - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_BUFFER_MANAGER_H_ diff --git a/tensorflow/lite/experimental/litert/core/model/buffer_manager_test.cc b/tensorflow/lite/experimental/litert/core/model/buffer_manager_test.cc deleted file mode 100644 index b077eda8b4f3da..00000000000000 --- a/tensorflow/lite/experimental/litert/core/model/buffer_manager_test.cc +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/core/model/buffer_manager.h" - -#include -#include - -#include -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" - -namespace litert::internal { - -namespace { - -static constexpr absl::string_view kData = "foo"; - -TEST(BufferManagerTest, EmptyFirstBuffer) { - BufferManager manager; - - EXPECT_EQ(manager.NumBuffers(), 1); - EXPECT_EQ(manager.GetBuffer(BufferManager::kEmptyBufferId)->Size(), 0); -} - -TEST(BufferManagerTest, RegisterNonOwnedBuffer) { - BufferManager manager; - - OwningBufferRef buffer(kData); - const auto id = manager.RegisterNonOwnedBuffer(buffer); - - EXPECT_EQ(manager.NumBuffers(), 2); - EXPECT_EQ(manager.GetBuffer(id)->StrView(), kData); -} - -TEST(BufferManagerTest, RegisterOwnedBuffer) { - BufferManager manager; - - OwningBufferRef buffer(kData); - const auto id = manager.RegisterOwnedBuffer(std::move(buffer)); - - EXPECT_EQ(manager.NumBuffers(), 2); - EXPECT_EQ(manager.GetBuffer(id)->StrView(), kData); -} - -TEST(BufferManagerTest, RegisterWithContext) { - BufferManager manager; - - OwningBufferRef buffer(kData); - BufferContext context = {true}; - const auto id = manager.RegisterNonOwnedBuffer(buffer, context); - - EXPECT_EQ(manager.NumBuffers(), 2); - EXPECT_EQ(manager.GetBuffer(id)->StrView(), kData); - EXPECT_EQ(manager.GetContext(id)->get().should_append, true); -} - -} // namespace - -} // namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/core/model/flatbuffer_to_litert.cc b/tensorflow/lite/experimental/litert/core/model/flatbuffer_to_litert.cc deleted file mode 100644 index 36c721af2009cc..00000000000000 --- a/tensorflow/lite/experimental/litert/core/model/flatbuffer_to_litert.cc +++ /dev/null @@ -1,154 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/core/model/flatbuffer_to_litert.h" - -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_layout.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" -#include "tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h" -#include "tensorflow/lite/schema/schema_generated.h" - -namespace litert::internal { - -LiteRtStatus IsOpSupported(const tflite::OperatorT& op) { - // TODO: b/365299994 - Check for supported options. - - if (!op.intermediates.empty()) { - // TODO: b/365299994 - Support intermediates. - LITERT_LOG(LITERT_ERROR, "Intermediate tensors not yet supported."); - return kLiteRtStatusErrorUnsupported; - } - - if (op.large_custom_options_size != 0) { - // TODO: b/365299994 - Support large custom options. - LITERT_LOG(LITERT_ERROR, "Large custom options not yet supported."); - return kLiteRtStatusErrorUnsupported; - } - - for (auto m_input : op.mutating_variable_inputs) { - if (m_input) { - // TODO: b/365299994 - Support mutating variable inputs. 
- LITERT_LOG(LITERT_ERROR, "Mutating variable inputs not yet supported."); - return kLiteRtStatusErrorUnsupported; - } - } - - return kLiteRtStatusOk; -} - -LiteRtStatus IsBufferSupported(const tflite::BufferT& buffer) { - if (buffer.offset != 0) { - // TODO: b/365299994 - Support buffer with offset. - LITERT_LOG(LITERT_ERROR, "Buffers with offset not yet supported."); - return kLiteRtStatusErrorUnsupported; - } - - return kLiteRtStatusOk; -} - -LiteRtStatus IsTensorSupported(const TflTensor& tensor) { - if (tensor.is_variable) { - // TODO: b/365299994 - Support variable tensors. - LITERT_LOG(LITERT_ERROR, "Variable tensors not yet supported."); - return kLiteRtStatusErrorUnsupported; - } - - if (!tensor.variant_tensors.empty()) { - // TODO: b/365299994 - Support variant tensors. - LITERT_LOG(LITERT_ERROR, "Variant tensors not yet supported."); - return kLiteRtStatusErrorUnsupported; - } - - if (tensor.sparsity) { - // TODO: b/365299994 - Support sparsity tensors. - LITERT_LOG(LITERT_ERROR, "Sparsity tensors not yet supported."); - return kLiteRtStatusErrorUnsupported; - } - - return kLiteRtStatusOk; -} - -LiteRtElementType MapElementType(TflElementType type) { - switch (type) { - case tflite::TensorType_FLOAT32: - return kLiteRtElementTypeFloat32; - case tflite::TensorType_FLOAT16: - return kLiteRtElementTypeFloat16; - case tflite::TensorType_INT32: - return kLiteRtElementTypeInt32; - case tflite::TensorType_INT64: - return kLiteRtElementTypeInt64; - case tflite::TensorType_BOOL: - return kLiteRtElementTypeBool; - case tflite::TensorType_INT16: - return kLiteRtElementTypeInt16; - case tflite::TensorType_INT8: - return kLiteRtElementTypeInt8; - case tflite::TensorType_UINT8: - return kLiteRtElementTypeUInt8; - case tflite::TensorType_INT4: - return kLiteRtElementTypeInt4; - default: - return kLiteRtElementTypeNone; - } -} - -Expected MapTensorType(const TflTensorType& tfl_tensor_type) { - const auto& [element_type, shape] = tfl_tensor_type; - auto ranked_shape = 
AsDynamicShape(shape); - if (!ranked_shape) { - LITERT_LOG(LITERT_ERROR, "Only ranked tensors currently supported"); - return Error(kLiteRtStatusErrorUnsupported); - } - - auto litert_element_type = MapElementType(element_type); - if (litert_element_type == kLiteRtElementTypeNone) { - LITERT_LOG(LITERT_ERROR, "Element type not currently supported"); - return Error(kLiteRtStatusErrorUnsupported); - } - - TensorTypeDetail detail; - detail.ranked_tensor_type.element_type = litert_element_type; - detail.ranked_tensor_type.layout = BuildLayout(*ranked_shape); - - return std::make_pair(kLiteRtRankedTensorType, detail); -} - -Expected MapQuantization(const TflQuantization* tfl_quantization, - ScratchBufferProvider buffer_provider) { - if (!IsQuantized(tfl_quantization)) { - return MakeEmptyQuantization(); - } - - if (auto tfl_qparams = AsPerTensorQparams(tfl_quantization)) { - return MakePerTensorQuantization(tfl_qparams->second, tfl_qparams->first); - } - - if (auto tfl_qparams = AsPerChannelQparams(tfl_quantization)) { - [[maybe_unused]] const auto& [quantized_dimension, num_channels, - zero_points, scales] = *tfl_qparams; - return MakePerChannelQuantization(scales, zero_points, quantized_dimension, - buffer_provider); - } - - LITERT_LOG(LITERT_ERROR, "Uknown tfl quantization type"); - return Error(kLiteRtStatusErrorUnsupported); -} -} // namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/core/model/flatbuffer_to_litert.h b/tensorflow/lite/experimental/litert/core/model/flatbuffer_to_litert.h deleted file mode 100644 index 92a7d11cdf0321..00000000000000 --- a/tensorflow/lite/experimental/litert/core/model/flatbuffer_to_litert.h +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_FLATBUFFER_TO_LITERT_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_FLATBUFFER_TO_LITERT_H_ - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" -#include "tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h" - -namespace litert::internal { - -LiteRtStatus IsOpSupported(const TflOp& op); - -LiteRtStatus IsBufferSupported(const TflBuffer& buffer); - -// Checks if the misc non-type non quantization parts of this tensor are -// supported in the litet model api. -LiteRtStatus IsTensorSupported(const TflTensor& tensor); - -LiteRtElementType MapElementType(TflElementType element_type); - -Expected MapTensorType(const TflTensorType& tfl_tensor_type); - -Expected MapQuantization(const TflQuantization* tfl_quantization, - ScratchBufferProvider buffer_provider); - -} // namespace litert::internal - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_FLATBUFFER_TO_LITERT_H_ diff --git a/tensorflow/lite/experimental/litert/core/model/flatbuffer_to_litert_test.cc b/tensorflow/lite/experimental/litert/core/model/flatbuffer_to_litert_test.cc deleted file mode 100644 index b0a2e6598a683f..00000000000000 --- a/tensorflow/lite/experimental/litert/core/model/flatbuffer_to_litert_test.cc +++ /dev/null @@ -1,125 +0,0 @@ -// Copyright 2024 Google LLC. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/core/model/flatbuffer_to_litert.h" - -#include -#include -#include - -#include -#include -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h" - -namespace litert::internal { -namespace { - -using ::testing::ElementsAreArray; - -TEST(FlatbufferToLiteRtTest, MapStaticTensorType) { - static constexpr int32_t kDims[] = {2, 2}; - static constexpr auto kDimsSpan = absl::MakeConstSpan(kDims); - - auto t = MapTensorType(std::make_pair(TflElementType::TensorType_INT32, - TflShapeInfo(kDimsSpan))); - ASSERT_TRUE(t); - - ASSERT_EQ(t->first, kLiteRtRankedTensorType); - auto& ranked = t->second.ranked_tensor_type; - EXPECT_EQ(ranked.element_type, kLiteRtElementTypeInt32); - EXPECT_EQ(absl::MakeSpan(ranked.layout.dimensions, ranked.layout.rank), - kDimsSpan); -} - -TEST(FlatbufferToLiteRtTest, MapStaticTensorInt4Type) { - static constexpr int32_t kDims[] = {2, 2}; - static constexpr auto kDimsSpan = absl::MakeConstSpan(kDims); - - auto t = MapTensorType( - std::make_pair(TflElementType::TensorType_INT4, TflShapeInfo(kDimsSpan))); - ASSERT_TRUE(t); - - ASSERT_EQ(t->first, kLiteRtRankedTensorType); - auto& ranked = t->second.ranked_tensor_type; - EXPECT_EQ(ranked.element_type, kLiteRtElementTypeInt4); - EXPECT_EQ(absl::MakeSpan(ranked.layout.dimensions, 
ranked.layout.rank), - kDimsSpan); -} - -TEST(FlatbufferToLiteRtTest, MapDynamicTensorType) { - static constexpr int32_t kDims[] = {-1, 2}; - static constexpr auto kDimsSpan = absl::MakeConstSpan(kDims); - - auto t = MapTensorType(std::make_pair(TflElementType::TensorType_INT32, - TflShapeInfo(kDimsSpan))); - ASSERT_TRUE(t); - - ASSERT_EQ(t->first, kLiteRtRankedTensorType); - auto& ranked = t->second.ranked_tensor_type; - EXPECT_EQ(ranked.element_type, kLiteRtElementTypeInt32); - EXPECT_EQ(absl::MakeSpan(ranked.layout.dimensions, ranked.layout.rank), - kDimsSpan); -} - -TEST(FlatbufferToLiteRtTest, MapNoQuantization) { - LiteRtTensorT tensor; - auto q = MapQuantization(nullptr, tensor); - ASSERT_TRUE(q); - ASSERT_EQ(q->first, kLiteRtQuantizationNone); -} - -TEST(FlatbufferToLiteRtTest, MapPerTensorQuantization) { - static constexpr float kScale = 1.0; - static constexpr int64_t kZp = 2; - - TflQuantization tfl_q; - tfl_q.scale.assign({kScale}); - tfl_q.zero_point.assign({kZp}); - - LiteRtTensorT tensor; - auto q = MapQuantization(&tfl_q, tensor); - ASSERT_TRUE(q); - ASSERT_EQ(q->first, kLiteRtQuantizationPerTensor); - EXPECT_EQ(q->second.per_tensor.scale, kScale); - EXPECT_EQ(q->second.per_tensor.zero_point, kZp); -} - -TEST(FlatbufferToLiteRtTest, MapPerChannelQuantization) { - static constexpr size_t kRank = 2; - static constexpr float kScales[kRank] = {1.0, 2.0}; - static constexpr int64_t kZps[kRank] = {2, 3}; - static constexpr size_t kQDim = 1; - - TflQuantization tfl_q; - tfl_q.scale.assign(kScales, kScales + kRank); - tfl_q.zero_point.assign(kZps, kZps + kRank); - tfl_q.quantized_dimension = kQDim; - - LiteRtTensorT tensor; - auto q = MapQuantization(&tfl_q, tensor); - ASSERT_TRUE(q); - ASSERT_EQ(q->first, kLiteRtQuantizationPerChannel); - EXPECT_THAT(absl::MakeConstSpan(q->second.per_channel.scales, kRank), - ElementsAreArray(kScales)); - - EXPECT_THAT(absl::MakeConstSpan(q->second.per_channel.zero_points, kRank), - ElementsAreArray(kZps)); - 
EXPECT_EQ(q->second.per_channel.quantized_dimension, kQDim); - EXPECT_EQ(q->second.per_channel.num_channels, kRank); -} - -} // namespace -} // namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/core/model/graph_validation.cc b/tensorflow/lite/experimental/litert/core/model/graph_validation.cc deleted file mode 100644 index a9a942c1bfaa14..00000000000000 --- a/tensorflow/lite/experimental/litert/core/model/graph_validation.cc +++ /dev/null @@ -1,114 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/core/model/graph_validation.h" - -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_detail.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" -#include "tensorflow/lite/experimental/litert/core/model/model_graph.h" - -namespace litert::internal { - -bool ValidateLocalTopology(const LiteRtOpT& litert_op) { - // Check number of in edges equals number of inputs and each input index - // appears on an in edge. 
- for (auto i = 0; i < litert_op.Inputs().size(); ++i) { - const auto& litert_tensor = litert_op.Input(i); - - auto input_use = - GetTensorUses(litert_tensor, FindUseInds(litert_tensor, litert_op)); - - if (!ContainsIf(input_use.cbegin(), input_use.cend(), - [i](auto u) { return u.second == i; })) { - LITERT_LOG(LITERT_WARNING, - "Input tensor %d not connected to op on correct index.", i); - return false; - } - } - - // Similar to above for outputs. - for (auto i = 0; i < litert_op.Outputs().size(); ++i) { - const auto& litert_tensor = litert_op.Output(i); - - if (litert_tensor.DefiningOp() != &litert_op) { - LITERT_LOG(LITERT_WARNING, "Output back edge doesn't refer to this op."); - return false; - } - - if (litert_tensor.DefiningOpOutInd() != i) { - LITERT_LOG(LITERT_WARNING, "Output back edge ind is incorrect."); - return false; - } - } - - return true; -} - -bool ValidateSubgraphIO(const LiteRtSubgraphT& litert_subgraph) { - auto num_implied_inputs = 0; - auto num_implied_outputs = 0; - for (auto* tensor : litert_subgraph.Tensors()) { - const auto implied_out = tensor->NumUses() == 0; - const auto implied_in = - !IsConstant(*tensor) && tensor->DefiningOp() == nullptr; - - if (implied_out && implied_in) { - LITERT_LOG(LITERT_WARNING, "Graph contains a dead tensor"); - return false; - } - - const auto is_io = IsIO(litert_subgraph, *tensor); - - if (implied_in) { - if (!is_io) { - LITERT_LOG(LITERT_WARNING, - "Implied input not reflected in subgraph io %lu", - tensor - litert_subgraph.Tensors().at(0)); - return false; - } - ++num_implied_inputs; - } - - if (implied_out) { - if (!is_io) { - LITERT_LOG(LITERT_WARNING, - "Implied output not reflected in subgraph io"); - return false; - } - ++num_implied_outputs; - } - } - - if (num_implied_inputs != litert_subgraph.NumInputs()) { - LITERT_LOG( - LITERT_WARNING, - "Number of implied %lu inputs not equal to number of actual inputs %lu", - num_implied_inputs, litert_subgraph.NumInputs()); - return false; - } - - if 
(num_implied_outputs != litert_subgraph.NumOutputs()) { - LITERT_LOG(LITERT_WARNING, - "Number of implied %lu outputs not equal to number of actual " - "outputs %lu", - num_implied_outputs, litert_subgraph.NumOutputs()); - return false; - } - - return true; -} - -} // namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/core/model/graph_validation.h b/tensorflow/lite/experimental/litert/core/model/graph_validation.h deleted file mode 100644 index c0a199294f8677..00000000000000 --- a/tensorflow/lite/experimental/litert/core/model/graph_validation.h +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_GRAPH_VALIDATION_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_GRAPH_VALIDATION_H_ - -#include - -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" -#include "tensorflow/lite/experimental/litert/core/model/model_graph.h" - -// Helper functions for validating the structure of IR graphs. - -namespace litert::internal { - -// Checks the double-linked edges to immediate neighbors are valid. -bool ValidateLocalTopology(const LiteRtOpT& litert_op); - -// Runs ValidateLocalTopology across given LiteRtOp iterator. 
-template -bool ValidateLocalTopology(OpIt start, OpIt end) { - return std::all_of(start, end, - [](const auto* op) { return ValidateLocalTopology(*op); }); -} - -// Checks the following are bijections: -// * non-const tensor with no defining op <-> subgraph input -// * tensor with no users <-> subgraph output (assuming no side effect ops) -// These are used to figure out the i/o signatures when building a subgraph -// from scratch. -bool ValidateSubgraphIO(const LiteRtSubgraphT& litert_subgraph); - -} // namespace litert::internal - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_GRAPH_VALIDATION_H_ diff --git a/tensorflow/lite/experimental/litert/core/model/ir_allocator.h b/tensorflow/lite/experimental/litert/core/model/ir_allocator.h deleted file mode 100644 index 43433c1ecd02c8..00000000000000 --- a/tensorflow/lite/experimental/litert/core/model/ir_allocator.h +++ /dev/null @@ -1,153 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_IR_ALLOCATOR_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_IR_ALLOCATOR_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "absl/types/span.h" - -namespace litert::internal { - -// A list of IR objects scoped to the same block (subgraph) that provides -// pointer stability. Facilitates management of memory and c-like access -// to elements. 
-template -class IrAllocator { - private: - using Storage = std::list; - using Refs = std::vector; - - public: - // Emplace a new element onto the list. - template - Ir& EmplaceBack(Args&&... args) { - auto& emp = storage_.emplace_back(std::forward(args)...); - refs_->push_back(&emp); - return emp; - } - - // Get the array of (stable) pointers to underlying elements. Suitable - // for passing through c-like interface. Consituent pointers are always - // guarateed to be stable (unless explicitly erased). The array of pointers - // itself is guaranteed to be stable so long as no length-changing operations - // occur, moving this class does not invalidate pointers or array. - absl::Span Elements() const { - return absl::MakeSpan(refs_->data(), refs_->size()); - } - - // Remove elements from the allocator if they match the predicate. - // Returns the number of elements removed. - size_t RemoveIf(std::function pred) { - auto ref_it = refs_->begin(); - for (auto it = storage_.begin(); it != storage_.end();) { - if (!pred(*it)) { - *ref_it = &*it; - ++ref_it; - ++it; - continue; - } - it = storage_.erase(it); - } - const size_t removed = refs_->end() - ref_it; - refs_->resize(refs_->size() - removed); - return removed; - } - - // Cuts all but the first `size` elements from storage. Does nothing if `size` - // is greater or equal to current size. - void ResizeDown(size_t size) { - if (size >= Size()) { - return; - } - storage_.resize(size); - refs_->resize(size); - } - - // Transfers the ownership of given allocator to this one. If `indices` is - // provided, only the objects at the given indices are transferred. 
- void TransferFrom(IrAllocator& other, - std::optional> indices = std::nullopt) { - if (!indices) { - storage_.splice(storage_.cend(), other.storage_); - refs_->insert(refs_->end(), other.refs_->cbegin(), other.refs_->cend()); - other.ResetRefs(); - return; - } - - auto& inds = *indices; - std::sort(inds.begin(), inds.end()); - std::vector its; - auto i = 0; - auto it = other.storage_.begin(); - for (auto ind : inds) { - std::advance(it, ind - i); - i = ind; - its.push_back(it); - } - for (auto it : its) { - storage_.splice(storage_.cend(), other.storage_, it); - } - - ResetRefs(); - other.ResetRefs(); - } - - // Override for rvalues. - void TransferFrom(IrAllocator&& other) { TransferFrom(other, std::nullopt); } - - // Transfers the object at the given index to the back of the given allocator. - void TransferTo(IrAllocator& other, - std::optional> indices = std::nullopt) { - other.TransferFrom(*this, std::move(indices)); - } - - // Number of elements stored by this allocator. - size_t Size() const { return storage_.size(); } - - IrAllocator() { refs_ = std::make_unique(); } - - // IR is generally semantically movable (without reference invalidation) - // but not copyable. IrAllocators reflect that, note moving lists - // does not invalidate references. 
- IrAllocator(const IrAllocator& other) = delete; - IrAllocator& operator=(const IrAllocator& other) = delete; - IrAllocator(IrAllocator&& other) = default; - IrAllocator& operator=(IrAllocator&& other) = default; - - private: - void ResetRefs() { - refs_->resize(storage_.size()); - auto it = storage_.begin(); - for (auto i = 0; i < storage_.size(); ++i, ++it) { - refs_->at(i) = &*it; - } - } - - Storage storage_; - std::unique_ptr refs_; -}; - -} // namespace litert::internal - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_IR_ALLOCATOR_H_ diff --git a/tensorflow/lite/experimental/litert/core/model/ir_allocator_test.cc b/tensorflow/lite/experimental/litert/core/model/ir_allocator_test.cc deleted file mode 100644 index dd895dce211e25..00000000000000 --- a/tensorflow/lite/experimental/litert/core/model/ir_allocator_test.cc +++ /dev/null @@ -1,129 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/core/model/ir_allocator.h" - -#include -#include -#include - -#include -#include -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" - -namespace litert::internal { -namespace { - -using ::testing::ElementsAreArray; - -static constexpr auto kCustomOpCode = kLiteRtOpCodeTflCustom; -static constexpr auto kNonCustomOpCode = kLiteRtOpCodeTflSoftmax; - -TEST(IrAllocatorTest, EmplaceBack) { - IrAllocator ops; - - LiteRtOpT my_op; - my_op.SetOpCode(kCustomOpCode); - - ops.EmplaceBack(std::move(my_op)); - ASSERT_EQ(ops.Elements().size(), 1); - EXPECT_EQ(ops.Elements().at(0)->OpCode(), kCustomOpCode); -} - -TEST(IrAllocatorTest, RemoveIf) { - IrAllocator ops; - - LiteRtOpT my_op; - my_op.SetOpCode(kNonCustomOpCode); - ops.EmplaceBack(std::move(my_op)); - - LiteRtOpT my_op2; - my_op2.SetOpCode(kCustomOpCode); - ops.EmplaceBack(std::move(my_op2)); - - LiteRtOpT my_op3; - my_op3.SetOpCode(kCustomOpCode); - ops.EmplaceBack(std::move(my_op3)); - - LiteRtOpT my_op4; - my_op4.SetOpCode(kNonCustomOpCode); - ops.EmplaceBack(std::move(my_op4)); - - auto pred = [](const auto& op) { return op.OpCode() != kCustomOpCode; }; - ASSERT_EQ(ops.RemoveIf(pred), 2); - - ASSERT_EQ(ops.Elements().size(), 2); - ASSERT_EQ(ops.Elements().at(0)->OpCode(), kCustomOpCode); - ASSERT_EQ(ops.Elements().at(1)->OpCode(), kCustomOpCode); -} - -TEST(IrAllocatorTest, ResizeDown) { - IrAllocator ops; - - LiteRtOp op1 = nullptr; - { - LiteRtOpT my_op; - my_op.SetOpCode(kNonCustomOpCode); - op1 = &ops.EmplaceBack(std::move(my_op)); - } - - { - LiteRtOpT my_op2; - my_op2.SetOpCode(kCustomOpCode); - ops.EmplaceBack(std::move(my_op2)); - } - - ops.ResizeDown(1); - - ASSERT_EQ(ops.Size(), 1); - EXPECT_EQ(ops.Elements().at(0), op1); -} - -TEST(IrAllocatorTest, Transfer) { - IrAllocator ops; - auto& op1 = ops.EmplaceBack(); - auto& op2 = ops.EmplaceBack(); - - IrAllocator other_ops; - auto& 
other_op1 = other_ops.EmplaceBack(); - auto& other_op2 = other_ops.EmplaceBack(); - - ops.TransferFrom(std::move(other_ops)); - - EXPECT_THAT(ops.Elements(), - ElementsAreArray({&op1, &op2, &other_op1, &other_op2})); -} - -TEST(IrAllocatorTest, TransferWithIndices) { - IrAllocator ops; - auto& op1 = ops.EmplaceBack(); - auto& op2 = ops.EmplaceBack(); - - IrAllocator other_ops; - auto& other_op1 = other_ops.EmplaceBack(); - auto& other_op2 = other_ops.EmplaceBack(); - auto& other_op3 = other_ops.EmplaceBack(); - auto& other_op4 = other_ops.EmplaceBack(); - - std::vector indices = {1, 3}; - ops.TransferFrom(other_ops, std::move(indices)); - - EXPECT_THAT(other_ops.Elements(), ElementsAreArray({&other_op1, &other_op3})); - EXPECT_THAT(ops.Elements(), - ElementsAreArray({&op1, &op2, &other_op2, &other_op4})); -} - -} // namespace -} // namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/core/model/litert_to_flatbuffer.cc b/tensorflow/lite/experimental/litert/core/model/litert_to_flatbuffer.cc deleted file mode 100644 index 90292600ace0d3..00000000000000 --- a/tensorflow/lite/experimental/litert/core/model/litert_to_flatbuffer.cc +++ /dev/null @@ -1,128 +0,0 @@ - -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/core/model/litert_to_flatbuffer.h" - -#include -#include - -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" -#include "tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h" -#include "tensorflow/lite/schema/schema_generated.h" - -namespace litert::internal { - -namespace { - -Expected MapElementType(LiteRtElementType litert_element_type) { - switch (litert_element_type) { - case kLiteRtElementTypeFloat32: - return tflite::TensorType_FLOAT32; - case kLiteRtElementTypeFloat16: - return tflite::TensorType_FLOAT16; - case kLiteRtElementTypeInt32: - return tflite::TensorType_INT32; - case kLiteRtElementTypeInt64: - return tflite::TensorType_INT64; - case kLiteRtElementTypeBool: - return tflite::TensorType_BOOL; - case kLiteRtElementTypeInt16: - return tflite::TensorType_INT16; - case kLiteRtElementTypeInt8: - return tflite::TensorType_INT8; - default: - return Error(kLiteRtStatusErrorUnsupported); - } -} - -template -Expected MapTensorTypeDetail( - const LiteRtTenzorType& litert_tensor_type) { - return Error(kLiteRtStatusErrorUnsupported); -} - -template <> -Expected MapTensorTypeDetail( - const LiteRtRankedTensorType& litert_tensor_type) { - auto tfl_element_type = MapElementType(litert_tensor_type.element_type); - if (!tfl_element_type) { - return tfl_element_type.Error(); - } - - auto litert_shape = absl::MakeConstSpan(litert_tensor_type.layout.dimensions, - litert_tensor_type.layout.rank); - return std::make_pair(*tfl_element_type, TflShapeInfo(litert_shape)); -} - -template -Expected MapQuantizationDetail( - const LiteRtQuantDetail& litert_quantization) { - return Error(kLiteRtStatusErrorUnsupported); -} - -template <> -Expected MapQuantizationDetail( - const 
LiteRtQuantizationPerTensor& litert_quantization) { - auto tfl_quantization = std::make_unique(); - tfl_quantization->scale.assign({litert_quantization.scale}); - tfl_quantization->zero_point.assign({litert_quantization.zero_point}); - return tfl_quantization; -} - -template <> -Expected -MapQuantizationDetail( - const LiteRtQuantizationPerChannel& litert_quantization) { - auto tfl_quantization = std::make_unique(); - - for (int i = 0; i < litert_quantization.num_channels; ++i) { - tfl_quantization->scale.push_back(litert_quantization.scales[i]); - tfl_quantization->zero_point.push_back(litert_quantization.zero_points[i]); - } - tfl_quantization->quantized_dimension = - litert_quantization.quantized_dimension; - return tfl_quantization; -} - -} // namespace - -Expected MapTensorType(const TensorType& litert_tensor_type) { - switch (litert_tensor_type.first) { - case kLiteRtRankedTensorType: - return MapTensorTypeDetail(litert_tensor_type.second.ranked_tensor_type); - default: - return Error(kLiteRtStatusErrorUnsupported); - } -} - -Expected MapQuantization( - const Quantization& litert_quantization) { - switch (litert_quantization.first) { - case kLiteRtQuantizationNone: - return TflQuantizationPtr(nullptr); - case kLiteRtQuantizationPerTensor: - return MapQuantizationDetail(litert_quantization.second.per_tensor); - case kLiteRtQuantizationPerChannel: - return MapQuantizationDetail(litert_quantization.second.per_channel); - default: - return Error(kLiteRtStatusErrorUnsupported); - } -} - -} // namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/core/model/litert_to_flatbuffer.h b/tensorflow/lite/experimental/litert/core/model/litert_to_flatbuffer.h deleted file mode 100644 index 4fbe51bf9d3a0b..00000000000000 --- a/tensorflow/lite/experimental/litert/core/model/litert_to_flatbuffer.h +++ /dev/null @@ -1,32 +0,0 @@ - -// Copyright 2024 Google LLC. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_LITERT_TO_FLATBUFFER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_LITERT_TO_FLATBUFFER_H_ - -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" -#include "tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h" - -namespace litert::internal { - -Expected MapTensorType(const TensorType& litert_tensor_type); - -Expected MapQuantization( - const Quantization& litert_quantization); - -} // namespace litert::internal - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_LITERT_TO_FLATBUFFER_H_ diff --git a/tensorflow/lite/experimental/litert/core/model/litert_to_flatbuffer_test.cc b/tensorflow/lite/experimental/litert/core/model/litert_to_flatbuffer_test.cc deleted file mode 100644 index 3f5c8fdf101fa1..00000000000000 --- a/tensorflow/lite/experimental/litert/core/model/litert_to_flatbuffer_test.cc +++ /dev/null @@ -1,108 +0,0 @@ - -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/core/model/litert_to_flatbuffer.h" - -#include -#include -#include - -#include -#include -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_layout.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" -#include "tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h" - -namespace litert::internal { -namespace { - -using ::testing::ElementsAreArray; - -TEST(LiteRtToFlatbufferTest, MapNoQuantization) { - Quantization q; - auto tfl_q = MapQuantization(q); - ASSERT_TRUE(tfl_q); - EXPECT_EQ(tfl_q.Value(), nullptr); -} - -TEST(LiteRtToFlatbufferTest, MapPerTensorQuantization) { - static constexpr float kScale = 1.0; - static constexpr int64_t kZp = 2; - - Quantization q; - q.first = kLiteRtQuantizationPerTensor; - q.second.per_tensor.scale = kScale; - q.second.per_tensor.zero_point = kZp; - - auto tfl_q = MapQuantization(q); - ASSERT_TRUE(tfl_q); - EXPECT_THAT(tfl_q->get()->scale, ElementsAreArray({kScale})); - EXPECT_THAT(tfl_q->get()->zero_point, ElementsAreArray({kZp})); -} - -TEST(LiteRtToFlatbufferTest, MapPerChannelQuantization) { - static constexpr size_t kRank = 2; - static constexpr size_t kQuantizedDimension = 1; - static constexpr float kScales[kRank] = {1.0, 2.0}; - static constexpr int64_t kZps[kRank] = {2, 3}; - - Quantization q; - q.first = kLiteRtQuantizationPerChannel; - q.second.per_channel.scales = const_cast(kScales); - q.second.per_channel.zero_points = const_cast(kZps); - 
q.second.per_channel.num_channels = kRank; - q.second.per_channel.quantized_dimension = kQuantizedDimension; - - auto tfl_q = MapQuantization(q); - ASSERT_TRUE(tfl_q); - EXPECT_THAT(tfl_q->get()->scale, ElementsAreArray(kScales)); - EXPECT_THAT(tfl_q->get()->zero_point, ElementsAreArray(kZps)); -} - -TEST(LiteRtToFlatbufferTest, MapDynamicTensorType) { - static constexpr int32_t kDims[] = {-1, 2}; - - TensorType t; - t.first = kLiteRtRankedTensorType; - t.second.ranked_tensor_type.element_type = kLiteRtElementTypeFloat32; - t.second.ranked_tensor_type.layout = BuildLayout(kDims); - - auto tfl_t = MapTensorType(t); - ASSERT_TRUE(tfl_t); - EXPECT_EQ(tfl_t->first, TflElementType::TensorType_FLOAT32); - EXPECT_TRUE(tfl_t->second.has_rank); - EXPECT_THAT(tfl_t->second.shape, ElementsAreArray({1, 2})); - EXPECT_THAT(tfl_t->second.shape_signature, ElementsAreArray(kDims)); -} - -TEST(LiteRtToFlatbufferTest, MapStaticTensorType) { - static constexpr int32_t kDims[] = {2, 2}; - - TensorType t; - t.first = kLiteRtRankedTensorType; - t.second.ranked_tensor_type.element_type = kLiteRtElementTypeFloat32; - t.second.ranked_tensor_type.layout = BuildLayout(kDims); - - auto tfl_t = MapTensorType(t); - ASSERT_TRUE(tfl_t); - EXPECT_EQ(tfl_t->first, TflElementType::TensorType_FLOAT32); - EXPECT_TRUE(tfl_t->second.has_rank); - EXPECT_THAT(tfl_t->second.shape, ElementsAreArray({2, 2})); - EXPECT_TRUE(tfl_t->second.shape_signature.empty()); -} - -} // namespace -} // namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/core/model/model.cc b/tensorflow/lite/experimental/litert/core/model/model.cc deleted file mode 100644 index 552e3f1d5ed96a..00000000000000 --- a/tensorflow/lite/experimental/litert/core/model/model.cc +++ /dev/null @@ -1,211 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/core/model/model.h" - -#include -#include -#include -#include -#include -#include -#include - -#include "absl/log/absl_check.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_layout.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/core/build_stamp.h" - -using ::litert::BufferRef; -using ::litert::internal::TflBuffer; -using ::litert::internal::TflBufferPtr; -using ::litert::internal::TflOpCode; -using ::litert::internal::TflOpCodePtr; -using ::litert::internal::TflOptions; -using ::litert::internal::TflOptions2; - -std::optional GetBuildStamp( - const LiteRtModelT& model) { - using ::litert::internal::kLiteRtBuildStampKey; - using ::litert::internal::ParseBuildStamp; - - auto stamp_meta = model.FindMetadata(kLiteRtBuildStampKey); - if (!stamp_meta) { - return std::nullopt; - } - auto parsed_stamp = ParseBuildStamp(*stamp_meta); - if (!parsed_stamp) { - return std::nullopt; - } - auto [soc_manufacturer, soc_model] = *parsed_stamp; - return LiteRtModelT::BuildStamp{soc_manufacturer, soc_model}; -} - -bool IsCompiled(const LiteRtModelT& model) { - return GetBuildStamp(model).has_value(); -} - -std::optional GetCustomOpCode(const LiteRtModelT& model, - const 
LiteRtOpT& op) { - if (op.OpCode() != kLiteRtOpCodeTflCustom) { - return {}; - } - const auto& tfl_op_codes = litert::internal::GetTflOpCodes(model); - const auto tfl_op_code_ind = litert::internal::GetTflOpCodeInd(op); - return tfl_op_codes[tfl_op_code_ind]->custom_code; -} - -TensorType MakeRankedTensorType(LiteRtElementType element_type, - absl::Span dims) { - TensorType tensor_type; - tensor_type.first = kLiteRtRankedTensorType; - auto& ranked = tensor_type.second.ranked_tensor_type; - ranked.element_type = element_type; - ABSL_DCHECK_LE(dims.size(), LITERT_TENSOR_MAX_RANK); - ranked.layout.rank = dims.size(); - std::copy(dims.begin(), dims.end(), ranked.layout.dimensions); - // Strides not yet supported. - ranked.layout.strides = nullptr; - return tensor_type; -} - -Quantization MakePerTensorQuantization(float scale, int64_t zero_point) { - Quantization quantization; - quantization.first = kLiteRtQuantizationPerTensor; - quantization.second.per_tensor.scale = scale; - quantization.second.per_tensor.zero_point = zero_point; - return quantization; -} - -LiteRtSignatureT MakeDefaultSignature(LiteRtSubgraph subgraph) { - auto tensor_name = [](auto* tensor) { return std::string(tensor->Name()); }; - - auto in_start = subgraph->Inputs().cbegin(); - auto in_end = subgraph->Inputs().cend(); - std::vector input_names(subgraph->NumInputs()); - std::transform(in_start, in_end, input_names.begin(), tensor_name); - - auto out_start = subgraph->Outputs().cbegin(); - auto out_end = subgraph->Outputs().cend(); - std::vector output_names(subgraph->NumOutputs()); - std::transform(out_start, out_end, output_names.begin(), tensor_name); - - std::string name(LiteRtSignatureT::kDefaultSignatureKey); - return LiteRtSignatureT(subgraph, std::move(input_names), - std::move(output_names), std::move(name)); -} - -::litert::Expected LookupSubgraph( - const LiteRtModelT& model, absl::string_view signature_key) { - auto sig = model.FindSignature(signature_key); - if (!sig) { - return 
sig.Error(); - } - return &sig->get().GetSubgraph(); -} - -void LiteRtModelT::TransferSubgraphTo(LiteRtSubgraphT::Alloc& dest, - std::vector indices) { - if (indices.empty()) { - return; - } - std::sort(indices.begin(), indices.end()); - std::vector new_inds(subgraphs_.Size(), 0); - auto num_removed = 0; - auto i = indices.begin(); - for (size_t j = 0; j < new_inds.size(); ++j) { - if (i != indices.end() && *i == j) { - ++num_removed; - // Keep track of removed sgs just for dcheck. - new_inds[j] = -1; - ++i; - continue; - } - new_inds[j] = j - num_removed; - } - - ForEachIr( - this, [&](LiteRtSubgraph subgraph, int32_t subgraph_index, LiteRtOp op) { - if (op->OpCode() != kLiteRtOpCodeShloComposite) { - return; - } - auto opts = litert::internal::TakeTflOptions2(*op); - auto& decomp_ind = - opts.AsStableHLOCompositeOptions()->decomposition_subgraph_index; - const auto new_ind = new_inds[decomp_ind]; - - // This op is either in a removed subgraph or refers to a subgraph that - // is not being removed. 
- ABSL_DCHECK((subgraph_index == -1) || (new_ind >= 0)); - - decomp_ind = new_ind; - litert::internal::SetTflOptions2(*op, std::move(opts)); - }); - - subgraphs_.TransferTo(dest, std::move(indices)); -} - -namespace litert::internal { - -void SetTflOpCodeInd(LiteRtOpT& litert_op, int32_t tfl_op_code_ind) { - litert_op.tfl_op_code_ind_ = tfl_op_code_ind; -} - -int32_t GetTflOpCodeInd(const LiteRtOpT& litert_op) { - return litert_op.tfl_op_code_ind_; -} - -const TflOptions& GetTflOptions(const LiteRtOpT& litert_op) { - return litert_op.tfl_option_; -} - -const TflOptions2& GetTflOptions2(const LiteRtOpT& litert_op) { - return litert_op.tfl_option_2_; -} - -TflOptions&& TakeTflOptions(LiteRtOpT& litert_op) { - return std::move(litert_op.tfl_option_); -} - -TflOptions2&& TakeTflOptions2(LiteRtOpT& litert_op) { - return std::move(litert_op.tfl_option_2_); -} - -const std::vector& GetTflOpCodes( - const LiteRtModelT& litert_model) { - return litert_model.tfl_operator_codes_; -} - -std::vector&& TakeTflOpCodes(LiteRtModelT& litert_model) { - return std::move(litert_model.tfl_operator_codes_); -} - -// new stuff start -void SetTflFlatbuffer(LiteRtModelT& litert_model, - LiteRtModelT::TflFlatbuffer&& tfl_flatbuffer) { - litert_model.tfl_flatbuffer_ = std::move(tfl_flatbuffer); -} - -const LiteRtModelT::TflFlatbuffer& GetTflFlatbuffer( - const LiteRtModelT& litert_model) { - return litert_model.tfl_flatbuffer_; -} -// new stuff end - -} // namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/core/model/model.h b/tensorflow/lite/experimental/litert/core/model/model.h deleted file mode 100644 index 5d4bcfacb0a380..00000000000000 --- a/tensorflow/lite/experimental/litert/core/model/model.h +++ /dev/null @@ -1,1024 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_MODEL_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_MODEL_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/log/absl_check.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" // IWYU pragma: export -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/core/model/buffer_manager.h" -#include "tensorflow/lite/experimental/litert/core/model/ir_allocator.h" -#include "tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h" - -//////////////////////////////////////////////////////////////////////////////// -// Internal LiteRtIR -// -// These are the backing definitions for the opaque types in the c api -// (c/litert_model.h). -// -// < STORAGE DETAIL > -// -// Unless deleted as a result of calls c api client, the lifetime of all "IR -// Objects" (definitions of opaque types) are designed to be transitively owned -// by the LiteRtModelT which is generally the longset living object. See various -// "Emplace" methods. 
-// -// Since c api clients interface with pointers to IR Ojbects, a form of pointer -// stability is desirable. Classes in this file enforce that pointers to IR -// Objects are valid for their entire life time. Thus a c api client may store -// pointers and depend on referential equality of IR Objects thoughout different -// calls. This also facilitates storing edge/parent-references as pointers -// within IR Objects. -// -// Direct copying is generally not allowed for IR Objects since copying -// instances of mutually recursive types is not entirely well-defined. -// -// IR Objects are generally default constructible to facilitate stable storage -// and iterative construction. -// -// < EXPOSING TFLITE SCHEMA > -// -// Direct access to tflite schema types is limited to the "detail" namespace. -// This indicates that encapsulating all the details of the flatbuffer is a WIP. -// Future implementations may use different data forms (new litert serialized -// format, tflite runtime types etc). -// -// < USAGE NOTE > -// -// The classes here contain only simple getters & setters. Care should be taken -// to leave the IR in a valid state when using setters since the graph is -// doubly-linked. Higher-level functionality for correct graph mutation can be -// found in "model_graph.h". -//////////////////////////////////////////////////////////////////////////////// - -// All tflite schema type usage. -namespace litert::internal { - -// OP - -// Placeholder for the ind of the dispatch op code added during serialization. 
-static constexpr auto kDispatchOpCodeTflInd = -1; - -void SetTflOpCodeInd(LiteRtOpT& litert_op, int32_t tfl_op_code_ind); - -int32_t GetTflOpCodeInd(const LiteRtOpT& litert_op); - -template -void SetTflOptions(LiteRtOpT& litert_op, Arg&& arg); - -template -void SetTflOptions2(LiteRtOpT& litert_op, Arg&& arg); - -const ::litert::internal::TflOptions& GetTflOptions(const LiteRtOpT& litert_op); - -const ::litert::internal::TflOptions2& GetTflOptions2( - const LiteRtOpT& litert_op); - -::litert::internal::TflOptions&& TakeTflOptions(LiteRtOpT& litert_op); - -::litert::internal::TflOptions2&& TakeTflOptions2(LiteRtOpT& litert_op); - -void ClearTflOptions(LiteRtOpT& litert_op); - -// MODEL - -const std::vector<::litert::internal::TflOpCodePtr>& GetTflOpCodes( - const LiteRtModelT& litert_model); - -template -void SetTflOpCodes(LiteRtModelT& litert_model, Arg&& arg); - -std::vector<::litert::internal::TflOpCodePtr>&& TakeTflOpCodes( - LiteRtModelT& litert_model); - -void SetTflFlatbuffer(LiteRtModelT& litert_model, - ::litert::internal::FlatbufferWrapper&& tfl_flatbuffer); - -const ::litert::internal::FlatbufferWrapper& GetTflFlatbuffer( - const LiteRtModelT& litert_model); - -} // namespace litert::internal - -// -// Helpers for conceptual unions from C api. -// - -// // For requesting opaque data stored within IR. -using ScratchBufferProvider = std::function; - -// TENSOR TYPE - -// Detail convenience type for tensor type union. -typedef union { - LiteRtUnrankedTensorType unranked_tensor_type; - LiteRtRankedTensorType ranked_tensor_type; -} TensorTypeDetail; - -// Union and identifier for tensor types. -using TensorType = std::pair; - -// Construct tensor type union as ranked tensor. NOTE: Copies data in `dims`. -TensorType MakeRankedTensorType(LiteRtElementType element_type, - absl::Span dims); - -// QUANTIZATION TYPE - -// Detail convenience type for quantization type union. 
-typedef union { - LiteRtQuantizationPerTensor per_tensor; - LiteRtQuantizationPerChannel per_channel; -} QuantizationDetail; - -// Union and identifier for quantization types. -using Quantization = std::pair; - -// Make default type with quantization info. -inline Quantization MakeEmptyQuantization() { - return Quantization(kLiteRtQuantizationNone, QuantizationDetail()); -} - -// Construct quantization type as per tensor. -Quantization MakePerTensorQuantization(float scale, int64_t zero_point); - -// Construct quantization type as per channel, requires buffer callback to -// store data. -template -Quantization MakePerChannelQuantization(const Scales& scales, - const ZeroPoints& zero_points, - int32_t quantized_dim, - ScratchBufferProvider buffer_provider) { - const auto size = std::size(scales); - ABSL_DCHECK_EQ(size, std::size(zero_points)); - - Quantization res; - res.first = kLiteRtQuantizationPerChannel; - - res.second.per_channel.num_channels = size; - res.second.per_channel.quantized_dimension = quantized_dim; - - const size_t scales_buf_size = size * sizeof(float); - const size_t zeros_buf_size = size * sizeof(int64_t); - auto* scales_buf = reinterpret_cast(buffer_provider(scales_buf_size)); - auto* zeros_buf = reinterpret_cast(buffer_provider(zeros_buf_size)); - std::copy(std::cbegin(scales), std::cend(scales), scales_buf); - std::copy(std::cbegin(zero_points), std::cend(zero_points), zeros_buf); - - res.second.per_channel.scales = scales_buf; - res.second.per_channel.zero_points = zeros_buf; - - return res; -} - -// -// Tensor -// - -// Constant data associated with a tensor. -class LiteRtWeightsT { - private: - using OwnedBuffer = ::litert::OwningBufferRef; - - public: - using BufferId = ::litert::internal::BufferManager::BufferId; - using BufferManager = ::litert::internal::BufferManager; - - // Underlying data. 
- ::litert::BufferRef Buffer() const { - auto buf = GetBufferManager()->GetBuffer(buffer_id_); - ABSL_DCHECK(buf.HasValue()); - return *buf; - } - - // Set the buffer manager, expects a stable pointer. A default buffer manager - // will be initialized for convenience but most cases will share a single - // buffer manager owned by the model. - void SetBufferManager(BufferManager* buffer_manager) { - buffer_manager_ = buffer_manager; - } - - // Get the underlying buffer manager. - BufferManager* GetBufferManager() const { - if (std::holds_alternative(buffer_manager_)) { - return std::get(buffer_manager_); - } else { - return std::get(buffer_manager_).get(); - } - } - - // Set from a pre-registered buffer. This expects buffer was registered - // with the same manager. - void SetBufferId(BufferId buffer_id) { buffer_id_ = buffer_id; } - - // Get the id generated for the buffer by the manager. - BufferId GetBufferId() const { return buffer_id_; } - - // IR is generally, default constructible and movable but not copyable. - LiteRtWeightsT() = default; - explicit LiteRtWeightsT(BufferManager* buffer_manager) - : buffer_manager_(buffer_manager) {} - LiteRtWeightsT(const LiteRtWeightsT&) = delete; - LiteRtWeightsT(LiteRtWeightsT&&) = default; - LiteRtWeightsT& operator=(const LiteRtWeightsT&) = delete; - LiteRtWeightsT& operator=(LiteRtWeightsT&&) = default; - - private: - BufferId buffer_id_ = BufferManager::kEmptyBufferId; - std::variant buffer_manager_ = - std::make_unique(); -}; - -// Set weights via an unowned buffer. Caller is responsible for ensuring the -// buffer outlives the weights. Registers the buffer with the manager. -inline void SetWeightsFromUnownedBuffer( - LiteRtWeightsT& weights, ::litert::BufferRef buffer, - std::optional context = std::nullopt) { - auto* manager = weights.GetBufferManager(); - auto buf_id = manager->RegisterNonOwnedBuffer(buffer, context); - weights.SetBufferId(buf_id); -} - -// Set weights via an unowned buffer. 
Caller is responsible for ensuring the -// buffer outlives the weights. Registers the buffer with the manager. -inline void SetWeightsFromOwnedBuffer( - LiteRtWeightsT& weights, ::litert::OwningBufferRef&& buffer, - std::optional context = std::nullopt) { - auto* manager = weights.GetBufferManager(); - auto buf_id = manager->RegisterOwnedBuffer(std::move(buffer), context); - weights.SetBufferId(buf_id); -} - -// Fundamental value in a litert program, "edges" in the graph. -class LiteRtTensorT { - private: - using UserData = std::unique_ptr; - - public: - using Ref = std::reference_wrapper; - using Use = std::pair; - using UseVec = std::vector; - using Alloc = ::litert::internal::IrAllocator; - - // The ops that take this tensor as input. - const std::vector& Users() const { return users_; } - std::vector& Users() { return users_; } - - // Which operand index users take this tensor on, respects the ordering of - // users.. - const std::vector& UserArgInds() const { - return user_arg_inds_; - } - std::vector& UserArgInds() { return user_arg_inds_; } - - // Number of uses, same as number of user arg inds. - size_t NumUses() const { return users_.size(); } - - // Get the ith use. - Use GetUse(size_t ind) const { - return {users_.at(ind), user_arg_inds_.at(ind)}; - } - - // Remove the use at the given index. - void RemoveUse(size_t ind) { - users_.erase(users_.begin() + ind); - user_arg_inds_.erase(user_arg_inds_.begin() + ind); - } - - // Get the op that outputs this tensor, null if constant or subgraph input. - LiteRtOp DefiningOp() const { return defining_op_; } - - // Get the output index of the op that defines this tensor, only meaningful - // if it has a defining op. - LiteRtParamIndex DefiningOpOutInd() const { return defining_op_out_ind_; } - - // Update the defining op of this tensor. The caller is required to update the - // given op's output if not already correct. 
- void SetDefiningOp(LiteRtOpT& defining_op, LiteRtParamIndex out_ind) { - defining_op_ = &defining_op; - defining_op_out_ind_ = out_ind; - } - - // Set the defining op to none. - void ClearDefiningOp() { - defining_op_ = nullptr; - defining_op_out_ind_ = 0; - } - - // Any constant data associated with this tensor. - const LiteRtWeightsT& Weights() const { return weights_; } - LiteRtWeightsT& Weights() { return weights_; } - - // Authored name associated with this tensor. May be empty. - absl::string_view Name() const { return name_; } - - // Update the name associated with this tensor. - void SetName(std::string name) { name_ = std::move(name); } - - // Get quantization information for this tensor. - const Quantization& Qparams() const { return quantization_; } - Quantization& Qparams() { return quantization_; } - - // Set quantization information. - template - void SetQarams(Arg&& arg) { - quantization_ = std::forward(arg); - } - - // Get the tensor type of this tensor. - const TensorType& Type() const { return tensor_type_; } - TensorType& Type() { return tensor_type_; } - - // Set the tensor type. - template - void SetType(Arg&& arg) { - tensor_type_ = std::forward(arg); - } - - // Get a new buffer that will live as long as this tensor. Used for storing - // various buffers passed through c-api (dims, quantization etc). - // NOTE: This is just scratch data unrelated to weights buffer. - uint8_t* RequestScratchBuffer(size_t size) { - user_data_.push_back(std::make_unique(size)); - return user_data_.back().get(); - } - - // Allow for implicit conversion to scratch buffer provider. - // NOTE: This is just scratch data unrelated to weights buffer. - // NOLINTNEXTLINE - operator ScratchBufferProvider() & { - return [this](auto s) { return this->RequestScratchBuffer(s); }; - } - - // IR is generally, default constructible and movable but not copyable. 
- LiteRtTensorT() = default; - LiteRtTensorT(::litert::internal::BufferManager* buffer_manager) - : weights_(buffer_manager) {} - LiteRtTensorT(const LiteRtTensorT&) = delete; - LiteRtTensorT(LiteRtTensorT&&) = default; - LiteRtTensorT& operator=(const LiteRtTensorT&) = delete; - LiteRtTensorT& operator=(LiteRtTensorT&&) = default; - - private: - std::vector users_; - std::vector user_arg_inds_; - - LiteRtOp defining_op_ = nullptr; - LiteRtParamIndex defining_op_out_ind_; - - LiteRtWeightsT weights_; - Quantization quantization_; - TensorType tensor_type_; - - std::string name_; - - std::vector user_data_; -}; - -// Helper to get multiple uses at once. -template -LiteRtTensorT::UseVec GetTensorUses(const LiteRtTensorT& tensor, - const Inds& inds) { - auto start = std::cbegin(inds); - auto end = std::cend(inds); - LiteRtTensorT::UseVec uses(end - start); - auto get = [&tensor = std::as_const(tensor)](auto i) { - return tensor.GetUse(i); - }; - std::transform(start, end, uses.begin(), get); - return uses; -} - -// -// Op -// - -// Fundamental unit of compute of a litert program, or "nodes" in the graph. -class LiteRtOpT { - public: - using Ref = std::reference_wrapper; - using Alloc = ::litert::internal::IrAllocator; - - // Input tensors for this op. - const std::vector& Inputs() const { return inputs_; } - std::vector& Inputs() { return inputs_; } - - // Access input at given ind. - LiteRtTensorT& Input(size_t ind) { return *Inputs().at(ind); } - const LiteRtTensorT& Input(size_t ind) const { return *Inputs().at(ind); } - - // Number of input tensors. - size_t NumInputs() const { return inputs_.size(); } - - // Output tensors for this op. - const std::vector& Outputs() const { return outputs_; } - std::vector& Outputs() { return outputs_; } - - // Number of output tensors. - size_t NumOutputs() const { return outputs_.size(); } - - // Access output at given ind. 
- LiteRtTensorT& Output(size_t ind) { return *Outputs().at(ind); } - const LiteRtTensorT& Output(size_t ind) const { return *Outputs().at(ind); } - - // Remove the ith entry of input list. - void RemoveInput(size_t ind) { inputs_.erase(inputs_.begin() + ind); } - - // Remove the ith entry of output list. - void RemoveOutput(size_t ind) { outputs_.erase(outputs_.begin() + ind); } - - // Get any custom options attached to this op. Empty if there are none. - litert::BufferRef CustomOptions() const { return custom_options_; } - - // Attach custom opaque optins to this op. - template - void SetCustomOptions(Args&&... args) { - custom_options_ = - ::litert::OwningBufferRef(std::forward(args)...); - } - - // Sets the custom options to zero length buffer. - void ClearCustomOptions() { custom_options_.Reset(); } - - // Get the op code. - LiteRtOpCode OpCode() const { return litert_op_code_; } - - // Set the op code. - void SetOpCode(LiteRtOpCode litert_op_code) { - litert_op_code_ = litert_op_code; - } - - // IR is generally, default constructible and movable but not copyable. - LiteRtOpT() = default; - LiteRtOpT(const LiteRtOpT&) = delete; - LiteRtOpT(LiteRtOpT&&) = default; - LiteRtOpT& operator=(const LiteRtOpT&) = delete; - LiteRtOpT& operator=(LiteRtOpT&&) = default; - - // Friendship for internal tflite details. 
- friend void litert::internal::SetTflOpCodeInd(LiteRtOpT& litert_op, - int32_t tfl_op_code_ind); - - friend int32_t litert::internal::GetTflOpCodeInd(const LiteRtOpT& litert_op); - - template - friend void litert::internal::SetTflOptions(LiteRtOpT& litert_op, Arg&& arg); - - template - friend void litert::internal::SetTflOptions2(LiteRtOpT& litert_op, Arg&& arg); - - friend const ::litert::internal::TflOptions& litert::internal::GetTflOptions( - const LiteRtOpT& litert_op); - - friend const ::litert::internal::TflOptions2& - litert::internal::GetTflOptions2(const LiteRtOpT& litert_op); - - friend ::litert::internal::TflOptions&& litert::internal::TakeTflOptions( - LiteRtOpT& litert_op); - - friend ::litert::internal::TflOptions2&& litert::internal::TakeTflOptions2( - LiteRtOpT& litert_op); - - friend void litert::internal::ClearTflOptions(LiteRtOpT& litert_op); - - private: - LiteRtOpCode litert_op_code_; - - ::litert::OwningBufferRef custom_options_; - - std::vector inputs_; - std::vector outputs_; - - // TFLITE - int32_t tfl_op_code_ind_ = litert::internal::kDispatchOpCodeTflInd; - ::litert::internal::TflOptions tfl_option_; - ::litert::internal::TflOptions2 tfl_option_2_; -}; - -// Clears any attribute data and sets the op to be a dispatch op. -inline void MakeDispatchOp(LiteRtOpT& op) { - litert::internal::ClearTflOptions(op); - op.ClearCustomOptions(); - op.SetOpCode(kLiteRtOpCodeTflCustom); - litert::internal::SetTflOpCodeInd(op, - litert::internal::kDispatchOpCodeTflInd); -} - -// -// Subgraph -// - -// Fundamental block of a litert program. Manages the storage of all -// ops and tensor within. -class LiteRtSubgraphT { - public: - using Ref = std::reference_wrapper; - using Alloc = ::litert::internal::IrAllocator; - - // Get a stable pointer for all of the tensors in this subgraph. - absl::Span Tensors() { return tensors_.Elements(); } - absl::Span Tensors() const { return tensors_.Elements(); } - - // Access the tensor at given ind. 
- LiteRtTensorT& Tensor(size_t ind) { return *Tensors().at(ind); } - const LiteRtTensorT& Tensor(size_t ind) const { return *Tensors().at(ind); } - - // Get a stable pointer for all of the ops in this subgraph. Will - // be a valid toplological order. - absl::Span Ops() { return ops_.Elements(); } - absl::Span Ops() const { return ops_.Elements(); } - - // Access op at the given ind. - LiteRtOpT& Op(size_t ind) { return *Ops().at(ind); } - const LiteRtOpT& Op(size_t ind) const { return *Ops().at(ind); } - - // All the subgraph input tensors, these also exist in Tensors. - const std::vector& Inputs() const { return inputs_; } - std::vector& Inputs() { return inputs_; } - - // Number of inputs tensors. - size_t NumInputs() const { return inputs_.size(); } - - // Access the subgraph input at given ind. - LiteRtTensorT& Input(size_t ind) { return *Inputs().at(ind); } - const LiteRtTensorT& Input(size_t ind) const { return *Inputs().at(ind); } - - // All the subgraph output tensors, these also exist in Tensors. - const std::vector& Outputs() const { return outputs_; } - std::vector& Outputs() { return outputs_; } - - // Number of outputs tensors. - size_t NumOutputs() const { return outputs_.size(); } - - // Access the subgraph output at given ind. - LiteRtTensorT& Output(size_t ind) { return *Outputs().at(ind); } - const LiteRtTensorT& Output(size_t ind) const { return *Outputs().at(ind); } - - // Clear the entry for the ith input. - void ClearInput(size_t ind) { inputs_.erase(inputs_.begin() + ind); } - - // Clear the entry for the ith output. - void ClearOutput(size_t ind) { outputs_.erase(outputs_.begin() + ind); } - - // Construct a new tensor which will be owned by this subgraph and get a - // reference to it. - template - LiteRtTensorT& EmplaceTensor(Args&&... 
args) { - if (buffer_manager_ == nullptr) { - return tensors_.EmplaceBack(std::forward(args)...); - } else { - // std::cerr << "Emplacing tensor with buffer manager \n"; - return tensors_.EmplaceBack(buffer_manager_, std::forward(args)...); - } - } - - // Construct a new op which will be owned by this subgraph and get a - // reference to it. - template - LiteRtOpT& EmplaceOp(Args&&... args) { - return ops_.EmplaceBack(std::forward(args)...); - } - - // De-allocates ops that pass given predicate. Returns number of ops removed. - size_t RemoveOpIf(std::function pred) { - return ops_.RemoveIf(pred); - } - - // De-allocates tensors that pass given predicate. Returns number of tensors - // removed. - size_t RemoveTensorIf(std::function pred) { - return tensors_.RemoveIf(pred); - } - - // IR is generally, default constructible and movable but not copyable. - LiteRtSubgraphT() = default; - LiteRtSubgraphT(::litert::internal::BufferManager* buffer_manager) - : buffer_manager_(buffer_manager) {}; - LiteRtSubgraphT(const LiteRtSubgraphT&) = delete; - LiteRtSubgraphT(LiteRtSubgraphT&&) = default; - LiteRtSubgraphT& operator=(const LiteRtSubgraphT&) = delete; - LiteRtSubgraphT& operator=(LiteRtSubgraphT&&) = default; - - // Get the buffer manager for this subgraph. - ::litert::internal::BufferManager* GetBufferManager() const { - return buffer_manager_; - } - - private: - // If null, tensors emplaced will own their own buffer managers. 
- ::litert::internal::BufferManager* buffer_manager_ = nullptr; - - LiteRtTensorT::Alloc tensors_; - - LiteRtOpT::Alloc ops_; - - std::vector inputs_; - std::vector outputs_; -}; - -// -// Signature -// - -class LiteRtSignatureT { - private: - using StrVec = std::vector; - - public: - using Ptr = std::unique_ptr; - using Ref = std::reference_wrapper; - using Alloc = ::litert::internal::IrAllocator; - - static constexpr absl::string_view kDefaultSignatureKey = - ""; - - LiteRtSignatureT(LiteRtSubgraph subgraph, StrVec input_names, - StrVec output_names, std::string key) - : key_(std::move(key)), - subgraph_(subgraph), - input_names_(std::move(input_names)), - output_names_(std::move(output_names)) {} - - // String named inputs for called subgraph. - const StrVec& InputNames() const { return input_names_; } - - // String named outputs for called subgraph. - const StrVec& OutputNames() const { return output_names_; } - - // Get the callable subgraph. - const LiteRtSubgraphT& GetSubgraph() const { return *subgraph_; } - LiteRtSubgraphT& GetSubgraph() { return *subgraph_; } - - // Name of the callable signature. - absl::string_view Key() const { return key_; } - - bool operator==(const LiteRtSignatureT& other) const { - const auto key_eq = key_ == other.key_; - const auto subgraph_eq = subgraph_ == other.subgraph_; - const auto input_names_eq = input_names_ == other.input_names_; - const auto output_names_eq = output_names_ == other.output_names_; - return key_eq && subgraph_eq && input_names_eq && output_names_eq; - } - - // IR is generally, default constructible and movable but not copyable. 
- LiteRtSignatureT() = default; - LiteRtSignatureT(const LiteRtSignatureT&) = delete; - LiteRtSignatureT(LiteRtSignatureT&&) = default; - LiteRtSignatureT& operator=(const LiteRtSignatureT&) = delete; - LiteRtSignatureT& operator=(LiteRtSignatureT&&) = default; - - private: - std::string key_; - - LiteRtSubgraph subgraph_; - - StrVec input_names_; - StrVec output_names_; -}; - -// Make a basic signature from information in the given subgraph. Used with the -// main subgraph when no explicit signatures have been authored. -LiteRtSignatureT MakeDefaultSignature(LiteRtSubgraph subgraph); - -// -// Model -// - -// Root-level graph object for litert programs. Manages the storage -// of all litert graph objects within. -class LiteRtModelT { - public: - using Ref = std::reference_wrapper; - using Ptr = std::unique_ptr; - using TflOpCodes = std::vector; - - using BufferManager = ::litert::internal::BufferManager; - using BufferId = BufferManager::BufferId; - - using OpAssetReference = std::pair; - using OpAssetMap = absl::flat_hash_map; - - using MetadataMap = absl::flat_hash_map; - - using TflFlatbuffer = ::litert::internal::FlatbufferWrapper; - - // TODO replace this with the index of the default signature. - static constexpr const size_t kMainSubgraphIndex = 0; - - // SUBGRAPHS - - // Get a stable pointer for all of the subgraphs within this model. - absl::Span Subgraphs() { return subgraphs_.Elements(); } - absl::Span Subgraphs() const { - return subgraphs_.Elements(); - } - - // Access subgraph at given ind. - LiteRtSubgraphT& Subgraph(size_t ind) { return *Subgraphs().at(ind); } - const LiteRtSubgraphT& Subgraph(size_t ind) const { - return *Subgraphs().at(ind); - } - - // Number of subraphs. - size_t NumSubgraphs() const { return subgraphs_.Elements().size(); } - - // Default entry point of this model. 
- const LiteRtSubgraphT* MainSubgraph() const { - return &Subgraph(kMainSubgraphIndex); - } - LiteRtSubgraph MainSubgraph() { return &Subgraph(kMainSubgraphIndex); } - - // Look up signature by key. - litert::Expected FindSignature( - absl::string_view signature_key) const { - for (LiteRtSignature sig : signatures_.Elements()) { - if (sig->Key() == signature_key) { - return std::ref(*sig); - } - } - return ::litert::Error(kLiteRtStatusErrorNotFound, "Signature not found"); - } - - // Build a new subgraph and get a stable reference to it. - template - LiteRtSubgraphT& EmplaceSubgraph(Args&&... args) { - return subgraphs_.EmplaceBack(Buffers(), std::forward(args)...); - } - - // Transfers given subgraphs into this model. New subgraphs are appended. - void TransferSubgraphsFrom(LiteRtSubgraphT::Alloc&& subgraphs) { - // TODO: Consider mergeing buffer managers here. - subgraphs_.TransferFrom(std::move(subgraphs)); - } - - // Cut all by the first `size` subgraphs. Does nothing if given size is - // greater or equal to current. - void ResizeSubgraphsDown(size_t size) { subgraphs_.ResizeDown(size); } - - // Transfers the subgraph at the given index to the back of the given - // allocator. Also updates any IR owned by the model that refers to subgraphs - // by index (e.g. composites). Does not update any IR in the subgraphs being - // transferred. - void TransferSubgraphTo(LiteRtSubgraphT::Alloc& dest, - std::vector indices); - - // SIGNATURES - - // All signatures registered with this model. - absl::Span Signatures() const { - return signatures_.Elements(); - } - - // Construct a new signature for this model. - template - LiteRtSignatureT& EmplaceSignature(Args&&... args) { - return signatures_.EmplaceBack(std::forward(args)...); - } - - // METADATA - - // Look up metadata by key, getting a view of its buffer as a string - // if it exists. 
- litert::Expected> FindMetadata( - absl::string_view key) const { - if (auto it = metadata_.find(key); it != metadata_.end()) { - const auto buf_id = it->second; - return Buffers()->GetBuffer(buf_id); - } - return ::litert::Error(kLiteRtStatusErrorNotFound); - } - - // Metadata key-val pair iterator. - MetadataMap::iterator MetadataBegin() { return metadata_.begin(); } - MetadataMap::iterator MetadataEnd() { return metadata_.end(); } - - // Adds a new metadata buffer to the model. Fails if it already exists. - template - LiteRtStatus PushMetadata(absl::string_view key, Args&&... args) { - if (metadata_.contains(key)) { - return kLiteRtStatusErrorInvalidArgument; - } - const auto buf_id = Buffers()->RegisterOwnedBuffer( - ::litert::OwningBufferRef(std::forward(args)...)); - metadata_.emplace(std::make_pair(std::string(key), buf_id)); - return kLiteRtStatusOk; - } - - // BUFFERS - - // Get stable pointer to buffer manager object. - BufferManager* Buffers() const { return buffer_manager_.get(); } - - // Attach an asset to the given op. An asset is a non-tensor buffer - // that is used by the op. Assets may be referenced by multiple ops. - // Each edge from an op to an asset is identified by a name. All buffers - // are appended to the model upon serialization and referenced by offset - // relative to the start of the model within the referring op's custom - // options. - void AttachAssetToOp(LiteRtOp op, BufferId buf_id, std::string name) { - OpAssetReference ref = {buf_id, std::move(name)}; - external_buffer_map_.emplace(op, std::move(ref)); - } - - // Returns an immutable view of the external buffer and the name of the edge - // if the given op has one attached. - litert::Expected FindOpAsset(LiteRtOp op) { - if (auto it = external_buffer_map_.find(op); - it != external_buffer_map_.end()) { - return it->second; - } - return ::litert::Error(kLiteRtStatusErrorNotFound); - } - - // Contains details about the compiler used if this model was compiled. 
- struct BuildStamp { - absl::string_view soc_manufacturer; - absl::string_view soc_model; - }; - - // IR is generally, default constructible and movable but not copyable. - LiteRtModelT() = default; - LiteRtModelT(const LiteRtModelT&) = delete; - LiteRtModelT(LiteRtModelT&&) = default; - LiteRtModelT& operator=(const LiteRtModelT&) = delete; - LiteRtModelT& operator=(LiteRtModelT&&) = default; - - // TFLITE - - // Friendship for internal tflite details. - friend const TflOpCodes& litert::internal::GetTflOpCodes( - const LiteRtModelT& litert_model); - - template - friend void litert::internal::SetTflOpCodes(LiteRtModelT& litert_model, - Arg&& arg); - - friend TflOpCodes&& litert::internal::TakeTflOpCodes( - LiteRtModelT& litert_model); - - friend void litert::internal::SetTflFlatbuffer( - LiteRtModelT& litert_model, TflFlatbuffer&& tfl_flatbuffer); - - friend const TflFlatbuffer& litert::internal::GetTflFlatbuffer( - const LiteRtModelT& litert_model); - - explicit LiteRtModelT(TflFlatbuffer&& tfl_flatbuffer) - : tfl_flatbuffer_(std::move(tfl_flatbuffer)) {} - - private: - LiteRtSubgraphT::Alloc subgraphs_; - LiteRtSignatureT::Alloc signatures_; - - MetadataMap metadata_; - OpAssetMap external_buffer_map_; - - // Use unique ptr here to keep stable. - BufferManager::Ptr buffer_manager_ = std::make_unique(); - - // TFLITE - TflOpCodes tfl_operator_codes_; - TflFlatbuffer tfl_flatbuffer_; -}; - -// Get the build stamp from the model if it exists. -// TODO: Consider a setter and internalizeing all build stamp stuff behind model -// interface. -std::optional GetBuildStamp( - const LiteRtModelT& model); - -// Returns true if this model contains any ops compiled for NPU. -bool IsCompiled(const LiteRtModelT& model); - -// Get the custom op code from a given op if it is a custom op. -std::optional GetCustomOpCode(const LiteRtModelT& model, - const LiteRtOpT& op); - -// Lookup subgraph by signature name. 
-::litert::Expected LookupSubgraph( - const LiteRtModelT& model, absl::string_view signature_key); - -namespace litert::internal { - -template -void SetTflOptions(LiteRtOpT& litert_op, Arg&& arg) { - litert_op.tfl_option_ = std::forward(arg); -} - -template -void SetTflOptions2(LiteRtOpT& litert_op, Arg&& arg) { - litert_op.tfl_option_2_ = std::forward(arg); -} - -inline void ClearTflOptions(LiteRtOpT& litert_op) { - litert_op.tfl_option_2_.Reset(); - litert_op.tfl_option_.Reset(); -} - -template -void SetTflOpCodes(LiteRtModelT& litert_model, Arg&& arg) { - litert_model.tfl_operator_codes_ = std::forward(arg); -} - -} // namespace litert::internal - -// -// Misc Ir Containers -// - -using LiteRtOpWithPartitionIndex = std::pair; - -// Used for communicating selections of ops in when partitioning. -class LiteRtOpListT { - public: - void Push(LiteRtOp op, LiteRtParamIndex partition_index = 0) { - values_.push_back(LiteRtOpWithPartitionIndex(op, partition_index)); - } - - std::vector Values() const { - std::vector ops; - ops.reserve(values_.size()); - ops.assign(values_.begin(), values_.end()); - - return ops; - } - - private: - // Investigate if this is possible with vector (hit some issues). - std::list values_; -}; - -// -// Traversal Utils -// - -// Apply func to all the IR in the given model. Iteration behavior is determined -// by the callback signature. -template -void ForEachIr(LiteRtModel model, F func) { - // Per subgraph callbacks. - using SgF1 = std::function; - using SgF2 = std::function; - - // Per op callbacks. 
- using OpF1 = std::function; - using OpF2 = std::function; - using OpF3 = - std::function; - - constexpr bool kIsSgOpF1 = std::is_convertible_v; - constexpr bool kIsSgF2 = std::is_convertible_v; - constexpr bool kIsOpF1 = std::is_convertible_v; - constexpr bool kIsOpF2 = std::is_convertible_v; - constexpr bool kIsOpF3 = std::is_convertible_v; - - for (int i = 0; i < model->NumSubgraphs(); ++i) { - auto subgraph = model->Subgraphs()[i]; - - if constexpr (kIsSgF2) { - func(subgraph, i); - } else if constexpr (kIsSgOpF1) { - func(subgraph); - } else { - for (int j = 0; j < subgraph->Ops().size(); ++j) { - auto* op = subgraph->Ops()[j]; - if constexpr (kIsOpF1) { - func(op); - } else if constexpr (kIsOpF2) { - func(subgraph, op); - } else if constexpr (kIsOpF3) { - func(subgraph, i, op); - } - } - } - } -} - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_MODEL_H_ diff --git a/tensorflow/lite/experimental/litert/core/model/model_buffer.cc b/tensorflow/lite/experimental/litert/core/model/model_buffer.cc deleted file mode 100644 index 3353b3adbf10af..00000000000000 --- a/tensorflow/lite/experimental/litert/core/model/model_buffer.cc +++ /dev/null @@ -1,134 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/core/model/model_buffer.h" - -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/core/filesystem.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" -#include "tensorflow/lite/experimental/litert/core/model/model_load.h" -#include "tensorflow/lite/experimental/litert/core/model/model_serialize.h" - -namespace litert { -namespace internal { - -Expected> GetModelBufWithByteCode( - LiteRtModelT&& model, - const absl::flat_hash_map>& - custom_code_to_npu_bytecode, - size_t bytecode_alignment) { - for (const auto& subgraph : model.Subgraphs()) { - for (auto op : subgraph->Ops()) { - if (op->OpCode() == kLiteRtOpCodeTflCustom) { - auto custom_code = GetCustomOpCode(model, *op); - if (!custom_code) { - continue; - } - - auto iter = custom_code_to_npu_bytecode.find(*custom_code); - if (iter == custom_code_to_npu_bytecode.end()) { - return Error(kLiteRtStatusErrorUnsupported, - absl::StrFormat("Unexpected custom code: %s", - custom_code->c_str())); - } - - LiteRtOpT* custom_op = op; - OwningBufferRef byte_code(iter->second); - const auto buf_id = - model.Buffers()->RegisterOwnedBuffer(std::move(byte_code)); - model.AttachAssetToOp(custom_op, buf_id, ""); - } - } - } - - return SerializeModel(std::move(model), bytecode_alignment); -} - -Expected> GetModelBufWithByteCode( - absl::string_view tfl_file, - const absl::flat_hash_map& - custom_code_to_npu_file, - size_t bytecode_alignment) { - auto model = LoadModelFromFile(tfl_file); - if (!model) { - return model.Error(); - } - - 
absl::flat_hash_map> - custom_code_to_npu_bytecode; - for (auto& iter : custom_code_to_npu_file) { - auto npu_file_buf = LoadBinaryFile(iter.second); - if (!npu_file_buf) { - return npu_file_buf.Error(); - } - custom_code_to_npu_bytecode[iter.first] = std::move(*npu_file_buf); - } - - return GetModelBufWithByteCode( - std::move(**model), custom_code_to_npu_bytecode, bytecode_alignment); -} - -Expected> GetModelBufWithByteCode( - LiteRtModelT&& model, BufferRef npu_byte_code, - size_t bytecode_alignment) { - absl::flat_hash_map> - custom_code_to_npu_bytecode; - for (const auto& subgraph : model.Subgraphs()) { - for (auto op : subgraph->Ops()) { - if (op->OpCode() == kLiteRtOpCodeTflCustom) { - auto custom_code = GetCustomOpCode(model, *op); - if (!custom_code) { - continue; - } - OwningBufferRef byte_code(npu_byte_code.Data(), - npu_byte_code.Size()); - custom_code_to_npu_bytecode[*custom_code] = std::move(byte_code); - } - } - } - - return GetModelBufWithByteCode(std::move(model), custom_code_to_npu_bytecode, - bytecode_alignment); -} - -Expected> GetModelBufWithByteCode( - absl::string_view tfl_file, absl::string_view npu_file, - size_t bytecode_alignment) { - auto model = LoadModelFromFile(tfl_file); - if (!model) { - return model.Error(); - } - - auto npu_file_buf = LoadBinaryFile(npu_file); - if (!npu_file_buf) { - return npu_file_buf.Error(); - } - - return GetModelBufWithByteCode(std::move(**model), std::move(*npu_file_buf), - bytecode_alignment); -} - -} // namespace internal -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/core/model/model_buffer.h b/tensorflow/lite/experimental/litert/core/model/model_buffer.h deleted file mode 100644 index 623e86f19b2899..00000000000000 --- a/tensorflow/lite/experimental/litert/core/model/model_buffer.h +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright 2024 Google LLC. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_MODEL_BUFFER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_MODEL_BUFFER_H_ - -#include "absl/container/flat_hash_map.h" -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" - -namespace litert::internal { - -// Get a buffer that is the concatenation of given tflite file and one or more -// NPU byte code files. Adds metadata containing the offset/size of npu byte -// code. TFL custom ops are mapped to NPU byte code by their custom code, which -// must be non-null. -// -// NOTE: this is intended to be used for testing and tools and may be removed in -// the future. -Expected> GetModelBufWithByteCode( - absl::string_view tfl_file, - const absl::flat_hash_map& - custom_code_to_npu_file, - size_t bytecode_alignment = 1); - -// Same as above, but with a map specifying NPU byte code buffers. -Expected> GetModelBufWithByteCode( - LiteRtModelT&& model, - const absl::flat_hash_map>& - custom_code_to_npu_bytecode, - size_t bytecode_alignment = 1); - -// Same as above, but only a single NPU byte code file is specified. 
-Expected> GetModelBufWithByteCode( - absl::string_view tfl_file, absl::string_view npu_file, - size_t bytecode_alignment = 1); - -// Same as above, but only a single NPU byte code buffer is specified. -Expected> GetModelBufWithByteCode( - LiteRtModelT&& model, BufferRef npu_byte_code, - size_t bytecode_alignment = 1); - -} // namespace litert::internal - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_MODEL_BUFFER_H_ diff --git a/tensorflow/lite/experimental/litert/core/model/model_buffer_test.cc b/tensorflow/lite/experimental/litert/core/model/model_buffer_test.cc deleted file mode 100644 index 00eb7f557f045e..00000000000000 --- a/tensorflow/lite/experimental/litert/core/model/model_buffer_test.cc +++ /dev/null @@ -1,128 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/core/model/model_buffer.h" - -#include -#include - -#include -#include "absl/container/flat_hash_map.h" -#include "absl/strings/string_view.h" -#include "tensorflow/compiler/mlir/lite/allocation.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/core/dispatch_op_schema.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" -#include "tensorflow/lite/experimental/litert/core/model/model_load.h" -#include "tensorflow/lite/experimental/litert/test/common.h" -#include "tensorflow/lite/experimental/litert/test/testdata/simple_model_test_vectors.h" -#include "tensorflow/lite/interpreter.h" -#include "tensorflow/lite/interpreter_builder.h" -#include "tensorflow/lite/kernels/register.h" -#include "tensorflow/lite/model_builder.h" -#include "tensorflow/lite/stderr_reporter.h" - -namespace litert::internal { -namespace { - -static constexpr absl::string_view kNpuFile = kGoogleTensorModelFileName; -static constexpr absl::string_view kTfliteFile = "simple_model_npu.tflite"; -static constexpr absl::string_view kCascadedTfliteFile = - "simple_cascade_model_npu.tflite"; - -TEST(GetModelBufWithByteCode, CreateInterpreter) { - auto model_with_byte_code = - GetModelBufWithByteCode(testing::GetTestFilePath(kTfliteFile), - testing::GetTestFilePath(kNpuFile)); - ASSERT_TRUE(model_with_byte_code); - - auto alloc = std::make_unique( - model_with_byte_code->Data(), model_with_byte_code->Size(), - tflite::DefaultErrorReporter()); - - auto fb_model = tflite::FlatBufferModel::BuildFromBuffer( - reinterpret_cast(alloc->base()), alloc->bytes()); - ASSERT_NE(fb_model, nullptr); - - tflite::ops::builtin::BuiltinOpResolver resolver; - std::unique_ptr interpreter; - tflite::InterpreterBuilder(*fb_model, resolver)(&interpreter); - EXPECT_NE(interpreter, nullptr); -} - -TEST(GetModelBufWithByteCode, CheckAppended) { - auto model_with_byte_code = - 
GetModelBufWithByteCode(testing::GetTestFilePath(kTfliteFile), - testing::GetTestFilePath(kNpuFile)); - ASSERT_TRUE(model_with_byte_code); - - auto model = LoadModelFromBuffer(*model_with_byte_code); - ASSERT_TRUE(model); - - auto* op = model->get()->Subgraphs().front()->Ops().front(); - ASSERT_EQ(op->OpCode(), kLiteRtOpCodeTflCustom); - auto dispatch_opts = GetDispatchOpOptions(op->CustomOptions()); - EXPECT_EQ(dispatch_opts.name, ""); - EXPECT_LE(dispatch_opts.bytecode_offset + dispatch_opts.bytecode_size, - model_with_byte_code->Size()); -} - -TEST(GetModelBufWithByteCode, CreateInterpreterWithMultpleNpuNodes) { - absl::flat_hash_map custom_code_to_npu_file = { - {"DISPATCH_OP_1", testing::GetTestFilePath(kNpuFile)}, - {"DISPATCH_OP_2", testing::GetTestFilePath(kNpuFile)}, - }; - - auto model_with_byte_code = GetModelBufWithByteCode( - testing::GetTestFilePath(kCascadedTfliteFile), custom_code_to_npu_file); - ASSERT_TRUE(model_with_byte_code); - - auto alloc = std::make_unique( - model_with_byte_code->Data(), model_with_byte_code->Size(), - tflite::DefaultErrorReporter()); - - auto fb_model = tflite::FlatBufferModel::BuildFromBuffer( - reinterpret_cast(alloc->base()), alloc->bytes()); - ASSERT_NE(fb_model, nullptr); - - tflite::ops::builtin::BuiltinOpResolver resolver; - std::unique_ptr interpreter; - tflite::InterpreterBuilder(*fb_model, resolver)(&interpreter); - EXPECT_NE(interpreter, nullptr); -} - -TEST(GetModelBufWithByteCode, CheckAppendedWithMultipleNpuOps) { - absl::flat_hash_map custom_code_to_npu_file = { - {"DISPATCH_OP_1", testing::GetTestFilePath(kNpuFile)}, - {"DISPATCH_OP_2", testing::GetTestFilePath(kNpuFile)}, - }; - - auto model_with_byte_code = GetModelBufWithByteCode( - testing::GetTestFilePath(kCascadedTfliteFile), custom_code_to_npu_file); - ASSERT_TRUE(model_with_byte_code); - - auto model = LoadModelFromBuffer(*model_with_byte_code); - ASSERT_TRUE(model); - - for (auto& op : model->get()->Subgraphs().front()->Ops()) { - 
ASSERT_EQ(op->OpCode(), kLiteRtOpCodeTflCustom); - auto dispatch_opts = GetDispatchOpOptions(op->CustomOptions()); - EXPECT_EQ(dispatch_opts.name, ""); - EXPECT_LE(dispatch_opts.bytecode_offset + dispatch_opts.bytecode_size, - model_with_byte_code->Size()); - } -} - -} // namespace -} // namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/core/model/model_file_test.cc b/tensorflow/lite/experimental/litert/core/model/model_file_test.cc deleted file mode 100644 index 9f3e07fd6caf3e..00000000000000 --- a/tensorflow/lite/experimental/litert/core/model/model_file_test.cc +++ /dev/null @@ -1,1041 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include -#include -#include -#include // NOLINT -#include -#include -#include -#include -#include - -// schema/mutable/schema_generated.h and schema/schema_generated.h (included -// through flatbuffer_tools.h via model.h) have the same #ifdef, thus this line -// need to be put at the top to ensure we get the "mutable" version. 
-#if 1 -#include "tensorflow/compiler/mlir/lite/schema/mutable/schema_generated.h" -#endif - -#include // IWYU pragma: keep -#include -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" -#include "tensorflow/lite/experimental/litert/cc/litert_element_type.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model_predicates.h" -#include "tensorflow/lite/experimental/litert/core/dispatch_op_schema.h" -#include "tensorflow/lite/experimental/litert/core/model/buffer_manager.h" -#include "tensorflow/lite/experimental/litert/core/model/graph_validation.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" -#include "tensorflow/lite/experimental/litert/core/model/model_file_test_util.h" -#include "tensorflow/lite/experimental/litert/core/model/model_load.h" -#include "tensorflow/lite/experimental/litert/core/model/model_serialize.h" -#include "tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h" -#include "tensorflow/lite/experimental/litert/test/common.h" -#include "tensorflow/lite/experimental/litert/test/matchers.h" -#include "tensorflow/lite/experimental/litert/test/test_models.h" -#include "tensorflow/lite/schema/mutable/schema_generated.h" - -namespace litert::internal { -namespace { - -using ::litert::testing::GetTestFilePath; -using ::testing::Each; -using ::testing::ElementsAreArray; -using ::testing::FloatEq; -using ::testing::Values; -using ::testing::litert::IsError; - -using ModelFactory = std::function()>; - -static constexpr absl::string_view 
kAddSimple = "add_simple.tflite"; -static constexpr absl::string_view kAddCst = "add_cst.tflite"; -static constexpr absl::string_view kDynamicShapeModel = - "dynamic_shape_tensor.tflite"; -static constexpr absl::string_view kSimpleMultiOp = "simple_multi_op.tflite"; -static constexpr absl::string_view kOneMul = "one_mul.tflite"; -static constexpr absl::string_view kSimpleMultiSubgraph = - "multi_subgraph.tflite"; -static constexpr absl::string_view kCstMultiSubgraph = - "cst_multi_subgraph.tflite"; - -// Load a model, then serialize and re-load. Used to test serialization. -Expected LoadModelThroughRoundTrip(absl::string_view filename) { - auto model = Model::CreateFromFile(GetTestFilePath(filename)); - if (!model) { - return model.Error(); - } - - OwningBufferRef buf; - auto [data, size, offset] = buf.GetWeak(); - - const auto opts = litert::SerializationOptions::Defaults(); - LITERT_RETURN_IF_ERROR(LiteRtSerializeModel(model->Release(), &data, &size, - &offset, true, opts)); - - // Reload model. - LiteRtModel result = nullptr; - LITERT_RETURN_IF_ERROR( - LiteRtCreateModelFromBuffer(buf.Data(), buf.Size(), &result)); - - return Model::CreateFromOwnedHandle(result); -} - -ModelFactory MakeRoundTripFactory(absl::string_view filename) { - return [=]() { return LoadModelThroughRoundTrip(filename); }; -} - -ModelFactory MakeLoadFactory(absl::string_view filename) { - return [=]() { return Model::CreateFromFile(GetTestFilePath(filename)); }; -} - -// Test fixture parameterized by a file path to test model. -class TestWithModelPath : public ::testing::TestWithParam { - protected: - std::string GetTestModelPath() const { - return testing::GetTestFilePath(GetParam()); - } -}; - -// Test fixture pareterized by a function that loads a model. 
-class TestWithModelFactory : public ::testing::TestWithParam { - protected: - Expected LoadModel() { return GetParam()(); } -}; - -// Simple tests -//===--------------------------------------------------------------------------- - -TEST(ModelLoadTest, BadFilepath) { - LiteRtModel model = nullptr; - EXPECT_THAT(LiteRtCreateModelFromFile("bad_path", &model), - IsError(kLiteRtStatusErrorNotFound)); -} - -TEST(ModelLoadTest, BadFileData) { - // NOLINTBEGIN -#ifndef NDEBUG - // In debug mode, flatbuffers will `assert` while verifying. This will - // cause this test to crash (as expected). - GTEST_SKIP(); -#endif - std::filesystem::path test_file_path(::testing::TempDir()); - test_file_path.append("bad_file.txt"); - - std::ofstream bad_file; - bad_file.open(test_file_path.c_str()); - bad_file << "not_tflite"; - bad_file.close(); - - LiteRtModel model = nullptr; - EXPECT_THAT(LiteRtCreateModelFromFile(test_file_path.c_str(), &model), - IsError(kLiteRtStatusErrorInvalidFlatbuffer)); - // NOLINTEND -} - -TEST(ModelLoadTest, GetCustomOpCode) { - auto model = litert::testing::LoadTestFileModel("simple_model_npu.tflite"); - ASSERT_TRUE(model); - const auto& litert_model = *model.Get(); - const auto& op = *litert_model.MainSubgraph()->Ops().front(); - auto custom_op_code = GetCustomOpCode(litert_model, op); - ASSERT_TRUE(custom_op_code.has_value()); - EXPECT_EQ(*custom_op_code, "DISPATCH_OP"); -} - -TEST(ModelLoadTest, WithMetadata) { - constexpr static absl::string_view kMetadataName = "an_soc_manufacturer"; - constexpr static absl::string_view kMetadataData = "My_Meta_Data"; - - auto flatbuffer = - FlatbufferWrapper::CreateFromTflFile(GetTestFilePath(kAddSimple)); - auto tfl_model = flatbuffer->get()->Unpack(); - PushMetadata(kMetadataName, *tfl_model, - BufferRef(kMetadataData.data(), kMetadataData.size())); - auto serialialized = SerializeFlatbuffer(*tfl_model); - - auto litert_model = LoadModelFromBuffer(serialialized); - ASSERT_TRUE(litert_model); - - auto metadata = 
litert_model->get()->FindMetadata(kMetadataName); - ASSERT_TRUE(metadata); - EXPECT_EQ(metadata->StrView(), kMetadataData); -} - -TEST(ModelSerializeTest, WithMetadata) { - auto model = litert::testing::LoadTestFileModel(kAddSimple); - - constexpr static absl::string_view kMetadataName = "an_soc_manufacturer"; - constexpr static absl::string_view kMetadataData = "My_Meta_Data"; - - LITERT_ASSERT_OK(model.Get()->PushMetadata( - kMetadataName, OwningBufferRef(kMetadataData))); - - auto serialized = SerializeModel(std::move(*model.Get())); - EXPECT_TRUE(VerifyFlatbuffer(serialized->Span())); - - auto re_loaded = LoadModelFromBuffer(*serialized); - auto metadata = re_loaded->get()->FindMetadata(kMetadataName); - EXPECT_EQ(metadata->StrView(), kMetadataData); -} - -TEST(ModelLoadTest, WithSignature) { - auto model = litert::testing::LoadTestFileModel(kAddSimple); - auto& litert_model = *model.Get(); - - auto signature = - litert_model.FindSignature(LiteRtSignatureT::kDefaultSignatureKey); - ASSERT_TRUE(signature); - - EXPECT_EQ(signature->get().InputNames().size(), 1); - EXPECT_EQ(signature->get().OutputNames().size(), 1); - EXPECT_EQ(&signature->get().GetSubgraph(), litert_model.MainSubgraph()); -} - -TEST(ModelLoadTest, NoSignature) { - auto model = *Model::CreateFromFile(testing::GetTfliteFilePath( - "java/demo/app/src/main/assets/mobilenet_v1_1.0_224.tflite")); - if (!model) { - GTEST_SKIP() << "Model file is not available."; - } - auto& litert_model = *model.Get(); - auto signature = - litert_model.FindSignature(LiteRtSignatureT::kDefaultSignatureKey); - ASSERT_TRUE(signature); - EXPECT_EQ(signature->get().InputNames().size(), 1); - EXPECT_EQ(signature->get().OutputNames().size(), 1); - EXPECT_EQ(&signature->get().GetSubgraph(), litert_model.MainSubgraph()); -} - -TEST(ModelSerializeTest, WithSignature) { - auto model = litert::testing::LoadTestFileModel(kAddSimple); - auto& litert_model = *model.Get(); - - static constexpr char kInput[] = "foo"; - static constexpr 
char kOutput[] = "bar"; - static constexpr char kKey[] = "newKey"; - - LiteRtSignatureT signature(litert_model.MainSubgraph(), {kInput}, {kOutput}, - kKey); - litert_model.EmplaceSignature(std::move(signature)); - - auto serialized = SerializeModel(std::move(*model.Get())); - EXPECT_TRUE(VerifyFlatbuffer(serialized->Span())); - - auto re_loaded = LoadModelFromBuffer(*serialized); - auto re_loaded_signature = re_loaded->get()->FindSignature(kKey); - ASSERT_TRUE(re_loaded_signature); - const auto& sig = re_loaded_signature->get(); - - const auto& inputs = sig.InputNames(); - const auto& outputs = sig.OutputNames(); - EXPECT_THAT(inputs, ElementsAreArray({kInput})); - EXPECT_THAT(outputs, ElementsAreArray({kOutput})); - EXPECT_EQ(&sig.GetSubgraph(), re_loaded->get()->MainSubgraph()); -} - -TEST(ModelLoadTest, ReverseSignature) { - auto model = - litert::testing::LoadTestFileModel("reverse_signature_model.tflite"); - ASSERT_TRUE(model); - auto& litert_model = *model.Get(); - - auto signature = litert_model.FindSignature("serving_default"); - ASSERT_TRUE(signature); - - // Check if the input and output names are in the order of the subgraph - // inputs and outputs instead of the signature appearance order. - const auto& sig = signature->get(); - ASSERT_EQ(sig.InputNames().size(), 2); - EXPECT_STREQ(sig.InputNames()[0].c_str(), "y"); - EXPECT_STREQ(sig.InputNames()[1].c_str(), "x"); - ASSERT_EQ(sig.OutputNames().size(), 2); - EXPECT_STREQ(sig.OutputNames()[0].c_str(), "sum"); - EXPECT_STREQ(sig.OutputNames()[1].c_str(), "prod"); - - auto serialized = SerializeModel(std::move(*model.Get())); - EXPECT_TRUE(VerifyFlatbuffer(serialized->Span())); - - auto re_loaded = LoadModelFromBuffer(*serialized); - auto re_loaded_signature = re_loaded->get()->FindSignature("serving_default"); - ASSERT_TRUE(re_loaded_signature); - - // Check again with the serialized model. 
- const auto& re_sig = re_loaded_signature->get(); - ASSERT_EQ(re_sig.InputNames().size(), 2); - EXPECT_STREQ(re_sig.InputNames()[0].c_str(), "y"); - EXPECT_STREQ(re_sig.InputNames()[1].c_str(), "x"); - ASSERT_EQ(re_sig.OutputNames().size(), 2); - EXPECT_STREQ(re_sig.OutputNames()[0].c_str(), "sum"); - EXPECT_STREQ(re_sig.OutputNames()[1].c_str(), "prod"); -} - -TEST(ModelLoadTest, WithOffsetTensorBuffer) { - static constexpr absl::string_view kTensorData = "SOME_TENSOR_DATA"; - - auto flatbuffer = - FlatbufferWrapper::CreateFromTflFile(GetTestFilePath(kAddSimple)); - auto tfl_model = flatbuffer->get()->Unpack(); - const auto buf_ind = tfl_model->subgraphs[0]->tensors[0]->buffer; - auto& tfl_buffer = tfl_model->buffers[buf_ind]; - tfl_buffer->offset = 1; - tfl_buffer->size = 1; - auto model_buf = SerializeFlatbuffer(*tfl_model); - auto* packed_tfl = tflite::GetMutableModel(model_buf.Data()); - auto* buf = packed_tfl->mutable_buffers()->GetMutableObject(buf_ind); - ASSERT_TRUE(buf->mutate_offset(model_buf.Size())); - ASSERT_TRUE(buf->mutate_size(kTensorData.size())); - OwningBufferRef final_serializd(kTensorData.size() + - model_buf.Size()); - std::memcpy(final_serializd.Data(), model_buf.Data(), model_buf.Size()); - std::memcpy(final_serializd.Data() + model_buf.Size(), kTensorData.data(), - kTensorData.size()); - - auto litert_model = LoadModelFromBuffer(final_serializd); - ASSERT_TRUE(litert_model); - - const auto& weights_buffer = - litert_model->get()->Subgraph(0).Tensor(0).Weights(); - EXPECT_EQ(weights_buffer.Buffer().StrView(), kTensorData); - - // The loaded buffer should indicate that it should be also serialized as - // external. - const auto will_append = weights_buffer.GetBufferManager() - ->GetContext(weights_buffer.GetBufferId()) - ->get() - .should_append; - EXPECT_TRUE(will_append); - - // All tensors in the first subgraph should have the same buffer manager as - // the model. 
- for (auto* tensor : litert_model->get()->Subgraph(0).Tensors()) { - EXPECT_EQ(tensor->Weights().GetBufferManager(), - litert_model->get()->Buffers()); - } -} - -TEST(ModelSerializeTest, WithOffsetTensorBuffer) { - static constexpr absl::string_view kTensorData = "SOME_TENSOR_DATA"; - - LiteRtModelT root; - auto& sg = root.EmplaceSubgraph(); - auto& tensor = sg.EmplaceTensor(); - sg.EmplaceOp(); - tensor.SetType(MakeRankedTensorType(kLiteRtElementTypeFloat32, {})); - auto& weights = tensor.Weights(); - weights.SetBufferManager(root.Buffers()); - - OwningBufferRef buffer(kTensorData); - BufferContext context; - context.should_append = true; - SetWeightsFromOwnedBuffer(weights, std::move(buffer), context); - - auto serialized = SerializeModel(std::move(root)); - ASSERT_TRUE(serialized); - - // Verify the op contains an offset and size to the byte code and the correct - // name. - auto fb = FlatbufferWrapper::CreateFromBuffer(*serialized); - ASSERT_TRUE(fb); - - auto tfl = fb->get()->Unpack(); - const auto& tfl_tensor = tfl->subgraphs[0]->tensors[0]; - const auto tfl_buffer_ind = tfl_tensor->buffer; - const auto& tfl_buffer = tfl->buffers[tfl_buffer_ind]; - - auto data = - serialized->StrView().substr(tfl_buffer->offset, tfl_buffer->size); - EXPECT_EQ(data, kTensorData); -} - -TEST(ModelSerializeTest, WithMultipleOffsetTensorBuffer) { - static constexpr absl::string_view kTensorData = "SOME_TENSOR_DATA"; - static constexpr absl::string_view kTensorData2 = "SOME_TENSOR_DATA2"; - - LiteRtModelT root; - auto& sg = root.EmplaceSubgraph(); - sg.EmplaceOp(); - - { - auto& tensor = sg.EmplaceTensor(); - tensor.SetType(MakeRankedTensorType(kLiteRtElementTypeFloat32, {})); - auto& weights = tensor.Weights(); - weights.SetBufferManager(root.Buffers()); - - OwningBufferRef buffer(kTensorData); - BufferContext context; - context.should_append = true; - SetWeightsFromOwnedBuffer(weights, std::move(buffer), context); - } - - { - auto& tensor = sg.EmplaceTensor(); - 
tensor.SetType(MakeRankedTensorType(kLiteRtElementTypeFloat32, {})); - auto& weights = tensor.Weights(); - weights.SetBufferManager(root.Buffers()); - - OwningBufferRef buffer(kTensorData2); - BufferContext context; - context.should_append = true; - SetWeightsFromOwnedBuffer(weights, std::move(buffer), context); - } - - auto serialized = SerializeModel(std::move(root)); - ASSERT_TRUE(serialized); - - // Verify the op contains an offset and size to the byte code and the correct - // name. - auto fb = FlatbufferWrapper::CreateFromBuffer(*serialized); - ASSERT_TRUE(fb); - - auto tfl = fb->get()->Unpack(); - - { - const auto& tfl_tensor = tfl->subgraphs[0]->tensors[0]; - const auto tfl_buffer_ind = tfl_tensor->buffer; - const auto& tfl_buffer = tfl->buffers[tfl_buffer_ind]; - - auto data = - serialized->StrView().substr(tfl_buffer->offset, tfl_buffer->size); - EXPECT_EQ(data, kTensorData); - } - - { - const auto& tfl_tensor = tfl->subgraphs[0]->tensors[1]; - const auto tfl_buffer_ind = tfl_tensor->buffer; - const auto& tfl_buffer = tfl->buffers[tfl_buffer_ind]; - - auto data = - serialized->StrView().substr(tfl_buffer->offset, tfl_buffer->size); - EXPECT_EQ(data, kTensorData2); - } -} - -TEST(ModelSerializeTest, WithSingleExternalBuffer) { - static constexpr absl::string_view kByteCode = "SOME_BYTE_CODE"; - static constexpr absl::string_view kName = "foo"; - - LiteRtModelT root; - auto& sg = root.EmplaceSubgraph(); - auto& op = sg.EmplaceOp(); - - OwningBufferRef buffer(kByteCode); - const auto buf_id = root.Buffers()->RegisterOwnedBuffer(std::move(buffer)); - root.AttachAssetToOp(&op, buf_id, std::string(kName)); - - auto serialized = SerializeModel(std::move(root)); - ASSERT_TRUE(serialized); - - // Verify the op contains an offset and size to the byte code and the correct - // name. 
- auto fb = FlatbufferWrapper::CreateFromBuffer(*serialized); - ASSERT_TRUE(fb); - - auto tfl = fb->get()->Unpack(); - const auto& opts = tfl->subgraphs[0]->operators[0]->custom_options; - BufferRef opts_buffer(opts.data(), opts.size()); - - auto dispatch_opts = GetDispatchOpOptions(opts_buffer); - EXPECT_EQ(dispatch_opts.name, kName); - EXPECT_EQ(serialized->StrView().substr(dispatch_opts.bytecode_offset, - dispatch_opts.bytecode_size), - kByteCode); -} - -TEST(ModelSerializeTest, WithMultipleUniqueExternalBuffer) { - static constexpr absl::string_view kByteCode = "SOME_BYTE_CODE"; - static constexpr absl::string_view kName = "foo"; - static constexpr absl::string_view kByteCode2 = "SOME_BYTE_CODE2"; - static constexpr absl::string_view kName2 = "bar"; - - LiteRtModelT root; - auto& sg = root.EmplaceSubgraph(); - auto& op = sg.EmplaceOp(); - auto& op2 = sg.EmplaceOp(); - - OwningBufferRef buffer(kByteCode); - const auto buf_id = root.Buffers()->RegisterOwnedBuffer(std::move(buffer)); - root.AttachAssetToOp(&op, buf_id, std::string(kName)); - - OwningBufferRef buffer2(kByteCode2); - const auto buf_id2 = root.Buffers()->RegisterOwnedBuffer(std::move(buffer2)); - root.AttachAssetToOp(&op2, buf_id2, std::string(kName2)); - - auto serialized = SerializeModel(std::move(root)); - ASSERT_TRUE(serialized); - - // Verify both ops contains an offset and size to the byte code and the - // correct name. 
- auto fb = FlatbufferWrapper::CreateFromBuffer(*serialized); - ASSERT_TRUE(fb); - - auto tfl = fb->get()->Unpack(); - - { - const auto& opts = tfl->subgraphs[0]->operators[0]->custom_options; - BufferRef opts_buffer(opts.data(), opts.size()); - - auto dispatch_opts = GetDispatchOpOptions(opts_buffer); - EXPECT_EQ(dispatch_opts.name, kName); - EXPECT_EQ(serialized->StrView().substr(dispatch_opts.bytecode_offset, - dispatch_opts.bytecode_size), - kByteCode); - } - - { - const auto& opts = tfl->subgraphs[0]->operators[1]->custom_options; - BufferRef opts_buffer(opts.data(), opts.size()); - - auto dispatch_opts = GetDispatchOpOptions(opts_buffer); - EXPECT_EQ(dispatch_opts.name, kName2); - EXPECT_EQ(serialized->StrView().substr(dispatch_opts.bytecode_offset, - dispatch_opts.bytecode_size), - kByteCode2); - } -} - -TEST(ModelSerializeTest, WithSharedExternalBuffer) { - static constexpr absl::string_view kByteCode = "SOME_BYTE_CODE"; - static constexpr absl::string_view kName = "foo"; - static constexpr absl::string_view kName2 = "bar"; - - LiteRtModelT root; - auto& sg = root.EmplaceSubgraph(); - auto& op = sg.EmplaceOp(); - auto& op2 = sg.EmplaceOp(); - - OwningBufferRef buffer(kByteCode); - const auto buf_id = root.Buffers()->RegisterOwnedBuffer(std::move(buffer)); - - root.AttachAssetToOp(&op, buf_id, std::string(kName)); - root.AttachAssetToOp(&op2, buf_id, std::string(kName2)); - - auto serialized = SerializeModel(std::move(root)); - ASSERT_TRUE(serialized); - - // Verify both ops point to the same appended buffer. 
- auto fb = FlatbufferWrapper::CreateFromBuffer(*serialized); - ASSERT_TRUE(fb); - - auto tfl = fb->get()->Unpack(); - - { - const auto& opts = tfl->subgraphs[0]->operators[0]->custom_options; - BufferRef opts_buffer(opts.data(), opts.size()); - - auto dispatch_opts = GetDispatchOpOptions(opts_buffer); - EXPECT_EQ(dispatch_opts.name, kName); - EXPECT_EQ(serialized->StrView().substr(dispatch_opts.bytecode_offset, - dispatch_opts.bytecode_size), - kByteCode); - } - - { - const auto& opts = tfl->subgraphs[0]->operators[1]->custom_options; - BufferRef opts_buffer(opts.data(), opts.size()); - - auto dispatch_opts = GetDispatchOpOptions(opts_buffer); - EXPECT_EQ(dispatch_opts.name, kName2); - EXPECT_EQ(serialized->StrView().substr(dispatch_opts.bytecode_offset, - dispatch_opts.bytecode_size), - kByteCode); - } -} - -TEST(ModelSerializeTest, WithOffsetTensorBufferAndOpAsset) { - static constexpr absl::string_view kTensorData = "SOME_TENSOR_DATA"; - static constexpr absl::string_view kByteCode = "SOME_BYTE_CODE"; - static constexpr absl::string_view kName = "name"; - - LiteRtModelT root; - auto& sg = root.EmplaceSubgraph(); - auto& op = sg.EmplaceOp(); - auto& tensor = sg.EmplaceTensor(); - tensor.SetType(MakeRankedTensorType(kLiteRtElementTypeFloat32, {})); - auto& weights = tensor.Weights(); - weights.SetBufferManager(root.Buffers()); - - { - OwningBufferRef buffer(kTensorData); - BufferContext context; - context.should_append = true; - SetWeightsFromOwnedBuffer(weights, std::move(buffer), context); - } - - { - OwningBufferRef buffer(kByteCode); - const auto buf_id = root.Buffers()->RegisterOwnedBuffer(std::move(buffer)); - root.AttachAssetToOp(&op, buf_id, std::string(kName)); - } - - auto serialized = SerializeModel(std::move(root)); - ASSERT_TRUE(serialized); - - auto fb = FlatbufferWrapper::CreateFromBuffer(*serialized); - ASSERT_TRUE(fb); - auto tfl = fb->get()->Unpack(); - - { - const auto& tfl_tensor = tfl->subgraphs[0]->tensors[0]; - const auto tfl_buffer_ind = 
tfl_tensor->buffer; - const auto& tfl_buffer = tfl->buffers[tfl_buffer_ind]; - - auto data = - serialized->StrView().substr(tfl_buffer->offset, tfl_buffer->size); - EXPECT_EQ(data, kTensorData); - } - - { - const auto& opts = tfl->subgraphs[0]->operators[0]->custom_options; - BufferRef opts_buffer(opts.data(), opts.size()); - - auto dispatch_opts = GetDispatchOpOptions(opts_buffer); - EXPECT_EQ(dispatch_opts.name, kName); - EXPECT_EQ(serialized->StrView().substr(dispatch_opts.bytecode_offset, - dispatch_opts.bytecode_size), - kByteCode); - } -} - -TEST(ModelSerializeTest, WithOffsetTensorBufferAndOpAssetHasAlignment) { - static constexpr absl::string_view kTensorData = "SOME_TENSOR_DATA"; - static constexpr absl::string_view kByteCode = "SOME_BYTE_CODE"; - static constexpr absl::string_view kName = "name"; - static constexpr size_t kAlignment = 32; - - LiteRtModelT root; - auto& sg = root.EmplaceSubgraph(); - auto& op = sg.EmplaceOp(); - auto& tensor = sg.EmplaceTensor(); - tensor.SetType(MakeRankedTensorType(kLiteRtElementTypeFloat32, {})); - auto& weights = tensor.Weights(); - weights.SetBufferManager(root.Buffers()); - - { - OwningBufferRef buffer(kTensorData); - BufferContext context; - context.should_append = true; - SetWeightsFromOwnedBuffer(weights, std::move(buffer), context); - } - - { - OwningBufferRef buffer(kByteCode); - const auto buf_id = root.Buffers()->RegisterOwnedBuffer(std::move(buffer)); - root.AttachAssetToOp(&op, buf_id, std::string(kName)); - } - - auto serialized = SerializeModel(std::move(root), kAlignment); - ASSERT_TRUE(serialized); - - auto fb = FlatbufferWrapper::CreateFromBuffer(*serialized); - ASSERT_TRUE(fb); - auto tfl = fb->get()->Unpack(); - - { - const auto& tfl_tensor = tfl->subgraphs[0]->tensors[0]; - const auto tfl_buffer_ind = tfl_tensor->buffer; - const auto& tfl_buffer = tfl->buffers[tfl_buffer_ind]; - - auto data = - serialized->StrView().substr(tfl_buffer->offset, tfl_buffer->size); - EXPECT_EQ(data, kTensorData); - } - - 
{ - const auto& opts = tfl->subgraphs[0]->operators[0]->custom_options; - BufferRef opts_buffer(opts.data(), opts.size()); - - auto dispatch_opts = GetDispatchOpOptions(opts_buffer); - EXPECT_EQ(dispatch_opts.name, kName); - ASSERT_EQ(dispatch_opts.bytecode_offset % kAlignment, 0); - EXPECT_EQ(serialized->StrView().substr(dispatch_opts.bytecode_offset, - dispatch_opts.bytecode_size), - kByteCode); - } -} - -// Tests that explicitly check litert graph structure. -//===--------------------------------------------------------------------------- - -using AddSimpleTest = TestWithModelFactory; - -TEST_P(AddSimpleTest, CheckGraph) { - auto model = LoadModel(); - ASSERT_TRUE(model); - - // func(arg0) - // output = tfl.add(arg0, arg0) - // return(output) - // - - auto subgraph = model->MainSubgraph(); - const auto subgraph_inputs = subgraph->Inputs(); - const auto subgraph_outputs = subgraph->Outputs(); - const auto ops = subgraph->Ops(); - - ASSERT_EQ(subgraph_inputs.size(), 1); - ASSERT_EQ(subgraph_outputs.size(), 1); - - const auto& internal_ops = subgraph->Get()->Ops(); - ASSERT_TRUE( - ValidateLocalTopology(internal_ops.cbegin(), internal_ops.cend())); - ASSERT_TRUE(ValidateSubgraphIO(*subgraph->Get())); - - ASSERT_EQ(ops.size(), 1); - const auto& op = ops.front(); - - const TensorTypeInfo float_2by2_type(ElementType::Float32, {2, 2}); - ASSERT_TRUE( - MatchOpType(op, {float_2by2_type, float_2by2_type}, {float_2by2_type})); - EXPECT_EQ(op.Code(), kLiteRtOpCodeTflAdd); - - const auto op_inputs = op.Inputs(); - ASSERT_EQ(op_inputs.size(), 2); - ASSERT_EQ(op_inputs.front().Get(), subgraph_inputs.front().Get()); - ASSERT_EQ(op_inputs.front().Get(), op_inputs.back().Get()); - - const auto op_outputs = op.Outputs(); - ASSERT_EQ(op_outputs.size(), 1); - ASSERT_EQ(op_outputs.front().Get(), subgraph_outputs.front().Get()); - - ASSERT_FALSE(subgraph_outputs.front().IsConstant()); - ASSERT_FALSE(subgraph_inputs.front().IsConstant()); -} - -INSTANTIATE_TEST_SUITE_P(ModelLoadTests, 
AddSimpleTest, - Values(MakeLoadFactory(kAddSimple))); - -INSTANTIATE_TEST_SUITE_P(ModelSerializeTests, AddSimpleTest, - Values(MakeRoundTripFactory(kAddSimple))); - -using AddCstTest = TestWithModelFactory; - -TEST_P(AddCstTest, CheckGraph) { - auto model = LoadModel(); - ASSERT_TRUE(model); - - // func(arg0) - // cst = ConstantTensor([1, 2, 3, 4]) - // output = tfl.add(arg0, cst) - // return(output) - // - - auto subgraph = model->MainSubgraph(); - const auto subgraph_inputs = subgraph->Inputs(); - const auto subgraph_outputs = subgraph->Outputs(); - const auto ops = subgraph->Ops(); - - ASSERT_EQ(subgraph_inputs.size(), 1); - ASSERT_EQ(subgraph_outputs.size(), 1); - - const auto& internal_ops = subgraph->Get()->Ops(); - ASSERT_TRUE( - ValidateLocalTopology(internal_ops.cbegin(), internal_ops.cend())); - ASSERT_TRUE(ValidateSubgraphIO(*subgraph->Get())); - - ASSERT_EQ(ops.size(), 1); - const auto& op = ops.front(); - - const TensorTypeInfo float_by4_type(ElementType::Float32, {4}); - ASSERT_TRUE( - MatchOpType(op, {float_by4_type, float_by4_type}, {float_by4_type})); - EXPECT_EQ(op.Code(), kLiteRtOpCodeTflAdd); - - const auto op_inputs = op.Inputs(); - ASSERT_EQ(op_inputs.size(), 2); - ASSERT_EQ(op_inputs.front().Get(), subgraph_inputs.front().Get()); - ASSERT_TRUE(MatchWeights(op_inputs.back(), - absl::Span({1.0, 2.0, 3.0, 4.0}))); - - const auto op_outputs = op.Outputs(); - ASSERT_EQ(op_outputs.size(), 1); - ASSERT_EQ(op_outputs.front().Get(), subgraph_outputs.front().Get()); - - ASSERT_FALSE(subgraph_outputs.front().IsConstant()); - ASSERT_FALSE(subgraph_inputs.front().IsConstant()); -} - -INSTANTIATE_TEST_SUITE_P(ModelLoadTests, AddCstTest, - Values(MakeLoadFactory(kAddCst))); - -INSTANTIATE_TEST_SUITE_P(ModelSerializeTests, AddCstTest, - Values(MakeRoundTripFactory(kAddCst))); - -using SimpleMultiOpTest = TestWithModelFactory; - -TEST_P(SimpleMultiOpTest, CheckGraph) { - auto model = LoadModel(); - ASSERT_TRUE(model); - - // func.func @main(arg0) - // 0 = 
tfl.add arg0, arg0 - // 1 = tfl.mul 0, 0 - // 2 = tfl.mul 1, 1 - // 3 = tfl.add 2, 2 - // return 3 - - auto subgraph = model->MainSubgraph(); - const auto subgraph_inputs = subgraph->Inputs(); - const auto subgraph_outputs = subgraph->Outputs(); - const auto ops = subgraph->Ops(); - - ASSERT_EQ(subgraph_inputs.size(), 1); - ASSERT_EQ(subgraph_outputs.size(), 1); - - const auto& internal_ops = subgraph->Get()->Ops(); - ASSERT_TRUE( - ValidateLocalTopology(internal_ops.cbegin(), internal_ops.cend())); - ASSERT_TRUE(ValidateSubgraphIO(*subgraph->Get())); - - ASSERT_EQ(ops.size(), 4); - - for (const auto& op : ops) { - const auto inputs = op.Inputs(); - ASSERT_EQ(inputs.size(), 2); - ASSERT_EQ(inputs.front().Get(), inputs.back().Get()); - } - - const TensorTypeInfo float_2by2_type(ElementType::Float32, {2, 2}); - - ASSERT_TRUE(MatchOpType(ops.at(2), {float_2by2_type, float_2by2_type}, - {float_2by2_type})); - EXPECT_EQ(ops.at(2).Code(), kLiteRtOpCodeTflMul); -} - -INSTANTIATE_TEST_SUITE_P(ModelLoadTests, SimpleMultiOpTest, - Values(MakeLoadFactory(kSimpleMultiOp))); - -INSTANTIATE_TEST_SUITE_P(ModelSerializeTests, SimpleMultiOpTest, - Values(MakeRoundTripFactory(kSimpleMultiOp))); - -using SimpleMultiSubgraphTest = TestWithModelFactory; - -TEST_P(SimpleMultiSubgraphTest, CheckGraph) { - auto model_wrap = LoadModel(); - ASSERT_TRUE(model_wrap); - auto& model = *model_wrap->Get(); - - ASSERT_EQ(model.NumSubgraphs(), 3); - - { - auto& main = *model.MainSubgraph(); - EXPECT_EQ(main.NumInputs(), 1); - EXPECT_EQ(main.NumOutputs(), 1); - EXPECT_EQ(main.Ops().size(), 1); - EXPECT_EQ(main.Tensors().size(), 3); - auto& op = main.Op(0); - auto* cst = op.Inputs().back(); - auto data = Tensor(cst).WeightsData(); - ASSERT_TRUE(data); - EXPECT_THAT(*data, Each(FloatEq(-1.0))); - EXPECT_TRUE(ValidateLocalTopology(main.Ops().cbegin(), main.Ops().cend())); - EXPECT_TRUE(ValidateSubgraphIO(main)); - } - - { - auto& func1 = model.Subgraph(1); - EXPECT_EQ(func1.NumInputs(), 1); - 
EXPECT_EQ(func1.NumOutputs(), 1); - EXPECT_EQ(func1.Ops().size(), 1); - EXPECT_EQ(func1.Tensors().size(), 3); - auto& op = func1.Op(0); - auto* cst = op.Inputs().back(); - auto data = Tensor(cst).WeightsData(); - ASSERT_TRUE(data); - EXPECT_THAT(*data, Each(FloatEq(1.0))); - EXPECT_TRUE( - ValidateLocalTopology(func1.Ops().cbegin(), func1.Ops().cend())); - EXPECT_TRUE(ValidateSubgraphIO(func1)); - } - - { - auto& func2 = model.Subgraph(2); - EXPECT_EQ(func2.NumInputs(), 1); - EXPECT_EQ(func2.NumOutputs(), 1); - EXPECT_EQ(func2.Ops().size(), 1); - EXPECT_EQ(func2.Tensors().size(), 3); - auto& op = func2.Op(0); - auto* cst = op.Inputs().back(); - auto data = Tensor(cst).WeightsData(); - ASSERT_TRUE(data); - EXPECT_THAT(*data, Each(FloatEq(2.0))); - EXPECT_TRUE( - ValidateLocalTopology(func2.Ops().cbegin(), func2.Ops().cend())); - EXPECT_TRUE(ValidateSubgraphIO(func2)); - } -} - -INSTANTIATE_TEST_SUITE_P(ModelLoadTests, SimpleMultiSubgraphTest, - Values(MakeLoadFactory(kSimpleMultiSubgraph))); - -INSTANTIATE_TEST_SUITE_P(ModelSerializeTests, SimpleMultiSubgraphTest, - Values(MakeRoundTripFactory(kSimpleMultiSubgraph))); - -// Test when flatbuffer export has optimized multiple tensors to share the -// same buffer. 
-using MultiSubgraphDupeConstTest = TestWithModelFactory; - -TEST_P(MultiSubgraphDupeConstTest, CheckGraph) { - static constexpr std::array kWeights = {1.0, 2.0, 3.0, 4.0}; - - auto model_wrap = LoadModel(); - ASSERT_TRUE(model_wrap); - auto& model = *model_wrap->Get(); - - ASSERT_EQ(model.NumSubgraphs(), 2); - - { - ASSERT_EQ(model.Subgraph(0).Ops().size(), 1); - ASSERT_EQ(model.Subgraph(0).Tensors().size(), 3); - auto& cst = model.Subgraph(0).Op(0).Input(1); - Tensor t(&cst); - EXPECT_THAT(*t.WeightsData(), ElementsAreArray(kWeights)); - } - - { - ASSERT_EQ(model.Subgraph(1).Ops().size(), 1); - ASSERT_EQ(model.Subgraph(1).Tensors().size(), 3); - auto& cst = model.Subgraph(1).Op(0).Input(1); - Tensor t(&cst); - EXPECT_THAT(*t.WeightsData(), ElementsAreArray(kWeights)); - } - auto buf_id_0 = model.Subgraph(0).Op(0).Input(1).Weights().GetBufferId(); - auto buf_id_1 = model.Subgraph(1).Op(0).Input(1).Weights().GetBufferId(); - ASSERT_EQ(buf_id_0, buf_id_1); -} - -INSTANTIATE_TEST_SUITE_P(ModelLoadTests, MultiSubgraphDupeConstTest, - Values(MakeLoadFactory(kCstMultiSubgraph))); - -INSTANTIATE_TEST_SUITE_P(ModelSerializeTests, MultiSubgraphDupeConstTest, - Values(MakeRoundTripFactory(kCstMultiSubgraph))); - -// Tests that programmatically check litert against tflite models. 
-//===--------------------------------------------------------------------------- - -using ModelLoadOpCheckTest = TestWithModelPath; - -TEST_P(ModelLoadOpCheckTest, CheckOps) { - const auto model_path = GetTestModelPath(); - - auto flatbuffer = FlatbufferWrapper::CreateFromTflFile(model_path); - ASSERT_TRUE(flatbuffer); - auto expected_fb = flatbuffer->get()->Unpack(); - - auto model = LoadModelFromFile(model_path); - ASSERT_TRUE(model); - - const auto* subgraph = model->get()->MainSubgraph(); - const auto& ops = subgraph->Ops(); - - const auto& fb_subgraph = *expected_fb->subgraphs.front(); - const auto& fb_ops = fb_subgraph.operators; - const auto& fb_tensors = fb_subgraph.tensors; - - ASSERT_EQ(ops.size(), fb_ops.size()); - - auto get_tfl_tensor = [&](uint32_t ind) -> const TflTensor& { - return *fb_tensors.at(ind); - }; - - for (auto i = 0; i < ops.size(); ++i) { - ASSERT_TRUE(EqualsFbOp(*ops.at(i), *fb_ops.at(i), get_tfl_tensor)); - } -} - -INSTANTIATE_TEST_SUITE_P(ModelLoadQuantizedOpCheckTest, ModelLoadOpCheckTest, - ::testing::ValuesIn(kAllQModels)); - -INSTANTIATE_TEST_SUITE_P(ModelLoadDynamicOpCheckTest, ModelLoadOpCheckTest, - ::testing::ValuesIn({kDynamicShapeModel})); - -using ModelSerializeOpCheckTest = TestWithModelPath; - -TEST_P(ModelSerializeOpCheckTest, CheckOps) { - const auto model_path = GetTestModelPath(); - - // Save the initial fb for comparison. - auto expected_fb_data = FlatbufferWrapper::CreateFromTflFile(model_path); - ASSERT_TRUE(expected_fb_data); - auto expected_fb = expected_fb_data->get()->Unpack(); - - // Round trip the model. 
- auto model = LoadModelFromFile(model_path); - ASSERT_TRUE(model); - auto serialized = SerializeModel(std::move(**model)); - - auto actual_fb_data = FlatbufferWrapper::CreateFromBuffer(*serialized); - ASSERT_TRUE(actual_fb_data); - auto actual_fb = actual_fb_data->get()->Unpack(); - - const auto& expected_fb_subgraph = *expected_fb->subgraphs.front(); - const auto& expected_fb_ops = expected_fb_subgraph.operators; - const auto& expected_fb_tensors = expected_fb_subgraph.tensors; - - const auto& actual_fb_subgraph = *actual_fb->subgraphs.front(); - const auto& actual_fb_ops = actual_fb_subgraph.operators; - const auto& actual_fb_tensors = actual_fb_subgraph.tensors; - - ASSERT_EQ(expected_fb_ops.size(), actual_fb_ops.size()); - for (auto i = 0; i < actual_fb_ops.size(); ++i) { - const auto& expected = *expected_fb_ops.at(i); - const auto& actual = *actual_fb_ops.at(i); - EXPECT_EQ(expected.inputs.size(), actual.inputs.size()); - EXPECT_EQ(expected.outputs.size(), actual.outputs.size()); - } - - ASSERT_EQ(expected_fb_tensors.size(), actual_fb_tensors.size()); - for (auto i = 0; i < actual_fb_tensors.size(); ++i) { - const auto& expected = *expected_fb_tensors.at(i); - const auto& actual = *actual_fb_tensors.at(i); - - EXPECT_EQ(actual.type, expected.type); - EXPECT_EQ(actual.shape, expected.shape); - EXPECT_EQ(actual.shape_signature, expected.shape_signature); - - const auto expected_q_params = expected.quantization.get(); - const auto actual_q_params = actual.quantization.get(); - - const auto neither_quantized = - !IsQuantized(expected_q_params) && !IsQuantized(actual_q_params); - const auto both_per_tensor = IsPerTensorQuantized(expected_q_params) && - IsPerTensorQuantized(actual_q_params); - ASSERT_TRUE(neither_quantized || both_per_tensor); - - if (both_per_tensor) { - const auto expected_per_tensor = AsPerTensorQparams(expected_q_params); - const auto actual_per_tensor = AsPerTensorQparams(actual_q_params); - EXPECT_EQ(*expected_per_tensor, 
*actual_per_tensor); - } - } -} - -INSTANTIATE_TEST_SUITE_P(ModelSerializeOpCheckTest, ModelSerializeOpCheckTest, - ::testing::ValuesIn({kOneMul, kDynamicShapeModel})); - -INSTANTIATE_TEST_SUITE_P(ModelSerializeQuantizedOpCheckTest, - ModelSerializeOpCheckTest, - ::testing::ValuesIn(kAllQModels)); - -} // namespace -} // namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/core/model/model_file_test_util.cc b/tensorflow/lite/experimental/litert/core/model/model_file_test_util.cc deleted file mode 100644 index 55bb72fa0c2961..00000000000000 --- a/tensorflow/lite/experimental/litert/core/model/model_file_test_util.cc +++ /dev/null @@ -1,181 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/core/model/model_file_test_util.h" - -#include - -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_detail.h" -#include "tensorflow/lite/experimental/litert/core/model/flatbuffer_to_litert.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" -#include "tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h" - -namespace litert::internal { - -namespace { - -template -bool EqualsFbQuantizationDetail(LiteRtQType litert_quantization, - const TflQuantization* tfl_quantization) { - return false; -} - -template <> -bool EqualsFbQuantizationDetail( - LiteRtQuantizationPerTensor litert_quantization, - const TflQuantization* tfl_quantization) { - auto tfl_q_params = AsPerTensorQparams(tfl_quantization); - if (!tfl_q_params) return false; - return litert_quantization.zero_point == tfl_q_params->first && - litert_quantization.scale == tfl_q_params->second; -} - -template <> -bool EqualsFbQuantizationDetail( - LiteRtQuantizationPerChannel litert_quantization, - const TflQuantization* tfl_quantization) { - auto tfl_q_params = AsPerChannelQparams(tfl_quantization); - if (!tfl_q_params) return false; - const auto& [quantized_dimension, num_channels, zero_points, scales] = - *tfl_q_params; - const auto qd_eq = - litert_quantization.quantized_dimension == quantized_dimension; - const auto num_chan_eq = litert_quantization.num_channels == num_channels; - const auto zeros_eq = std::equal(zero_points.begin(), zero_points.end(), - litert_quantization.zero_points); - const auto scales_eq = - std::equal(scales.begin(), scales.end(), litert_quantization.scales); - return qd_eq && num_chan_eq && zeros_eq && scales_eq; -} -template -bool EqualsFbTensorTypeDetail(LiteRtTenzorType litert_tensor_type, - const TflTensorType& tfl_tensor) { - 
LITERT_LOG(LITERT_ERROR, "LiteRtTensorType not supported"); - return false; -} - -template <> -bool EqualsFbTensorTypeDetail( - LiteRtRankedTensorType litert_tensor_type, - const TflTensorType& tfl_tensor_type) { - auto tfl_shape = AsDynamicShape(tfl_tensor_type.second); - if (!tfl_shape) { - LITERT_LOG(LITERT_ERROR, "Not ranked shape"); - return false; - } - - if (MapElementType(tfl_tensor_type.first) != - static_cast(litert_tensor_type.element_type)) { - LITERT_LOG(LITERT_ERROR, "Element type not equal"); - return false; - } - - auto same_or_both_dyn = [](auto l, auto r) { - const auto same_static = l >= 0 && l == r; - const auto both_dyn = l < 0 && r < 0; - return same_static || both_dyn; - }; - - auto& layout = litert_tensor_type.layout; - const bool shape_eq = - AllZip(*tfl_shape, absl::MakeConstSpan(layout.dimensions, layout.rank), - same_or_both_dyn); - if (!shape_eq) { - LITERT_LOG(LITERT_ERROR, "Shapes are not equal"); - return false; - } - - return true; -} - -} // namespace - -bool EqualsFbQuantization(const Quantization& litert_quantization, - const TflQuantization* tfl_quantization) { - switch (litert_quantization.first) { - case kLiteRtQuantizationPerTensor: - return EqualsFbQuantizationDetail(litert_quantization.second.per_tensor, - tfl_quantization); - case kLiteRtQuantizationPerChannel: - return EqualsFbQuantizationDetail(litert_quantization.second.per_channel, - tfl_quantization); - case kLiteRtQuantizationNone: - return !IsQuantized(tfl_quantization); - default: - // Not implemented yet. - return false; - } -} - -// Compare tensor type within litert tensor to the type within flatbuffer -// tensor. 
-bool EqualsFbTensorType(const TensorType& litert_tensor_type, - const TflTensorType& tfl_tensor_type) { - switch (litert_tensor_type.first) { - case kLiteRtRankedTensorType: - return EqualsFbTensorTypeDetail( - litert_tensor_type.second.ranked_tensor_type, tfl_tensor_type); - default: - LITERT_LOG(LITERT_ERROR, "Tensor kind not supported"); - // Not implemented yet. - return false; - } -} - -bool EqualsFbTensor(const LiteRtTensorT& litert_tensor, - const TflTensor& tfl_tensor) { - if (!EqualsFbTensorType(litert_tensor.Type(), - {tfl_tensor.type, TflShapeInfo(tfl_tensor)})) { - LITERT_LOG(LITERT_ERROR, "Tensor not same type"); - return false; - } - - if (!EqualsFbQuantization(litert_tensor.Qparams(), - tfl_tensor.quantization.get())) { - LITERT_LOG(LITERT_ERROR, "Tensor not same quantization"); - return false; - } - - return true; -} - -bool EqualsFbOp(const LiteRtOpT& litert_op, const TflOp& tfl_op, - GetTflTensor get_tfl_tensor) { - auto check_tensors = [&](auto& litert_tensors, auto& tfl_tensors) { - if (litert_tensors.size() != tfl_tensors.size()) { - LITERT_LOG(LITERT_ERROR, "Tensors not same size"); - return false; - } - - for (auto i = 0; i < litert_tensors.size(); ++i) { - const auto& fb_tensor = get_tfl_tensor(tfl_tensors.at(i)).get(); - const auto& litert_tensor = *litert_tensors.at(i); - - if (!EqualsFbTensor(litert_tensor, fb_tensor)) { - LITERT_LOG(LITERT_ERROR, "Tensor %d not same", i); - return false; - } - } - - return true; - }; - - return check_tensors(litert_op.Inputs(), tfl_op.inputs) && - check_tensors(litert_op.Outputs(), tfl_op.outputs); -} - -} // namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/core/model/model_file_test_util.h b/tensorflow/lite/experimental/litert/core/model/model_file_test_util.h deleted file mode 100644 index df0138e321c063..00000000000000 --- a/tensorflow/lite/experimental/litert/core/model/model_file_test_util.h +++ /dev/null @@ -1,51 +0,0 @@ -// Copyright 2024 Google LLC. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_MODEL_FILE_TEST_UTIL_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_MODEL_FILE_TEST_UTIL_H_ - -#include -#include - -#include "tensorflow/lite/experimental/litert/core/model/model.h" -#include "tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h" - -namespace litert::internal { - -// Callback to get a tfl tensor from it's index. -using GetTflTensor = - std::function(uint32_t ind)>; - -// Compare q-params for having the same type and values. -bool EqualsFbQuantization(const Quantization& litert_quantization, - const TflQuantization* tfl_quantization); - -// Compare tensor types for having the same shape and element type. -bool EqualsFbTensorType(const TensorType& litert_tensor_type, - const TflTensorType& tfl_tensor_type); - -// Compare litert op to flatbuffer op along with their input/output tensors -// types and quantization. Takes a callback to lookup tfl tensors the indices -// within the tfl op. -bool EqualsFbOp(const LiteRtOpT& litert_op, const TflOp& tfl_op, - GetTflTensor get_tfl_tensor); - -// Compare litert tensor to flatbuffer tensor for having same types and -// quantization. 
-bool EqualsFbTensor(const LiteRtTensorT& litert_tensor, - const TflTensor& tfl_tensor); - -} // namespace litert::internal - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_MODEL_FILE_TEST_UTIL_H_ diff --git a/tensorflow/lite/experimental/litert/core/model/model_graph.cc b/tensorflow/lite/experimental/litert/core/model/model_graph.cc deleted file mode 100644 index f7a8bb4c80ef61..00000000000000 --- a/tensorflow/lite/experimental/litert/core/model/model_graph.cc +++ /dev/null @@ -1,226 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/core/model/model_graph.h" - -#include -#include -#include -#include - -#include "absl/log/absl_check.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" -#include "tensorflow/lite/experimental/litert/cc/litert_detail.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" - -namespace litert::internal { - -namespace { - -bool IsOpDead(const LiteRtOpT& op) { - return op.Inputs().empty() && op.Outputs().empty(); -} - -bool IsTensorDead(const LiteRtTensorT& tensor) { - return tensor.DefiningOp() == nullptr && tensor.NumUses() == 0; -} - -} // namespace - -void CloneTo(const LiteRtTensorT& src, LiteRtTensorT& dest) { - dest.SetName({src.Name().cbegin(), src.Name().cend()}); - dest.SetQarams(src.Qparams()); - dest.SetType(src.Type()); - - // Manully copy per-channel quantization params,quant array is owned by - // tensor. - if (src.Qparams().first == kLiteRtQuantizationPerChannel) { - std::vector scales( - src.Qparams().second.per_channel.scales, - src.Qparams().second.per_channel.scales + - src.Qparams().second.per_channel.num_channels); - std::vector zero_points( - src.Qparams().second.per_channel.zero_points, - src.Qparams().second.per_channel.zero_points + - src.Qparams().second.per_channel.num_channels); - Quantization dest_qparams = MakePerChannelQuantization( - scales, zero_points, - src.Qparams().second.per_channel.quantized_dimension, - [&dest](auto s) { return dest.RequestScratchBuffer(s); }); - dest.SetQarams(std::move(dest_qparams)); - } - - // Move weight buffer from src to dest. 
- const auto& src_weights = src.Weights(); - auto& dest_weights = dest.Weights(); - - const auto same_manager = - src_weights.GetBufferManager() == dest_weights.GetBufferManager(); - - if (same_manager) { - dest_weights.SetBufferId(src_weights.GetBufferId()); - } else { - OwningBufferRef weights_buffer(src_weights.Buffer().Data(), - src_weights.Buffer().Size()); - SetWeightsFromOwnedBuffer(dest_weights, std::move(weights_buffer)); - } -} - -void CloneTo(const LiteRtOpT& src, LiteRtOpT& dest) { - dest.SetCustomOptions(src.CustomOptions().Data(), src.CustomOptions().Size()); - litert::internal::SetTflOptions(dest, litert::internal::GetTflOptions(src)); - litert::internal::SetTflOpCodeInd(dest, - litert::internal::GetTflOpCodeInd(src)); - dest.SetOpCode(src.OpCode()); -} - -LiteRtTensorT& MakeClone(LiteRtSubgraphT& parent, const LiteRtTensorT& src) { - auto& new_tensor = parent.EmplaceTensor(); - CloneTo(src, new_tensor); - return new_tensor; -} - -LiteRtOpT& MakeClone(LiteRtSubgraphT& parent, const LiteRtOpT& src) { - auto& new_op = parent.EmplaceOp(); - CloneTo(src, new_op); - return new_op; -} - -std::optional FindInput(const LiteRtOpT& op, - const LiteRtTensorT& tensor) { - return FindInd(op.Inputs().cbegin(), op.Inputs().cend(), &tensor); -} - -std::optional FindOutput(const LiteRtOpT& op, - const LiteRtTensorT& tensor) { - return FindInd(op.Outputs().cbegin(), op.Outputs().cend(), &tensor); -} - -std::optional FindInput(const LiteRtSubgraphT& subgraph, - const LiteRtTensorT& tensor) { - return FindInd(subgraph.Inputs().cbegin(), subgraph.Inputs().cend(), &tensor); -} - -std::optional FindOutput(const LiteRtSubgraphT& subgraph, - const LiteRtTensorT& tensor) { - return FindInd(subgraph.Outputs().cbegin(), subgraph.Outputs().cend(), - &tensor); -} - -UseIndices FindUseInds(const LiteRtTensorT& tensor, const LiteRtOpT& op) { - UseIndices res; - for (auto i = 0; i < tensor.NumUses(); ++i) { - if (tensor.Users().at(i) == &op) { - res.push_back(i); - } - } - return 
res; -} - -bool IsConstant(const LiteRtTensorT& tensor) { - bool is_zero_sized = false; - auto layout = tensor.Type().second.ranked_tensor_type.layout; - if (layout.rank == 1) { - if (layout.dimensions[0] == 0) { - is_zero_sized = true; - } - } - const auto is_const = tensor.Weights().Buffer().Size() > 0 || is_zero_sized; - ABSL_DCHECK(!is_const || tensor.DefiningOp() == nullptr) - << "Constant tensors should not be defined by an op"; - return is_const; -} - -void AttachInput(LiteRtTensor tensor, LiteRtOpT& op) { - op.Inputs().push_back(tensor); - tensor->Users().push_back(&op); - tensor->UserArgInds().push_back(op.Inputs().size() - 1); -} - -void AttachOutput(LiteRtTensor tensor, LiteRtOpT& op) { - ABSL_DCHECK(tensor->DefiningOp() == nullptr) - << "Cannot add an already defined tensor as op output"; - op.Outputs().push_back(tensor); - tensor->SetDefiningOp(op, op.Outputs().size() - 1); -} - -LiteRtTensor DisconnectInput(LiteRtOpT& op, LiteRtParamIndex input_ind) { - ABSL_DCHECK(input_ind < op.Inputs().size()) << "Removing tensor index oob"; - auto& input = op.Input(input_ind); - - // Find the index of the use for the given in edge. - auto target_use_ind = -1; - for (auto i = 0; i < input.NumUses(); ++i) { - if (input.Users().at(i) == &op && input.UserArgInds().at(i) == input_ind) { - target_use_ind = i; - } - } - ABSL_DCHECK_GE(target_use_ind, 0) << "Malformed graph"; - - // Slide latter input use arg inds to the left. - for (auto i = input_ind + 1; i < op.Inputs().size(); ++i) { - auto& r_in = op.Input(i); - for (auto u = 0; u < r_in.NumUses(); ++u) { - auto& r_arg_ind = r_in.UserArgInds().at(u); - if (r_in.Users().at(u) == &op && r_arg_ind > input_ind) { - r_arg_ind -= 1; - } - } - } - - // Update the edges. 
- input.RemoveUse(target_use_ind); - op.RemoveInput(input_ind); - - return &input; -} - -bool IsIO(const LiteRtSubgraphT& subgraph, const LiteRtTensorT& tensor) { - return FindInput(subgraph, tensor) || FindOutput(subgraph, tensor); -} - -LiteRtTensor DisconnectOutput(LiteRtOpT& op, LiteRtParamIndex output_ind) { - ABSL_DCHECK(output_ind < op.Outputs().size()) << "Removing tensor index oob"; - auto& output = op.Output(output_ind); - output.ClearDefiningOp(); - op.RemoveOutput(output_ind); - return &output; -} - -void Drop(LiteRtOpT& litert_op) { - while (!litert_op.Inputs().empty()) { - DisconnectInput(litert_op, 0); - } - while (!litert_op.Outputs().empty()) { - DisconnectOutput(litert_op, 0); - } -} - -bool DCE(LiteRtSubgraphT& subgraph) { - const auto ops_removed = subgraph.RemoveOpIf(IsOpDead); - - auto rm_tensor = [&subgraph = std::as_const(subgraph)](const auto& t) { - return IsTensorDead(t) && !IsIO(subgraph, t); - }; - const auto tensors_removed = subgraph.RemoveTensorIf(rm_tensor); - LITERT_LOG(LITERT_INFO, "Removed %d ops, %d tensors", ops_removed, - tensors_removed); - - return (ops_removed + tensors_removed) > 0; -} - -} // namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/core/model/model_graph.h b/tensorflow/lite/experimental/litert/core/model/model_graph.h deleted file mode 100644 index 55e00e90c833ee..00000000000000 --- a/tensorflow/lite/experimental/litert/core/model/model_graph.h +++ /dev/null @@ -1,110 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_MODEL_GRAPH_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_MODEL_GRAPH_H_ - -#include -#include - -#include "absl/container/inlined_vector.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" -#include "tensorflow/lite/experimental/litert/cc/litert_consts.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" - -namespace litert::internal { - -// using IrMapping = absl::flat_hash_map; - -// CLONING - -// Clones the basic data between tensors (like name and data) but not -// things related to incoming/outgoing edges (users, defining op) or weights. -void CloneTo(const LiteRtTensorT& src, LiteRtTensorT& dest); - -// Clones the basic data between ops (like op code and options) but -// things related to incoming/outgoing edges (input/output tensors). -void CloneTo(const LiteRtOpT& src, LiteRtOpT& dest); - -// Same as clone to, but allocates a the dest tensor into given subgraph. -LiteRtTensorT& MakeClone(LiteRtSubgraphT& parent, const LiteRtTensorT& src); - -// Same as clone to, but allocates a the dest op into given subgraph. -LiteRtOpT& MakeClone(LiteRtSubgraphT& parent, const LiteRtOpT& src); - -// OBSERVERS - -// Checks if tensor is input to given op, return its index if so. -std::optional FindInput(const LiteRtOpT& op, - const LiteRtTensorT& tensor); - -// Checks if tensor is output to given op, return its index if so. -std::optional FindOutput(const LiteRtOpT& op, - const LiteRtTensorT& tensor); - -// Checks if tensor is input to given subgraph, return its index if so. 
-std::optional FindInput(const LiteRtSubgraphT& subgraph, - const LiteRtTensorT& tensor); - -// Checks if tensor is output to given subgraph, return its index if so. -std::optional FindOutput(const LiteRtSubgraphT& subgraph, - const LiteRtTensorT& tensor); - -// Check if tensor is part of subgraph IO. -bool IsIO(const LiteRtSubgraphT& subgraph, const LiteRtTensorT& tensor); - -using UseIndices = - absl::InlinedVector; - -// Checks if tensor is used by op, return the use inds for each use of tensor by -// op (there may be multiple). These are the indexes to call -// LiteRtTensorT::GetUse with. -UseIndices FindUseInds(const LiteRtTensorT& tensor, const LiteRtOpT& op); - -// Is this tensor a constant tensor? -bool IsConstant(const LiteRtTensorT& tensor); - -// MUTATORS - -// Attaches the pre-allocated tensor to be an input of given op. -void AttachInput(LiteRtTensor tensor, LiteRtOpT& op); - -// Attaches the pre-allocated tensor to be an output of given op. -void AttachOutput(LiteRtTensor tensor, LiteRtOpT& op); - -// Remove the input edge from an op. Return the disconnected tensor. -LiteRtTensor DisconnectInput(LiteRtOpT& op, LiteRtParamIndex input_ind); - -// Remove an output edge from an op. Return the disconnected tensor. -LiteRtTensor DisconnectOutput(LiteRtOpT& op, LiteRtParamIndex output_ind); - -// Remove all incoming and outgoing edges from this op. This can prep nodes -// for removal in DCE. -void Drop(LiteRtOpT& litert_op); - -// Run very naive dead code elimination. Removes only ops/tensors that have no -// in/out edges. Ops are handled first. Ignores subgraph IO. Not recursive and -// does only one pass. Returns if the graph was modified. -// NOTE: This de-allocates removed objects, only use when references to these -// objects will not be used. -// TODO: Update this with complete work-list based approach. 
-bool DCE(LiteRtSubgraphT& subgraph); - -} // namespace litert::internal - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_MODEL_GRAPH_H_ diff --git a/tensorflow/lite/experimental/litert/core/model/model_graph_test.cc b/tensorflow/lite/experimental/litert/core/model/model_graph_test.cc deleted file mode 100644 index 4258bc9edb7418..00000000000000 --- a/tensorflow/lite/experimental/litert/core/model/model_graph_test.cc +++ /dev/null @@ -1,419 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/core/model/model_graph.h" - -#include -#include -#include - -#include -#include -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" -#include "tensorflow/lite/experimental/litert/core/model/graph_validation.h" -#include "tensorflow/lite/experimental/litert/core/model/ir_allocator.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" - -namespace litert::internal { -namespace { - -using ::testing::UnorderedElementsAreArray; - -// Custom matcher; example: -// ``` -// LiteRtTensor tensor ... -// EXPECT_THAT(tensor, HasRankedType(kLiteRtInt, absl::MakeSpan({2, 2}))); -// ``` -// TODO: Update to use dumping API directly and move to shared header. 
-MATCHER_P2(HasRankedType, element_type, shape, "") { - if (arg.Type().first != kLiteRtRankedTensorType) { - *result_listener << "Not ranked tensor type"; - return false; - } - const auto& ranked_tensor_type = arg.Type().second.ranked_tensor_type; - const auto& layout = ranked_tensor_type.layout; - - const auto element_type_eq = ranked_tensor_type.element_type == element_type; - const auto rank_eq = layout.rank == std::size(shape); - - auto actual_shape = absl::MakeConstSpan(layout.dimensions, layout.rank); - auto expected_shape = - absl::MakeConstSpan(std::cbegin(shape), std::cend(shape)); - const auto shape_eq = actual_shape == expected_shape; - - if (shape_eq && element_type_eq && rank_eq) { - return true; - } - - *result_listener << "\n"; - if (!shape_eq) { - *result_listener << "Not correct shape\n"; - } - if (!element_type_eq) { - *result_listener << "Not correct element type\n"; - } - if (!rank_eq) { - *result_listener << "Not correct rank\n"; - } - - *result_listener << absl::StreamFormat("Actual ElementType is: %d\n", - ranked_tensor_type.element_type); - *result_listener << absl::StreamFormat("Actual Rank is: %lu\n", layout.rank); - *result_listener << "Actual shape is: { "; - for (const auto d : actual_shape) { - *result_listener << absl::StreamFormat("%d, ", d); - } - *result_listener << "}\n"; - - return false; -} - -using ::testing::ElementsAreArray; - -static constexpr size_t kRank = 1; -static constexpr int32_t kDims[] = {2}; -static constexpr absl::Span kDimsSpan(kDims); -static constexpr auto kType = kLiteRtElementTypeInt32; -static constexpr absl::string_view kCustomOptions = "OPTIONS"; -static constexpr auto kOpCode = kLiteRtOpCodeTflMul; - -LiteRtTensorT TestTensor() { - LiteRtTensorT tensor; - tensor.Type().first = kLiteRtRankedTensorType; - tensor.Type().second.ranked_tensor_type.element_type = kType; - tensor.Type().second.ranked_tensor_type.layout.dimensions[0] = kDims[0]; - tensor.Type().second.ranked_tensor_type.layout.rank = kRank; - 
return tensor; -} - -LiteRtTensorT& TestTensor(LiteRtTensorT& tensor) { - tensor.Type().first = kLiteRtRankedTensorType; - tensor.Type().second.ranked_tensor_type.element_type = kType; - tensor.Type().second.ranked_tensor_type.layout.dimensions[0] = kDims[0]; - tensor.Type().second.ranked_tensor_type.layout.rank = kRank; - return tensor; -} - -LiteRtOpT TestOp() { - LiteRtOpT op; - op.SetOpCode(kOpCode); - op.SetCustomOptions(kCustomOptions); - return op; -} - -TEST(ModelGraphTest, CloneTensor) { - LiteRtTensorT dest; - CloneTo(TestTensor(), dest); - EXPECT_THAT(dest, HasRankedType(kType, kDimsSpan)); -} - -TEST(ModelQuantizationTypeTest, ClonePerChannelQuantization) { - static constexpr std::array kScale = {1.0f, 2.0f}; - static constexpr std::array kZero = {1L, 2L}; - static constexpr int32_t kQdim = 0; - - IrAllocator tensor_allocator; - auto& tensor = tensor_allocator.EmplaceBack(); - LiteRtTensorT dest; - const auto quant = MakePerChannelQuantization( - kScale, kZero, kQdim, - [&tensor](auto s) { return tensor.RequestScratchBuffer(s); }); - - ASSERT_EQ(quant.first, kLiteRtQuantizationPerChannel); - const auto& per_channel = quant.second.per_channel; - - const auto size = per_channel.num_channels; - ASSERT_EQ(size, 2); - EXPECT_EQ(per_channel.quantized_dimension, 0); - tensor.SetQarams(quant); - - CloneTo(tensor, dest); - // Mimic DCE. 
- tensor_allocator.RemoveIf([](auto& t) { return true; }); - auto dest_quant = dest.Qparams(); - - auto scales = absl::MakeConstSpan(dest_quant.second.per_channel.scales, - dest_quant.second.per_channel.num_channels); - auto zeros = absl::MakeConstSpan(dest_quant.second.per_channel.zero_points, - dest_quant.second.per_channel.num_channels); - - ASSERT_EQ(scales.size(), 2); - ASSERT_EQ(zeros.size(), 2); - EXPECT_THAT(scales, ElementsAreArray(kScale)); - EXPECT_THAT(zeros, ElementsAreArray(kZero)); -} - -TEST(ModelGraphTest, MakeCloneTensor) { - LiteRtSubgraphT subgraph; - auto& dest = MakeClone(subgraph, TestTensor()); - EXPECT_THAT(dest, HasRankedType(kType, kDimsSpan)); -} - -TEST(ModelGraphTest, CloneCstSameManager) { - OwningBufferRef buffer("DATA"); - LiteRtModelT model; - const auto num_buffers = model.Buffers()->NumBuffers(); - auto& sg = model.EmplaceSubgraph(); - auto& src = TestTensor(sg.EmplaceTensor()); - SetWeightsFromUnownedBuffer(src.Weights(), buffer); - auto& dest = MakeClone(sg, src); - EXPECT_EQ(dest.Weights().Buffer().StrView(), buffer.StrView()); - EXPECT_EQ(model.Buffers()->NumBuffers(), num_buffers + 1); - EXPECT_EQ(dest.Weights().GetBufferId(), src.Weights().GetBufferId()); - EXPECT_EQ(dest.Weights().GetBufferManager(), - src.Weights().GetBufferManager()); - EXPECT_EQ(dest.Weights().Buffer().Data(), src.Weights().Buffer().Data()); -} - -TEST(ModelGraphTest, CloneCstDifferentManager) { - OwningBufferRef buffer("DATA"); - LiteRtSubgraphT sg; - auto& src = TestTensor(sg.EmplaceTensor()); - SetWeightsFromUnownedBuffer(src.Weights(), buffer); - auto& dest = MakeClone(sg, src); - EXPECT_EQ(dest.Weights().Buffer().StrView(), buffer.StrView()); - EXPECT_NE(dest.Weights().GetBufferManager(), - src.Weights().GetBufferManager()); - EXPECT_NE(dest.Weights().Buffer().Data(), src.Weights().Buffer().Data()); -} - -TEST(ModelGraphTest, CloneOp) { - LiteRtOpT dest; - CloneTo(TestOp(), dest); - EXPECT_EQ(dest.OpCode(), kOpCode); - 
EXPECT_EQ(dest.CustomOptions().StrView(), kCustomOptions); -} - -TEST(ModelGraphTest, MakeCloneOp) { - LiteRtSubgraphT subgraph; - auto& dest = MakeClone(subgraph, TestOp()); - EXPECT_EQ(dest.OpCode(), kOpCode); - EXPECT_EQ(dest.CustomOptions().StrView(), kCustomOptions); -} - -TEST(ModelGraphTest, OpFindInput) { - auto op = TestOp(); - auto tensor = TestTensor(); - AttachInput(&tensor, op); - auto input = FindInput(op, tensor); - ASSERT_TRUE(input); - EXPECT_EQ(*input, 0); -} - -TEST(ModelGraphTest, OpFindOutput) { - auto op = TestOp(); - auto tensor = TestTensor(); - AttachOutput(&tensor, op); - auto output = FindOutput(op, tensor); - ASSERT_TRUE(output); - EXPECT_EQ(*output, 0); -} - -TEST(ModelGraphTest, SubgraphFindInput) { - LiteRtSubgraphT subgraph; - auto tensor = TestTensor(); - subgraph.Inputs().push_back(&tensor); - auto input = FindInput(subgraph, tensor); - ASSERT_TRUE(input); - EXPECT_EQ(*input, 0); -} - -TEST(ModelGraphTest, SubgraphFindOutput) { - LiteRtSubgraphT subgraph; - auto tensor = TestTensor(); - subgraph.Outputs().push_back(&tensor); - auto output = FindOutput(subgraph, tensor); - ASSERT_TRUE(output); - EXPECT_EQ(*output, 0); -} - -TEST(ModelGraphTest, TensorFindUseInds) { - auto op1 = TestOp(); - auto op2 = TestOp(); - auto tensor = TestTensor(); - - AttachInput(&tensor, op1); - AttachInput(&tensor, op2); - AttachInput(&tensor, op1); - - auto use_inds = FindUseInds(tensor, op1); - auto uses = GetTensorUses(tensor, use_inds); - ASSERT_EQ(uses.size(), 2); - - LiteRtTensorT::UseVec expected = {{&op1, 0}, {&op1, 1}}; - EXPECT_THAT(uses, UnorderedElementsAreArray(expected)); -} - -TEST(ModelGraphTest, OpAttachInput) { - auto op = TestOp(); - auto tensor = TestTensor(); - AttachInput(&tensor, op); - EXPECT_THAT(op.Inputs(), ElementsAreArray({&tensor})); - EXPECT_THAT(tensor.Users(), ElementsAreArray({&op})); - EXPECT_THAT(tensor.UserArgInds(), ElementsAreArray({0})); -} - -TEST(ModelGraphTest, OpAttachOutput) { - auto op = TestOp(); - auto 
tensor = TestTensor(); - AttachOutput(&tensor, op); - EXPECT_THAT(op.Outputs(), ElementsAreArray({&tensor})); - EXPECT_EQ(tensor.DefiningOp(), &op); - EXPECT_EQ(tensor.DefiningOpOutInd(), 0); -} - -TEST(ModelGraphTest, DisconnectInputOp) { - auto op = TestOp(); - auto tensor = TestTensor(); - AttachInput(&tensor, op); - auto disconnected = DisconnectInput(op, 0); - EXPECT_EQ(disconnected, &tensor); - EXPECT_TRUE(op.Inputs().empty()); - EXPECT_TRUE(tensor.Users().empty()); - EXPECT_TRUE(tensor.UserArgInds().empty()); -} - -TEST(ModelGraphTest, DisconnectMiddleInputOp) { - auto op = TestOp(); - - auto tensor1 = TestTensor(); - auto tensor2 = TestTensor(); - auto tensor3 = TestTensor(); - - AttachInput(&tensor1, op); - AttachInput(&tensor2, op); - AttachInput(&tensor3, op); - - auto disconnected = DisconnectInput(op, 1); - - EXPECT_EQ(disconnected, &tensor2); - ASSERT_EQ(op.Inputs().size(), 2); - EXPECT_EQ(op.Inputs().front(), &tensor1); - EXPECT_EQ(op.Inputs().back(), &tensor3); - ASSERT_TRUE(tensor2.Users().empty()); - ASSERT_TRUE(tensor2.UserArgInds().empty()); - - ASSERT_TRUE(ValidateLocalTopology(op)); -} - -TEST(ModelGraphTest, DisconnectOutputOp) { - auto op = TestOp(); - auto tensor = TestTensor(); - AttachOutput(&tensor, op); - auto disconnected = DisconnectOutput(op, 0); - EXPECT_EQ(disconnected, &tensor); - EXPECT_EQ(tensor.DefiningOp(), nullptr); - EXPECT_TRUE(op.Outputs().empty()); -} - -TEST(ModelGraphTest, DropOp) { - LiteRtOpT op; - - LiteRtTensorT input1; - LiteRtTensorT input2; - LiteRtTensorT output; - - AttachInput(&input1, op); - AttachInput(&input2, op); - AttachOutput(&output, op); - - Drop(op); - - EXPECT_TRUE(op.Inputs().empty()); - EXPECT_TRUE(op.Outputs().empty()); - EXPECT_TRUE(input1.Users().empty()); - EXPECT_TRUE(input2.Users().empty()); - EXPECT_EQ(output.DefiningOp(), nullptr); -} - -TEST(ModelGraphTestDCE, NoDeadCode) { - LiteRtSubgraphT subgraph; - - auto& input = subgraph.EmplaceTensor(); - auto& output = subgraph.EmplaceTensor(); - 
- auto& op = subgraph.EmplaceOp(); - - AttachInput(&input, op); - AttachOutput(&output, op); - - subgraph.Inputs().push_back(&input); - subgraph.Outputs().push_back(&output); - - ASSERT_FALSE(DCE(subgraph)); - EXPECT_EQ(subgraph.Ops().size(), 1); - EXPECT_EQ(subgraph.Tensors().size(), 2); - - ASSERT_TRUE( - ValidateLocalTopology(subgraph.Ops().cbegin(), subgraph.Ops().cend())); - ASSERT_TRUE(ValidateSubgraphIO(subgraph)); -} - -TEST(ModelGraphTestDCE, DeadTensor) { - LiteRtSubgraphT subgraph; - subgraph.EmplaceTensor(); - - ASSERT_TRUE(DCE(subgraph)); - EXPECT_TRUE(subgraph.Tensors().empty()); - - ASSERT_TRUE( - ValidateLocalTopology(subgraph.Ops().cbegin(), subgraph.Ops().cend())); - ASSERT_TRUE(ValidateSubgraphIO(subgraph)); -} - -TEST(ModelGraphTestDCE, DeadOp) { - LiteRtSubgraphT subgraph; - subgraph.EmplaceOp(); - - ASSERT_TRUE(DCE(subgraph)); - EXPECT_TRUE(subgraph.Ops().empty()); - - ASSERT_TRUE( - ValidateLocalTopology(subgraph.Ops().cbegin(), subgraph.Ops().cend())); - ASSERT_TRUE(ValidateSubgraphIO(subgraph)); -} - -TEST(ModelGraphTestDCE, SomeDead) { - LiteRtSubgraphT subgraph; - - auto& input = subgraph.EmplaceTensor(); - auto& output = subgraph.EmplaceTensor(); - - auto& op = subgraph.EmplaceOp(); - - AttachInput(&input, op); - AttachOutput(&output, op); - - // Dead - subgraph.EmplaceTensor(); - subgraph.EmplaceOp(); - - subgraph.Inputs().push_back(&input); - subgraph.Outputs().push_back(&output); - - ASSERT_TRUE(DCE(subgraph)); - EXPECT_EQ(subgraph.Ops().size(), 1); - EXPECT_EQ(subgraph.Tensors().size(), 2); - - ASSERT_TRUE( - ValidateLocalTopology(subgraph.Ops().cbegin(), subgraph.Ops().cend())); - ASSERT_TRUE(ValidateSubgraphIO(subgraph)); -} - -} // namespace -} // namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/core/model/model_load.cc b/tensorflow/lite/experimental/litert/core/model/model_load.cc deleted file mode 100644 index 86f11054608b05..00000000000000 --- 
a/tensorflow/lite/experimental/litert/core/model/model_load.cc +++ /dev/null @@ -1,432 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/core/model/model_load.h" - -#include -#include -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/core/model/buffer_manager.h" -#include "tensorflow/lite/experimental/litert/core/model/flatbuffer_to_litert.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" -#include "tensorflow/lite/experimental/litert/core/model/model_graph.h" -#include "tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h" -#include "tensorflow/lite/schema/schema_generated.h" - -namespace litert::internal { -namespace { - -// Provides a view of model-level resources when constructing litert graph. 
-class FlatbufferContext { - public: - using LiteRtBufferId = uint32_t; - using TflBufferInd = uint32_t; - using BufferIdMap = absl::flat_hash_map; - - FlatbufferContext(const FlatbufferWrapper& tfl_flatbuffer, - BufferManager* buffer_manager) - : tfl_flatbuffer_(tfl_flatbuffer), buffer_manager_(buffer_manager) {} - - void SetOpCode(LiteRtOpT& litert_op, uint32_t ind) { - const auto builtin_code = - PackedModel()->operator_codes()->Get(ind)->builtin_code(); - litert_op.SetOpCode(static_cast(builtin_code)); - litert::internal::SetTflOpCodeInd(litert_op, ind); - } - - // Get the buffer at the given index in the tflite model. - Expected GetTflBuffer(uint32_t ind) const { - const auto* packed_model = tfl_flatbuffer_.PackedModel(); - if (ind >= packed_model->buffers()->size()) { - LITERT_LOG(LITERT_ERROR, "Buffer index out of range"); - return Error(kLiteRtStatusErrorInvalidArgument); - } - return packed_model->buffers()->Get(ind); - } - - BufferManager* GetBufferManager() { return buffer_manager_; } - - const uint8_t* AllocBase() const { return tfl_flatbuffer_.AllocBase(); } - - const TflPackedModel* PackedModel() const { - return tfl_flatbuffer_.PackedModel(); - } - - BufferIdMap& RegisteredTflBufferIds() { return registered_tfl_buffer_ids_; } - - private: - const FlatbufferWrapper& tfl_flatbuffer_; - BufferManager* buffer_manager_; - BufferIdMap registered_tfl_buffer_ids_; -}; - -LiteRtStatus UnpackOp(FlatbufferContext& context, LiteRtSubgraphT& parent, - const TflPackedOp& tfl_op, LiteRtOpT& litert_op) { - // I/O TENSORS - - if (tfl_op.intermediates() && tfl_op.intermediates()->size() != 0) { - // TODO: b/365299994 - Support intermediates. - LITERT_LOG(LITERT_ERROR, "Intermediate tensors not yet supported."); - return kLiteRtStatusErrorUnsupported; - } - - if (tfl_op.mutating_variable_inputs() && - tfl_op.mutating_variable_inputs()->size() != 0) { - // TODO: b/365299994 - Support mutating variable inputs. 
- LITERT_LOG(LITERT_ERROR, "Mutating variable inputs not yet supported."); - return kLiteRtStatusErrorUnsupported; - } - - const auto num_inputs = tfl_op.inputs()->size(); - for (auto i = 0; i < num_inputs; ++i) { - const auto input_ind = tfl_op.inputs()->Get(i); - // Skipping optional input tensor. - if (input_ind == -1) { - continue; - } - AttachInput(&parent.Tensor(input_ind), litert_op); - } - - const auto num_outputs = tfl_op.outputs()->size(); - for (auto i = 0; i < num_outputs; ++i) { - const auto output_ind = tfl_op.outputs()->Get(i); - AttachOutput(&parent.Tensor(output_ind), litert_op); - } - - // OPTIONS - - if (tfl_op.large_custom_options_size() != 0) { - // TODO: b/365299994 - Support large custom options. - LITERT_LOG(LITERT_ERROR, "Large custom options not yet supported."); - return kLiteRtStatusErrorUnsupported; - } - - const auto* custom_opts = tfl_op.custom_options(); - if (custom_opts) { - litert_op.SetCustomOptions(custom_opts->data(), custom_opts->size()); - } - - // TODO figure out how to parse builtins with the packed flatbuffer api. - TflOpPtr tfl_op_ptr(tfl_op.UnPack()); - litert::internal::SetTflOptions(litert_op, - std::move(tfl_op_ptr->builtin_options)); - litert::internal::SetTflOptions2(litert_op, - std::move(tfl_op_ptr->builtin_options_2)); - - // OP CODE - - context.SetOpCode(litert_op, tfl_op.opcode_index()); - - return kLiteRtStatusOk; -} - -struct TflBufferContext { - BufferRef buffer; - // Is buffer appended to the flatbuffer? - bool is_external; -}; - -Expected ReadBuffer(FlatbufferContext& context, - uint32_t buffer_ind) { - auto buffer = context.GetTflBuffer(buffer_ind); - if (!buffer) { - return buffer.Error(); - } - - const auto& tfl_buffer = **buffer; - - if (tfl_buffer.offset() != 0) { - // Data is appended to the end of the flatbuffer. 
- - const auto* alloc_base = context.AllocBase(); - const auto offset = tfl_buffer.offset(); - const auto size = tfl_buffer.size(); - - return TflBufferContext{BufferRef(alloc_base + offset, size), - true}; - } else if (tfl_buffer.data()) { - // Data is in the flatbuffer. - - const auto* start = tfl_buffer.data()->data(); - const auto size = tfl_buffer.data()->size(); - - return TflBufferContext{BufferRef(start, size), false}; - } else { - return TflBufferContext{}; - } -} - -LiteRtStatus UnpackTensor(FlatbufferContext& context, - const TflPackedTensor& tfl_tensor, - LiteRtTensorT& litert_tensor) { - const auto buffer_ind = tfl_tensor.buffer(); - if (buffer_ind != 0) { - auto buffer = ReadBuffer(context, buffer_ind); - if (!buffer) { - return buffer.Error().Status(); - } - - auto it = context.RegisteredTflBufferIds().find(buffer_ind); - if (it != context.RegisteredTflBufferIds().end()) { - litert_tensor.Weights().SetBufferId(it->second); - } else { - BufferContext lrt_buf_ctx; - lrt_buf_ctx.should_append = buffer->is_external; - SetWeightsFromUnownedBuffer(litert_tensor.Weights(), buffer->buffer, - lrt_buf_ctx); - context.RegisteredTflBufferIds()[buffer_ind] = - litert_tensor.Weights().GetBufferId(); - } - } - - // TENSOR TYPE - - TflTensorType tfl_tensor_type(tfl_tensor.type(), TflShapeInfo(tfl_tensor)); - auto tensor_type = MapTensorType(tfl_tensor_type); - if (!tensor_type) { - return tensor_type.Error().Status(); - } - - litert_tensor.SetType(std::move(*tensor_type)); - - // QUANTIZATION - - if (tfl_tensor.quantization()) { - TflQuantizationPtr tfl_quantization(tfl_tensor.quantization()->UnPack()); - auto quantization = MapQuantization(tfl_quantization.get(), litert_tensor); - if (!quantization) { - return quantization.Error().Status(); - } - litert_tensor.SetQarams(std::move(*quantization)); - } - - // MISC - - if (tfl_tensor.name()) { - litert_tensor.SetName(tfl_tensor.name()->str()); - } - - if (tfl_tensor.is_variable()) { - // TODO: b/365299994 - Support 
variable tensors. - LITERT_LOG(LITERT_ERROR, "Variable tensors not yet supported."); - return kLiteRtStatusErrorUnsupported; - } - - if (tfl_tensor.variant_tensors() && - tfl_tensor.variant_tensors()->size() != 0) { - // TODO: b/365299994 - Support variant tensors. - LITERT_LOG(LITERT_ERROR, "Variant tensors not yet supported."); - return kLiteRtStatusErrorUnsupported; - } - - if (tfl_tensor.sparsity() != nullptr) { - // TODO: b/365299994 - Support sparsity tensors. - LITERT_LOG(LITERT_ERROR, "Sparsity tensors not yet supported."); - return kLiteRtStatusErrorUnsupported; - } - - return kLiteRtStatusOk; -} - -LiteRtStatus UnpackSubgraph(FlatbufferContext& context, - const TflPackedSubgraph& tfl_subgraph, - LiteRtSubgraphT& litert_subgraph) { - // Unpack tensors. - const auto num_tensors = tfl_subgraph.tensors()->size(); - for (auto i = 0; i < num_tensors; ++i) { - const auto* tfl_tensor = tfl_subgraph.tensors()->Get(i); - LITERT_RETURN_IF_ERROR( - UnpackTensor(context, *tfl_tensor, litert_subgraph.EmplaceTensor())); - } - - // Unpack ops, pass litert_subgraph so they can look up the new litert - // tensors. - const auto num_ops = tfl_subgraph.operators()->size(); - for (auto i = 0; i < num_ops; ++i) { - const auto* tfl_op = tfl_subgraph.operators()->Get(i); - LITERT_RETURN_IF_ERROR(UnpackOp(context, litert_subgraph, *tfl_op, - litert_subgraph.EmplaceOp())); - } - - // Update subgraph I/O. 
- const auto num_inputs = tfl_subgraph.inputs()->size(); - for (auto i = 0; i < num_inputs; ++i) { - const auto tfl_input_ind = tfl_subgraph.inputs()->Get(i); - litert_subgraph.Inputs().push_back(&litert_subgraph.Tensor(tfl_input_ind)); - } - const auto num_outputs = tfl_subgraph.outputs()->size(); - for (auto i = 0; i < num_outputs; ++i) { - const auto tfl_output_ind = tfl_subgraph.outputs()->Get(i); - litert_subgraph.Outputs().push_back( - &litert_subgraph.Tensor(tfl_output_ind)); - } - - return kLiteRtStatusOk; -} - -LiteRtStatus UnpackSignatures(std::vector& tfl_signatures, - LiteRtModelT& parent) { - for (auto& tfl_signature : tfl_signatures) { - if (tfl_signature->subgraph_index >= parent.Subgraphs().size()) { - LITERT_LOG(LITERT_ERROR, - "Signature does not refer to a valid subgraph index."); - return kLiteRtStatusErrorInvalidArgument; - } - - auto* litert_subgraph = - parent.Subgraphs().at(tfl_signature->subgraph_index); - - auto& tfl_inputs = tfl_signature->inputs; - auto& tfl_outputs = tfl_signature->outputs; - - // Tflite signatures map a tensor index to a name. The input & output - // indexes of signatures and subgraph are not matched, but the nubmer of - // inputs and outputs should be the same. - if (tfl_inputs.size() != litert_subgraph->Inputs().size() || - tfl_outputs.size() != litert_subgraph->Outputs().size()) { - LITERT_LOG(LITERT_ERROR, - "Signature has incorrect number of input/outputs"); - return kLiteRtStatusErrorInvalidFlatbuffer; - } - - // The tensor names may not be matched between signature and subgraph. - // Update the tensor names with the signature names since the signature - // names are used for LiteRT APIs. 
- for (auto i = 0; i < tfl_inputs.size(); ++i) { - const auto& tfl_input = tfl_inputs.at(i); - auto* index_litert_input = - litert_subgraph->Tensors().at(tfl_input->tensor_index); - index_litert_input->SetName(tfl_input->name); - } - for (auto i = 0; i < tfl_outputs.size(); ++i) { - const auto& tfl_output = tfl_outputs.at(i); - auto* index_litert_output = - litert_subgraph->Tensors().at(tfl_output->tensor_index); - index_litert_output->SetName(tfl_output->name); - } - - // Keep signature input/output names in the same order as the subgraph. - std::vector input_names; - input_names.reserve(tfl_inputs.size()); - for (auto& tensor : litert_subgraph->Inputs()) { - input_names.push_back(std::string(tensor->Name())); - } - std::vector output_names; - output_names.reserve(tfl_outputs.size()); - for (auto& tensor : litert_subgraph->Outputs()) { - output_names.push_back(std::string(tensor->Name())); - } - - parent.EmplaceSignature(litert_subgraph, std::move(input_names), - std::move(output_names), - tfl_signature->signature_key); - } - - if (tfl_signatures.empty()) { - parent.EmplaceSignature(MakeDefaultSignature(parent.MainSubgraph())); - } - - return kLiteRtStatusOk; -} - -Expected UnpackModel(FlatbufferWrapper&& flatbuffer) { - auto litert_model = std::make_unique(std::move(flatbuffer)); - - FlatbufferContext context(litert::internal::GetTflFlatbuffer(*litert_model), - litert_model->Buffers()); - const auto* packed_model = context.PackedModel(); - - if (packed_model->subgraphs()) { - const auto num_subgraphs = packed_model->subgraphs()->size(); - for (auto i = 0; i < num_subgraphs; ++i) { - const auto* tfl_subgraph = packed_model->subgraphs()->Get(i); - LITERT_RETURN_IF_ERROR(UnpackSubgraph(context, *tfl_subgraph, - litert_model->EmplaceSubgraph())); - } - } - - // TODO Figure out how to load signatures in packed flatbuffer. 
- if (packed_model->signature_defs()) { - std::vector tfl_signatures; - for (auto i = 0; i < packed_model->signature_defs()->size(); ++i) { - const auto* tfl_signature = packed_model->signature_defs()->Get(i); - tfl_signatures.push_back(TflSignaturePtr(tfl_signature->UnPack())); - } - LITERT_RETURN_IF_ERROR(UnpackSignatures(tfl_signatures, *litert_model)); - } else { - litert_model->EmplaceSignature( - MakeDefaultSignature(litert_model->MainSubgraph())); - } - - if (packed_model->metadata()) { - const auto num_metadata = packed_model->metadata()->size(); - for (auto i = 0; i < num_metadata; ++i) { - const auto* tfl_metadata = packed_model->metadata()->Get(i); - auto name = tfl_metadata->name()->str(); - const auto buf_id = tfl_metadata->buffer(); - auto buf = ReadBuffer(context, buf_id); - if (!buf) { - return buf.Error(); - } - - litert_model->PushMetadata(name, buf->buffer.Data(), buf->buffer.Size()); - } - } - - if (packed_model->operator_codes()) { - const auto num_operator_codes = packed_model->operator_codes()->size(); - std::vector tfl_op_codes(num_operator_codes); - for (auto i = 0; i < num_operator_codes; ++i) { - const auto* tfl_op_code = packed_model->operator_codes()->Get(i); - TflOpCodePtr tfl_op_code_ptr(tfl_op_code->UnPack()); - tfl_op_codes[i] = std::move(tfl_op_code_ptr); - } - litert::internal::SetTflOpCodes(*litert_model, std::move(tfl_op_codes)); - } - - return litert_model; -} - -} // namespace - -Expected LoadModelFromBuffer(BufferRef buffer) { - auto flatbuffer = FlatbufferWrapper::CreateFromBuffer(buffer); - if (!flatbuffer) { - return flatbuffer.Error(); - } - return UnpackModel(std::move(**flatbuffer)); -} - -Expected LoadModelFromFile(absl::string_view filename) { - auto flatbuffer = FlatbufferWrapper::CreateFromTflFile(filename); - if (!flatbuffer) { - return flatbuffer.Error(); - } - return UnpackModel(std::move(**flatbuffer)); -} - -} // namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/core/model/model_load.h 
b/tensorflow/lite/experimental/litert/core/model/model_load.h deleted file mode 100644 index b6a8c2cdd0f650..00000000000000 --- a/tensorflow/lite/experimental/litert/core/model/model_load.h +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_MODEL_LOAD_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_MODEL_LOAD_H_ - -#include -#include - -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" - -namespace litert::internal { - -Expected> LoadModelFromFile( - absl::string_view filename); - -Expected> LoadModelFromBuffer( - BufferRef buffer); - -} // namespace litert::internal - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_MODEL_LOAD_H_ diff --git a/tensorflow/lite/experimental/litert/core/model/model_serialize.cc b/tensorflow/lite/experimental/litert/core/model/model_serialize.cc deleted file mode 100644 index bc3f1467fd8a6f..00000000000000 --- a/tensorflow/lite/experimental/litert/core/model/model_serialize.cc +++ /dev/null @@ -1,559 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/core/model/model_serialize.h" - -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -// schema/mutable/schema_generated.h and schema/schema_generated.h (included -// through flatbuffer_tools.h via model.h) have the same #ifdef, thus this line -// need to be put at the top to ensure we get the "mutable" version. -#if 1 -#include "tensorflow/compiler/mlir/lite/schema/mutable/schema_generated.h" -#endif - -#include "absl/container/flat_hash_map.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/core/build_stamp.h" -#include "tensorflow/lite/experimental/litert/core/dispatch_op_schema.h" -#include "tensorflow/lite/experimental/litert/core/insert_order_map.h" -#include "tensorflow/lite/experimental/litert/core/model/litert_to_flatbuffer.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" -#include "tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h" -#include "tensorflow/lite/schema/mutable/schema_generated.h" - -namespace litert::internal { -namespace { - -using TensorMap = absl::flat_hash_map; - -// This is expected to be used to serialize the dispatch op custom code. 
-TflOpCodePtr MakeCustomOpCode(std::string custom_code_name) { - auto custom_code = std::make_unique(); - custom_code->builtin_code = ::tflite::BuiltinOperator_CUSTOM; - custom_code->custom_code = std::move(custom_code_name); - custom_code->version = 1; - return custom_code; -} - -// Utility for accessing flatbuffer state and other relevant state. -class SerializationContext { - public: - // Subgraph and op index pair. - using TflOpInd = std::pair; - using TflOpAssetMap = - absl::flat_hash_map; - using TflBufferInd = uint32_t; - using TflOffsetTensorMap = - absl::flat_hash_map; - using TflBufferIdMap = - absl::flat_hash_map; - - explicit SerializationContext(uint32_t dispatch_op_code_ind, - LiteRtModelT& litert_model, - size_t bytecode_alignment) - : tfl_model_(std::make_unique()), - dispatch_op_code_ind_(dispatch_op_code_ind), - litert_model_(litert_model), - bytecode_alignment_(bytecode_alignment) { - // Tfl expects empty buffer 0. - tfl_model_->buffers.push_back(std::make_unique()); - } - - TflModel& Model() { return *tfl_model_.get(); } - - TflModelPtr Release() && { return std::move(tfl_model_); } - - LiteRtModelT& LitertModel() { return litert_model_; } - - size_t BytecodeAlignment() const { return bytecode_alignment_; } - - LiteRtStatus HandleTensorBuffer(TflTensor& tfl_tensor, - const LiteRtTensorT& litert_tensor) { - const auto litert_buf_id = litert_tensor.Weights().GetBufferId(); - auto* buffer_manager = litert_tensor.Weights().GetBufferManager(); - - auto litert_buf_ctx = buffer_manager->GetContext(litert_buf_id); - if (!litert_buf_ctx) { - LITERT_LOG(LITERT_ERROR, "Failed to get buffer context"); - return litert_buf_ctx.Error().Status(); - } - - auto litert_buf = buffer_manager->GetBuffer(litert_buf_id); - if (!litert_buf) { - LITERT_LOG(LITERT_ERROR, "Failed to get buffer"); - return litert_buf.Error().Status(); - } - - TflBufferInd tfl_buffer_ind; - if (buffer_id_map_.contains(litert_buf_id)) { - tfl_buffer_ind = buffer_id_map_.at(litert_buf_id); - } 
else { - auto& tfl_buffer = - tfl_model_->buffers.emplace_back(std::make_unique()); - tfl_buffer_ind = tfl_model_->buffers.size() - 1; - - if (litert_buf_ctx->get().should_append) { - tfl_buffer->offset = 1; - tfl_buffer->size = 1; - offset_tensor_map_.emplace(tfl_buffer_ind, litert_buf_id); - } else { - tfl_buffer->data.assign(litert_buf->Data(), - litert_buf->Data() + litert_buf->Size()); - } - buffer_id_map_[litert_buf_id] = tfl_buffer_ind; - } - - tfl_tensor.buffer = tfl_buffer_ind; - - return kLiteRtStatusOk; - } - - // Add to tfl model metadata. - void PushMetadata(std::string key, BufferRef data) { - auto& tfl_buffer = - tfl_model_->buffers.emplace_back(std::make_unique()); - const auto tfl_buffer_ind = tfl_model_->buffers.size() - 1; - tfl_buffer->data.assign(data.Data(), data.Data() + data.Size()); - tfl_model_->metadata_buffer.push_back(tfl_buffer_ind); - auto tfl_metadata = std::make_unique(); - tfl_metadata->name = key; - tfl_metadata->buffer = tfl_buffer_ind; - tfl_model_->metadata.push_back(std::move(tfl_metadata)); - } - - // Keep track of the given ops index as having a particular asset. - // These will be used to update the ops with the correct offset and size - // after the model is fully packed. - void AttachAssetToOp(size_t subgraph_ind, size_t op_ind, - LiteRtModelT::OpAssetReference asset) { - TflOpInd tfl_op_ind = {subgraph_ind, op_ind}; - op_asset_map_.emplace(tfl_op_ind, asset); - } - - const TflOpAssetMap& OpAssetMap() const { return op_asset_map_; } - - const TflOffsetTensorMap& OffsetTensorMap() const { - return offset_tensor_map_; - } - - // Get the index in the tfl op codes for the dispatch custom code. - // This should be the only new custom code added after loading the initial - // tfl. 
- uint32_t DispatchOpCodeInd() const { return dispatch_op_code_ind_; } - - private: - TflModelPtr tfl_model_; - uint32_t dispatch_op_code_ind_; - LiteRtModelT& litert_model_; - - TflOpAssetMap op_asset_map_; - TflOffsetTensorMap offset_tensor_map_; - TflBufferIdMap buffer_id_map_; - size_t bytecode_alignment_ = 0; -}; - -void SetOptions(const LiteRtOpT& litert_op, TflOp& tfl_op) { - tfl_op.builtin_options = litert::internal::GetTflOptions(litert_op); - if (litert_op.CustomOptions().Size() != 0) { - tfl_op.custom_options = litert_op.CustomOptions().ToVec(); - tfl_op.custom_options_format = tflite::CustomOptionsFormat_FLEXBUFFERS; - } -} - -LiteRtStatus PackOp(SerializationContext& builder, LiteRtOpT& litert_op, - TflOp& tfl_op, const TensorMap& tensor_map) { - // Get index of the op code in the tfl model. - auto tfl_op_code_ind = litert::internal::GetTflOpCodeInd(litert_op); - const bool is_dispatch_op = - tfl_op_code_ind == litert::internal::kDispatchOpCodeTflInd; - - if (is_dispatch_op) { - tfl_op_code_ind = builder.DispatchOpCodeInd(); - } - - tfl_op.opcode_index = tfl_op_code_ind; - - // Look up the tensor indices in the tfl model. - for (auto* in : litert_op.Inputs()) { - tfl_op.inputs.push_back(tensor_map.at(in)); - } - for (auto* out : litert_op.Outputs()) { - tfl_op.outputs.push_back(tensor_map.at(out)); - } - - // Set generic options. 
- tfl_op.builtin_options = litert::internal::GetTflOptions(litert_op); - - return kLiteRtStatusOk; -} - -LiteRtStatus PackTensor(SerializationContext& builder, - LiteRtTensorT& litert_tensor, TflTensor& tfl_tensor) { - auto tfl_tensor_type = MapTensorType(litert_tensor.Type()); - if (!tfl_tensor_type) { - return tfl_tensor_type.Error().Status(); - } - auto [tfl_elem_type, tfl_shape] = *tfl_tensor_type; - - tfl_tensor.type = tfl_elem_type; - tfl_tensor.shape.assign(tfl_shape.shape.begin(), tfl_shape.shape.end()); - tfl_tensor.has_rank = tfl_shape.has_rank; - tfl_tensor.shape_signature.assign(tfl_shape.shape_signature.begin(), - tfl_shape.shape_signature.end()); - - auto tfl_quantization = MapQuantization(litert_tensor.Qparams()); - if (!tfl_quantization) { - return tfl_quantization.Error().Status(); - } - tfl_tensor.quantization = std::move(*tfl_quantization); - - LITERT_RETURN_IF_ERROR(builder.HandleTensorBuffer(tfl_tensor, litert_tensor)); - - tfl_tensor.name = std::string(litert_tensor.Name()); - - return kLiteRtStatusOk; -} - -LiteRtStatus PackSubgraph(SerializationContext& builder, - LiteRtSubgraphT& litert_subgraph, - TflSubgraph& tfl_subgraph, TensorMap& tensor_map, - size_t subgraph_ind) { - for (auto* tensor : litert_subgraph.Tensors()) { - tfl_subgraph.tensors.push_back(std::make_unique()); - tensor_map.insert({tensor, tfl_subgraph.tensors.size() - 1}); - LITERT_RETURN_IF_ERROR( - PackTensor(builder, *tensor, *tfl_subgraph.tensors.back())); - } - - for (auto i = 0; i < litert_subgraph.Ops().size(); ++i) { - auto* op = litert_subgraph.Ops().at(i); - - tfl_subgraph.operators.push_back(std::make_unique()); - auto& tfl_op = *tfl_subgraph.operators.back(); - LITERT_RETURN_IF_ERROR(PackOp(builder, *op, tfl_op, tensor_map)); - - // Set custom options. - if (auto op_asset = builder.LitertModel().FindOpAsset(op)) { - // This mechanism is currently only used for dispatch ops to store - // location of bytecode. 
Here we update the name and placeholder values - // for offset and size. These will be updated when the model is fully - // packed. - auto dispatch_opts = MakeDispatchOpOptions({ - 1, - 1, - std::string(op_asset->second), - }); - tfl_op.custom_options = dispatch_opts.ToVec(); - - // Save the "location" of the op and its asset. - builder.AttachAssetToOp(subgraph_ind, i, *op_asset); - - } else if (op->CustomOptions().Size() != 0) { - tfl_op.custom_options = op->CustomOptions().ToVec(); - } - - tfl_op.custom_options_format = tflite::CustomOptionsFormat_FLEXBUFFERS; - } - - for (auto* in : litert_subgraph.Inputs()) { - tfl_subgraph.inputs.push_back(tensor_map.at(in)); - } - - for (auto* out : litert_subgraph.Outputs()) { - tfl_subgraph.outputs.push_back(tensor_map.at(out)); - } - - return kLiteRtStatusOk; -} - -Expected PackAsTflite(SerializationContext& builder) { - auto& litert_model = builder.LitertModel(); - - // Pack litert subgraphs into tfl subgraphs and save the mapping of - // tensors. - TensorMap tensor_map; - for (auto i = 0; i < litert_model.Subgraphs().size(); ++i) { - auto& litert_subgraph = litert_model.Subgraph(i); - auto& tfl_subgraph = *builder.Model().subgraphs.emplace_back( - std::make_unique()); - LITERT_RETURN_IF_ERROR( - PackSubgraph(builder, litert_subgraph, tfl_subgraph, tensor_map, i)); - } - - // Serialize the signatures using saved tensor mapping. 
- for (auto* litert_signature : litert_model.Signatures()) { - auto* litert_subgraph = &litert_signature->GetSubgraph(); - - auto& tfl_signature = *builder.Model().signature_defs.emplace_back( - std::make_unique()); - tfl_signature.signature_key = std::string(litert_signature->Key()); - - auto begin = litert_model.Subgraphs().cbegin(); - auto end = litert_model.Subgraphs().cend(); - const auto litert_subgraph_ind = - std::find(begin, end, litert_subgraph) - begin; - tfl_signature.subgraph_index = litert_subgraph_ind; - - auto input_ind = 0; - for (const auto& litert_name : litert_signature->InputNames()) { - auto& tfl_input = *tfl_signature.inputs.emplace_back( - std::make_unique<::tflite::TensorMapT>()); - tfl_input.name = litert_name; - tfl_input.tensor_index = - tensor_map.find(litert_subgraph->Inputs().at(input_ind))->second; - ++input_ind; - } - - auto output_ind = 0; - for (const auto& litert_name : litert_signature->OutputNames()) { - auto& tfl_output = *tfl_signature.outputs.emplace_back( - std::make_unique<::tflite::TensorMapT>()); - tfl_output.name = litert_name; - tfl_output.tensor_index = - tensor_map.find(litert_subgraph->Outputs().at(output_ind))->second; - ++output_ind; - } - } - - // Serialize metadata. - for (auto it = litert_model.MetadataBegin(); it != litert_model.MetadataEnd(); - ++it) { - const auto& [key, buf_id] = *it; - auto buf = litert_model.Buffers()->GetBuffer(buf_id); - if (!buf) { - LITERT_LOG(LITERT_ERROR, "Failed to find metadata buffer"); - return buf.Error(); - } - builder.PushMetadata(key, *buf); - } - - builder.Model().version = 3; - - return std::move(builder).Release(); -} - -// Appends external buffers to the back of the serialized tflite model. Updates -// the ops that references them with the correct offset and size in-place. 
-Expected> SerializeWithAppendedBuffers( - SerializationContext& builder, OwningBufferRef serialized_tfl, - LiteRtModelT& litert_model) { - if (builder.OpAssetMap().empty() && builder.OffsetTensorMap().empty()) { - return serialized_tfl; - } - - const auto align = builder.BytecodeAlignment(); - // Pad the original model to the next multiple of the alignment. - auto align_offset = [align](size_t& cur_offset) { - cur_offset = (cur_offset + align - 1) & ~(align - 1); - }; - - size_t cur_offset = serialized_tfl.Size(); - align_offset(cur_offset); - - // Calculate the offset and size of each op asset. - InsertOrderMap> - asset_buffer_offsets; - for (auto it = builder.OpAssetMap().cbegin(); - it != builder.OpAssetMap().cend(); ++it) { - const auto& [buf_id, name] = it->second; - auto asset_buf = litert_model.Buffers()->GetBuffer(buf_id); - if (!asset_buf) { - return asset_buf.Error(); - } - if (asset_buffer_offsets.Contains(buf_id)) { - continue; - } - asset_buffer_offsets.InsertOrAssign(buf_id, - {cur_offset, asset_buf->Size()}); - cur_offset += asset_buf->Size(); - align_offset(cur_offset); - } - - // Calculate the offset and size of each offset tensor. - InsertOrderMap> - offset_tensor_offsets; - for (auto it = builder.OffsetTensorMap().cbegin(); - it != builder.OffsetTensorMap().cend(); ++it) { - const auto& [tfl_buffer_ind, litert_buf_id] = *it; - auto litert_buf = litert_model.Buffers()->GetBuffer(litert_buf_id); - if (!litert_buf) { - LITERT_LOG(LITERT_ERROR, "Failed to find offset tensor buffer"); - return litert_buf.Error(); - } - if (offset_tensor_offsets.Contains(tfl_buffer_ind)) { - continue; - } - offset_tensor_offsets.InsertOrAssign(tfl_buffer_ind, - {cur_offset, litert_buf->Size()}); - cur_offset += litert_buf->Size(); - } - - // Read serialized tflite in packed form. - auto* tfl_model = tflite::GetMutableModel(serialized_tfl.Data()); - - // Find the ops that have external buffers and mark them with the future size - // and offset. 
- for (auto sg_ind = 0; sg_ind < tfl_model->mutable_subgraphs()->size(); - ++sg_ind) { - auto* sg = tfl_model->mutable_subgraphs()->GetMutableObject(sg_ind); - - for (auto op_ind = 0; op_ind < sg->mutable_operators()->size(); ++op_ind) { - SerializationContext::TflOpInd ind = {sg_ind, op_ind}; - - auto asset_buffer = builder.OpAssetMap().find(ind); - if (asset_buffer == builder.OpAssetMap().end()) { - // No external buffer for this op. - continue; - } - - auto* op = sg->mutable_operators()->GetMutableObject(op_ind); - - // The id of the buffer in the litert model. - const auto buf_id = asset_buffer->second.first; - - // The real offset and size of the buffer in the serialized tflite model. - const auto offset_and_size = asset_buffer_offsets.Find(buf_id); - if (!offset_and_size) { - LITERT_LOG(LITERT_ERROR, "Failed to find offset and size for buffer"); - return Error(kLiteRtStatusErrorInvalidFlatbuffer); - } - const auto [offset, size] = offset_and_size->get().second; - - // The custom options should have already been set with the name and - // placeholder values for size and offset. - MutableBufferRef old_raw_opts( - op->mutable_custom_options()->data(), - op->mutable_custom_options()->size()); - - // Update with real size and offset. - DispatchOpOptions dispach_opts(GetDispatchOpOptions(old_raw_opts)); - dispach_opts.bytecode_offset = offset; - dispach_opts.bytecode_size = size; - - if (!UpdateDispatchOpOptionsInPlace(dispach_opts, old_raw_opts)) { - LITERT_LOG(LITERT_ERROR, "Failed to update dispatch op options"); - return Error(kLiteRtStatusErrorInvalidFlatbuffer); - } - } - } - - // Find the buffers that are offset buffers and mark them with the future - // size and offset. - for (auto i = 0; i < tfl_model->mutable_buffers()->size(); ++i) { - auto* tfl_buffer = tfl_model->mutable_buffers()->GetMutableObject(i); - auto offset_size = offset_tensor_offsets.Find(i); - if (!offset_size) { - // Not offset buffer. 
- continue; - } - const auto [offset, size] = offset_size->get().second; - const auto offset_ok = tfl_buffer->mutate_offset(offset); - const auto size_ok = tfl_buffer->mutate_size(size); - if (!offset_ok || !size_ok) { - LITERT_LOG(LITERT_ERROR, "Failed to update offset and size for buffer"); - return Error(kLiteRtStatusErrorInvalidFlatbuffer); - } - } - - // Allocate buffer enough for original model and appendd buffers and copy. - OwningBufferRef final_model(cur_offset); - - // Copy serialized tflite model. - uint8_t* const start = final_model.Data(); - std::memcpy(start, serialized_tfl.Data(), serialized_tfl.Size()); - - // Copy asset buffers (aligned). - for (auto it = asset_buffer_offsets.Begin(); it != asset_buffer_offsets.End(); - ++it) { - const auto buf_id = it->first; - - auto asset_buf = litert_model.Buffers()->GetBuffer(buf_id); - if (!asset_buf) { - LITERT_LOG(LITERT_ERROR, "Failed to find asset buffer"); - return asset_buf.Error(); - } - uint8_t* const offset = start + it->second.first; - std::memcpy(offset, asset_buf->Data(), asset_buf->Size()); - } - - // Copy offset tensor buffers. - for (auto it = offset_tensor_offsets.Begin(); - it != offset_tensor_offsets.End(); ++it) { - const auto buf_id = it->first; - - auto offset_buf = litert_model.Buffers()->GetBuffer(buf_id); - if (!offset_buf) { - LITERT_LOG(LITERT_ERROR, "Failed to find offset tensor buffer"); - return offset_buf.Error(); - } - - uint8_t* const offset = start + it->second.first; - std::memcpy(offset, offset_buf->Data(), offset_buf->Size()); - } - - return final_model; -} - -} // namespace - -Expected> SerializeModel(LiteRtModelT&& model, - size_t bytecode_alignment) { - // Pass the op code list through that was saved during loading. 
Add one more - // op code for the dispatch ops - auto tfl_op_codes = litert::internal::TakeTflOpCodes(model); - tfl_op_codes.push_back( - MakeCustomOpCode(std::string(kLiteRtDispatchOpCustomCode))); - - SerializationContext builder(tfl_op_codes.size() - 1, model, - bytecode_alignment); - builder.Model().operator_codes = std::move(tfl_op_codes); - - auto tfl_model = PackAsTflite(builder); - if (!tfl_model) { - LITERT_LOG(LITERT_ERROR, "Failed to pack as tflite"); - return tfl_model.Error(); - } - - auto serialized_tfl = SerializeFlatbuffer(**tfl_model); - auto serialized_with_buffers = - SerializeWithAppendedBuffers(builder, std::move(serialized_tfl), model); - if (!serialized_with_buffers) { - LITERT_LOG(LITERT_ERROR, "Failed to serialize with appended buffers"); - return serialized_with_buffers.Error(); - } - - if (!VerifyFlatbuffer(serialized_with_buffers->Span())) { - LITERT_LOG(LITERT_ERROR, "Failed to verify flatbuffer"); - return Error(kLiteRtStatusErrorInvalidFlatbuffer); - } - - return serialized_with_buffers; -} - -} // namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/core/model/model_serialize.h b/tensorflow/lite/experimental/litert/core/model/model_serialize.h deleted file mode 100644 index 0ffa2d878ba8c3..00000000000000 --- a/tensorflow/lite/experimental/litert/core/model/model_serialize.h +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_MODEL_SERIALIZE_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_MODEL_SERIALIZE_H_ - -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" - -namespace litert::internal { - -Expected> SerializeModel( - LiteRtModelT&& model, size_t bytecode_alignment = 1); - -} // namespace litert::internal - - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_MODEL_SERIALIZE_H_ diff --git a/tensorflow/lite/experimental/litert/core/model/model_test.cc b/tensorflow/lite/experimental/litert/core/model/model_test.cc deleted file mode 100644 index 52dfcdc3778a4f..00000000000000 --- a/tensorflow/lite/experimental/litert/core/model/model_test.cc +++ /dev/null @@ -1,531 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/core/model/model.h" - -#include -#include -#include -#include -#include -#include - -#include -#include -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" -#include "tensorflow/lite/experimental/litert/core/build_stamp.h" -#include "tensorflow/lite/experimental/litert/core/model/buffer_manager.h" -#include "tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h" -#include "tensorflow/lite/experimental/litert/test/matchers.h" -#include "tensorflow/lite/schema/schema_generated.h" - -namespace litert::internal { -namespace { - -using ::testing::ElementsAreArray; - -// -// Model -// - -TEST(ModelTest, GetMetadata) { - static constexpr absl::string_view kMetadata = "VALUE"; - static constexpr absl::string_view kKey = "KEY"; - - LiteRtModelT model; - LITERT_ASSERT_OK(model.PushMetadata(kKey, kMetadata)); - auto found_metadata = model.FindMetadata(kKey); - ASSERT_TRUE(found_metadata); - EXPECT_EQ(found_metadata->StrView(), kMetadata); -} - -TEST(ModelTest, MetadataDNE) { - LiteRtModelT model; - auto res = model.FindMetadata("FOO"); - ASSERT_FALSE(res.HasValue()); -} - -TEST(ModelTest, GetBuildStamp) { - static constexpr absl::string_view kSocManufacturer = "honda"; - static constexpr absl::string_view kSocModel = "accord"; - - LiteRtModelT model; - - LITERT_ASSERT_OK(model.PushMetadata( - kLiteRtBuildStampKey, *MakeBuildStamp(kSocManufacturer, kSocModel))); - auto build_stamp = GetBuildStamp(model); - ASSERT_TRUE(build_stamp); - EXPECT_TRUE(IsCompiled(model)); - EXPECT_EQ(build_stamp->soc_manufacturer, kSocManufacturer); - EXPECT_EQ(build_stamp->soc_model, kSocModel); -} - -TEST(ModelTest, EmplaceSubgraph) { - LiteRtModelT model; - auto& sg = model.EmplaceSubgraph(); - 
EXPECT_EQ(model.Subgraphs().size(), 1); - auto& tensor = sg.EmplaceTensor(); - EXPECT_EQ(tensor.Weights().GetBufferManager(), model.Buffers()); -} - -TEST(ModelTest, Signature) { - static constexpr absl::string_view kSignatureName = "MY_SIGNATURE"; - - const std::vector inputs = {"input_1", "input_2"}; - const std::vector outputs = {"output_1"}; - - LiteRtModelT model; - auto& subgraph = model.EmplaceSubgraph(); - - auto& signature = model.EmplaceSignature(&subgraph, inputs, outputs, - std::string(kSignatureName)); - - auto found_signature = model.FindSignature(kSignatureName); - ASSERT_TRUE(found_signature); - EXPECT_EQ(found_signature->get(), signature); -} - -TEST(ModelTest, SignatureDNE) { - static constexpr absl::string_view kSignatureName = "MY_SIGNATURE"; - LiteRtModelT model; - auto found_signature = model.FindSignature(kSignatureName); - EXPECT_FALSE(found_signature); -} - -TEST(ModelTest, AttachExternalBufferToOp) { - static constexpr absl::string_view kBufferData = "BUFFER_DATA"; - static constexpr absl::string_view kOpName = "OP1"; - static constexpr absl::string_view kOp2Name = "OP2"; - - LiteRtModelT model; - auto& subgraph = model.EmplaceSubgraph(); - auto& op = subgraph.EmplaceOp(); - auto& op2 = subgraph.EmplaceOp(); - - OwningBufferRef external_buf(kBufferData); - - auto buf1_id = model.Buffers()->RegisterOwnedBuffer(std::move(external_buf)); - - model.AttachAssetToOp(&op, buf1_id, std::string(kOpName)); - model.AttachAssetToOp(&op2, buf1_id, std::string(kOp2Name)); - - auto op_1_res = model.FindOpAsset(&op); - ASSERT_TRUE(op_1_res); - EXPECT_EQ(op_1_res->second, kOpName); - EXPECT_EQ(op_1_res->first, buf1_id); - - auto op_2_res = model.FindOpAsset(&op2); - ASSERT_TRUE(op_2_res); - EXPECT_EQ(op_2_res->second, kOp2Name); - EXPECT_EQ(op_2_res->first, buf1_id); -} - -TEST(ModelTest, ExternalBufferNotFound) { - LiteRtModelT model; - LiteRtOpT op; - ASSERT_FALSE(model.FindOpAsset(&op)); -} - -// -// Subgraph -// - -TEST(ModelSubgraphTest, Input) { - 
LiteRtTensorT tensor; - LiteRtSubgraphT subgraph; - subgraph.Inputs().push_back(&tensor); - EXPECT_EQ(&subgraph.Input(0), subgraph.Inputs().front()); -} - -TEST(ModelSubgraphTest, Output) { - LiteRtTensorT tensor; - LiteRtSubgraphT subgraph; - subgraph.Outputs().push_back(&tensor); - EXPECT_EQ(&subgraph.Output(0), subgraph.Outputs().front()); -} - -TEST(ModelSubgraphTest, EmplaceTensor) { - LiteRtSubgraphT subgraph; - auto& tensor = subgraph.EmplaceTensor(); - ASSERT_EQ(subgraph.Tensors().size(), 1); - EXPECT_THAT(subgraph.Tensors(), ElementsAreArray({&tensor})); -} - -TEST(ModelSubgraphTest, EmplaceOp) { - LiteRtSubgraphT subgraph; - auto& op = subgraph.EmplaceOp(); - ASSERT_EQ(subgraph.Ops().size(), 1); - EXPECT_THAT(subgraph.Ops(), ElementsAreArray({&op})); -} - -// -// Op -// - -TEST(ModelOpTest, Input) { - LiteRtOpT op; - LiteRtTensorT tensor; - op.Inputs().push_back(&tensor); - EXPECT_EQ(&op.Input(0), op.Inputs().front()); -} - -TEST(ModelOpTest, Output) { - LiteRtOpT op; - LiteRtTensorT tensor; - op.Outputs().push_back(&tensor); - EXPECT_EQ(&op.Output(0), op.Outputs().front()); -} - -TEST(ModelOpTest, CustomOptions) { - static constexpr absl::string_view kOpts = "OPTIONS"; - - LiteRtOpT op; - op.SetCustomOptions(kOpts); - EXPECT_EQ(op.CustomOptions().StrView(), kOpts); -} - -TEST(ModelOpTest, Options) { - static constexpr auto kOptsType = ::tflite::BuiltinOptions_AddOptions; - - TflOptions options; - options.type = kOptsType; - options.Set(::tflite::AddOptionsT()); - - LiteRtOpT op; - litert::internal::SetTflOptions(op, std::move(options)); - - ASSERT_EQ(litert::internal::GetTflOptions(op).type, kOptsType); -} - -TEST(ModelOpTest, OpCode) { - constexpr static auto kOpCode = kLiteRtOpCodeTflMul; - - LiteRtOpT op; - op.SetOpCode(kOpCode); - EXPECT_EQ(op.OpCode(), kOpCode); -} - -// -// Tensor -// - -TEST(ModelTensorTypeTest, MakeRankedTensorType) { - static constexpr const int32_t kDims[] = {2, 2}; - static constexpr auto kDimsSpan = 
absl::MakeConstSpan(kDims); - static constexpr auto kElementType = kLiteRtElementTypeFloat32; - const auto tensor_type = MakeRankedTensorType(kElementType, kDimsSpan); - ASSERT_EQ(tensor_type.first, kLiteRtRankedTensorType); - EXPECT_EQ(tensor_type.second.ranked_tensor_type.element_type, kElementType); - const auto& layout = tensor_type.second.ranked_tensor_type.layout; - ASSERT_EQ(layout.rank, kDimsSpan.size()); - EXPECT_THAT(absl::MakeConstSpan(layout.dimensions, kDimsSpan.size()), - ElementsAreArray(kDimsSpan)); -} - -TEST(ModelQuantizationTypeTest, MakePerTensor) { - static constexpr auto kScale = 1.0f; - static constexpr auto kZero = 1L; - const auto quant = MakePerTensorQuantization(kScale, kZero); - ASSERT_EQ(quant.first, kLiteRtQuantizationPerTensor); - const auto& per_tensor = quant.second.per_tensor; - EXPECT_EQ(per_tensor.scale, kScale); - EXPECT_EQ(per_tensor.zero_point, kZero); -} - -TEST(ModelQuantizationTypeTest, MakePerChannel) { - static constexpr std::array kScale = {1.0f, 2.0f}; - static constexpr std::array kZero = {1L, 2L}; - static constexpr int32_t kQdim = 0; - - LiteRtTensorT tensor; - const auto quant = MakePerChannelQuantization( - kScale, kZero, kQdim, - [&tensor](auto s) { return tensor.RequestScratchBuffer(s); }); - - ASSERT_EQ(quant.first, kLiteRtQuantizationPerChannel); - const auto& per_channel = quant.second.per_channel; - - const auto size = per_channel.num_channels; - ASSERT_EQ(size, 2); - EXPECT_EQ(per_channel.quantized_dimension, 0); - - auto scales = absl::MakeConstSpan(per_channel.scales, size); - auto zeros = absl::MakeConstSpan(per_channel.zero_points, size); - - EXPECT_THAT(scales, ElementsAreArray(kScale)); - EXPECT_THAT(zeros, ElementsAreArray(kZero)); -} - -TEST(ModelWeightsTest, EmptyWeights) { - LiteRtWeightsT weights; - EXPECT_EQ(weights.Buffer().Size(), 0); -} - -TEST(ModelWeightsTest, WeightsWithExternalBufferManager) { - static constexpr absl::string_view kData = "some_data"; - BufferManager manager; - - 
LiteRtWeightsT weights; - weights.SetBufferManager(&manager); - - BufferRef buf(kData.data(), kData.size()); - SetWeightsFromUnownedBuffer(weights, buf); - - EXPECT_EQ(manager.GetBuffer(weights.GetBufferId())->StrView(), kData); - EXPECT_EQ(weights.Buffer().StrView(), kData); -} - -TEST(ModelWeightsTest, WeightsFromUnownedBuffer) { - static constexpr absl::string_view kData = "some_data"; - - LiteRtWeightsT weights; - BufferRef buf(kData.data(), kData.size()); - SetWeightsFromUnownedBuffer(weights, buf); - - EXPECT_EQ(weights.Buffer().StrView(), kData); -} - -TEST(ModelWeightsTest, WeightsFromOwnedBuffer) { - static constexpr absl::string_view kData = "some_data"; - - LiteRtWeightsT weights; - - OwningBufferRef buf(kData); - SetWeightsFromUnownedBuffer(weights, std::move(buf)); - - EXPECT_EQ(weights.Buffer().StrView(), kData); -} - -TEST(ModelWeightsTest, OverwriteBuffer) { - static constexpr absl::string_view kData = "some_data"; - static constexpr absl::string_view kData2 = "some_data2"; - - LiteRtWeightsT weights; - - { - OwningBufferRef buf(kData); - SetWeightsFromOwnedBuffer(weights, std::move(buf)); - } - - { - OwningBufferRef buf(kData2); - SetWeightsFromOwnedBuffer(weights, std::move(buf)); - } - - EXPECT_EQ(weights.Buffer().StrView(), kData2); -} - -TEST(ModelTensorTest, Name) { - static constexpr absl::string_view kName = "TENSOR_NAME"; - - LiteRtTensorT tensor; - tensor.SetName(std::string(kName.begin(), kName.end())); - EXPECT_EQ(tensor.Name(), kName); -} - -TEST(ModelTensorTest, Use) { - LiteRtTensorT tensor; - tensor.Users().emplace_back(); - tensor.UserArgInds().push_back(0); - auto [user, ind] = tensor.GetUse(0); - EXPECT_EQ(user, tensor.Users().front()); - EXPECT_EQ(ind, 0); -} - -TEST(ModelTensorTest, DefiningOp) { - LiteRtTensorT tensor; - LiteRtOpT op; - tensor.SetDefiningOp(op, 0); - EXPECT_EQ(tensor.DefiningOp(), &op); - EXPECT_EQ(tensor.DefiningOpOutInd(), 0); -} - -TEST(ModelTest, TransferSubgraphToReindexComposite) { - LiteRtModelT model; - 
- auto& subgraph = model.EmplaceSubgraph(); - auto& other_subgraph = model.EmplaceSubgraph(); - auto& decomp_subgraph = model.EmplaceSubgraph(); - - auto& composite = subgraph.EmplaceOp(); - composite.SetOpCode(kLiteRtOpCodeShloComposite); - ::tflite::StableHLOCompositeOptionsT opts; - opts.name = "composite"; - opts.decomposition_subgraph_index = 2; - TflOptions2 options; - options.type = tflite::BuiltinOptions2_StableHLOCompositeOptions; - options.Set(std::move(opts)); - litert::internal::SetTflOptions2(composite, std::move(options)); - - LiteRtSubgraphT::Alloc dest; - std::vector indices = {1}; - model.TransferSubgraphTo(dest, std::move(indices)); - - EXPECT_THAT(model.Subgraphs(), - ElementsAreArray({&subgraph, &decomp_subgraph})); - EXPECT_THAT(dest.Elements(), ElementsAreArray({&other_subgraph})); - - const auto& new_opts = litert::internal::GetTflOptions2(composite); - const auto new_decomp_ind = - new_opts.AsStableHLOCompositeOptions()->decomposition_subgraph_index; - EXPECT_EQ(new_decomp_ind, 1); -} - -TEST(ModelTest, TransferSubgraphToReindexCompositeNoChange) { - LiteRtModelT model; - - auto& subgraph = model.EmplaceSubgraph(); - auto& decomp_subgraph = model.EmplaceSubgraph(); - auto& other_subgraph = model.EmplaceSubgraph(); - - auto& composite = subgraph.EmplaceOp(); - composite.SetOpCode(kLiteRtOpCodeShloComposite); - ::tflite::StableHLOCompositeOptionsT opts; - opts.name = "composite"; - opts.decomposition_subgraph_index = 1; - TflOptions2 options; - options.type = tflite::BuiltinOptions2_StableHLOCompositeOptions; - ; - options.Set(std::move(opts)); - litert::internal::SetTflOptions2(composite, std::move(options)); - - LiteRtSubgraphT::Alloc dest; - std::vector indices = {2}; - model.TransferSubgraphTo(dest, std::move(indices)); - - EXPECT_THAT(model.Subgraphs(), - ElementsAreArray({&subgraph, &decomp_subgraph})); - EXPECT_THAT(dest.Elements(), ElementsAreArray({&other_subgraph})); - - const auto& new_opts = 
litert::internal::GetTflOptions2(composite); - const auto new_decomp_ind = - new_opts.AsStableHLOCompositeOptions()->decomposition_subgraph_index; - EXPECT_EQ(new_decomp_ind, 1); -} - -TEST(ModelTest, TransferSubgraphToReindexCompositeMultiple) { - LiteRtModelT model; - - auto& subgraph = model.EmplaceSubgraph(); - auto& other_subgraph = model.EmplaceSubgraph(); - auto& other_subgraph2 = model.EmplaceSubgraph(); - auto& other_subgraph3 = model.EmplaceSubgraph(); - auto& decomp_subgraph = model.EmplaceSubgraph(); - auto& other_subgraph4 = model.EmplaceSubgraph(); - - auto& composite = subgraph.EmplaceOp(); - composite.SetOpCode(kLiteRtOpCodeShloComposite); - ::tflite::StableHLOCompositeOptionsT opts; - opts.name = "composite"; - opts.decomposition_subgraph_index = 4; - TflOptions2 options; - options.type = tflite::BuiltinOptions2_StableHLOCompositeOptions; - ; - options.Set(std::move(opts)); - litert::internal::SetTflOptions2(composite, std::move(options)); - - LiteRtSubgraphT::Alloc dest; - std::vector indices = {1, 3, 5}; - model.TransferSubgraphTo(dest, std::move(indices)); - - EXPECT_THAT(model.Subgraphs(), ElementsAreArray({&subgraph, &other_subgraph2, - &decomp_subgraph})); - EXPECT_THAT( - dest.Elements(), - ElementsAreArray({&other_subgraph, &other_subgraph3, &other_subgraph4})); - - const auto& new_opts = litert::internal::GetTflOptions2(composite); - const auto new_decomp_ind = - new_opts.AsStableHLOCompositeOptions()->decomposition_subgraph_index; - EXPECT_EQ(new_decomp_ind, 2); -} - -// -// Misc Ir Containers -// - -TEST(ModelOpListTest, Push) { - LiteRtOpListT op_list; - LiteRtOpT op; - op_list.Push(&op); - auto vec = op_list.Values(); - EXPECT_EQ(vec.front().first, &op); -} - -TEST(ModelOpListTest, PushWithIndex) { - LiteRtOpListT op_list; - LiteRtOpT op; - op_list.Push(&op, 1); - auto vec = op_list.Values(); - EXPECT_EQ(vec.front().first, &op); - EXPECT_EQ(vec.front().second, 1); -} - -// -// Traversal Utils -// - -TEST(CcForEachIrTest, OpF3) { - 
LiteRtModelT model; - model.EmplaceSubgraph().EmplaceOp(); - - int count = 0; - ForEachIr(&model, [&](LiteRtSubgraph subgraph, int32_t subgraph_index, - LiteRtOp op) { count++; }); - EXPECT_EQ(count, 1); -} - -TEST(CcForEachIrTest, OpF1) { - LiteRtModelT model; - model.EmplaceSubgraph().EmplaceOp(); - - int count = 0; - ForEachIr(&model, [&](LiteRtOp op) { count++; }); - EXPECT_EQ(count, 1); -} - -TEST(CcForEachIrTest, OpF2) { - LiteRtModelT model; - model.EmplaceSubgraph().EmplaceOp(); - - int count = 0; - ForEachIr(&model, [&](LiteRtSubgraph subgraph, LiteRtOp op) { count++; }); - EXPECT_EQ(count, 1); -} - -TEST(CcForEachIrTest, SgF1) { - LiteRtModelT model; - model.EmplaceSubgraph().EmplaceOp(); - - int count = 0; - ForEachIr(&model, [&](LiteRtSubgraph subgraph) { count++; }); - EXPECT_EQ(count, 1); -} - -TEST(CcForEachIrTest, SgF2) { - LiteRtModelT model; - model.EmplaceSubgraph().EmplaceOp(); - - int count = 0; - ForEachIr(&model, - [&](LiteRtSubgraph subgraph, int32_t subgraph_index) { count++; }); - EXPECT_EQ(count, 1); -} - -} // namespace -} // namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/core/util/BUILD b/tensorflow/lite/experimental/litert/core/util/BUILD deleted file mode 100644 index 88fb50a693cb11..00000000000000 --- a/tensorflow/lite/experimental/litert/core/util/BUILD +++ /dev/null @@ -1,89 +0,0 @@ -# Copyright 2024 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = [ - # copybara:uncomment "//third_party/mediapipe/calculators/tensor:__subpackages__", - "//tensorflow/lite/experimental/litert:__subpackages__", - ], -) - -cc_library( - name = "flatbuffer_tools", - srcs = ["flatbuffer_tools.cc"], - hdrs = [ - "flatbuffer_tools.h", - "//tensorflow/lite/experimental/litert/cc:litert_consts.h", - ], - deps = [ - "//tensorflow/compiler/mlir/lite:allocation", - "//tensorflow/lite:model_builder", - "//tensorflow/lite:stderr_reporter", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/cc:litert_buffer_ref", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/core:filesystem", - "//tensorflow/lite/schema:schema_fbs", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - "@flatbuffers//:runtime_cc", - ], -) - -cc_test( - name = "flatbuffer_tools_test", - srcs = ["flatbuffer_tools_test.cc"], - data = [ - "//tensorflow/lite/experimental/litert/test:mlir_test_data", - "//tensorflow/lite/experimental/litert/test:tflite_test_data", - ], - deps = [ - ":flatbuffer_tools", - "//tensorflow/lite/experimental/litert/cc:litert_buffer_ref", - "//tensorflow/lite/experimental/litert/test:common", - "//tensorflow/lite/experimental/litert/test:matchers", - "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "tensor_type_util", - srcs = [ - "tensor_type_util.cc", - ], - hdrs = [ - "tensor_type_util.h", - ], - deps = [ - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - ], -) - -cc_test( - name 
= "tensor_type_util_test", - srcs = ["tensor_type_util_test.cc"], - deps = [ - ":tensor_type_util", - "//tensorflow/lite/experimental/litert/c:litert_model", - "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest_main", - ], -) diff --git a/tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.cc b/tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.cc deleted file mode 100644 index ab67b75b2cbdc0..00000000000000 --- a/tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.cc +++ /dev/null @@ -1,330 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h" - -#include -#include -#include -#include - -#include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers -#include "tensorflow/compiler/mlir/lite/allocation.h" -#include "tensorflow/lite/experimental/litert/core/filesystem.h" - -#ifndef NDEBUG -// Make flatbuffers verifier `assert` in debug mode. 
-#define FLATBUFFERS_DEBUG_VERIFICATION_FAILURE - -#include "flatbuffers/flatbuffers.h" // from @flatbuffers // IWYU pragma: keep -#endif - -#include -#include - -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "flatbuffers/verifier.h" // from @flatbuffers -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/model_builder.h" -#include "tensorflow/lite/schema/schema_generated.h" -#include "tensorflow/lite/stderr_reporter.h" - -namespace litert::internal { - -using ::flatbuffers::Verifier; -using ::tflite::VerifyModelBuffer; - -namespace { - -Expected FindMetadataInd(const TflModel& model, - absl::string_view key) { - tflite::MetadataT* fb_metadata = nullptr; - for (auto& m : model.metadata) { - if (m->name == key) { - fb_metadata = m.get(); - break; - } - } - if (fb_metadata == nullptr) { - return Error(kLiteRtStatusErrorNotFound); - } - return fb_metadata->buffer; -} - -} // namespace - -absl::string_view FbBufToStr(const uint8_t* fb_data, size_t size) { - auto fb_buf_raw = reinterpret_cast(fb_data); - return absl::string_view(fb_buf_raw, size); -} - -absl::string_view FbBufToStr(absl::Span fb_buf) { - auto fb_buf_raw = reinterpret_cast(fb_buf.data()); - const size_t fb_buf_size = fb_buf.size(); - return absl::string_view(fb_buf_raw, fb_buf_size); -} - -absl::Span FbBufToStr(absl::Span fb_buf) { - return absl::MakeSpan(reinterpret_cast(fb_buf.data()), fb_buf.size()); -} - -absl::Span FbBufToStr(uint8_t* fb_data, size_t size) { - return absl::MakeSpan(reinterpret_cast(fb_data), size); -} - -bool VerifyFlatbuffer(absl::Span buf) { - return VerifyFlatbuffer(buf.data(), buf.size()); -} - -bool VerifyFlatbuffer(const uint8_t* buf, size_t buf_size) { - flatbuffers::Verifier::Options options; -#ifndef NDEBUG - options.assert = true; -#endif - flatbuffers::Verifier 
verifier(buf, buf_size, options); - return VerifyModelBuffer(verifier); -} - -Expected> GetMetadata(absl::string_view key, - TflModel& model) { - auto buffer_ind = FindMetadataInd(model, key); - if (!buffer_ind) { - // Metadata key already has value. - return buffer_ind.Error(); - } - auto& fb_vec = model.buffers.at(*buffer_ind)->data; - return MutableBufferRef(fb_vec.data(), fb_vec.size()); -} - -Expected> GetMetadata(absl::string_view key, - const TflModel& model) { - auto metadata = GetMetadata(key, const_cast(model)); - if (!metadata) { - return metadata.Error(); - } - return *metadata; -} - -LiteRtStatus PushMetadata(absl::string_view key, TflModel& model, - BufferRef metadata) { - auto buffer_ind = FindMetadataInd(model, key); - if (buffer_ind) { - // Metadata key already has value. - return kLiteRtStatusErrorInvalidArgument; - } - - auto& new_metadata = - model.metadata.emplace_back(std::make_unique()); - new_metadata->name.assign(key.data(), key.size()); - - const auto new_m_buffer_ind = model.buffers.size(); - new_metadata->buffer = new_m_buffer_ind; - - auto& new_buffer = model.buffers.emplace_back(std::make_unique()); - new_buffer->data.assign(metadata.Data(), metadata.Data() + metadata.Size()); - - return kLiteRtStatusOk; -} - -Expected> GetTflBuffer(TflModel& tfl_model, - uint32_t buffer_ind) { - if (buffer_ind >= tfl_model.buffers.size()) { - return Error(kLiteRtStatusErrorIndexOOB); - } - auto& tfl_data = tfl_model.buffers.at(buffer_ind)->data; - return MutableBufferRef(tfl_data.data(), tfl_data.size()); -} - -Expected> GetTflBuffer(const TflModel& tfl_model, - uint32_t buffer_ind) { - auto buffer = GetTflBuffer(const_cast(tfl_model), buffer_ind); - if (!buffer) { - return buffer.Error(); - } - return *buffer; -} - -Expected GetBuffer(const TflModel& tfl_model, - uint32_t buffer_ind) { - if (buffer_ind >= tfl_model.buffers.size()) { - return Error(kLiteRtStatusErrorIndexOOB); - } - return tfl_model.buffers.at(buffer_ind).get(); -} - -Expected 
TakeBuffer(TflModel& tfl_model, uint32_t buffer_ind) { - if (buffer_ind >= tfl_model.buffers.size()) { - return Error(kLiteRtStatusErrorIndexOOB); - } - return std::move(tfl_model.buffers.at(buffer_ind)); -} - -Expected PushTflBuffer(TflModel& tfl_model, - BufferRef buffer) { - tfl_model.buffers.emplace_back(std::make_unique<::tflite::BufferT>()) - ->data.assign(buffer.Data(), buffer.Data() + buffer.Size()); - return tfl_model.buffers.size() - 1; -} - -Expected GetTflOpCode(const TflModel& tfl_model, - uint32_t op_code_ind) { - if (op_code_ind >= tfl_model.operator_codes.size()) { - return Error(kLiteRtStatusErrorIndexOOB); - } - return std::move(tfl_model.operator_codes.at(op_code_ind)->builtin_code); -} - -bool IsRankedTensorType(const TflShapeInfo& tfl_shape) { - return tfl_shape.has_rank; -} - -bool IsStaticTensorType(const TflShapeInfo& tfl_shape) { - return !IsRankedTensorType(tfl_shape) || - std::none_of(tfl_shape.shape_signature.begin(), - tfl_shape.shape_signature.end(), - [](auto d) { return d < 0; }); -} - -Expected> AsStaticShape( - const TflShapeInfo& tfl_shape) { - if (!IsStaticTensorType(tfl_shape)) { - return Error(kLiteRtStatusErrorInvalidArgument); - } - return absl::MakeConstSpan(tfl_shape.shape.data(), tfl_shape.shape.size()); -} - -Expected> AsDynamicShape( - const TflShapeInfo& tfl_shape) { - auto static_shape = AsStaticShape(tfl_shape); - if (static_shape) { - return static_shape; - } - if (!IsRankedTensorType(tfl_shape)) { - return Error(kLiteRtStatusErrorInvalidArgument); - } - return absl::MakeConstSpan(tfl_shape.shape_signature.data(), - tfl_shape.shape_signature.size()); -} - -bool IsQuantized(const TflQuantization* tfl_quantization) { - return tfl_quantization && - (!tfl_quantization->scale.empty() || - tfl_quantization->details.type != tflite::QuantizationDetails_NONE); -} - -bool IsPerChannelQuantized(const TflQuantization* tfl_quantization) { - return tfl_quantization && tfl_quantization->scale.size() > 1; -} - -bool 
IsPerTensorQuantized(const TflQuantization* tfl_quantization) { - return tfl_quantization && tfl_quantization->scale.size() == 1; -} - -bool IsBlockwiseQuantized(const TflQuantization* tfl_quantization) { - return tfl_quantization && - tfl_quantization->details.type == - tflite::QuantizationDetails_BlockwiseQuantization; -} - -bool IsCustomQuantized(const TflQuantization* tfl_quantization) { - return tfl_quantization && tfl_quantization->details.type == - tflite::QuantizationDetails_CustomQuantization; -} - -Expected AsPerTensorQparams( - const TflQuantization* tfl_quantization) { - if (!IsPerTensorQuantized(tfl_quantization)) { - return Error(kLiteRtStatusErrorInvalidArgument); - } - return std::make_pair(tfl_quantization->zero_point.front(), - tfl_quantization->scale.front()); -} - -Expected AsPerChannelQparams( - const TflQuantization* tfl_quantization) { - if (!IsPerChannelQuantized(tfl_quantization)) { - return Error(kLiteRtStatusErrorInvalidArgument); - } - return TflPerChannelQParams(tfl_quantization->quantized_dimension, - tfl_quantization->zero_point.size(), - tfl_quantization->zero_point, - tfl_quantization->scale); -} - -::tflite::Allocation::Ptr MakeAllocation(BufferRef buf) { - return std::make_unique<::tflite::MemoryAllocation>( - buf.Data(), buf.Size(), ::tflite::DefaultErrorReporter()); -} - -Expected FlatbufferWrapper::CreateFromBuffer( - OwningBufferRef&& buffer) { - static constexpr size_t k2GiB = 2e+9; - if (buffer.Size() < k2GiB && - !VerifyFlatbuffer(buffer.Data(), buffer.Size())) { - return Error(kLiteRtStatusErrorInvalidFlatbuffer); - } - - auto alloc = MakeAllocation(buffer); - - if (alloc == nullptr) { - return Error(kLiteRtStatusErrorFileIO); - } - - auto fb_model = ::tflite::FlatBufferModel::BuildFromBuffer( - reinterpret_cast(alloc->base()), alloc->bytes()); - if (fb_model == nullptr) { - return Error(kLiteRtStatusErrorFileIO); - } - - return FlatbufferWrapper::Ptr(new FlatbufferWrapper( - std::move(fb_model), std::move(alloc), 
std::move(buffer))); -} - -Expected FlatbufferWrapper::CreateFromBuffer( - BufferRef buffer) { - return FlatbufferWrapper::CreateFromBuffer( - OwningBufferRef(buffer.Data(), buffer.Size())); -} - -Expected FlatbufferWrapper::CreateFromTflFile( - absl::string_view path) { - auto buf = LoadBinaryFile(path); - if (!buf) { - return buf.Error(); - } - return FlatbufferWrapper::CreateFromBuffer(std::move(*buf)); -} - -OwningBufferRef SerializeFlatbuffer(const TflModel& tfl_model) { - flatbuffers::FlatBufferBuilder b; - auto model_offset = tflite::Model::Pack(b, &tfl_model); - tflite::FinishModelBuffer(b, model_offset); - - OwningBufferRef buffer; - auto [new_buf, new_size, new_offset] = buffer.GetWeak(); - new_buf = b.ReleaseRaw(new_size, new_offset); - - return buffer; -} - -OwningBufferRef SerializeFlatbuffer( - const FlatbufferWrapper& flatbuffer) { - auto tfl_model = flatbuffer.Unpack(); - return SerializeFlatbuffer(*tfl_model); -} - -} // namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h b/tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h deleted file mode 100644 index bf0ccf6604f737..00000000000000 --- a/tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h +++ /dev/null @@ -1,314 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_UTIL_FLATBUFFER_TOOLS_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_UTIL_FLATBUFFER_TOOLS_H_ - -#include -#include -#include -#include -#include -#include - -#include "absl/container/inlined_vector.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "tensorflow/compiler/mlir/lite/allocation.h" -#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" -#include "tensorflow/lite/experimental/litert/cc/litert_consts.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/model_builder.h" -#include "tensorflow/lite/schema/schema_generated.h" - -namespace litert::internal { - -// Flatbuffer IR - -using TflTensor = ::tflite::TensorT; -using TflOp = ::tflite::OperatorT; -using TflBuffer = ::tflite::BufferT; -using TflSubgraph = ::tflite::SubGraphT; -using TflModel = ::tflite::ModelT; -using TflOpCodeEnum = ::tflite::BuiltinOperator; -using TflOpCode = ::tflite::OperatorCodeT; -using TflQuantization = ::tflite::QuantizationParametersT; -using TflElementType = ::tflite::TensorType; -using TflOptions = ::tflite::BuiltinOptionsUnion; -using TflOptions2 = ::tflite::BuiltinOptions2Union; -using TflSignature = ::tflite::SignatureDefT; -using TflMetadata = ::tflite::MetadataT; - -using TflPackedModel = ::tflite::Model; -using TflPackedSubgraph = ::tflite::SubGraph; -using TflPackedOp = ::tflite::Operator; -using TflPackedTensor = ::tflite::Tensor; -using TflPackedBuffer = ::tflite::Buffer; - -using TflBufferPtr = std::unique_ptr; -using TflModelPtr = std::unique_ptr; -using TflQuantizationPtr = std::unique_ptr; -using TflOpCodePtr = std::unique_ptr; -using TflSubgraphPtr = std::unique_ptr; -using TflTensorPtr = std::unique_ptr; -using TflOpPtr = std::unique_ptr; -using TflSignaturePtr = std::unique_ptr; -using TflMetadataPtr = std::unique_ptr; - -// Code and verion. -using TflOpCodeDetail = std::pair; - -// Zero-point, scale. 
-using TflPerTensorQParams = std::pair; - -// Quantized dim, num channels, zero-points, scales. -using TflPerChannelQParams = - std::tuple, std::vector>; - -// Mirror of all the tensor type related fields in flatbuffer tensor definition. -struct TflShapeInfo { - // Fixed or dynamic rank. - bool has_rank; - - // Basic shape, all elements are non-negative (even if this is a dynamic - // shape). - absl::InlinedVector shape; - - // Dynamic dyn info. If this is not empty, then its length is equal to shape. - // If i is a dyn dim, then shape[i] == 1 and shape_signature[i] < 0. Otherwise - // shape_signature[i] == shape[i]. - absl::InlinedVector shape_signature; - - // Convert from a single dims array. Will detect if array is static/dynamic - // and populate fields accordingly. - explicit TflShapeInfo(absl::Span shape_data) : has_rank(true) { - bool is_dyn = false; - shape.reserve(shape_data.size()); - shape_signature.reserve(shape_data.size()); - for (auto d : shape_data) { - if (d >= 0) { - shape.push_back(d); - shape_signature.push_back(d); - } else { - is_dyn = true; - shape.push_back(1); - shape_signature.push_back(-1); - } - } - if (!is_dyn) { - shape_signature.clear(); - } - } - - // Convert from tensor. - explicit TflShapeInfo(const TflTensor& tfl_tensor) - : has_rank(tfl_tensor.has_rank), - shape(tfl_tensor.shape.begin(), tfl_tensor.shape.end()), - shape_signature(tfl_tensor.shape_signature.begin(), - tfl_tensor.shape_signature.end()) {} - - explicit TflShapeInfo(const TflPackedTensor& tfl_tensor) - : has_rank(tfl_tensor.has_rank()) { - if (tfl_tensor.shape()) { - shape.assign(tfl_tensor.shape()->begin(), tfl_tensor.shape()->end()); - } - - if (tfl_tensor.shape_signature()) { - shape_signature.assign(tfl_tensor.shape_signature()->begin(), - tfl_tensor.shape_signature()->end()); - } - } -}; - -using TflTensorType = std::pair; - -// Flatbuffer bytes util. - -// Convenience method to get string view from native flatbuffer chars. 
-absl::string_view FbBufToStr(const uint8_t* fb_data, size_t size); - -// Span version. -absl::string_view FbBufToStr(absl::Span fb_buf); - -// Convenience method to get mutable signed char span from native flatbuffer -// chars. -absl::Span FbBufToStr(uint8_t* fb_data, size_t size); - -// Span to span version. -absl::Span FbBufToStr(absl::Span fb_buf); - -// Flatbuffer verifiers. - -// Verifies given serialized flatbuffer -bool VerifyFlatbuffer(const uint8_t* buf, size_t buf_size); - -// Override of above with view input. -bool VerifyFlatbuffer(absl::Span buf); - -// TFL flatbuffer IR helpers. - -// Get the metadata buffer under given key if it exists. -Expected> GetMetadata(absl::string_view key, - const TflModel& model); - -// Get the metadata buffer under given key if it exists that can be written to. -Expected> GetMutableMetadata(absl::string_view key, - TflModel& model); - -// Push the given metadata to the given key if the key does not already exist. -LiteRtStatus PushMetadata(absl::string_view key, TflModel& model, - BufferRef metadata); - -// Get the buffer object at the given index if it exists. -Expected> GetTflBuffer(const TflModel& tfl_model, - uint32_t buffer_ind); - -// Get the buffer object at the given index if it exists that can be written to. -Expected> GetMutableTflBuffer(TflModel& tfl_model, - uint32_t buffer_ind); - -// Get a non-owning view of tfl buffer if it exists. -Expected GetBuffer(const TflModel& tfl_model, - uint32_t buffer_ind); - -// Move and take ownership of the buffer object at given index if it exists. -Expected TakeBuffer(TflModel& tfl_model, uint32_t buffer_ind); - -// Add a new buffer to the tflite model, returning its index. -Expected PushTflBuffer(TflModel& tfl_model, - BufferRef buffer); - -// Make a tflite buffer from data. 
-template -TflBufferPtr MakeTflBuffer(std::initializer_list data) { - auto res = std::make_unique(); - const auto byte_size = data.size() * sizeof(T); - res->data.resize(byte_size); - for (auto it = data.begin(); it != data.end(); ++it) { - auto* write_to = - reinterpret_cast(res->data.data()) + (it - data.begin()); - *write_to = *it; - } - res->size = res->data.size(); - res->offset = 0; - return res; -} - -// Get the op code from the model at the given index if it exists. -Expected GetTflOpCode(const TflModel& tfl_model, - uint32_t op_code_ind); - -// Is tensor fixed rank, with possible dynamic dims. -bool IsRankedTensorType(const TflShapeInfo& tfl_shape); - -// Is ranked tensor type with static shape. -bool IsStaticTensorType(const TflShapeInfo& tfl_shape); - -// Get static shape info if given is indeed a static shape. -Expected> AsStaticShape( - const TflShapeInfo& tfl_shape); - -// Get ranked dynamic shape info if given is indeed a ranked. Still works with -// static shapes. -Expected> AsDynamicShape( - const TflShapeInfo& tfl_shape); - -// Is the tensor quantized. -bool IsQuantized(const TflQuantization* tfl_quantization); - -// Is the tensor per-tensor quantized. -bool IsPerTensorQuantized(const TflQuantization* tfl_quantization); - -// Is the tensor per-channel quantized. -bool IsPerChannelQuantized(const TflQuantization* tfl_quantization); - -// Is the tensor block-wise quantized. -bool IsBlockWiseQuantized(const TflQuantization* tfl_quantization); - -// Does tensor have custom quantization. -bool IsCustomQuantized(const TflQuantization* tfl_quantization); - -// Get the per-tensor tensor q-params if given tensor has them. -Expected AsPerTensorQparams( - const TflQuantization* tfl_quantization); - -// Get the per-channel tensor q-params if given tensor has them. -Expected AsPerChannelQparams( - const TflQuantization* tfl_quantization); - -// Flatbuffer management helpers. - -// Make a tfl allocation from buffer. 
-::tflite::Allocation::Ptr MakeAllocation(BufferRef buf); - -// Wrapper around a tflite model buffer. -class FlatbufferWrapper { - public: - using Ptr = std::unique_ptr; - - // TODO Don't return a unique_ptr, this can just be a move only type, all the - // fields are unique_ptrs. Load flatbuffer from file. - static Expected CreateFromTflFile(absl::string_view path); - - // Load flatbuffer from allocated buffer that will be copied. - static Expected CreateFromBuffer(BufferRef buffer); - - // Load flatbuffer from allocated buffer and take ownership. - static Expected CreateFromBuffer(OwningBufferRef&& buffer); - - // Underlying buffer. - BufferRef Buf() const { - return BufferRef(alloc_->base(), alloc_->bytes()); - } - - // Underlying model object. - const ::tflite::FlatBufferModel& FlatbufferModel() const { - return *fb_model_; - } - - // Packed schema object. - const TflPackedModel* PackedModel() const { return fb_model_->GetModel(); } - - // Unpack the contained flatbuffer. - TflModelPtr Unpack() const { - return TflModelPtr(fb_model_->GetModel()->UnPack()); - } - - // Address of first byte of the raw model buffer. - const uint8_t* AllocBase() const { return Buf().Data(); } - - // Default construct for compatibility. - FlatbufferWrapper() = default; - - private: - FlatbufferWrapper(::tflite::FlatBufferModel::Ptr fb_model, - ::tflite::Allocation::Ptr alloc, - OwningBufferRef&& model_buf) - : fb_model_(std::move(fb_model)), - alloc_(std::move(alloc)), - model_buf_(std::forward>(model_buf)) {} - - ::tflite::FlatBufferModel::Ptr fb_model_; - ::tflite::Allocation::Ptr alloc_; - OwningBufferRef model_buf_; -}; - -// Re-serialize the unpacked model from flatbuffer wrapper. 
-OwningBufferRef SerializeFlatbuffer( - const FlatbufferWrapper& flatbuffer); -OwningBufferRef SerializeFlatbuffer(const TflModel& tfl_model); - -} // namespace litert::internal - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_UTIL_FLATBUFFER_TOOLS_H_ diff --git a/tensorflow/lite/experimental/litert/core/util/flatbuffer_tools_test.cc b/tensorflow/lite/experimental/litert/core/util/flatbuffer_tools_test.cc deleted file mode 100644 index bc4fd6c493647c..00000000000000 --- a/tensorflow/lite/experimental/litert/core/util/flatbuffer_tools_test.cc +++ /dev/null @@ -1,175 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h" - -#include - -#include -#include -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" -#include "tensorflow/lite/experimental/litert/test/common.h" -#include "tensorflow/lite/experimental/litert/test/matchers.h" - -namespace litert::internal { -namespace { - -using ::testing::ElementsAre; -using ::testing::ElementsAreArray; -using ::testing::Lt; - -FlatbufferWrapper::Ptr TestFlatbuffer( - absl::string_view filename = "one_mul.tflite") { - const auto tfl_path = testing::GetTestFilePath(filename); - return *FlatbufferWrapper::CreateFromTflFile(tfl_path); -} - -static const absl::string_view kKey = "MyKey"; -static const absl::string_view kData = "MyData"; - -TEST(FlatbufferToolsTest, Metadata) { - auto flatbuffer = TestFlatbuffer(); - ASSERT_NE(flatbuffer, nullptr); - auto tfl_model = flatbuffer->Unpack(); - - LITERT_ASSERT_OK(PushMetadata( - kKey, *tfl_model, BufferRef(kData.data(), kData.size()))); - - auto metadata = GetMetadata(kKey, *tfl_model); - ASSERT_TRUE(metadata); - EXPECT_EQ(metadata->StrView(), kData); -} - -TEST(FlatbufferToolsTest, GetMetadataNotFound) { - auto flatbuffer = TestFlatbuffer(); - auto tfl_model = flatbuffer->Unpack(); - ASSERT_NE(flatbuffer, nullptr); - EXPECT_FALSE(GetMetadata(kKey, *tfl_model)); -} - -TEST(FlatbufferToolsTest, TflBuffer) { - auto flatbuffer = TestFlatbuffer(); - ASSERT_NE(flatbuffer, nullptr); - auto tfl_model = flatbuffer->Unpack(); - - auto ind = PushTflBuffer((*tfl_model), - BufferRef(kData.data(), kData.size())); - ASSERT_TRUE(ind); - - auto buf = GetTflBuffer((*tfl_model), *ind); - ASSERT_TRUE(buf); - ASSERT_EQ(buf->StrView(), kData); -} - -TEST(FlatbufferToolsTest, GetTflBufferNotFound) { - auto flatbuffer = TestFlatbuffer(); - ASSERT_NE(flatbuffer, nullptr); - auto tfl_model = flatbuffer->Unpack(); - - auto buf = GetTflBuffer((*tfl_model), 100); - ASSERT_FALSE(buf); -} - 
-TEST(FlatbufferToolsTest, GetTflOpCode) { - auto flatbuffer = TestFlatbuffer(); - ASSERT_NE(flatbuffer, nullptr); - auto tfl_model = flatbuffer->Unpack(); - - auto op_code = GetTflOpCode((*tfl_model), 0); - ASSERT_TRUE(op_code); -} - -TEST(FlatbufferToolsTest, GetTflOpCodeNotFound) { - auto flatbuffer = TestFlatbuffer(); - ASSERT_NE(flatbuffer, nullptr); - auto tfl_model = flatbuffer->Unpack(); - - auto op_code = GetTflOpCode((*tfl_model), 100); - ASSERT_FALSE(op_code); -} - -TEST(FlatbufferToolsTest, StaticTensorTypeTest) { - auto flatbuffer = TestFlatbuffer(); - auto tfl_model = flatbuffer->Unpack(); - auto& tensor = tfl_model->subgraphs.front()->tensors.front(); - - TflShapeInfo shape(*tensor); - - ASSERT_TRUE(IsRankedTensorType(shape)); - ASSERT_TRUE(IsStaticTensorType(shape)); - - auto static_shape = AsStaticShape(shape); - - ASSERT_TRUE(static_shape); - ASSERT_THAT(*static_shape, ElementsAreArray({2, 2})); -} - -TEST(FlatbufferToolsTest, UnrankedTensorTypeTest) { - auto flatbuffer = TestFlatbuffer("unranked_tensor.tflite"); - auto tfl_model = flatbuffer->Unpack(); - auto& tensor = tfl_model->subgraphs.front()->tensors.front(); - - TflShapeInfo shape(*tensor); - - ASSERT_FALSE(IsRankedTensorType(shape)); -} - -TEST(FlatbufferToolsTest, RankedDynamicTensorTypeTest) { - auto flatbuffer = TestFlatbuffer("dynamic_shape_tensor.tflite"); - auto tfl_model = flatbuffer->Unpack(); - auto& tensor = tfl_model->subgraphs.front()->tensors.front(); - - TflShapeInfo shape(*tensor); - - ASSERT_TRUE(IsRankedTensorType(shape)); - ASSERT_FALSE(IsStaticTensorType(shape)); - - auto dyn_shape = AsDynamicShape(shape); - - ASSERT_TRUE(dyn_shape); - ASSERT_THAT(*dyn_shape, ElementsAre(Lt(0), 2)); -} - -TEST(FlatbufferToolsTest, PerTensorQuantizedTest) { - auto flatbuffer = - TestFlatbuffer("single_add_default_a16w8_recipe_quantized.tflite"); - auto tfl_model = flatbuffer->Unpack(); - auto& tensor = tfl_model->subgraphs.front()->tensors.front(); - - const auto* const q_parms = 
tensor->quantization.get(); - - ASSERT_TRUE(IsQuantized(q_parms)); - EXPECT_TRUE(IsPerTensorQuantized(q_parms)); - - auto per_tensor = AsPerTensorQparams(q_parms); - ASSERT_TRUE(per_tensor); -} - -TEST(FlatbufferToolsTest, PerChannelQuantizedTest) { - auto flatbuffer = TestFlatbuffer("static_w8_a16_quantized_k_einsum.tflite"); - auto tfl_model = flatbuffer->Unpack(); - auto& tensor = tfl_model->subgraphs.front()->tensors[1]; - - const auto* const q_parms = tensor->quantization.get(); - - ASSERT_TRUE(IsQuantized(q_parms)); - EXPECT_TRUE(IsPerChannelQuantized(q_parms)); - - auto per_channel = AsPerChannelQparams(q_parms); - ASSERT_TRUE(per_channel); -} - -} // namespace -} // namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/core/util/tensor_type_util.cc b/tensorflow/lite/experimental/litert/core/util/tensor_type_util.cc deleted file mode 100644 index 4e3284374d24a2..00000000000000 --- a/tensorflow/lite/experimental/litert/core/util/tensor_type_util.cc +++ /dev/null @@ -1,57 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/core/util/tensor_type_util.h" - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" - -namespace litert { -namespace internal { - -Expected GetElementSize(LiteRtElementType element_type) { - switch (element_type) { - case kLiteRtElementTypeInt4: - return Ratio{1, 2}; - case kLiteRtElementTypeBool: - return Ratio{1, 1}; - case kLiteRtElementTypeInt8: - case kLiteRtElementTypeUInt8: - return Ratio{1, 1}; - case kLiteRtElementTypeInt16: - case kLiteRtElementTypeUInt16: - case kLiteRtElementTypeFloat16: - case kLiteRtElementTypeBFloat16: - return Ratio{2, 1}; - case kLiteRtElementTypeInt32: - case kLiteRtElementTypeUInt32: - case kLiteRtElementTypeFloat32: - return Ratio{4, 1}; - case kLiteRtElementTypeInt64: - case kLiteRtElementTypeUInt64: - case kLiteRtElementTypeFloat64: - return Ratio{8, 1}; - case kLiteRtElementTypeComplex64: - return Ratio{16, 1}; - case kLiteRtElementTypeComplex128: - return Ratio{32, 1}; - default: - return Unexpected(kLiteRtStatusErrorInvalidArgument, - "Unexpected element type"); - } -} - -} // namespace internal -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/core/util/tensor_type_util.h b/tensorflow/lite/experimental/litert/core/util/tensor_type_util.h deleted file mode 100644 index 9663b2ac337403..00000000000000 --- a/tensorflow/lite/experimental/litert/core/util/tensor_type_util.h +++ /dev/null @@ -1,111 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_UTIL_TENSOR_TYPE_UTIL_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_UTIL_TENSOR_TYPE_UTIL_H_ - -#include - -#include "absl/strings/str_cat.h" -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" - -namespace litert::internal { - -struct Ratio { - using Type = int; - Type num; - Type denom; - std::string ToString() const { return absl::StrCat(num, "/", denom); } -}; - -Expected GetElementSize(LiteRtElementType element_type); - -// Get the number of elements in a tensor with given dimensions. -template -Expected GetNumElements(absl::Span dimensions) { - size_t num_elements = 1; - for (auto i = 0; i < dimensions.size(); ++i) { - auto dim = dimensions[i]; - if (dim < 0) { - return Unexpected(kLiteRtStatusErrorInvalidArgument, - "Unexpected negative dimension"); - } else if (dim == 0) { - return Unexpected(kLiteRtStatusErrorInvalidArgument, - "Unexpected 0 dimension"); - } - num_elements *= dim; - } - return num_elements; -} - -inline Expected GetNumElements( - const LiteRtRankedTensorType& tensor_type) { - return GetNumElements( - absl::MakeSpan(tensor_type.layout.dimensions, tensor_type.layout.rank)); -} - -// Get the minimum number of bytes necessary to represent a packed tensor with a -// given element type and dimensions. 
-template -Expected GetNumPackedBytes(LiteRtElementType element_type, - absl::Span dimensions) { - auto element_size = GetElementSize(element_type); - if (!element_size) { - return element_size.Error(); - } - auto num_elements = GetNumElements(dimensions); - if (!num_elements) { - return num_elements.Error(); - } - return ((*num_elements * element_size->num) + (element_size->denom - 1)) / - element_size->denom; -} - -// Get the number of bytes necessary to represent a packed tensor type, ignoring -// any stride information. -inline Expected GetNumPackedBytes( - const LiteRtRankedTensorType& tensor_type) { - return GetNumPackedBytes( - tensor_type.element_type, - absl::MakeSpan(tensor_type.layout.dimensions, tensor_type.layout.rank)); -} - -// Get the minimum number of bytes necessary to represent a possibly unpacked -// tensor with a given element type, dimensions, and strides. -template -Expected GetNumBytes(LiteRtElementType element_type, - absl::Span dimensions, absl::Span strides) { - if (dimensions.size() != strides.size()) { - return Unexpected( - kLiteRtStatusErrorInvalidArgument, - "Dimensions and strides have different number of elements"); - } - auto element_size = GetElementSize(element_type); - if (!element_size) { - return element_size.Error(); - } - auto rank = dimensions.size(); - size_t num_elements = 1; - for (auto i = 0; i < rank; ++i) { - num_elements += (dimensions[i] - 1) * strides[i]; - } - return ((num_elements * element_size->num) + (element_size->denom - 1)) / - element_size->denom; -} - -} // namespace litert::internal - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_UTIL_TENSOR_TYPE_UTIL_H_ diff --git a/tensorflow/lite/experimental/litert/core/util/tensor_type_util_test.cc b/tensorflow/lite/experimental/litert/core/util/tensor_type_util_test.cc deleted file mode 100644 index bfb084140eb073..00000000000000 --- a/tensorflow/lite/experimental/litert/core/util/tensor_type_util_test.cc +++ /dev/null @@ -1,69 +0,0 @@ -// Copyright 2024 
Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/core/util/tensor_type_util.h" - -#include -#include - -#include // NOLINT: Need when ANDROID_API_LEVEL >= 26 -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" - -using litert::internal::GetNumBytes; -using litert::internal::GetNumElements; -using litert::internal::GetNumPackedBytes; - -TEST(TensorTypeUtil, GetNumElements) { - constexpr std::array dimensions = {3, 2, 1}; - auto num_elements = GetNumElements(absl::MakeSpan(dimensions)); - EXPECT_TRUE(num_elements); - EXPECT_EQ(*num_elements, 6); -} - -TEST(TensorTypeUtil, GetNumElementsWithUnknownDimension) { - constexpr std::array dimensions = {3, -1, 1}; - auto num_elements = GetNumElements(absl::MakeSpan(dimensions)); - EXPECT_FALSE(num_elements); -} - -TEST(TensorTypeUtil, GetNumElementsWithZeroDimension) { - constexpr std::array dimensions = {3, 0, 1}; - auto num_elements = GetNumElements(absl::MakeSpan(dimensions)); - EXPECT_FALSE(num_elements); -} - -TEST(TensorTypeUtil, GetNumPackedBytes) { - LiteRtElementType element_type = kLiteRtElementTypeInt32; - constexpr std::array dimensions = {3, 2, 1}; - auto num_bytes = GetNumPackedBytes(element_type, absl::MakeSpan(dimensions)); - EXPECT_TRUE(num_bytes); - EXPECT_EQ(*num_bytes, sizeof(int32_t) * 6); -} - -TEST(TensorTypeUtil, GetNumBytes) { - LiteRtElementType element_type = kLiteRtElementTypeInt32; - 
constexpr std::array dimensions = {3, 2, 1}; - constexpr std::array strides = {1, 4, 8}; - // The data should be allocated as follows (where 'X' is a used cell and 'o' - // is an unused/padding cell): - // - // XXXo XXX - // - // The total is 4 + 3 = 7 cells - auto num_bytes = GetNumBytes(element_type, absl::MakeSpan(dimensions), - absl::MakeSpan(strides)); - EXPECT_TRUE(num_bytes); - EXPECT_EQ(*num_bytes, sizeof(int32_t) * 7); -} diff --git a/tensorflow/lite/experimental/litert/core/version.h b/tensorflow/lite/experimental/litert/core/version.h deleted file mode 100644 index fa9b017917c349..00000000000000 --- a/tensorflow/lite/experimental/litert/core/version.h +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_VERSION_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_VERSION_H_ - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" - -namespace litert::internal { - -// Return true if two API versions are the same. -inline bool IsSameVersion(const LiteRtApiVersion& v1, - const LiteRtApiVersion& v2) { - return (v1.major == v2.major) && (v1.minor == v2.minor) && - (v1.patch == v2.patch); -} - -// Return true if a given API version is the same as the current runtime. 
-inline bool IsSameVersionAsRuntime(const LiteRtApiVersion& v) { - return IsSameVersion(v, {LITERT_API_VERSION_MAJOR, LITERT_API_VERSION_MINOR, - LITERT_API_VERSION_PATCH}); -} - -} // namespace litert::internal - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_VERSION_H_ diff --git a/tensorflow/lite/experimental/litert/integration_test/BUILD b/tensorflow/lite/experimental/litert/integration_test/BUILD deleted file mode 100644 index d36062a53e6c26..00000000000000 --- a/tensorflow/lite/experimental/litert/integration_test/BUILD +++ /dev/null @@ -1,129 +0,0 @@ -# Copyright 2025 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -load("//tensorflow/lite/experimental/litert/build_common:tfl_model_gen.bzl", "tfl_model_gen") -load("//tensorflow/lite/experimental/litert/integration_test:run_on_device.bzl", "litert_integration_test") - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], -) - -# C++ TEST SCAFFOLD ################################################################################ - -cc_test( - name = "gen_device_test", - srcs = ["gen_device_test.cc"], - copts = ["-DGOOGLE_COMMANDLINEFLAGS_FULL_API=1"], - data = [":single_op_models"], - linkopts = select({ - "//tensorflow:android": ["-landroid"], - "//conditions:default": [], - }), - tags = ["manual"], - deps = [ - ":gen_device_test_lib", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/cc:litert_environment", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/test:common", - "//tensorflow/lite/experimental/litert/test:matchers", - "//tensorflow/lite/experimental/litert/tools:dump", - "@com_google_absl//absl/flags:flag", - "@com_google_absl//absl/flags:parse", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest", - ], -) - -cc_library( - name = "gen_device_test_lib", - testonly = True, - hdrs = ["gen_device_test_lib.h"], - deps = [ - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/cc:litert_compiled_model", - "//tensorflow/lite/experimental/litert/cc:litert_environment", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_tensor_buffer", - "//tensorflow/lite/experimental/litert/core/model", - "//tensorflow/lite/experimental/litert/test:matchers", - 
"@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest", - ], -) - -# TEST MODELS ###################################################################################### - -filegroup( - name = "classic_ml_models", - srcs = glob(["classic_ml_models/*.tflite"]), -) - -tfl_model_gen( - name = "single_op_models", - srcs = glob(["single_op_models/*.mlir"]), - subdir = "single_op_models", -) - -filegroup( - name = "pre_compiled_models", - srcs = glob(["pre_compiled_models/*.tflite"]), -) - -# ON DEVICE INTEGRATION TESTS ###################################################################### - -# NOTE: Everything here should be built with -c opt --config=android_arm64. - -sh_binary( - name = "run_on_device_driver_OSS", - srcs = ["run_on_device_driver_OSS.sh"], -) - -litert_integration_test( - name = "single_op_device_tests_cpu", - hw = "cpu", - models = ":single_op_models", -) - -litert_integration_test( - name = "single_op_device_tests_qualcomm_JIT", - hw = "qualcomm", - models = ":single_op_models", - skips = [ - "greater_f32", # TODO: lukeboyer - Investigate (segfault). - "less_f32", # TODO: lukeboyer - Investigate (segfault). - ], -) - -litert_integration_test( - name = "classic_ml_device_tests_cpu", - hw = "cpu", - models = ":classic_ml_models", -) - -litert_integration_test( - name = "classic_ml_device_tests_qualcomm_JIT", - hw = "qualcomm", - models = ":classic_ml_models", -) - -litert_integration_test( - name = "pre_compiled_device_tests_qualcomm", - hw = "qualcomm", - models = ":pre_compiled_models", -) diff --git a/tensorflow/lite/experimental/litert/integration_test/gen_device_test.cc b/tensorflow/lite/experimental/litert/integration_test/gen_device_test.cc deleted file mode 100644 index fd2f05e70dd21f..00000000000000 --- a/tensorflow/lite/experimental/litert/integration_test/gen_device_test.cc +++ /dev/null @@ -1,199 +0,0 @@ -// Copyright 2025 Google LLC. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include // NOLINT -#include -#include -#include - -#include -#include "absl/flags/flag.h" -#include "absl/flags/parse.h" -#include "absl/strings/str_format.h" -#include "absl/strings/str_join.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/cc/litert_environment.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/integration_test/gen_device_test_lib.h" -#include "tensorflow/lite/experimental/litert/test/common.h" -#include "tensorflow/lite/experimental/litert/test/matchers.h" -#include "tensorflow/lite/experimental/litert/tools/dump.h" - -ABSL_FLAG(std::string, model_path, "", - "Tflite models to test. 
This can be a single tflite model or a " - "directory containing multiple tflite models."); -ABSL_FLAG(std::string, dispatch_library_dir, "/data/local/tmp/", - "Path to the dispatch library."); -ABSL_FLAG(std::string, compiler_library_dir, "/data/local/tmp/", - "Path to the compiler plugin library."); -ABSL_FLAG(std::string, hw, "cpu", "Which accelerator to use."); -ABSL_FLAG(std::vector, skips, std::vector{}, - "Substrings of models to skip."); - -namespace litert::test { -namespace { - -// UTILS /////////////////////////////////////////////////////////////////////// - -bool IsTfliteModel(const std::filesystem::path& path) { - return std::filesystem::is_regular_file(path) && - path.extension() == ".tflite"; -} - -std::vector GetModelPaths(const std::string& model_path_str) { - std::filesystem::path model_path = model_path_str; - std::vector models; - if (std::filesystem::is_directory(model_path)) { - for (const auto& entry : std::filesystem::directory_iterator(model_path)) { - if (!IsTfliteModel(entry.path())) { - continue; - } - models.push_back(entry.path().generic_string()); - } - return models; - } - - if (IsTfliteModel(model_path)) { - return {model_path.generic_string()}; - } - - return {}; -} - -std::string ModelName(const std::filesystem::path& path) { - return path.filename().replace_extension().generic_string(); -} - -} // namespace - -// FIXTURES //////////////////////////////////////////////////////////////////// - -class GenDeviceTestFixt : public ::testing::Test {}; - -// A test that simply calls the model and ensures it doesn't crash. -// Works with any accelerator. 
-template -class InvokeOnceTest : public GenDeviceTestFixt { - public: - InvokeOnceTest(std::string model_path, std::string dispatch_library_dir, - std::string compiler_library_dir) - : model_path_(std::move(model_path)), - dispatch_library_dir_(std::move(dispatch_library_dir)), - compiler_library_dir_(std::move(compiler_library_dir)) {} - - // Opens model and initializes the underlying invoker. - void SetUp() override { - const std::vector environment_options = { - litert::Environment::Option{ - litert::Environment::OptionTag::DispatchLibraryDir, - absl::string_view(dispatch_library_dir_), - }, - litert::Environment::Option{ - litert::Environment::OptionTag::CompilerPluginLibraryDir, - absl::string_view(compiler_library_dir_), - }, - }; - LITERT_ASSERT_OK_AND_ASSIGN( - auto env, litert::Environment::Create(environment_options)); - - LITERT_ASSERT_OK_AND_ASSIGN(auto model, - litert::Model::CreateFromFile(model_path_)); - litert::internal::Dump(*model.Get()); - - invoker_ = std::make_unique(std::move(env), std::move(model)); - invoker_->MaybeSkip(); - ASSERT_NO_FATAL_FAILURE(invoker_->Setup()); - } - - void TestBody() override { ASSERT_NO_FATAL_FAILURE(invoker_->Run()); } - - private: - std::string model_path_; - std::string dispatch_library_dir_; - std::string compiler_library_dir_; - - CmInvoker::Ptr invoker_; -}; - -// REGISTRATION //////////////////////////////////////////////////////////////// - -// Registers tests dynamically based on the hw flag and the model_path flag. -void ParseTests() { - auto model_path_flag = absl::GetFlag(FLAGS_model_path); - // Provide a sensible default based on linux/android. - if (model_path_flag.empty()) { -#if defined(__ANDROID__) - model_path_flag = "/data/local/tmp/"; -#else - // Set this on linux for smoke check linux presubmit. 
- model_path_flag = testing::GetLiteRtPath( - "integration_test/single_op_models/add_f32.tflite"); -#endif - } - const auto model_paths = GetModelPaths(model_path_flag); - const auto hw = absl::GetFlag(FLAGS_hw); - const auto dispatch_library_dir = absl::GetFlag(FLAGS_dispatch_library_dir); - const auto compiler_library_dir = absl::GetFlag(FLAGS_compiler_library_dir); - const auto skips = absl::GetFlag(FLAGS_skips); - - LITERT_LOG(LITERT_INFO, "hw: %s", hw.c_str()); - LITERT_LOG(LITERT_INFO, "model_path: %s", model_path_flag.c_str()); - LITERT_LOG(LITERT_INFO, "dispatch_library_dir: %s", - dispatch_library_dir.c_str()); - LITERT_LOG(LITERT_INFO, "compiler_library_dir: %s", - compiler_library_dir.c_str()); - LITERT_LOG(LITERT_INFO, "skips: %s", absl::StrJoin(skips, ",").c_str()); - - if (model_paths.empty()) { - LITERT_LOG(LITERT_WARNING, "No models found to test."); - return; - } - - for (const auto& model_path : model_paths) { - LITERT_LOG(LITERT_INFO, "model_path: %s", model_path.c_str()); - - const auto test_name = absl::StrFormat("%s_%s", ModelName(model_path), hw); - const auto should_skip = - std::any_of(skips.cbegin(), skips.cend(), [&](const auto& skip) { - return (model_path.find(skip) != std::string::npos); - }); - - ::testing::RegisterTest( - "GenDeviceTest", test_name.c_str(), nullptr, nullptr, __FILE__, - __LINE__, [=]() -> GenDeviceTestFixt* { - if (should_skip) { - return new InvokeOnceTest( - model_path, dispatch_library_dir, compiler_library_dir); - } else if (hw == "npu") { - return new InvokeOnceTest( - model_path, dispatch_library_dir, compiler_library_dir); - } else { - return new InvokeOnceTest( - model_path, dispatch_library_dir, compiler_library_dir); - } - }); - } -} - -} // namespace litert::test - -int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - absl::ParseCommandLine(argc, argv); - litert::test::ParseTests(); - return RUN_ALL_TESTS(); -} diff --git 
a/tensorflow/lite/experimental/litert/integration_test/gen_device_test_lib.h b/tensorflow/lite/experimental/litert/integration_test/gen_device_test_lib.h deleted file mode 100644 index b2e585d9a277b4..00000000000000 --- a/tensorflow/lite/experimental/litert/integration_test/gen_device_test_lib.h +++ /dev/null @@ -1,139 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include -#include - -#include -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_compiled_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_environment.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" -#include "tensorflow/lite/experimental/litert/test/matchers.h" - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_INTEGRATION_TEST_GEN_DEVICE_TEST_LIB_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_INTEGRATION_TEST_GEN_DEVICE_TEST_LIB_H_ - -namespace litert::test { - -// Absract wrapper for the invocation of the compiled model API within a -// standard test environment. 
-class CmInvoker { - public: - using Ptr = std::unique_ptr; - - CmInvoker(Environment&& env, Model&& model) - : env_(std::move(env)), model_(std::move(model)) {} - - // Setup the compiled model api and initialize the input and output buffers. - // Assumes default signature. - void Setup() { - LITERT_ASSERT_OK_AND_ASSIGN( - compiled_model_, CompiledModel::Create(env_, model_, Accelerator())); - const auto sig = model_.DefaultSignatureKey(); - LITERT_ASSERT_OK_AND_ASSIGN(input_buffers_, - compiled_model_.CreateInputBuffers(sig)); - LITERT_ASSERT_OK_AND_ASSIGN(output_buffers_, - compiled_model_.CreateOutputBuffers(sig)); - } - - // Invoke the compiled model api. Must be called after Setup(). - void Run() { - ASSERT_TRUE(compiled_model_.Run(model_.DefaultSignatureKey(), - input_buffers_, output_buffers_)); - } - - // Is this test in a state where it should be skipped? Implementations should - // call GTEST_SKIP(). - virtual void MaybeSkip() const = 0; - - // Which accelerator option to use. - virtual LiteRtHwAccelerators Accelerator() const = 0; - - std::vector& GetInputBuffers() { return input_buffers_; } - std::vector& GetOutputBuffers() { return output_buffers_; } - - virtual ~CmInvoker() = default; - - protected: - Environment env_; - Model model_; - - CompiledModel compiled_model_; - std::vector input_buffers_; - std::vector output_buffers_; -}; - -class SkippedCmInvoker : public CmInvoker { - public: - SkippedCmInvoker(Environment&& env, Model&& model) - : CmInvoker(std::move(env), std::move(model)) {} - void MaybeSkip() const override { - GTEST_SKIP() << "User requested skip for this model."; - } - - LiteRtHwAccelerators Accelerator() const override { - return kLiteRtHwAcceleratorNone; - }; -}; - -// Invocation of the compiled model API for the NPU accelerator. This handles -// both JIT and pre-compiled models. 
-class CmNpuInvoker : public CmInvoker { - public: - CmNpuInvoker(Environment&& env, Model&& model) - : CmInvoker(std::move(env), std::move(model)) {} - - // Will invocation require compilation. - bool IsJit() const { - auto& m = *model_.Get(); - return !IsCompiled(m); - } - - LiteRtHwAccelerators Accelerator() const override { - return IsJit() ? kLiteRtHwAcceleratorNpu : kLiteRtHwAcceleratorNone; - } - - void MaybeSkip() const override { -#if !defined(__ANDROID__) - GTEST_SKIP() << "NPU test must run on android device."; -#endif - } -}; - -// Invocation of the compiled model API on CPU. This can run on linux in -// addition to android. -class CmCpuInvoker : public CmInvoker { - public: - CmCpuInvoker(Environment&& env, Model&& model) - : CmInvoker(std::move(env), std::move(model)) {} - - LiteRtHwAccelerators Accelerator() const override { - return kLiteRtHwAcceleratorCpu; - } - - void MaybeSkip() const override { - if (IsCompiled(*model_.Get())) { - GTEST_SKIP() << "Cannot run CPU test on a compiled model."; - } - } -}; - -} // namespace litert::test - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_INTEGRATION_TEST_GEN_DEVICE_TEST_LIB_H_ diff --git a/tensorflow/lite/experimental/litert/integration_test/run_on_device.bzl b/tensorflow/lite/experimental/litert/integration_test/run_on_device.bzl deleted file mode 100644 index 3d229478e7c4c4..00000000000000 --- a/tensorflow/lite/experimental/litert/integration_test/run_on_device.bzl +++ /dev/null @@ -1,296 +0,0 @@ -# Copyright 2025 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This module defines the `run_on_device` macro, which helps to execute a binary target on a device. -""" - -load("//tensorflow:tensorflow.bzl", "if_oss") -load("//tensorflow/lite/experimental/litert/build_common:litert_build_defs.bzl", "absolute_label") - -# DEVICE PATHS ##################################################################################### - -DEVICE_RLOCATION_ROOT = "/data/local/tmp/runfiles" - -def device_rlocation(label = None, get_parent = False): - """Get the path on device for a given label. - - Args: - label: The label to get the path for. If None, returns the root path. - get_parent: If true, get the parent directory of the resolved path. - - Returns: - The path on device for the given label. - """ - if not label: - return DEVICE_RLOCATION_ROOT - abs_label = absolute_label(label) - res = DEVICE_RLOCATION_ROOT + "/" + abs_label.replace("//", "").replace(":", "/") - if get_parent: - return res[:res.rfind("/")] - return res - -def make_path_args(spec): - """Formats shell path-like variable assignment exprs from common directories in given labels - - Useful for making things like LD_LIBRARY_PATH=... for paths on device. - - An entry of the spec contains a key, and a list of labels. Unique leaf directories paths are - extracted from the labels and joined into a colon-separated string. - - Example: - ``` - make_path_args({ - "LD_LIBRARY_PATH": [ - "// foo : bar", - ], - "ADSP_LIBRARY_PATH": [ - "// foo : baz", - "// foo : bat" - ], - }) - ``` - will return: - ``` - LD_LIBRARY_PATH=/data/local/tmp/runfiles/foo/bar - ADSP_LIBRARY_PATH=/data/local/tmp/runfiles/foo/baz:/data/local/tmp/runfiles/foo/bat - ``` - - Args: - spec: A dict of path variable names to lists of labels. - - Returns: - A list of shell variable assignment expressions. 
- """ - - res = [] - for path_var, values in spec.items(): - # TODO: Figure out why OSS doesn't have `set` core datatype. - dirs = [] - for v in values: - parent = device_rlocation(v, True) - if parent not in dirs: - dirs.append(parent) - res.append("{path_var}={paths}".format( - path_var = path_var, - paths = ":".join(dirs), - )) - return res - -# DYNAMIC LIBRARY DEPENDENCIES ##################################################################### - -LITERT_CORE_LIBS = [ - "//tensorflow/lite/experimental/litert/c:libLiteRtRuntimeCApi.so", -] - -def make_lib_spec(**kwargs): - return struct( - litert_base_libs = LITERT_CORE_LIBS, - core_libs = kwargs["core_libs"], - kernel_libs = kwargs["kernel_libs"], - dispatch_lib = kwargs["dispatch_lib"], - compiler_lib = kwargs["compiler_lib"], - ) - -BASE_LIB_SPEC = make_lib_spec( - core_libs = [], - kernel_libs = [], - dispatch_lib = None, - compiler_lib = None, -) - -def all_libs(spec): - """ - Returns all the dynamic libraries needed for the given spec. - - Args: - spec: The lib spec to get the libs for. - - Returns: - A list of all the dynamic libraries needed for the given spec. 
- """ - libs = spec.litert_base_libs + spec.core_libs + spec.kernel_libs - for lib in [spec.dispatch_lib, spec.compiler_lib]: - if lib: - libs.append(lib) - return libs - -# QNN - -QUALCOMM_LIB_SPEC = make_lib_spec( - core_libs = [ - "//third_party/qairt/latest:lib/aarch64-android/libQnnHtp.so", - "//third_party/qairt/latest:lib/aarch64-android/libQnnHtpV75Stub.so", - "//third_party/qairt/latest:lib/aarch64-android/libQnnSystem.so", - "//third_party/qairt/latest:lib/aarch64-android/libQnnHtpPrepare.so", - ], - kernel_libs = ["//third_party/qairt/latest:lib/hexagon-v75/unsigned/libQnnHtpV75Skel.so"], - dispatch_lib = "//tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch:dispatch_api_so", - compiler_lib = "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler:qnn_compiler_plugin_so", -) - -# MTK -# TODO - -# GOOGLE TENSOR -# TODO - -def get_lib_spec(backend_id): - """ - Returns the dynamic library spec for the given backend id. - - Args: - backend_id: The backend id to get the lib spec for. - - Returns: - The dynamic library spec for the given backend id. - """ - if backend_id == "qualcomm": - return QUALCOMM_LIB_SPEC - if backend_id == "cpu": - return BASE_LIB_SPEC - else: - fail("Unsupported backend id: {}".format(backend_id)) - -# RUN ON DEVICE MACRO ############################################################################## - -def get_driver(): - return if_oss( - "//tensorflow/lite/experimental/litert/integration_test:run_on_device_driver_OSS", - "//tensorflow/lite/experimental/litert/integration_test/google:run_on_device_driver", - ) - -def run_on_device( - name, - target, - driver, - data = [], - exec_args = [], - exec_env_vars = []): - """ - Macro to execute a binary target on a device (locally through ADB). - - The output of this macro is an executable shell script that pushes all the necessary files to - the device and executes the target with the given arguments and environment variables. - - Args: - name: Name of the target. 
- target: The binary target to execute on device. - driver: The driver script to use for execution. - data: List of data files to push to the device. - exec_args: List of arguments to pass to the executable. - exec_env_vars: List of environment variables to set before executing the target. - """ - call_mobile_install = """ - echo '$(location {driver}) \ - --bin=$(rlocationpath {target}) \ - --data={data} \ - --do_exec=true \ - --exec_args={exec_args} \ - --exec_env_vars={exec_env_vars} \ - '\ - > $@ - """ - - concat_targ_data = "$$(echo \"$(rlocationpaths {})\" | sed \"s/ /,/g\")" - data_str = ",".join([concat_targ_data.format(d) for d in data]) - - # NOTE: Tilde delimiter here (also see driver script) to allow passing list args to underlying - # binary. - exec_args_str = "~".join(["{}".format(a) for a in exec_args]) - exec_env_vars_str = ",".join(["{}".format(a) for a in exec_env_vars]) - - driver_targ = driver.removesuffix(".sh") - driver_sh = driver_targ + ".sh" - - cmd = call_mobile_install.format( - driver = driver_sh, - target = target, - data = data_str, - exec_args = exec_args_str, - exec_env_vars = exec_env_vars_str, - ) - - exec_script = name + "_exec.sh" - - native.genrule( - name = name + "_gen_script", - srcs = [driver_sh] + [target] + data, - outs = [exec_script], - tags = ["manual", "notap"], - cmd = cmd, - testonly = True, - ) - - native.sh_binary( - testonly = True, - tags = ["manual", "notap"], - name = name, - deps = [driver_targ], - srcs = [exec_script], - data = [target] + data, - ) - -def litert_integration_test( - name, - models, - hw = "cpu", - skips = []): - """ - Higher level macro that configures run_on_device or a mobile test to run with gen_device_test. - - Args: - name: Name of the target. - models: A single target that may contain model or many models in the same directory. - hw: The backend to test against (see gen_device_test). - skips: List of substrings of models to skip. - """ - - # Get libs for the given backend. 
- lib_spec = get_lib_spec(hw) - - # Accelerator option to pass to the compiled model api on device. - hw_cfg = hw if hw == "cpu" else "npu" - - # Create env args for paths to dynamic libraries. - env_args = make_path_args({ - "LD_LIBRARY_PATH": lib_spec.litert_base_libs + lib_spec.core_libs + [lib_spec.dispatch_lib, lib_spec.compiler_lib], - "ADSP_LIBRARY_PATH": lib_spec.kernel_libs, - }) - - skips_str = ",".join(skips) - - # Create CLI args for the gen_device_test binary on device. - cli_args = [ - "--model_path={}".format(device_rlocation(models)), - "--dispatch_library_dir={}".format(device_rlocation(lib_spec.dispatch_lib, True)), - "--compiler_library_dir={}".format(device_rlocation(lib_spec.compiler_lib, True)), - "--hw={}".format(hw_cfg), - "--skips={}".format(skips_str), - ] - - data = [models] + all_libs(lib_spec) - driver = get_driver() - target = "//tensorflow/lite/experimental/litert/integration_test:gen_device_test" - - # TODO: Also kick off a xeno mobile test here. - - run_on_device( - name = name, - target = target, - driver = driver, - data = data, - exec_args = cli_args, - exec_env_vars = env_args, - ) diff --git a/tensorflow/lite/experimental/litert/integration_test/run_on_device_driver_OSS.sh b/tensorflow/lite/experimental/litert/integration_test/run_on_device_driver_OSS.sh deleted file mode 100755 index 6fa24babd8b328..00000000000000 --- a/tensorflow/lite/experimental/litert/integration_test/run_on_device_driver_OSS.sh +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright 2025 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -#!/bin/bash - -# TODO: Implement this script to leverage integration tests in OSS. - -# This script must handle the following flags: - -# DEFINE_string --required bin "" "The binary to execute on the device." -# DEFINE_array data --type=string "" "The data files to install on the device." -# DEFINE_bool do_exec false "Whether to execute the target on the device." -# DEFINE_array exec_args --type=string "" "The arguments to pass to the executable on device." -# DEFINE_array exec_env_vars --type=string "" "The environment variables to set for the executable on device." -# DEFINE_string device_rlocation_root "/data/local/tmp/runfiles" "The root directory for device relative locations." - -# This script must push the bin file and all the data files to the device under -# the device_rlocation_root directory. If do_exec is true, it must execute the -# binary on the device with the given exec_args and exec_env_vars. 
\ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/integration_test/single_op_models/add_f32.mlir b/tensorflow/lite/experimental/litert/integration_test/single_op_models/add_f32.mlir deleted file mode 100644 index d4e9d5f59da6dc..00000000000000 --- a/tensorflow/lite/experimental/litert/integration_test/single_op_models/add_f32.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module attributes {tfl.description = "MLIR Converted.", tfl.metadata = {min_runtime_version = "1.5.0\00\00\00\00\00\00\00\00\00\00\00"}, tfl.schema_version = 3 : i32} { - func.func @main(%arg0: tensor<256x256xf32>, %arg1: tensor<256x256xf32>) -> tensor<256x256xf32> attributes {tf.entry_function = {inputs = "arg0,arg1", outputs = "tfl.add"}} { - %0 = tfl.add %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<256x256xf32> - return %0 : tensor<256x256xf32> - } -} diff --git a/tensorflow/lite/experimental/litert/integration_test/single_op_models/concatenate_f32.mlir b/tensorflow/lite/experimental/litert/integration_test/single_op_models/concatenate_f32.mlir deleted file mode 100644 index ff1d3172f76e36..00000000000000 --- a/tensorflow/lite/experimental/litert/integration_test/single_op_models/concatenate_f32.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module attributes {tfl.description = "MLIR Converted.", tfl.metadata = {min_runtime_version = "1.5.0\00\00\00\00\00\00\00\00\00\00\00"}, tfl.schema_version = 3 : i32} { - func.func @main(%arg0: tensor<2x3x2xf32>, %arg1: tensor<2x4x2xf32>, %arg2: tensor<2x1x2xf32>) -> tensor<2x8x2xf32> attributes {tf.entry_function = {inputs = "arg0,arg1,arg2", outputs = "tfl.concatenation"}} { - %0 = "tfl.concatenation"(%arg0, %arg1, %arg2) <{axis = 1 : i32, fused_activation_function = "NONE"}> : (tensor<2x3x2xf32>, tensor<2x4x2xf32>, tensor<2x1x2xf32>) -> tensor<2x8x2xf32> - return %0 : tensor<2x8x2xf32> - } -} diff --git a/tensorflow/lite/experimental/litert/integration_test/single_op_models/divide_f32.mlir 
b/tensorflow/lite/experimental/litert/integration_test/single_op_models/divide_f32.mlir deleted file mode 100644 index 8bb1cf1b5f95af..00000000000000 --- a/tensorflow/lite/experimental/litert/integration_test/single_op_models/divide_f32.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module attributes {tfl.description = "MLIR Converted.", tfl.metadata = {min_runtime_version = "1.6.0\00\00\00\00\00\00\00\00\00\00\00"}, tfl.schema_version = 3 : i32} { - func.func @main(%arg0: tensor<256x256xf32>, %arg1: tensor<256x256xf32>) -> tensor<256x256xf32> attributes {tf.entry_function = {inputs = "arg0,arg1", outputs = "tfl.div"}} { - %0 = tfl.div %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<256x256xf32> - return %0 : tensor<256x256xf32> - } -} diff --git a/tensorflow/lite/experimental/litert/integration_test/single_op_models/greater_f32.mlir b/tensorflow/lite/experimental/litert/integration_test/single_op_models/greater_f32.mlir deleted file mode 100644 index 00fef7d8448236..00000000000000 --- a/tensorflow/lite/experimental/litert/integration_test/single_op_models/greater_f32.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module attributes {tfl.description = "MLIR Converted.", tfl.metadata = {min_runtime_version = "1.14.0\00\00\00\00\00\00\00\00\00\00"}, tfl.schema_version = 3 : i32} { - func.func @main(%arg0: tensor<256x256xf32>, %arg1: tensor<256x256xf32>) -> tensor<256x256xi1> attributes {tf.entry_function = {inputs = "arg0,arg1", outputs = "tfl.greater"}} { - %0 = tfl.greater(%arg0, %arg1) : (tensor<256x256xf32>, tensor<256x256xf32>) -> tensor<256x256xi1> - return %0 : tensor<256x256xi1> - } -} diff --git a/tensorflow/lite/experimental/litert/integration_test/single_op_models/less_f32.mlir b/tensorflow/lite/experimental/litert/integration_test/single_op_models/less_f32.mlir deleted file mode 100644 index 0c59c9e5c889ad..00000000000000 --- a/tensorflow/lite/experimental/litert/integration_test/single_op_models/less_f32.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module attributes 
{tfl.description = "MLIR Converted.", tfl.metadata = {min_runtime_version = "1.14.0\00\00\00\00\00\00\00\00\00\00"}, tfl.schema_version = 3 : i32} { - func.func @main(%arg0: tensor<256x256xf32>, %arg1: tensor<256x256xf32>) -> tensor<256x256xi1> attributes {tf.entry_function = {inputs = "arg0,arg1", outputs = "tfl.less"}} { - %0 = tfl.less(%arg0, %arg1) : (tensor<256x256xf32>, tensor<256x256xf32>) -> tensor<256x256xi1> - return %0 : tensor<256x256xi1> - } -} diff --git a/tensorflow/lite/experimental/litert/integration_test/single_op_models/multiply_f32.mlir b/tensorflow/lite/experimental/litert/integration_test/single_op_models/multiply_f32.mlir deleted file mode 100644 index 3390ac72910615..00000000000000 --- a/tensorflow/lite/experimental/litert/integration_test/single_op_models/multiply_f32.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module attributes {tfl.description = "MLIR Converted.", tfl.metadata = {min_runtime_version = "1.5.0\00\00\00\00\00\00\00\00\00\00\00"}, tfl.schema_version = 3 : i32} { - func.func @main(%arg0: tensor<256x256xf32>, %arg1: tensor<256x256xf32>) -> tensor<256x256xf32> attributes {tf.entry_function = {inputs = "arg0,arg1", outputs = "tfl.mul"}} { - %0 = tfl.mul %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<256x256xf32> - return %0 : tensor<256x256xf32> - } -} diff --git a/tensorflow/lite/experimental/litert/integration_test/single_op_models/reshape_f32.mlir b/tensorflow/lite/experimental/litert/integration_test/single_op_models/reshape_f32.mlir deleted file mode 100644 index 342cfcc69fa61c..00000000000000 --- a/tensorflow/lite/experimental/litert/integration_test/single_op_models/reshape_f32.mlir +++ /dev/null @@ -1,7 +0,0 @@ -module attributes {tfl.description = "MLIR Converted.", tfl.metadata = {min_runtime_version = "1.5.0\00\00\00\00\00\00\00\00\00\00\00"}, tfl.schema_version = 3 : i32} { - func.func @main(%arg0: tensor<3x4xf32>) -> tensor<4x3xf32> attributes {tf.entry_function = {inputs = "arg0", outputs = "tfl.reshape"}} { - 
%0 = "tfl.pseudo_const"() <{value = dense<[4, 3]> : tensor<2xi32>}> : () -> tensor<2xi32> - %1 = "tfl.reshape"(%arg0, %0) : (tensor<3x4xf32>, tensor<2xi32>) -> tensor<4x3xf32> - return %1 : tensor<4x3xf32> - } -} diff --git a/tensorflow/lite/experimental/litert/integration_test/single_op_models/reshape_f32_large_rank.mlir b/tensorflow/lite/experimental/litert/integration_test/single_op_models/reshape_f32_large_rank.mlir deleted file mode 100644 index 63df1776dd7b25..00000000000000 --- a/tensorflow/lite/experimental/litert/integration_test/single_op_models/reshape_f32_large_rank.mlir +++ /dev/null @@ -1,7 +0,0 @@ -module attributes {tfl.description = "MLIR Converted.", tfl.metadata = {min_runtime_version = "1.5.0\00\00\00\00\00\00\00\00\00\00\00"}, tfl.schema_version = 3 : i32} { - func.func @main(%arg0: tensor<2x3x4x5x6x7x8xf32>) -> tensor<8x7x6x5x4x3x2xf32> attributes {tf.entry_function = {inputs = "arg0", outputs = "tfl.reshape"}} { - %0 = "tfl.pseudo_const"() <{value = dense<[8, 7, 6, 5, 4, 3, 2]> : tensor<7xi32>}> : () -> tensor<7xi32> - %1 = "tfl.reshape"(%arg0, %0) : (tensor<2x3x4x5x6x7x8xf32>, tensor<7xi32>) -> tensor<8x7x6x5x4x3x2xf32> - return %1 : tensor<8x7x6x5x4x3x2xf32> - } -} diff --git a/tensorflow/lite/experimental/litert/integration_test/single_op_models/rsqrt_f32.mlir b/tensorflow/lite/experimental/litert/integration_test/single_op_models/rsqrt_f32.mlir deleted file mode 100644 index 51e03dc9cdcdaf..00000000000000 --- a/tensorflow/lite/experimental/litert/integration_test/single_op_models/rsqrt_f32.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module attributes {tfl.description = "MLIR Converted.", tfl.metadata = {min_runtime_version = "1.10.0\00\00\00\00\00\00\00\00\00\00"}, tfl.schema_version = 3 : i32} { - func.func @main(%arg0: tensor<256xf32>) -> tensor<256xf32> attributes {tf.entry_function = {inputs = "arg0", outputs = "tfl.rsqrt"}} { - %0 = "tfl.rsqrt"(%arg0) : (tensor<256xf32>) -> tensor<256xf32> - return %0 : tensor<256xf32> - } -} diff --git 
a/tensorflow/lite/experimental/litert/integration_test/single_op_models/select_f32.mlir b/tensorflow/lite/experimental/litert/integration_test/single_op_models/select_f32.mlir deleted file mode 100644 index c37db98eee2114..00000000000000 --- a/tensorflow/lite/experimental/litert/integration_test/single_op_models/select_f32.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module attributes {tfl.description = "MLIR Converted.", tfl.metadata = {min_runtime_version = "1.14.0\00\00\00\00\00\00\00\00\00\00"}, tfl.schema_version = 3 : i32} { - func.func @main(%arg0: tensor<2x2xi1>, %arg1: tensor<2x2xf32>, %arg2: tensor<2x2xf32>) -> tensor<2x2xf32> attributes {tf.entry_function = {inputs = "arg0,arg1,arg2", outputs = "tfl.select"}} { - %0 = "tfl.select"(%arg0, %arg1, %arg2) : (tensor<2x2xi1>, tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - return %0 : tensor<2x2xf32> - } -} diff --git a/tensorflow/lite/experimental/litert/integration_test/single_op_models/slice_f32.mlir b/tensorflow/lite/experimental/litert/integration_test/single_op_models/slice_f32.mlir deleted file mode 100644 index 50b62c65be8ff2..00000000000000 --- a/tensorflow/lite/experimental/litert/integration_test/single_op_models/slice_f32.mlir +++ /dev/null @@ -1,8 +0,0 @@ -module attributes {tfl.description = "MLIR Converted.", tfl.metadata = {min_runtime_version = "1.14.0\00\00\00\00\00\00\00\00\00\00"}, tfl.schema_version = 3 : i32} { - func.func @main(%arg0: tensor<3x4xf32>) -> tensor<2x2xf32> attributes {tf.entry_function = {inputs = "arg0", outputs = "tfl.slice"}} { - %0 = "tfl.pseudo_const"() <{value = dense<[1, 2]> : tensor<2xi32>}> : () -> tensor<2xi32> - %1 = "tfl.pseudo_const"() <{value = dense<2> : tensor<2xi32>}> : () -> tensor<2xi32> - %2 = "tfl.slice"(%arg0, %0, %1) : (tensor<3x4xf32>, tensor<2xi32>, tensor<2xi32>) -> tensor<2x2xf32> - return %2 : tensor<2x2xf32> - } -} diff --git a/tensorflow/lite/experimental/litert/integration_test/single_op_models/subtract_f32.mlir 
b/tensorflow/lite/experimental/litert/integration_test/single_op_models/subtract_f32.mlir deleted file mode 100644 index 4265a8c95eeabd..00000000000000 --- a/tensorflow/lite/experimental/litert/integration_test/single_op_models/subtract_f32.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module attributes {tfl.description = "MLIR Converted.", tfl.metadata = {min_runtime_version = "1.6.0\00\00\00\00\00\00\00\00\00\00\00"}, tfl.schema_version = 3 : i32} { - func.func @main(%arg0: tensor<256x256xf32>, %arg1: tensor<256x256xf32>) -> tensor<256x256xf32> attributes {tf.entry_function = {inputs = "arg0,arg1", outputs = "tfl.sub"}} { - %0 = tfl.sub %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<256x256xf32> - return %0 : tensor<256x256xf32> - } -} diff --git a/tensorflow/lite/experimental/litert/integration_test/single_op_models/tanh_f32.mlir b/tensorflow/lite/experimental/litert/integration_test/single_op_models/tanh_f32.mlir deleted file mode 100644 index 022392b4f294a4..00000000000000 --- a/tensorflow/lite/experimental/litert/integration_test/single_op_models/tanh_f32.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module attributes {tfl.description = "MLIR Converted.", tfl.metadata = {min_runtime_version = "1.14.0\00\00\00\00\00\00\00\00\00\00"}, tfl.schema_version = 3 : i32} { - func.func @main(%arg0: tensor<256xf32>) -> tensor<256xf32> attributes {tf.entry_function = {inputs = "arg0", outputs = "tfl.tanh"}} { - %0 = "tfl.tanh"(%arg0) : (tensor<256xf32>) -> tensor<256xf32> - return %0 : tensor<256xf32> - } -} diff --git a/tensorflow/lite/experimental/litert/python/BUILD b/tensorflow/lite/experimental/litert/python/BUILD deleted file mode 100644 index eeab7c9f0c21b0..00000000000000 --- a/tensorflow/lite/experimental/litert/python/BUILD +++ /dev/null @@ -1,18 +0,0 @@ -# Copyright 2025 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], -) diff --git a/tensorflow/lite/experimental/litert/runtime/BUILD b/tensorflow/lite/experimental/litert/runtime/BUILD deleted file mode 100644 index a0e34ead449a1c..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/BUILD +++ /dev/null @@ -1,414 +0,0 @@ -# Copyright 2024 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -load("//tensorflow/lite/experimental/litert/build_common:litert_build_defs.bzl", "gtest_main_no_heapcheck_deps") -load("//tensorflow/lite/experimental/litert/build_common:special_rule.bzl", "gles_deps", "gles_linkopts", "lite_rt_friends") - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], -) - -cc_library( - name = "event", - srcs = [ - "event.cc", - ], - hdrs = [ - "event.h", - ], - deps = [ - ":gpu_environment", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_event_type", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/runtime/opencl:cl_event", - "@com_google_absl//absl/strings:str_format", - ], -) - -cc_library( - name = "tensor_buffer", - srcs = [ - "ahwb_buffer.cc", - "dmabuf_buffer.cc", - "fastrpc_buffer.cc", - "gl_buffer.cc", - "gl_texture.cc", - "ion_buffer.cc", - "open_cl_buffer.cc", - "tensor_buffer.cc", - ], - hdrs = [ - "ahwb_buffer.h", - "dmabuf_buffer.h", - "event.h", - "fastrpc_buffer.h", - "gl_buffer.h", - "gl_texture.h", - "ion_buffer.h", - "open_cl_buffer.h", - "tensor_buffer.h", - "tensor_buffer_requirements.h", - "//tensorflow/lite/experimental/litert/c:litert_event.h", - "//tensorflow/lite/experimental/litert/c:litert_tensor_buffer.h", - "//tensorflow/lite/experimental/litert/c:litert_tensor_buffer_requirements.h", - ], - linkopts = gles_linkopts(), - deps = [ - ":gpu_environment", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_event_type", - "//tensorflow/lite/experimental/litert/c:litert_gl_types", - "//tensorflow/lite/experimental/litert/c:litert_layout", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - 
"//tensorflow/lite/experimental/litert/c:litert_tensor_buffer_types", - "//tensorflow/lite/experimental/litert/cc:litert_event", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_tensor_buffer_utils", - "//tensorflow/lite/experimental/litert/core/util:tensor_type_util", - "//tensorflow/lite/experimental/litert/runtime/opencl:buffer", - "//tensorflow/lite/experimental/litert/runtime/opencl:cl_command_queue", - "//tensorflow/lite/experimental/litert/runtime/opencl:cl_context", - "//tensorflow/lite/experimental/litert/runtime/opencl:opencl_wrapper", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:node_hash_map", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/types:span", - "@opencl_headers", - ] + gles_deps() + select({ - "//tensorflow:android": [ - "//tensorflow/lite/delegates/gpu/gl:gl_buffer", - "//tensorflow/lite/delegates/gpu/gl:gl_texture", - ], - "//conditions:default": [], - }), -) - -cc_library( - name = "gpu_environment", - srcs = [ - "gpu_environment.cc", - ], - hdrs = [ - "gpu_environment.h", - ], - visibility = [ - "//tensorflow/lite/experimental/litert:__subpackages__", - "//tensorflow/lite/experimental/litert/c:__subpackages__", - ], - deps = [ - "//tensorflow/lite/experimental/litert/c:litert_any", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/core:environment", - "//tensorflow/lite/experimental/litert/runtime/opencl:cl_command_queue", - "//tensorflow/lite/experimental/litert/runtime/opencl:cl_context", - "//tensorflow/lite/experimental/litert/runtime/opencl:cl_device", 
- "//tensorflow/lite/experimental/litert/runtime/opencl:opencl_wrapper", - "@opencl_headers", - ], -) - -cc_test( - name = "gpu_environment_test", - srcs = ["gpu_environment_test.cc"], - linkopts = select({ - "//tensorflow:android": ["-landroid"], - "//conditions:default": [], - }), - tags = [ - "requires-gpu-nvidia", - ], - deps = [ - ":gpu_environment", - "@com_google_googletest//:gtest_main", - # copybara:uncomment_begin(google-only) - # "//third_party/ml_drift/cl:environment", - # "//third_party/ml_drift/cl:opencl_wrapper", - # copybara:uncomment_end - "//tensorflow/lite/experimental/litert/c:litert_environment", - "//tensorflow/lite/experimental/litert/cc:litert_any", - "//tensorflow/lite/experimental/litert/runtime/opencl:opencl_wrapper", - ], -) - -cc_library( - name = "tfl_utils", - srcs = [ - "tfl_utils.cc", - ], - hdrs = [ - "tfl_utils.h", - ], - deps = [ - "//tensorflow/lite/c:c_api", - "//tensorflow/lite/c:c_api_opaque", - "//tensorflow/lite/c:c_api_types", - "//tensorflow/lite/c:common", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/cc:litert_detail", - "//tensorflow/lite/experimental/litert/cc:litert_element_type", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_model", - ], -) - -cc_library( - name = "external_litert_buffer_context", - srcs = ["external_litert_buffer_context.cc"], - hdrs = ["external_litert_buffer_context.h"], - visibility = [ - "//tensorflow/lite/experimental/litert:__subpackages__", - ] + lite_rt_friends(), - deps = [ - ":tfl_utils", - "//tensorflow/lite/c:c_api", - "//tensorflow/lite/c:c_api_opaque", - "//tensorflow/lite/c:c_api_types", - "//tensorflow/lite/c:common", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_tensor_buffer", - 
"//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_tensor_buffer", - "//tensorflow/lite/experimental/litert/cc:litert_tensor_buffer_requirements", - ], -) - -cc_library( - name = "compiled_model", - srcs = ["compiled_model.cc"], - hdrs = ["compiled_model.h"], - deps = [ - ":accelerator", - ":accelerator_model_compilation_data", - ":compilation_options", - ":external_litert_buffer_context", - ":tensor_buffer", - "//tensorflow/compiler/mlir/lite:allocation", - "//tensorflow/lite:builtin_ops", - "//tensorflow/lite:framework", - "//tensorflow/lite:model_builder", - "//tensorflow/lite/c:c_api_opaque", - "//tensorflow/lite/c:common", - "//tensorflow/lite/core:cc_api_stable", - "//tensorflow/lite/delegates/utils:simple_opaque_delegate", - "//tensorflow/lite/experimental/litert/c:litert_accelerator", - "//tensorflow/lite/experimental/litert/c:litert_accelerator_compilation_options", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_compilation_options", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_tensor_buffer", - "//tensorflow/lite/experimental/litert/cc:litert_buffer_ref", - "//tensorflow/lite/experimental/litert/cc:litert_event", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_tensor_buffer", - "//tensorflow/lite/experimental/litert/cc:litert_tensor_buffer_requirements", - "//tensorflow/lite/experimental/litert/compiler/plugin:compiler_plugin", - "//tensorflow/lite/experimental/litert/core:build_stamp", - "//tensorflow/lite/experimental/litert/core:environment", - "//tensorflow/lite/experimental/litert/core/model", - 
"//tensorflow/lite/experimental/litert/core/model:model_serialize", - "//tensorflow/lite/experimental/litert/core/util:flatbuffer_tools", - "//tensorflow/lite/kernels:builtin_ops", - "@com_google_absl//absl/cleanup", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:string_view", - ], -) - -cc_test( - name = "compiled_model_test", - srcs = ["compiled_model_test.cc"], - data = [ - "//tensorflow/lite/experimental/litert/test:testdata/simple_model.tflite", - ], - linkopts = select({ - "//tensorflow:android": ["-landroid"], - "//conditions:default": [], - }), - # require GPU to run OpenCL tests. - tags = [ - "requires-gpu-nvidia", - ], - deps = [ - ":compiled_model", - ":tensor_buffer", - "//tensorflow/lite:framework", - "//tensorflow/lite/c:c_api_opaque", - "//tensorflow/lite/c:common", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_compilation_options", - "//tensorflow/lite/experimental/litert/c:litert_environment", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_tensor_buffer", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_tensor_buffer", - "//tensorflow/lite/experimental/litert/core/model", - "//tensorflow/lite/experimental/litert/test:common", - "//tensorflow/lite/experimental/litert/test:matchers", - "//tensorflow/lite/experimental/litert/test:simple_model", - "//tensorflow/lite/kernels:builtin_ops", - "@com_google_absl//absl/log:absl_log", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "compilation_options", - hdrs = [ - "compilation_options.h", - 
"//tensorflow/lite/experimental/litert/c:litert_compilation_options.h", - ], - deps = [ - "//tensorflow/lite/experimental/litert/c:litert_accelerator_compilation_options", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_layout", - "//tensorflow/lite/experimental/litert/c:litert_tensor_buffer", - "//tensorflow/lite/experimental/litert/cc:litert_accelerator_compilation_options", - ], -) - -cc_test( - name = "gl_buffer_test", - srcs = ["gl_buffer_test.cc"], - linkopts = select({ - "//tensorflow:android": [ - "-landroid", - ], - "//conditions:default": [], - }), - tags = [ - "notap", - ], - deps = [ - ":tensor_buffer", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/test:matchers", - ] + gtest_main_no_heapcheck_deps() + select({ - "//tensorflow:android": [ - "//tensorflow/lite/delegates/gpu/gl:egl_environment", - "//tensorflow/lite/delegates/gpu/gl:gl_buffer", - ], - "//conditions:default": [], - }), -) - -cc_library( - name = "tensor_buffer_conversion", - srcs = ["tensor_buffer_conversion.cc"], - hdrs = ["tensor_buffer_conversion.h"], - linkopts = gles_linkopts(), - deps = [ - ":tensor_buffer", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_environment", - "//tensorflow/lite/experimental/litert/c:litert_tensor_buffer_types", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_tensor_buffer_utils", - "@com_google_absl//absl/strings:str_format", - ] + gles_deps(), -) - -cc_test( - name = "tensor_buffer_conversion_test", - srcs = ["tensor_buffer_conversion_test.cc"], - linkopts = select({ - "//tensorflow:android": ["-landroid"], - "//conditions:default": [], - }), - tags = [ - "notap", - ], - deps = [ - ":tensor_buffer", - ":tensor_buffer_conversion", - 
"//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_environment", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_tensor_buffer_types", - "//tensorflow/lite/experimental/litert/cc:litert_layout", - "//tensorflow/lite/experimental/litert/core:environment", - "//tensorflow/lite/experimental/litert/test:matchers", - "@com_google_googletest//:gtest_main", - ] + select({ - "//tensorflow:android": [ - "//tensorflow/lite/delegates/gpu/gl:egl_environment", - ], - "//conditions:default": [], - }), -) - -cc_library( - name = "accelerator", - hdrs = ["accelerator.h"], - deps = [ - "//tensorflow/lite/experimental/litert/c:litert_accelerator_compilation_options", - "//tensorflow/lite/experimental/litert/c:litert_common", - ], -) - -cc_library( - name = "accelerator_registry", - srcs = ["accelerator_registry.cc"], - hdrs = ["accelerator_registry.h"], - deps = [ - ":accelerator", - "//tensorflow/lite/experimental/litert/c:litert_accelerator_compilation_options", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_shared_library", - ], -) - -cc_library( - name = "accelerator_model_compilation_data", - hdrs = ["accelerator_model_compilation_data.h"], - deps = [ - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/cc:litert_accelerator_compilation_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - ], -) - -cc_test( - name = "accelerator_model_compilation_data_test", - srcs = ["accelerator_model_compilation_data_test.cc"], - deps = [ - ":accelerator_model_compilation_data", - "//tensorflow/lite/experimental/litert/c:litert_accelerator_compilation_options", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/core:version", - 
"//tensorflow/lite/experimental/litert/test:matchers", - "@com_google_googletest//:gtest_main", - ], -) diff --git a/tensorflow/lite/experimental/litert/runtime/accelerator.h b/tensorflow/lite/experimental/litert/runtime/accelerator.h deleted file mode 100644 index 2574588482976f..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/accelerator.h +++ /dev/null @@ -1,74 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_ACCELERATOR_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_ACCELERATOR_H_ - -#include -#include -#include - -#include "tensorflow/lite/experimental/litert/c/litert_accelerator_compilation_options.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" - -// We need to forward declare this to avoid a dependency loop. -struct LiteRtCompiledModelT; -struct LiteRtEnvironmentT; - -struct LiteRtAcceleratorT { - // Points to the type-erased accelerator state. - void* data; - - // Points to the environment that owns this accelerator. - LiteRtEnvironmentT* env; - - // NOLINTBEGIN(*-readability-class-member-naming) - - // Releases the the data. - // - // This function is used by the framework to clean up the accelerator. It - // should not be called by client code. - void (*ReleaseData)(void*); - - // Retrieves the accelerator name. 
- LiteRtStatus (*GetName)(LiteRtAcceleratorT* accelerator, const char** name); - - // Retrieves the accelerator version. - LiteRtStatus (*GetVersion)(LiteRtAcceleratorT* accelerator, - LiteRtApiVersion* version); - - // Retrieves the accelerator hardware support. - LiteRtStatus (*GetHardwareSupport)( - LiteRtAcceleratorT* accelerator, - LiteRtHwAcceleratorSet* supported_hardware); - - // Creates a delegate for the accelerator. - // Used void** instead of TfLiteOpaqueDelegate** to avoid TFLite dependency. - LiteRtStatus (*CreateDelegate)( - LiteRtAcceleratorT* accelerator, - LiteRtAcceleratorCompilationOptions compilation_options, void** delegate); - - // Destroys created delegate for the accelerator. - // The function signature is matched with existing TfLiteOpaqueDelegate - // interface to use. - // Used void* instead of TfLiteOpaqueDelegate* to avoid TFLite dependency. - void (*DestroyDelegate)(void* delegate); - - LiteRtStatus (*IsTfLiteDelegateResponsibleForJitCompilation)( - LiteRtAcceleratorT* accelerator, bool* does_jit_compilation); - - // NOLINTEND(*-readability-class-member-naming) -}; - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_ACCELERATOR_H_ diff --git a/tensorflow/lite/experimental/litert/runtime/accelerator_model_compilation_data.h b/tensorflow/lite/experimental/litert/runtime/accelerator_model_compilation_data.h deleted file mode 100644 index 9f465134533d25..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/accelerator_model_compilation_data.h +++ /dev/null @@ -1,54 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_ACCELERATOR_MODEL_COMPILATION_DATA_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_ACCELERATOR_MODEL_COMPILATION_DATA_H_ - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_accelerator_compilation_options.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" - -namespace litert::internal { - -// Holds environment data that accelerators may need to prepare their -// delegates. -// -// These options are automatically added to the compilation options list -// during the creation of the compiled model. -struct ModelCompilationData { - static constexpr LiteRtApiVersion kVersion = {1, 0, 0}; - static constexpr auto kIdentifier = "environment-compilation-options"; - - static Expected CreateOptions() { - auto* payload_data = new ModelCompilationData; - auto payload_destructor = [](void* payload_data) { - delete reinterpret_cast(payload_data); - }; - return AcceleratorCompilationOptions::Create( - kVersion, kIdentifier, payload_data, payload_destructor); - } - - // Pointer to the start of the model file memory allocation. - const char* allocation_base; - // File descriptor of the model file memory allocation. If there is no such - // file descriptor, this must be set to -1. 
- int allocation_fd; - - private: - ModelCompilationData() = default; -}; - -} // namespace litert::internal - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_ACCELERATOR_MODEL_COMPILATION_DATA_H_ diff --git a/tensorflow/lite/experimental/litert/runtime/accelerator_model_compilation_data_test.cc b/tensorflow/lite/experimental/litert/runtime/accelerator_model_compilation_data_test.cc deleted file mode 100644 index 8ebb1aa5426ba5..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/accelerator_model_compilation_data_test.cc +++ /dev/null @@ -1,41 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/runtime/accelerator_model_compilation_data.h" - -#include -#include -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/core/version.h" -#include "tensorflow/lite/experimental/litert/test/matchers.h" - -namespace { - -using testing::Eq; -using testing::StrEq; - -TEST(ModelCompilationDataTest, CreateSetsUpAllNecessaryFields) { - LITERT_ASSERT_OK_AND_ASSIGN( - auto options, litert::internal::ModelCompilationData::CreateOptions()); - - LITERT_ASSERT_OK_AND_ASSIGN(auto identifier, options.GetIdentifier()); - EXPECT_THAT(identifier, - StrEq(litert::internal::ModelCompilationData::kIdentifier)); - - LITERT_ASSERT_OK_AND_ASSIGN(auto version, options.GetVersion()); - EXPECT_TRUE(litert::internal::IsSameVersion( - version, litert::internal::ModelCompilationData::kVersion)); -} - -} // namespace diff --git a/tensorflow/lite/experimental/litert/runtime/accelerator_registry.cc b/tensorflow/lite/experimental/litert/runtime/accelerator_registry.cc deleted file mode 100644 index c74577d7f26ef8..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/accelerator_registry.cc +++ /dev/null @@ -1,66 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/runtime/accelerator_registry.h" - -#include -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_shared_library.h" - -namespace litert::internal { - -void AcceleratorRegistry::DestroyAccelerator(LiteRtAcceleratorT* accelerator) { - if (accelerator && accelerator->ReleaseData) { - accelerator->env = nullptr; - accelerator->ReleaseData(accelerator->data); - } - delete accelerator; -} - -Expected AcceleratorRegistry::RegisterAccelerator( - Ptr accelerator) { - if (!accelerator) { - return Error(kLiteRtStatusErrorInvalidArgument, - "Cannot register a null accelerator."); - } - accelerators_.push_back(std::move(accelerator)); - return accelerators_.back().get(); -} - -Expected AcceleratorRegistry::Get(LiteRtParamIndex idx) { - if (idx >= size()) { - return Error(kLiteRtStatusErrorNotFound, "Cannot find accelerator."); - } - return accelerators_[idx].get(); -} - -Expected AcceleratorRegistry::FindAcceleratorIndex( - LiteRtAcceleratorT* accelerator) { - for (size_t idx = 0; idx < accelerators_.size(); ++idx) { - if (accelerator == accelerators_[idx].get()) { - return static_cast(idx); - } - } - return Error(kLiteRtStatusErrorNotFound, - "The accelerator is not registered in the LiteRT environment."); -} - -void AcceleratorRegistry::TakeOwnershipOfSharedLibrary(SharedLibrary lib) { - accelerator_shared_libraries_.push_back(std::move(lib)); -} - -} // namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/runtime/accelerator_registry.h b/tensorflow/lite/experimental/litert/runtime/accelerator_registry.h deleted file mode 100644 index 11c4feec022985..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/accelerator_registry.h +++ /dev/null @@ -1,89 +0,0 @@ -// Copyright 2025 Google LLC. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_ACCELERATOR_REGISTRY_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_ACCELERATOR_REGISTRY_H_ - -#include -#include -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_shared_library.h" -#include "tensorflow/lite/experimental/litert/runtime/accelerator.h" - -namespace litert::internal { - -// Holds a list of accelerators. -// -// This is a helper class for the LiteRT environment that manages the -// accelerators (and their resources) that are registered with it. -class AcceleratorRegistry { - public: - struct Deleter { - void operator()(LiteRtAcceleratorT* accelerator) { - DestroyAccelerator(accelerator); - } - }; - - // Wraps a pointer for LiteRtAcceleratorT with a custom deleter that handles - // cleaning up the accelerator internal data. - using Ptr = std::unique_ptr<::LiteRtAcceleratorT, Deleter>; - - // Internal implementation for the C API. - [[nodiscard]] - static Ptr CreateEmptyAccelerator() { - return Ptr(new LiteRtAcceleratorT()); - } - - // Internal implementation for the C API. - static void DestroyAccelerator(::LiteRtAcceleratorT* accelerator); - - // Registers an accelerator. - Expected RegisterAccelerator(Ptr accelerator); - - // Returns the idx-th accelerator that was registered. 
- [[nodiscard]] - Expected Get(LiteRtParamIndex idx); - - // Goes through accelerators and find the index of the given one. - Expected FindAcceleratorIndex( - LiteRtAcceleratorT* accelerator); - - // Gives ownership of the shared library to the registry. - // - // This should be called when an accelerator is loaded from a shared library - // to tie the library lifetime to the registry. - // - // The library will be closed when the registry is destroyed. - void TakeOwnershipOfSharedLibrary(SharedLibrary library); - - // Returns the number of accelerators that have been registered. - size_t size() const { return accelerators_.size(); } - auto begin() const { return accelerators_.begin(); } - auto begin() { return accelerators_.begin(); } - auto end() const { return accelerators_.end(); } - auto end() { return accelerators_.end(); } - - private: - std::vector accelerators_; - // Some accelerators are loaded as shared libraries. This list keeps these - // libraries loaded while the environment uses them. - std::vector accelerator_shared_libraries_; -}; - -} // namespace litert::internal - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_ACCELERATOR_REGISTRY_H_ diff --git a/tensorflow/lite/experimental/litert/runtime/accelerator_test.cc b/tensorflow/lite/experimental/litert/runtime/accelerator_test.cc deleted file mode 100644 index 84f88d13b61c75..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/accelerator_test.cc +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/runtime/accelerator.h" - -#include - -namespace litert::internal { -namespace { - -TEST(AcceleratorRegistryTest, CreateEmptyAcceleratorWorks) { - [[maybe_unused]] - auto accelerator_squeleton = AcceleratorRegistry::CreateEmptyAccelerator(); -} - -TEST(AcceleratorRegistryTest, AcceleratorCanBeRegisteredAndRetrieved) { - AcceleratorRegistry registry; - - auto registered_accelerator1 = registry.RegisterAccelerator( - AcceleratorRegistry::CreateEmptyAccelerator()); - ASSERT_TRUE(registered_accelerator1); - - auto registered_accelerator2 = registry.RegisterAccelerator( - AcceleratorRegistry::CreateEmptyAccelerator()); - ASSERT_TRUE(registered_accelerator2); - - ASSERT_NE(registered_accelerator1, registered_accelerator2); - - auto queried_accelerator1 = registry.Get(0); - ASSERT_TRUE(queried_accelerator1); - EXPECT_EQ(queried_accelerator1, registered_accelerator1); - - auto queried_accelerator2 = registry.Get(1); - ASSERT_TRUE(queried_accelerator2); - EXPECT_EQ(queried_accelerator2, registered_accelerator2); - - EXPECT_FALSE(registry.Get(2)); - EXPECT_FALSE(registry.Get(-1)); - - auto idx1 = registry.FindAcceleratorIndex(queried_accelerator1.Value()); - ASSERT_TRUE(idx1); - EXPECT_EQ(idx1.Value(), 0); - - auto idx2 = registry.FindAcceleratorIndex(queried_accelerator2.Value()); - ASSERT_TRUE(idx2); - EXPECT_EQ(idx2.Value(), 1); -} - -} // namespace -} // namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/runtime/accelerators/BUILD b/tensorflow/lite/experimental/litert/runtime/accelerators/BUILD deleted file mode 100644 index c007232c0c1702..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/accelerators/BUILD +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright 2025 Google LLC. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = [ - "//tensorflow/lite/experimental/litert:__subpackages__", - ], -) - -cc_library( - name = "auto_registration", - srcs = ["auto_registration.cc"], - hdrs = ["auto_registration.h"], - deps = [ - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_shared_library", - "//tensorflow/lite/experimental/litert/core:environment", - "//tensorflow/lite/experimental/litert/runtime/accelerators/dispatch:dispatch_accelerator", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/strings:string_view", - ], -) - -cc_library( - name = "accelerator_implementation_helper", - hdrs = ["accelerator_implementation_helper.h"], - deps = [ - "//tensorflow/lite/experimental/litert/c:litert_accelerator", - "//tensorflow/lite/experimental/litert/c:litert_accelerator_compilation_options", - "//tensorflow/lite/experimental/litert/c:litert_accelerator_registration", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - 
"//tensorflow/lite/experimental/litert/runtime:accelerator_model_compilation_data", - "@com_google_absl//absl/strings:string_view", - ], -) diff --git a/tensorflow/lite/experimental/litert/runtime/accelerators/accelerator_implementation_helper.h b/tensorflow/lite/experimental/litert/runtime/accelerators/accelerator_implementation_helper.h deleted file mode 100644 index d44d6533267cfc..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/accelerators/accelerator_implementation_helper.h +++ /dev/null @@ -1,147 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_ACCELERATORS_ACCELERATOR_IMPLEMENTATION_HELPER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_ACCELERATORS_ACCELERATOR_IMPLEMENTATION_HELPER_H_ - -#include -#include - -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/c/litert_accelerator.h" -#include "tensorflow/lite/experimental/litert/c/litert_accelerator_compilation_options.h" -#include "tensorflow/lite/experimental/litert/c/litert_accelerator_registration.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/runtime/accelerator_model_compilation_data.h" - -namespace litert::internal { - -struct AcceleratorDestructor { - void operator()(LiteRtAccelerator accelerator) { - LiteRtDestroyAccelerator(accelerator); - } -}; - -// RAII wrapper for accelerator handles. -using AcceleratorGuard = - std::unique_ptr::element_type, - AcceleratorDestructor>; - -// Helps setting up an accelerator handle for accelerators that use the -// `AcceleratorImplementationHelper` template as a base class. -template -Expected SetAcceleratorBoilerplateFunctions( - AcceleratorGuard& accelerator) { - LITERT_RETURN_IF_ERROR( - LiteRtSetAcceleratorGetName(accelerator.get(), T::GetName)); - LITERT_RETURN_IF_ERROR( - LiteRtSetAcceleratorGetVersion(accelerator.get(), T::GetVersion)); - LITERT_RETURN_IF_ERROR(LiteRtSetAcceleratorGetHardwareSupport( - accelerator.get(), T::GetHardwareSupport)); - LITERT_RETURN_IF_ERROR(LiteRtSetDelegateFunction( - accelerator.get(), T::CreateDelegate, T::DestroyDelegate)); - return {}; -} - -// Goes through the options in the linked list and returns the model -// compilation data if it exists. 
-inline static Expected -GetModelCompilationData(LiteRtAcceleratorCompilationOptions options) { - LiteRtApiVersion payload_version; - void* payload_data; - LITERT_RETURN_IF_ERROR(LiteRtFindAcceleratorCompilationOptionsData( - options, litert::internal::ModelCompilationData::kIdentifier, - &payload_version, &payload_data)); - return reinterpret_cast( - payload_data); -} - -// Helps accelerator implementation by providing a lot of the boilerplate -// needed. -// -// Warning: The provided Ptr assumes that AcceleratorClass instances are -// created using `operator new`. -// -// Warning: `version` should be incremented every time the code of this -// accelerator is updated according to semanting versioning. -template -class AcceleratorImplementationHelper { - public: - // The accelerator name returned by `GetName`. - constexpr static const absl::string_view kName = name_; - // The accelerator version returned by `GetVersion`. - constexpr static const LiteRtApiVersion kVersion = version_; - // The accelerator hardware support returned by `GetHardwareSupport`. - constexpr static const LiteRtHwAcceleratorSet kHwSupport = hardware_support_; - - struct Deleter { - void operator()(AcceleratorClass* accelerator_impl) { - delete accelerator_impl; - } - }; - - // Owning pointer wrapping the accelerator. - using Ptr = std::unique_ptr; - - // Creates a new instance of the accelerator implementation. - template - static Ptr Allocate(Args&&... args) { - return Ptr(new AcceleratorClass(std::forward(args)...)); - } - - // Deletes the accelerator data. - static void Destroy(void* accelerator_impl) { - Deleter()(reinterpret_cast(accelerator_impl)); - } - - // Returns the accelerator's name by setting `name`. 
- static LiteRtStatus GetName(LiteRtAccelerator accelerator, - const char** name) { - LITERT_ENSURE(accelerator != nullptr, kLiteRtStatusErrorInvalidArgument, - "Accelerator handle is invalid."); - LITERT_ENSURE(name != nullptr, kLiteRtStatusErrorInvalidArgument, - "Name pointer is null."); - *name = kName.data(); - return kLiteRtStatusOk; - } - - // Returns the accelerator's version by setting `version`. - static LiteRtStatus GetVersion(LiteRtAccelerator accelerator, - LiteRtApiVersion* version) { - LITERT_ENSURE(accelerator != nullptr, kLiteRtStatusErrorInvalidArgument, - "Accelerator handle is invalid."); - LITERT_ENSURE(version != nullptr, kLiteRtStatusErrorInvalidArgument, - "Version pointer is null."); - *version = kVersion; - return kLiteRtStatusOk; - } - - // Returns the accelerator's hardware support by setting `hw_set`. - static LiteRtStatus GetHardwareSupport(LiteRtAccelerator accelerator, - LiteRtHwAcceleratorSet* hw_set) { - LITERT_ENSURE(accelerator != nullptr, kLiteRtStatusErrorInvalidArgument, - "Accelerator handle is invalid."); - LITERT_ENSURE(hw_set != nullptr, kLiteRtStatusErrorInvalidArgument, - "Hardware support pointer is null."); - *hw_set = kHwSupport; - return kLiteRtStatusOk; - } -}; - -} // namespace litert::internal - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_ACCELERATORS_ACCELERATOR_IMPLEMENTATION_HELPER_H_ diff --git a/tensorflow/lite/experimental/litert/runtime/accelerators/auto_registration.cc b/tensorflow/lite/experimental/litert/runtime/accelerators/auto_registration.cc deleted file mode 100644 index ecda799184f4c8..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/accelerators/auto_registration.cc +++ /dev/null @@ -1,89 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/runtime/accelerators/auto_registration.h" - -#include - -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_environment.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/cc/litert_shared_library.h" -#include "tensorflow/lite/experimental/litert/core/environment.h" -#include "tensorflow/lite/experimental/litert/runtime/accelerators/dispatch/dispatch_accelerator.h" - -// Define a function pointer to allow the accelerator registration to be -// overridden by the LiteRT environment. This is to use the GPU accelerator -// statically linked. -extern "C" bool (*LiteRtRegisterStaticLinkedAcceleratorGpu)( - LiteRtEnvironmentT& environment) = nullptr; - -namespace litert { - -Expected TriggerAcceleratorAutomaticRegistration( - LiteRtEnvironmentT& environment) { - // Register the NPU accelerator. - - auto npu_registration = - LiteRtRegisterNpuAccelerator(&environment, /*options=*/nullptr); - if (npu_registration != kLiteRtStatusOk) { - LITERT_LOG(LITERT_WARNING, - "GPU accelerator could not be loaded and registered: %s.", - LiteRtGetStatusString(npu_registration)); - } else { - LITERT_LOG(LITERT_INFO, "NPU accelerator registered."); - } - - // Register the GPU accelerator. 
- if (LiteRtRegisterStaticLinkedAcceleratorGpu != nullptr && - LiteRtRegisterStaticLinkedAcceleratorGpu(environment)) { - LITERT_LOG(LITERT_INFO, "Statically linked GPU accelerator registered."); - return {}; - } - auto gpu_registration = RegisterSharedObjectAccelerator( - environment, /*plugin_path=*/"libLiteRtGpuAccelerator.so", - /*registration_function_name=*/"LiteRtRegisterAcceleratorGpuOpenCl"); - if (!gpu_registration) { - LITERT_LOG(LITERT_WARNING, - "GPU accelerator could not be loaded and registered: %s.", - gpu_registration.Error().Message().c_str()); - } else { - LITERT_LOG(LITERT_INFO, "GPU accelerator registered."); - } - return {}; -}; - -Expected RegisterSharedObjectAccelerator( - LiteRtEnvironmentT& environment, absl::string_view plugin_path, - absl::string_view registration_function_name) { - auto maybe_lib = SharedLibrary::Load(plugin_path, RtldFlags::Lazy().Local()); - if (!maybe_lib.HasValue()) { - maybe_lib = SharedLibrary::Load(RtldFlags::kDefault); - } - // Note: the Load(kDefault) overload always succeeds, so we are sure that - // maybe_lib contains a value. - SharedLibrary lib(std::move(maybe_lib.Value())); - LITERT_ASSIGN_OR_RETURN(auto registration_function, - lib.LookupSymbol( - registration_function_name.data())); - LITERT_RETURN_IF_ERROR(registration_function(&environment)); - environment.GetAcceleratorRegistry().TakeOwnershipOfSharedLibrary( - std::move(lib)); - return {}; -} - -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/runtime/accelerators/auto_registration.h b/tensorflow/lite/experimental/litert/runtime/accelerators/auto_registration.h deleted file mode 100644 index 5ec12d7ed8a735..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/accelerators/auto_registration.h +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_ACCELERATORS_AUTO_REGISTRATION_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_ACCELERATORS_AUTO_REGISTRATION_H_ - -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/core/environment.h" - -namespace litert { - -Expected TriggerAcceleratorAutomaticRegistration( - LiteRtEnvironmentT& environment); - -Expected RegisterSharedObjectAccelerator( - LiteRtEnvironmentT& environment, absl::string_view plugin_path, - absl::string_view registration_function_name); - -} // namespace litert - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_ACCELERATORS_AUTO_REGISTRATION_H_ diff --git a/tensorflow/lite/experimental/litert/runtime/accelerators/dispatch/BUILD b/tensorflow/lite/experimental/litert/runtime/accelerators/dispatch/BUILD deleted file mode 100644 index 68758738b7dc45..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/accelerators/dispatch/BUILD +++ /dev/null @@ -1,41 +0,0 @@ -# Copyright 2025 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = [ - "//tensorflow/lite/experimental/litert:__subpackages__", - ], -) - -cc_library( - name = "dispatch_accelerator", - srcs = ["dispatch_accelerator.cc"], - hdrs = ["dispatch_accelerator.h"], - deps = [ - "//tensorflow/lite/c:c_api_types", - "//tensorflow/lite/experimental/litert/c:litert_accelerator", - "//tensorflow/lite/experimental/litert/c:litert_accelerator_compilation_options", - "//tensorflow/lite/experimental/litert/c:litert_accelerator_registration", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_dispatch_delegate", - "//tensorflow/lite/experimental/litert/cc:litert_any", - "//tensorflow/lite/experimental/litert/cc:litert_dispatch_delegate", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/core:environment", - "//tensorflow/lite/experimental/litert/runtime:accelerator_model_compilation_data", - "@com_google_absl//absl/strings:string_view", - ], -) diff --git a/tensorflow/lite/experimental/litert/runtime/accelerators/dispatch/dispatch_accelerator.cc b/tensorflow/lite/experimental/litert/runtime/accelerators/dispatch/dispatch_accelerator.cc deleted file mode 100644 index 66189b4a84b028..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/accelerators/dispatch/dispatch_accelerator.cc +++ /dev/null @@ -1,242 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/runtime/accelerators/dispatch/dispatch_accelerator.h" - -#include -#include -#include - -#include "absl/strings/string_view.h" -#include "tensorflow/lite/c/c_api_types.h" -#include "tensorflow/lite/experimental/litert/c/litert_accelerator.h" -#include "tensorflow/lite/experimental/litert/c/litert_accelerator_compilation_options.h" -#include "tensorflow/lite/experimental/litert/c/litert_accelerator_registration.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_dispatch_delegate.h" -#include "tensorflow/lite/experimental/litert/c/litert_environment.h" -#include "tensorflow/lite/experimental/litert/cc/litert_any.h" -#include "tensorflow/lite/experimental/litert/cc/litert_dispatch_delegate.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/core/environment.h" -#include "tensorflow/lite/experimental/litert/runtime/accelerator_model_compilation_data.h" - -namespace litert { - -class NpuAccelerator final { - constexpr static const absl::string_view kName = "NpuAccelerator"; - // Warning: this should be incremented every time the code of this accelerator - // is updated according to semanting versioning. 
- constexpr static const LiteRtApiVersion kVersion{1, 0, 0}; - constexpr static const LiteRtHwAcceleratorSet kHwSupport = - kLiteRtHwAcceleratorNpu; - - public: - explicit NpuAccelerator(std::string library_folder) - : library_folder_(std::move(library_folder)) {} - - struct Deleter { - void operator()(NpuAccelerator* npu_accelerator) { delete npu_accelerator; } - }; - using Ptr = std::unique_ptr; - - static Expected Create(std::string library_folder) { - LITERT_RETURN_IF_ERROR( - !library_folder.empty(), - Error(kLiteRtStatusErrorInvalidArgument, - "Dispatch API implementation library folder was not specified.")); - return Ptr(new NpuAccelerator(std::move(library_folder))); - } - - // C API - - // Deletes the accelerator data. - static void Destroy(void* npu_accelerator) { - Deleter()(reinterpret_cast(npu_accelerator)); - } - - // Stores the accelerator's name in `name`. - static LiteRtStatus GetName(LiteRtAccelerator accelerator, - const char** name) { - LITERT_ENSURE(accelerator != nullptr, kLiteRtStatusErrorInvalidArgument, - "Accelerator handle is invalid."); - LITERT_ENSURE(name != nullptr, kLiteRtStatusErrorInvalidArgument, - "Name pointer is null."); - *name = kName.data(); - return kLiteRtStatusOk; - } - - // Stores the accelerator's version in `version`. - static LiteRtStatus GetVersion(LiteRtAccelerator accelerator, - LiteRtApiVersion* version) { - LITERT_ENSURE(accelerator != nullptr, kLiteRtStatusErrorInvalidArgument, - "Accelerator handle is invalid."); - LITERT_ENSURE(version != nullptr, kLiteRtStatusErrorInvalidArgument, - "Version pointer is null."); - *version = kVersion; - return kLiteRtStatusOk; - } - - // Stores the accelerator's hardware support in `hw_set`. 
- static LiteRtStatus GetHardwareSupport(LiteRtAccelerator accelerator, - LiteRtHwAcceleratorSet* hw_set) { - LITERT_ENSURE(accelerator != nullptr, kLiteRtStatusErrorInvalidArgument, - "Accelerator handle is invalid."); - LITERT_ENSURE(hw_set != nullptr, kLiteRtStatusErrorInvalidArgument, - "Harware support pointer is null."); - *hw_set = kHwSupport; - return kLiteRtStatusOk; - } - - // Goes through the options in the linked list and returns the model - // compilation data if it exists. - static Expected - GetModelCompilationData(LiteRtAcceleratorCompilationOptions options) { - LiteRtApiVersion payload_version; - void* payload_data; - LITERT_RETURN_IF_ERROR(LiteRtFindAcceleratorCompilationOptionsData( - options, litert::internal::ModelCompilationData::kIdentifier, - &payload_version, &payload_data)); - return reinterpret_cast( - payload_data); - } - - // Creates a Dispatch delegate instance. - static LiteRtStatus CreateDelegate( - LiteRtAccelerator accelerator, - LiteRtAcceleratorCompilationOptions options, void** delegate) { - LITERT_ENSURE(delegate != nullptr, kLiteRtStatusErrorInvalidArgument, - "Delegate pointer is null."); - LITERT_ENSURE(accelerator != nullptr, kLiteRtStatusErrorInvalidArgument, - "Accelerator handle is invalid."); - LITERT_ENSURE(accelerator->env != nullptr, - kLiteRtStatusErrorInvalidArgument, - "Accelerator is not registered to an environment."); - - LITERT_ASSIGN_OR_RETURN( - const litert::internal::ModelCompilationData* compilation_data, - GetModelCompilationData(options)); - - LITERT_ENSURE(compilation_data->allocation_base, - kLiteRtStatusErrorRuntimeFailure, - "No model allocation was passed by the runtime."); - - auto dispatch_delegate_options = litert::CreateDispatchDelegateOptionsPtr( - &accelerator->env->GetOptions()); - LITERT_ENSURE(dispatch_delegate_options != nullptr, - kLiteRtStatusErrorRuntimeFailure, - "Dispatch delegate options failed to be created."); - - LITERT_ENSURE( - LiteRtDispatchDelegateAddAllocBaseOption( - 
dispatch_delegate_options.get(), - compilation_data->allocation_base) == kTfLiteOk, - kLiteRtStatusErrorRuntimeFailure, - "Could not add allocation base to dispatch delegate options."); - - if (compilation_data->allocation_fd != -1) { - LITERT_ENSURE(LiteRtDispatchDelegateAddAllocFdOption( - dispatch_delegate_options.get(), - compilation_data->allocation_fd) == kTfLiteOk, - kLiteRtStatusErrorRuntimeFailure, - "Could not add allocation file descriptor to dispatch " - "delegate options."); - } - - auto dispatch_delegate = litert::CreateDispatchDelegatePtr( - &accelerator->env->GetOptions(), std::move(dispatch_delegate_options)); - LITERT_ENSURE(dispatch_delegate != nullptr, - kLiteRtStatusErrorRuntimeFailure, - "Dispatch delegate failed to be created."); - - *delegate = dispatch_delegate.release(); - return kLiteRtStatusOk; - } - - // Destroys a Dispatch delegate instance. - static void DestroyDelegate(void* delegate) { - LiteRtDestroyDispatchDelegate( - reinterpret_cast(delegate)); - } - - private: - // Note: we do not directly use the option structure because we want to copy - // and own all the option data. - - // Folder to the Dispatch API implementation shared library. 
- std::string library_folder_; -}; - -namespace { - -struct AcceleratorDestructor { - void operator()(LiteRtAccelerator accelerator) { - LiteRtDestroyAccelerator(accelerator); - } -}; - -using AcceleratorGuard = - std::unique_ptr::element_type, - AcceleratorDestructor>; - -} // namespace -} // namespace litert - -extern "C" { - -LiteRtStatus LiteRtRegisterNpuAccelerator( - LiteRtEnvironmentT* environment, LiteRtNpuAcceleratorOptions* options) { - LITERT_ENSURE(environment != nullptr, kLiteRtStatusErrorInvalidArgument, - "accelerator handle is invalid"); - LiteRtAccelerator accelerator_handle; - LITERT_RETURN_IF_ERROR(LiteRtCreateAccelerator(&accelerator_handle)); - litert::AcceleratorGuard accelerator(accelerator_handle); - LITERT_RETURN_IF_ERROR(LiteRtSetAcceleratorGetName( - accelerator.get(), litert::NpuAccelerator::GetName)); - LITERT_RETURN_IF_ERROR(LiteRtSetAcceleratorGetVersion( - accelerator.get(), litert::NpuAccelerator::GetVersion)); - LITERT_RETURN_IF_ERROR(LiteRtSetAcceleratorGetHardwareSupport( - accelerator.get(), litert::NpuAccelerator::GetHardwareSupport)); - - LITERT_RETURN_IF_ERROR(LiteRtSetDelegateFunction( - accelerator.get(), litert::NpuAccelerator::CreateDelegate, - litert::NpuAccelerator::DestroyDelegate)); - - std::string library_folder; - if (options && options->library_folder) { - library_folder = options->library_folder; - } - // Check the environment options if the library folder wasn't set in the - // options. 
- if (library_folder.empty()) { - if (auto env_library_folder = - environment->GetOption(kLiteRtEnvOptionTagDispatchLibraryDir); - env_library_folder.has_value()) { - LITERT_ASSIGN_OR_RETURN( - library_folder, litert::Get(env_library_folder.value())); - } - } - - LITERT_ASSIGN_OR_RETURN( - auto accelerator_impl, - litert::NpuAccelerator::Create(std::move(library_folder))); - - LITERT_RETURN_IF_ERROR(LiteRtRegisterAccelerator( - environment, accelerator.release(), accelerator_impl.release(), - litert::NpuAccelerator::Destroy)); - return kLiteRtStatusOk; -} - -} // extern "C" diff --git a/tensorflow/lite/experimental/litert/runtime/accelerators/dispatch/dispatch_accelerator.h b/tensorflow/lite/experimental/litert/runtime/accelerators/dispatch/dispatch_accelerator.h deleted file mode 100644 index 9c1d93938eb28c..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/accelerators/dispatch/dispatch_accelerator.h +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_ACCELERATORS_DISPATCH_DISPATCH_ACCELERATOR_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_ACCELERATORS_DISPATCH_DISPATCH_ACCELERATOR_H_ - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_environment.h" - -#ifdef __cplusplus -extern "C" { -#endif - -struct LiteRtNpuAcceleratorOptions { - const char* library_folder; -}; - -// Registers the NPU accelerator to the given environment. -// -// `options` may be null, in which case the accelerator is registered with -// a default configuration. -// -// If `options.library_folder` is not specified, the library folder is replaced -// with the `LiteRtEnvOptionTagDispatchLibraryDir` environment option (that was -// passed upon creation). -// -// Once this function has returned, options may be freed or reused. -LiteRtStatus LiteRtRegisterNpuAccelerator(LiteRtEnvironment environment, - LiteRtNpuAcceleratorOptions* options); - -#ifdef __cplusplus -} // extern "C" -#endif - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_ACCELERATORS_DISPATCH_DISPATCH_ACCELERATOR_H_ diff --git a/tensorflow/lite/experimental/litert/runtime/accelerators/xnnpack/BUILD b/tensorflow/lite/experimental/litert/runtime/accelerators/xnnpack/BUILD deleted file mode 100644 index 0a3140eb8ca7c2..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/accelerators/xnnpack/BUILD +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright 2025 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = [ - "//tensorflow/lite/experimental/litert:__subpackages__", - ], -) - -cc_library( - name = "xnnpack_accelerator", - srcs = ["xnnpack_accelerator.cc"], - hdrs = ["xnnpack_accelerator.h"], - deps = [ - "//tensorflow/lite/c:c_api_types", - "//tensorflow/lite/delegates/xnnpack:xnnpack_delegate", - "//tensorflow/lite/experimental/litert/c:litert_accelerator", - "//tensorflow/lite/experimental/litert/c:litert_accelerator_compilation_options", - "//tensorflow/lite/experimental/litert/c:litert_accelerator_registration", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/core:environment", - "//tensorflow/lite/experimental/litert/runtime/accelerators:accelerator_implementation_helper", - ], -) diff --git a/tensorflow/lite/experimental/litert/runtime/accelerators/xnnpack/xnnpack_accelerator.cc b/tensorflow/lite/experimental/litert/runtime/accelerators/xnnpack/xnnpack_accelerator.cc deleted file mode 100644 index a69e36147f00da..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/accelerators/xnnpack/xnnpack_accelerator.cc +++ /dev/null @@ -1,101 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/runtime/accelerators/xnnpack/xnnpack_accelerator.h" - -#include - -#include "tensorflow/lite/c/c_api_types.h" -#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" -#include "tensorflow/lite/experimental/litert/c/litert_accelerator.h" -#include "tensorflow/lite/experimental/litert/c/litert_accelerator_compilation_options.h" -#include "tensorflow/lite/experimental/litert/c/litert_accelerator_registration.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_environment.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/core/environment.h" -#include "tensorflow/lite/experimental/litert/runtime/accelerators/accelerator_implementation_helper.h" - -namespace litert { - -namespace { -constexpr const char kCpuAcceleratorName[] = "CpuAccelerator"; -constexpr const LiteRtApiVersion kCpuAcceleratorVersion{1, 0, 0}; - -class CpuAccelerator final - : public internal::AcceleratorImplementationHelper< - CpuAccelerator, kCpuAcceleratorName, kCpuAcceleratorVersion, - kLiteRtHwAcceleratorCpu> { - public: - CpuAccelerator() = default; - - static Expected Create() { return Allocate(); } - - // C API - - // Creates a Dispatch delegate instance. 
- static LiteRtStatus CreateDelegate( - LiteRtAccelerator accelerator, - LiteRtAcceleratorCompilationOptions options, void** delegate) { - LITERT_ENSURE(delegate != nullptr, kLiteRtStatusErrorInvalidArgument, - "Delegate pointer is null."); - LITERT_ENSURE(accelerator != nullptr, kLiteRtStatusErrorInvalidArgument, - "Accelerator handle is invalid."); - LITERT_ENSURE(accelerator->env != nullptr, - kLiteRtStatusErrorInvalidArgument, - "Accelerator is not registered to an environment."); - - // TODO: b/403547017 - Make the CPU accelerator configurable using the - // compilation options. - auto xnn_options = TfLiteXNNPackDelegateOptionsDefault(); - *delegate = TfLiteXNNPackDelegateCreate(&xnn_options); - - LITERT_ENSURE(*delegate != nullptr, kLiteRtStatusErrorRuntimeFailure, - "XNNPack delegate failed to be created."); - return kLiteRtStatusOk; - } - - // Destroys an XNNPack delegate instance. - static void DestroyDelegate(void* delegate) { - TfLiteXNNPackDelegateDelete(reinterpret_cast(delegate)); - } -}; - -} // namespace -} // namespace litert - -extern "C" { - -LiteRtStatus LiteRtRegisterCpuAccelerator( - LiteRtEnvironmentT* environment, LiteRtCpuAcceleratorOptions* options) { - LITERT_ENSURE(environment != nullptr, kLiteRtStatusErrorInvalidArgument, - "accelerator handle is invalid"); - LiteRtAccelerator accelerator_handle; - LITERT_RETURN_IF_ERROR(LiteRtCreateAccelerator(&accelerator_handle)); - litert::internal::AcceleratorGuard accelerator(accelerator_handle); - - LITERT_RETURN_IF_ERROR(litert::internal::SetAcceleratorBoilerplateFunctions< - litert::CpuAccelerator>(accelerator)); - - LITERT_ASSIGN_OR_RETURN(auto accelerator_impl, - litert::CpuAccelerator::Create()); - - LITERT_RETURN_IF_ERROR(LiteRtRegisterAccelerator( - environment, accelerator.release(), accelerator_impl.release(), - litert::CpuAccelerator::Destroy)); - return kLiteRtStatusOk; -} - -} // extern "C" diff --git 
a/tensorflow/lite/experimental/litert/runtime/accelerators/xnnpack/xnnpack_accelerator.h b/tensorflow/lite/experimental/litert/runtime/accelerators/xnnpack/xnnpack_accelerator.h deleted file mode 100644 index 01a252c7e1d429..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/accelerators/xnnpack/xnnpack_accelerator.h +++ /dev/null @@ -1,41 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_ACCELERATORS_XNNPACK_XNNPACK_ACCELERATOR_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_ACCELERATORS_XNNPACK_XNNPACK_ACCELERATOR_H_ - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_environment.h" - -#ifdef __cplusplus -extern "C" { -#endif - -// Options that may be passed to the CPU accelerator when it is registered. -struct LiteRtCpuAcceleratorOptions {}; - -// Registers the CPU accelerator to the given environment. -// -// `options` may be null, in which case the accelerator is registered with -// a default configuration. -// -// Once this function has returned, options may be freed or reused. 
-LiteRtStatus LiteRtRegisterCpuAccelerator(LiteRtEnvironment environment, - LiteRtCpuAcceleratorOptions* options); - -#ifdef __cplusplus -} // extern "C" -#endif - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_ACCELERATORS_XNNPACK_XNNPACK_ACCELERATOR_H_ diff --git a/tensorflow/lite/experimental/litert/runtime/ahwb_buffer.cc b/tensorflow/lite/experimental/litert/runtime/ahwb_buffer.cc deleted file mode 100644 index 26746dcd632546..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/ahwb_buffer.cc +++ /dev/null @@ -1,111 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/runtime/ahwb_buffer.h" - -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/runtime/event.h" -#if LITERT_HAS_AHWB_SUPPORT -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#endif // LITERT_HAS_AHWB_SUPPORT - -namespace litert { -namespace internal { - -bool AhwbBuffer::IsSupported() { -#if LITERT_HAS_AHWB_SUPPORT - return true; -#else - return false; -#endif -} - -Expected AhwbBuffer::Alloc(size_t size) { -#if LITERT_HAS_AHWB_SUPPORT - AHardwareBuffer* ahwb; - AHardwareBuffer_Desc ahwb_desc = { - .width = static_cast(size), - .height = 1, - .layers = 1, - .format = AHARDWAREBUFFER_FORMAT_BLOB, - .usage = AHARDWAREBUFFER_USAGE_CPU_WRITE_RARELY | - AHARDWAREBUFFER_USAGE_CPU_READ_RARELY | - AHARDWAREBUFFER_USAGE_GPU_DATA_BUFFER}; - if (AHardwareBuffer_allocate(&ahwb_desc, &ahwb) != 0) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to allocate AHWB"); - } - return AhwbBuffer{/*.ahwb=*/ahwb}; -#else - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "AHardwareBuffers are not supported on this platform"); -#endif // LITERT_HAS_AHWB_SUPPORT -} - -void AhwbBuffer::Free(AHardwareBuffer* ahwb) { -#if LITERT_HAS_AHWB_SUPPORT - AHardwareBuffer_release(ahwb); -#endif -} - -Expected AhwbBuffer::GetSize(AHardwareBuffer* ahwb) { -#if LITERT_HAS_AHWB_SUPPORT - AHardwareBuffer_Desc ahwb_desc; - AHardwareBuffer_describe(ahwb, &ahwb_desc); - return static_cast(ahwb_desc.width) * ahwb_desc.height * - ahwb_desc.layers; -#else - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "AHardwareBuffers are not supported on this platform"); -#endif // LITERT_HAS_AHWB_SUPPORT -} - -Expected AhwbBuffer::Lock(AHardwareBuffer* ahwb, LiteRtEventT* event) { -#if LITERT_HAS_AHWB_SUPPORT - int fence = -1; - if (event != nullptr) { - LITERT_ASSIGN_OR_RETURN(fence, 
event->GetSyncFenceFd()); - } - void* host_addr; - LITERT_RETURN_IF_ERROR( - AHardwareBuffer_lock(ahwb, - AHARDWAREBUFFER_USAGE_CPU_READ_RARELY | - AHARDWAREBUFFER_USAGE_CPU_WRITE_RARELY, - fence, /*rect=*/nullptr, &host_addr) == 0, - Unexpected(kLiteRtStatusErrorRuntimeFailure, "Failed to lock AHWB")); - return host_addr; -#else - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "AHardwareBuffers are not supported on this platform"); -#endif -} - -Expected AhwbBuffer::Unlock(AHardwareBuffer* ahwb) { -#if LITERT_HAS_AHWB_SUPPORT - if (AHardwareBuffer_unlock(ahwb, /*fence=*/nullptr) != 0) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to unlock AHWB"); - } - return {}; -#else - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "AHardwareBuffers are not supported on this platform"); -#endif -} - -} // namespace internal -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/runtime/ahwb_buffer.h b/tensorflow/lite/experimental/litert/runtime/ahwb_buffer.h deleted file mode 100644 index 8722305225c1e7..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/ahwb_buffer.h +++ /dev/null @@ -1,53 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_AHWB_BUFFER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_AHWB_BUFFER_H_ - -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/runtime/event.h" - -#if LITERT_HAS_AHWB_SUPPORT -#include -#else -// Define a place holder AHardwareBuffer struct just to enable compilation. -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus -typedef struct AHardwareBuffer AHardwareBuffer; -#ifdef __cplusplus -} // extern "C" -#endif // __cplusplus -#endif // LITERT_HAS_AHWB_SUPPORT - -namespace litert::internal { - -struct AhwbBuffer { - AHardwareBuffer* ahwb; - - static bool IsSupported(); - static Expected Alloc(size_t size); - static void Free(AHardwareBuffer* ahwb); - static Expected GetSize(AHardwareBuffer* ahwb); - static Expected Lock(AHardwareBuffer* ahwb, - LiteRtEventT* event = nullptr); - static Expected Unlock(AHardwareBuffer* ahwb); -}; - -} // namespace litert::internal - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_AHWB_BUFFER_H_ diff --git a/tensorflow/lite/experimental/litert/runtime/compilation_options.h b/tensorflow/lite/experimental/litert/runtime/compilation_options.h deleted file mode 100644 index b92f6555a66b2d..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/compilation_options.h +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_COMPILATION_OPTIONS_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_COMPILATION_OPTIONS_H_ - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_accelerator_compilation_options.h" - -struct LiteRtCompilationOptionsT { - // This should be updated every time a field is added/edited. - // - // - Renaming a field: increment patch; - // - Adding or deprecating a field: set patch to 0, increment minor. - // - Breaking layout compatibility: set patch and minor to 0, increment major. - // - // Note: Changing a default value does not impact the version. - LiteRtApiVersion version = {.major = 0, .minor = 0, .patch = 1}; - LiteRtHwAcceleratorSet hardware_accelerators = kLiteRtHwAcceleratorNone; - litert::AcceleratorCompilationOptions accelerator_compilation_options; -}; - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_COMPILATION_OPTIONS_H_ diff --git a/tensorflow/lite/experimental/litert/runtime/compiled_model.cc b/tensorflow/lite/experimental/litert/runtime/compiled_model.cc deleted file mode 100644 index 04e538cf72d77b..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/compiled_model.cc +++ /dev/null @@ -1,638 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/runtime/compiled_model.h" - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "absl/cleanup/cleanup.h" -#include "tensorflow/lite/experimental/litert/c/litert_accelerator.h" -#include "tensorflow/lite/experimental/litert/c/litert_accelerator_compilation_options.h" -#include "tensorflow/lite/experimental/litert/cc/litert_event.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h" -#include "tensorflow/lite/experimental/litert/runtime/accelerator.h" -#include "tensorflow/lite/experimental/litert/runtime/accelerator_model_compilation_data.h" - -#if defined(__ANDROID__) -#include -#endif - -#include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" -#include "tensorflow/compiler/mlir/lite/allocation.h" -#include "tensorflow/lite/builtin_ops.h" -#include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/core/interpreter_builder.h" -#include "tensorflow/lite/delegates/utils/simple_opaque_delegate.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_compilation_options.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h" -#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_requirements.h" -#include "tensorflow/lite/experimental/litert/compiler/plugin/compiler_plugin.h" 
-#include "tensorflow/lite/experimental/litert/core/build_stamp.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" -#include "tensorflow/lite/experimental/litert/core/model/model_serialize.h" -#include "tensorflow/lite/experimental/litert/runtime/compilation_options.h" -#include "tensorflow/lite/experimental/litert/runtime/external_litert_buffer_context.h" -#include "tensorflow/lite/experimental/litert/runtime/tensor_buffer.h" -#include "tensorflow/lite/interpreter.h" -#include "tensorflow/lite/kernels/register.h" -#include "tensorflow/lite/model_builder.h" -#include "tensorflow/lite/stderr_reporter.h" - -using litert::Error; -using litert::Expected; -using litert::OwningBufferRef; -using litert::TensorBuffer; -using litert::Unexpected; -using litert::internal::ExternalLiteRtBufferContext; -using litert::internal::SerializeModel; - -Expected LiteRtCompiledModelT::InitializeRuntime() { - tflite::ops::builtin::BuiltinOpResolver resolver; - tflite::InterpreterBuilder(*fb_model_, resolver)(&interp_); - if (interp_ == nullptr) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to build TFL interpreter"); - } - - signature_keys_ = interp_->signature_keys(); - if (signature_keys_.empty()) { - static auto* default_signature_key = - new std::string(LiteRtSignatureT::kDefaultSignatureKey); - signature_keys_.push_back(default_signature_key); - } - // Register the ExternalLiteRtBufferContext for TensorBuffer handshaking. 
- buffer_context_ = - std::make_unique(); - interp_->SetExternalContext(kTfLiteLiteRtBufferContext, - buffer_context_.get()); - - return {}; -} - -Expected LiteRtCompiledModelT::InitializeModel( - LiteRtModelT& model, LiteRtHwAcceleratorSet hw_accelerators, - LiteRtEnvironmentT& env) { - bool need_reserialization = false; - - if (hw_accelerators != kLiteRtHwAcceleratorNone) { - LITERT_LOG(LITERT_INFO, "Applying compiler plugins..."); - auto jit_result = litert::internal::ApplyPlugins( - &env, &model, hw_accelerators, &need_reserialization); - if (!jit_result) { - LITERT_LOG(LITERT_WARNING, "Failed to apply compiler plugins: %s", - jit_result.Error().Message().c_str()); - } else { - LITERT_LOG( - LITERT_INFO, "%d compiler plugins were applied successfully: %s", - jit_result->num_applied_plugins, jit_result->success_message.c_str()); - LITERT_LOG(LITERT_WARNING, "Plugin errs: %s", - jit_result->error_message.c_str()); - } - } - - const auto& tfl_wrapper = litert::internal::GetTflFlatbuffer(model); - // Currently, in all situations where litert model was import from a - // flatbuffer, the litert model will own said flatbuffer and stored it in the - // OwningBufferRef. 
- auto tfl_buf = tfl_wrapper.Buf(); - - if (!need_reserialization && tfl_buf.Data() != nullptr) { - LITERT_LOG( - LITERT_INFO, - "Flatbuffer model initialized directly from incoming litert model."); - fb_model_ = tflite::FlatBufferModel::BuildFromBuffer(tfl_buf.StrData(), - tfl_buf.Size()); - return {}; - } - - LITERT_LOG(LITERT_INFO, "JIT compilation changed model, reserializing..."); - - auto serialized = SerializeModel(std::move(model)); - if (!serialized) { - return serialized.Error(); - } - - model_buf_ = std::move(*serialized); - fb_model_ = tflite::FlatBufferModel::BuildFromBuffer( - reinterpret_cast(model_buf_.Data()), model_buf_.Size()); - if (fb_model_ == nullptr) { - return Unexpected(kLiteRtStatusErrorFileIO, - "Failed to build flatbuffer from buffer"); - } - - return {}; -} - -namespace { - -// A utility class that allows appending additional compilation options, but -// only for the duration of a scope. -class ScopedCompilationOptionsModifier { - public: - explicit ScopedCompilationOptionsModifier( - LiteRtCompilationOptions compilation_options) - : accelerator_options_( - compilation_options->accelerator_compilation_options) {} - - ~ScopedCompilationOptionsModifier() { - // Remove any option that was appended during the lifetime of this object. 
- while (--num_appended_options_ >= 0) { - accelerator_options_.Pop(); - } - } - - Expected Append( - litert::AcceleratorCompilationOptions&& accelerator_options) { - auto status = accelerator_options_.Append(std::move(accelerator_options)); - if (status) { - ++num_appended_options_; - } - return status; - } - - private: - litert::AcceleratorCompilationOptions& accelerator_options_; - int num_appended_options_ = 0; -}; - -int GetAllocationFd(const tflite::Allocation* allocation) { - if (allocation != nullptr && - allocation->type() == tflite::Allocation::Type::kMMap) { - auto& mmap_allocation = - static_cast(*allocation); - return mmap_allocation.fd(); - } - return -1; -} - -} // namespace - -Expected LiteRtCompiledModelT::Create( - LiteRtEnvironmentT* env, LiteRtModel model, - LiteRtCompilationOptions jit_compilation_options) { - // If no compilation options were passed, we use default object. This allows - // us to add (for instance) accelerator compilation options. - std::unique_ptr - placeholder_jit_compilation_options; - if (!jit_compilation_options) { - placeholder_jit_compilation_options = - std::make_unique(); - jit_compilation_options = placeholder_jit_compilation_options.get(); - } - - auto compiled_model = std::make_unique(); - - LiteRtHwAcceleratorSet hardware_accelerators = kLiteRtHwAcceleratorNone; - if (jit_compilation_options) { - LiteRtGetCompilationOptionsHardwareAccelerators(jit_compilation_options, - &hardware_accelerators); - } - - LITERT_RETURN_IF_ERROR( - compiled_model->InitializeModel(*model, hardware_accelerators, *env)); - - LITERT_RETURN_IF_ERROR(compiled_model->InitializeRuntime()); - if (compiled_model->GetModelBase() == nullptr) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to initialize model memory."); - } - - // Add a new link in the accelerator compilation options that holds some data - // that is computed during model compilation. 
- LITERT_ASSIGN_OR_RETURN( - auto model_compilation_data_options, - litert::internal::ModelCompilationData::CreateOptions()); - - LITERT_ASSIGN_OR_RETURN( - auto* model_compilation_data, - model_compilation_data_options - .GetData()); - model_compilation_data->allocation_base = compiled_model->GetModelBase(); - model_compilation_data->allocation_fd = - GetAllocationFd(compiled_model->fb_model_->allocation()); - - // Temporarily append model_compilation_data to the jit_compilation_options, - // but remove it before returning from this function since the caller owns - // jit_compilation_options and may use it for other purposes. - ScopedCompilationOptionsModifier scoped_modifier(jit_compilation_options); - LITERT_RETURN_IF_ERROR( - scoped_modifier.Append(std::move(model_compilation_data_options))); - - // Retrieve the accelerator options list. - LiteRtAcceleratorCompilationOptions accelerator_options = nullptr; - LITERT_RETURN_IF_ERROR(LiteRtGetAcceleratorCompilationOptions( - jit_compilation_options, &accelerator_options)); - - // Apply accelerators matching the requested hardware support to the - // model in the order they were registered. - for (auto& accelerator : env->GetAcceleratorRegistry()) { - bool delegate_responsible_for_jit = false; - LITERT_RETURN_IF_ERROR( - LiteRtIsAcceleratorDelegateResponsibleForJitCompilation( - accelerator.get(), &delegate_responsible_for_jit)); - LiteRtHwAcceleratorSet accelerator_supported_hardware; - LITERT_RETURN_IF_ERROR(accelerator->GetHardwareSupport( - accelerator.get(), &accelerator_supported_hardware)); - // We don't apply the delegate if: - // - the delegate is responsible for JIT compilation - // - and JIT has not been requested for the hardware it supports. 
- if (delegate_responsible_for_jit && - !(hardware_accelerators & accelerator_supported_hardware)) { - continue; - } - - TfLiteOpaqueDelegate* delegate_ptr = nullptr; - LITERT_RETURN_IF_ERROR( - accelerator->CreateDelegate(accelerator.get(), accelerator_options, - reinterpret_cast(&delegate_ptr))); - - auto delegate = tflite::TfLiteOpaqueDelegateUniquePtr( - delegate_ptr, reinterpret_cast( - accelerator->DestroyDelegate)); - - if (compiled_model->interp_->ModifyGraphWithDelegate(delegate_ptr) != - kTfLiteOk) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to modify graph with delegate"); - } - compiled_model->RegisterDelegate(std::move(delegate)); - } - - compiled_model->CheckCpuTensors(); - return compiled_model; -} - -void LiteRtCompiledModelT::CheckCpuTensors() { - cpu_tensors_.clear(); - for (int subgraph_no = 0; subgraph_no < interp_->subgraphs_size(); - ++subgraph_no) { - auto* subgraph = interp_->subgraph(subgraph_no); - auto& execution_plan = subgraph->execution_plan(); - auto& nodes_and_registration = subgraph->nodes_and_registration(); - for (int execution_plan_index = 0; - execution_plan_index < execution_plan.size(); execution_plan_index++) { - int node_index = execution_plan[execution_plan_index]; - auto& node = nodes_and_registration[node_index].first; - const TfLiteRegistration& registration = - nodes_and_registration[node_index].second; - - if (registration.builtin_code == kTfLiteBuiltinDelegate) { - continue; - } - if (registration.builtin_code == kTfLiteBuiltinCustom && - litert::internal::kLiteRtDispatchOpCustomCode == - registration.custom_name) - continue; - for (int i = 0; i < node.inputs->size; ++i) { - int input_tensor_index = node.inputs->data[i]; - if (input_tensor_index == kTfLiteOptionalTensor) continue; - cpu_tensors_.insert(subgraph->tensor(input_tensor_index)); - } - } - } -} - -litert::Expected -LiteRtCompiledModelT::GetTensorBufferRequirements(const TfLiteTensor* tensor) { - // Use the buffer context to get the 
buffer requirements only if the tensor - // is not a CPU tensor. - if (cpu_tensors_.find(tensor) == cpu_tensors_.end()) { - auto requirements = buffer_context_->GetBufferRequirement(tensor); - if (requirements) { - return (*requirements)->Get(); - } - } else { - LITERT_LOG(LITERT_VERBOSE, "Tensor %s is shared with CPU.\n", tensor->name); - } - LiteRtTensorBufferRequirements litert_cpu_buffer_requirements; - LiteRtTensorBufferType cpu_buffer_type[] = { - kLiteRtTensorBufferTypeHostMemory}; - uint32_t cpu_buffer_strides[] = {0}; - auto res = LiteRtCreateTensorBufferRequirements( - /*num_supported_tensor_buffer_types=*/1, cpu_buffer_type, tensor->bytes, - /*num_strides=*/1, cpu_buffer_strides, &litert_cpu_buffer_requirements); - if (res != kLiteRtStatusOk) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to create CPU buffer requirements"); - } - cpu_buffer_requirements_[tensor] = - litert::TensorBufferRequirements(litert_cpu_buffer_requirements); - return litert_cpu_buffer_requirements; -} - -Expected -LiteRtCompiledModelT::GetInputBufferRequirements( - absl::string_view signature_key, size_t input_index) { - auto runner = GetSignatureRunner(signature_key); - if (runner == nullptr) { - return Unexpected(kLiteRtStatusErrorNotFound, - "Failed to get signature runner"); - } - auto input_names = runner->subgraph_input_names(); - if (input_index >= input_names.size()) { - return Unexpected(kLiteRtStatusErrorIndexOOB, "Input index out of range"); - } - auto input_name = input_names[input_index]; - auto* input_tensor = runner->input_tensor(input_name); - if (input_tensor == nullptr) { - return Unexpected(kLiteRtStatusErrorNotFound, "Failed to get input tensor"); - } - - return GetTensorBufferRequirements(input_tensor); -} - -Expected -LiteRtCompiledModelT::GetOutputBufferRequirements( - absl::string_view signature_key, size_t output_index) { - auto runner = GetSignatureRunner(signature_key); - if (runner == nullptr) { - return 
Unexpected(kLiteRtStatusErrorNotFound, - "Failed to get signature runner"); - } - auto output_names = runner->subgraph_output_names(); - if (output_index >= output_names.size()) { - return Unexpected(kLiteRtStatusErrorIndexOOB, "Output index out of range"); - } - auto output_name = output_names[output_index]; - auto* output_tensor = runner->output_tensor(output_name); - if (output_tensor == nullptr) { - return Unexpected(kLiteRtStatusErrorNotFound, - "Failed to get output tensor"); - } - - return GetTensorBufferRequirements(output_tensor); -} - -tflite::SignatureRunner* LiteRtCompiledModelT::GetSignatureRunner( - absl::string_view signature_key) { - if (signature_runners_.contains(signature_key)) { - return signature_runners_[signature_key]; - } - auto runner = interp_->GetSignatureRunner( - signature_key == LiteRtSignatureT::kDefaultSignatureKey - ? nullptr - : std::string(signature_key).c_str()); - signature_runners_[signature_key] = runner; - return runner; -} - -Expected LiteRtCompiledModelT::RegisterBuffer( - tflite::SignatureRunner* runner, TfLiteTensor* tensor, - const char* tensor_name, LiteRtTensorBuffer buffer, bool is_input, - std::vector& locked_buffers) { - bool backend_requires_cpu_buffer = false; - - auto requirements = buffer_context_->GetBufferRequirement(tensor); - if (requirements) { - auto supported_types = (*requirements)->SupportedTypes(); - if (!supported_types) { - return supported_types.Error(); - } - - for (auto& type : *supported_types) { - if (type == buffer->buffer_type()) { - // Register tensor buffer if it can be used by the backend. - buffer->Duplicate(); - TensorBuffer duplicated_buffer(buffer); - if (auto status = buffer_context_->RegisterTensorBuffer( - tensor, std::move(duplicated_buffer)); - status != kLiteRtStatusOk) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to register tensor buffer"); - } - // Mark the tensor as non-CPU to avoid TFLite from allocating it. 
- tensor->allocation_type = kTfLiteNonCpu; - tensor->data.data = nullptr; - return {}; - } - if (type == kLiteRtTensorBufferTypeHostMemory) { - backend_requires_cpu_buffer = true; - } - } - } else { - // If the BufferRequirement is not registered, assumes the backend requires - // CPU buffer. - backend_requires_cpu_buffer = true; - } - - if (backend_requires_cpu_buffer) { - // When backend requires CPU buffer. - bool buffer_is_cpu_compatible = - buffer->buffer_type() == kLiteRtTensorBufferTypeHostMemory || - buffer->buffer_type() == kLiteRtTensorBufferTypeOpenCl; -#if defined(__ANDROID__) - if (buffer->buffer_type() == kLiteRtTensorBufferTypeAhwb) { - if (__builtin_available(android 26, *)) { - auto ahwb = buffer->GetAhwbBuffer(); - if (ahwb) { - // TODO: b/382330322 - Update logic to check if the AHWB (stride) is - // CPU compatible. - AHardwareBuffer_Desc desc; - AHardwareBuffer_describe(*ahwb, &desc); - buffer_is_cpu_compatible = true; - } - } - } -#endif - if (buffer_is_cpu_compatible) { - void* host_mem_addr; - if (auto status = LiteRtLockTensorBuffer(buffer, &host_mem_addr); - status != kLiteRtStatusOk) { - return Unexpected(status, "Failed to lock the tensor buffer"); - } - locked_buffers.push_back(buffer); - TfLiteCustomAllocation custom_allocation{host_mem_addr, tensor->bytes}; - if (is_input) { - runner->SetCustomAllocationForInputTensor(tensor_name, - custom_allocation, - /*flags=*/0); - } else { - runner->SetCustomAllocationForOutputTensor(tensor_name, - custom_allocation, - /*flags=*/0); - } - return {}; - } - } - - // If the tensor is shared with CPU, register tensor buffer as is and let - // accelerator handle the conversion. 
- if (cpu_tensors_.find(tensor) != cpu_tensors_.end()) { - void* host_mem_addr; - if (auto status = LiteRtLockTensorBuffer(buffer, &host_mem_addr); - status != kLiteRtStatusOk) { - return Unexpected(status, "Failed to lock the tensor buffer"); - } - locked_buffers.push_back(buffer); - TfLiteCustomAllocation custom_allocation{host_mem_addr, tensor->bytes}; - if (is_input) { - runner->SetCustomAllocationForInputTensor(tensor_name, custom_allocation, - /*flags=*/0); - } else { - runner->SetCustomAllocationForOutputTensor(tensor_name, custom_allocation, - /*flags=*/0); - } - return {}; - } - // TODO: b/382330322 - Add buffer conversion logic instead of returning error. - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "The given buffer type is not supported."); -} - -Expected LiteRtCompiledModelT::Run( - absl::string_view signature_key, - const std::vector& input_buffers, - const std::vector& output_buffers, bool& async) { - auto runner = GetSignatureRunner(signature_key); - if (runner == nullptr) { - return Unexpected(kLiteRtStatusErrorNotFound, - "Failed to get signature runner"); - } - size_t num_inputs = input_buffers.size(); - if (num_inputs != runner->subgraph_input_names().size()) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Input buffer size mismatch"); - } - size_t num_outputs = output_buffers.size(); - if (num_outputs != runner->subgraph_output_names().size()) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Output buffer size mismatch"); - } - - // In general output buffer events are assigned by the runtime and not the - // caller; here we check for any violation of that condition. - for (auto litert_output_buffer : output_buffers) { - if (litert_output_buffer->HasEvent()) { - return Error(kLiteRtStatusErrorInvalidArgument, - "Output buffers cannot have events attached"); - } - } - - // The collection of locked buffers. It is used to unlock the buffers after - // the inference is done. 
- std::vector locked_buffers; - locked_buffers.reserve(num_inputs + num_outputs); - auto unlock_buffers = absl::MakeCleanup([&locked_buffers]() { - for (auto locked_buffer : locked_buffers) { - LiteRtUnlockTensorBuffer(locked_buffer); - } - }); - for (int i = 0; i < num_inputs; ++i) { - const auto& input_name = runner->subgraph_input_names()[i]; - auto* input_tensor = runner->input_tensor(input_name); - auto res = - RegisterBuffer(runner, input_tensor, input_name, input_buffers[i], - /*is_input=*/true, locked_buffers); - if (!res) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - absl::StrCat("Failed to register input tensor buffer: ", - res.Error().Message())); - } - } - - for (int i = 0; i < runner->subgraph_output_names().size(); ++i) { - const auto& output_name = runner->subgraph_output_names()[i]; - auto* output_tensor = runner->output_tensor(output_name); - auto res = RegisterBuffer(runner, const_cast(output_tensor), - output_name, output_buffers[i], - /*is_input=*/false, locked_buffers); - if (!res) { - return Unexpected( - kLiteRtStatusErrorRuntimeFailure, - absl::StrCat("Failed to register output tensor buffer: ", - res.Error().Message())); - } - } - - if (auto res = runner->AllocateTensors(); res != kTfLiteOk) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to allocate tensors"); - } - - // Relay the intended async execution mode to DelegateKernel of Accelerator. - buffer_context_->SetAsyncExecutionMode(async); - - if (auto res = runner->Invoke(); res != kTfLiteOk) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, "Failed to invoke"); - } - - if (async) { - // If the caller requested async execution, then set async to true if any of - // the output buffers have been assigned a synchronization event. 
- async = false; - for (auto& tb : output_buffers) { - async |= tb->HasEvent(); - } - } else { - // If the caller has not requested async execution, then wait on - // synchronization events that have been attached to the outputs. - for (auto& tb : output_buffers) { - if (tb->HasEvent()) { - auto event = tb->GetEvent(); - if (auto status = litert::Event(*event, /*owned=*/false) - .Wait(/*timeout_in_ms=*/-1); - !status) { - return status; - } - } - } - } - - return {}; -} - -litert::Expected LiteRtCompiledModelT::RunCApi( - size_t signature_index, size_t num_input_buffers, - LiteRtTensorBuffer* input_buffers, size_t num_output_buffers, - LiteRtTensorBuffer* output_buffers, bool* async) { - if (signature_index >= signature_keys_.size()) { - return litert::Unexpected( - kLiteRtStatusErrorIndexOOB, - "Signature index is out of range of signature keys"); - } - std::vector input_buffers_vec; - input_buffers_vec.reserve(num_input_buffers); - for (int i = 0; i < num_input_buffers; ++i) { - input_buffers_vec.push_back(std::move(input_buffers[i])); - } - std::vector output_buffers_vec; - output_buffers_vec.reserve(num_output_buffers); - for (int i = 0; i < num_output_buffers; ++i) { - output_buffers_vec.push_back(std::move(output_buffers[i])); - } - bool async_ = async ? *async : false; - auto result = Run(*signature_keys_[signature_index], input_buffers_vec, - output_buffers_vec, async_); - if (async) { - *async = async_; - } - return result; -} diff --git a/tensorflow/lite/experimental/litert/runtime/compiled_model.h b/tensorflow/lite/experimental/litert/runtime/compiled_model.h deleted file mode 100644 index 792efae934e4b8..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/compiled_model.h +++ /dev/null @@ -1,226 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_COMPILED_MODEL_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_COMPILED_MODEL_H_ - -#include -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/container/flat_hash_set.h" -#include "absl/strings/string_view.h" -#include "tensorflow/compiler/mlir/lite/allocation.h" -#include "tensorflow/lite/delegates/utils/simple_opaque_delegate.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_compilation_options.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h" -#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_requirements.h" -#include "tensorflow/lite/experimental/litert/core/environment.h" -#include "tensorflow/lite/experimental/litert/runtime/external_litert_buffer_context.h" -#include "tensorflow/lite/experimental/litert/runtime/tensor_buffer.h" -#include "tensorflow/lite/interpreter.h" -#include "tensorflow/lite/model_builder.h" - -// The LiteRtCompiledModelT is internal implementation of CompiledModel C++ API. 
-class LiteRtCompiledModelT { - public: - using Ptr = std::unique_ptr; - - LiteRtCompiledModelT() = default; - ~LiteRtCompiledModelT() = default; - - // Creates a LiteRtCompiledModelT from a LiteRtModel object. - // The model is loaded into memory and the caller takes ownership of the - // returned object. - static litert::Expected Create( - LiteRtEnvironmentT* env, LiteRtModel model, - LiteRtCompilationOptions jit_compilation_options = nullptr); - - // Returns the buffer requirements for the n-th input tensor. The returned - // LiteRtTensorBufferRequirements is used to create the input tensor - // buffer. - litert::Expected GetInputBufferRequirements( - absl::string_view signature_key, size_t input_index); - - // The same as GetInputBufferRequirements() for C API. - litert::Expected - GetInputBufferRequirementsCApi(size_t signature_index, size_t input_index) { - if (signature_index >= signature_keys_.size()) { - return litert::Unexpected( - kLiteRtStatusErrorIndexOOB, - "Signature index is out of range of signature keys"); - } - return GetInputBufferRequirements(*signature_keys_[signature_index], - input_index); - } - - // Returns the buffer requirements for the n-th output tensor. The returned - // LiteRtTensorBufferRequirements is used to create the output tensor - // buffer. - litert::Expected GetOutputBufferRequirements( - absl::string_view signature_key, size_t output_index); - - // The same as GetOutputBufferRequirements() for C API. - litert::Expected - GetOutputBufferRequirementsCApi(size_t signature_index, size_t output_index) { - if (signature_index >= signature_keys_.size()) { - return litert::Unexpected( - kLiteRtStatusErrorIndexOOB, - "Signature index is out of range of signature keys"); - } - return GetOutputBufferRequirements(*signature_keys_[signature_index], - output_index); - } - - // Runs the model of the given signature with the provided input/output - // litert::TensorBuffers. 
If parameter `async` is true, then the model is run - // asynchronously, if possible. Upon returning, the function sets parameter - // `async` to true if asynchronous execution was requested and possible, - // otherwise it sets it to false. - litert::Expected Run( - absl::string_view signature_key, - const std::vector& input_buffers, - const std::vector& output_buffers, bool& async); - - // The same as Run() for C API. - litert::Expected RunCApi(size_t signature_index, - size_t num_input_buffers, - LiteRtTensorBuffer* input_buffers, - size_t num_output_buffers, - LiteRtTensorBuffer* output_buffers, - bool* async); - - private: - // Initializes the internal TFLite interpreter and related objects. - // This is called in the public Create*() methods. - // The flatbuffer_model_ must be set before calling this method. - litert::Expected InitializeRuntime(); - - // Handles any JIT compilation and intializes the flatbuffer_model_ and - // related field within the compiled model. - // - // If no JIT compilation is requested, the compiled model will point to the - // underlying tflite::Model* owned by the input litert model. The compiled - // models alloc_ and model_buf_ will be nullptr as these are only relevant - // when compiled model owns a flatbuffer. - // - // If JIT compilation does occur, a new flatbuffer owned by the compiled model - // will be serialized from the result of compilation. The alloc_ and - // model_buf_ will be set for storage of the new flatbuffer. - // - // NOTE: JIT compilation invalidates the input litert model. - // TODO: Design a better abstraction for optional ownership for flatbuffer, - // consider caching JIT result. - litert::Expected InitializeModel(LiteRtModelT& model, - LiteRtHwAcceleratorSet hw_accelerators, - LiteRtEnvironmentT& env); - - // Returns the base address of the flatbuffer memory. 
- // - // If no JIT compilation has taken place, this points to flatbuffer memory - // owned by the incoming litert model (litert models always owns their - // flatbuffer memory until serialization). - // - // If JIT compilation has taken place, this points to the base address of the - // a newly serialized flatbuffer which is owned by the compiled model (in - // model_buf_); - // - // NOTE: This should never be nullptr after initialization. - const char* GetModelBase() { - if (fb_model_ == nullptr) { - return nullptr; - } - - // fb_model_->allocation is only null when the flatbuffer is built with - // BuildFlatBufferFromModel, which is not currently in use in either - // litert::LoadModel or LiteRtCompiledModelT::Create. - const auto* alloc = fb_model_->allocation(); - if (alloc) { - // NOTE: During JIT, alloc->base() == model_buf_.Data(), which is owned - // by the compiled model. Otherwise, model_buf_.Data() is nullptr and - // alloc->base() points a buffer owned by the incoming litert model. - return reinterpret_cast(alloc->base()); - } - - return nullptr; - } - - // Returns the buffer requirements for the given tensor. - litert::Expected GetTensorBufferRequirements( - const TfLiteTensor* tensor); - - // Returns the SignatureRunner for the given signature key. - // If the signature key is not found, returns nullptr. - tflite::SignatureRunner* GetSignatureRunner(absl::string_view signature_key); - - // Registers the TensorBuffer for the given tensor with the SignatureRunner. - // If the TensorBuffer can be directly consumed as CPU Tensors, they'll be - // locked and use it with CustomAllocation. The locked buffer is kept in the - // `locked_buffers`. Caller is responsible for unlocking of these buffers. - // If the TensorBuffer can be consumed by the delegate, then `tensor` will be - // marked as non-CPU to avoid TFLite from allocating it. 
- litert::Expected RegisterBuffer( - tflite::SignatureRunner* runner, TfLiteTensor* tensor, - const char* tensor_name, LiteRtTensorBuffer buffer, bool is_input, - std::vector& locked_buffers); - - void RegisterDelegate(tflite::TfLiteOpaqueDelegateUniquePtr&& delegate) { - delegates_.push_back(std::move(delegate)); - } - - // Checks the CPU Tensors and stores them in the `cpu_tensors_` set. - void CheckCpuTensors(); - - // Map from signature key to SignatureRunner. This is used to lazy calling - // GetSignatureRunner() which is expensive. - absl::flat_hash_map - signature_runners_; - - // The buffer requirement maps for CPU buffers. For delegates with CPU - // buffers, they don't register TensorBufferRequirements. Instead, the - // CompiledModel creates the TensorBufferRequirements and stores them - // in this map. - absl::flat_hash_map - cpu_buffer_requirements_; - - // The Interpreter and related objects used to run the model. - std::unique_ptr<::tflite::Interpreter> interp_; - std::unique_ptr<::tflite::FlatBufferModel> fb_model_; - litert::OwningBufferRef model_buf_; - std::vector signature_keys_; - - // The ExternalLiteRtBufferContext used to register tensor buffers with - // Delegates. - // Note: The ExternalLiteRtBufferContext must be destroyed after the - // Interpreter. - std::unique_ptr - buffer_context_; - - std::vector delegates_; - - // The set of CPU Tensors. This is used to manage TensorBufferRequirements - // for shared CPU Tensors. - absl::flat_hash_set cpu_tensors_; -}; - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_COMPILED_MODEL_H_ diff --git a/tensorflow/lite/experimental/litert/runtime/compiled_model_test.cc b/tensorflow/lite/experimental/litert/runtime/compiled_model_test.cc deleted file mode 100644 index 76797ac5074eed..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/compiled_model_test.cc +++ /dev/null @@ -1,544 +0,0 @@ -// Copyright 2024 Google LLC. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/runtime/compiled_model.h" - -#include -#include -#include -#include -#include -#include - -#include -#include -#include "absl/log/absl_log.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_compilation_options.h" -#include "tensorflow/lite/experimental/litert/c/litert_environment.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_requirements.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" -#include "tensorflow/lite/experimental/litert/runtime/open_cl_buffer.h" -#include "tensorflow/lite/experimental/litert/runtime/tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/runtime/tensor_buffer_requirements.h" -#include "tensorflow/lite/experimental/litert/test/common.h" -#include "tensorflow/lite/experimental/litert/test/matchers.h" -#include 
"tensorflow/lite/experimental/litert/test/testdata/simple_model_test_vectors.h" - -namespace litert { -namespace { - -using ::testing::ElementsAre; -using ::testing::FloatNear; -using ::testing::Pointwise; - -// Creates a tensor buffer of the given tensor, buffer type, and size. -Expected CreateBufferOfType( - const LiteRtTensorT& tensor, LiteRtTensorBufferType buffer_type, - size_t bytes) { - const LiteRtRankedTensorType ranked_tensor_type = - tensor.Type().second.ranked_tensor_type; - - LiteRtTensorBufferT* tensor_buffer; - LITERT_RETURN_IF_ERROR(LiteRtCreateManagedTensorBuffer( - buffer_type, &ranked_tensor_type, bytes, &tensor_buffer)); - - return tensor_buffer; -} - -// Creates input or output tensor buffers of the given model, buffer type and -// size. -Expected> CreateInputOutputBuffersOfType( - LiteRtModelT& model, absl::string_view signature_key, - LiteRtTensorBufferType buffer_type, size_t bytes, bool is_input) { - LITERT_ASSIGN_OR_RETURN(const LiteRtSignatureT& signature, - model.FindSignature(signature_key)); - const LiteRtSubgraphT& subgraph = signature.GetSubgraph(); - - const std::vector& tensors = - is_input ? subgraph.Inputs() : subgraph.Outputs(); - - std::vector tensor_buffers; - tensor_buffers.reserve(tensors.size()); - - for (int i = 0; i < tensors.size(); ++i) { - LITERT_ASSIGN_OR_RETURN( - LiteRtTensorBufferT * tensor_buffer, - CreateBufferOfType(*tensors[i], buffer_type, bytes)); - tensor_buffers.push_back(tensor_buffer); - } - return tensor_buffers; -} - -// Creates input buffers of the given model, buffer type, and size. -Expected> CreateInputBuffersOfType( - LiteRtModelT& model, absl::string_view signature_key, - LiteRtTensorBufferType buffer_type, size_t bytes) { - return CreateInputOutputBuffersOfType(model, signature_key, buffer_type, - bytes, /*is_input=*/true); -} - -// Creates output buffers of the given model, buffer type, and size. 
-Expected> CreateOutputBuffersOfType( - LiteRtModelT& model, absl::string_view signature_key, - LiteRtTensorBufferType buffer_type, size_t bytes) { - return CreateInputOutputBuffersOfType(model, signature_key, buffer_type, - bytes, /*is_input=*/false); -} - -// Creates a tensor buffer of the given tensor and buffer requirements. -Expected CreateBufferFromRequirements( - const LiteRtTensorT& tensor, - const LiteRtTensorBufferRequirementsT& requirements) { - return CreateBufferOfType(tensor, requirements.SupportedBufferTypes().at(0), - requirements.BufferSize()); -} - -// Creates input or output tensor buffers of the given model and requirements. -Expected> -CreateInputOutputBuffersFromRequirements(LiteRtModelT& model, - absl::string_view signature_key, - LiteRtCompiledModelT& compiled_model, - bool is_input) { - LITERT_ASSIGN_OR_RETURN(const LiteRtSignatureT& signature, - model.FindSignature(signature_key)); - const LiteRtSubgraphT& subgraph = signature.GetSubgraph(); - - const std::vector& tensors = - is_input ? subgraph.Inputs() : subgraph.Outputs(); - - std::vector tensor_buffers; - tensor_buffers.reserve(tensors.size()); - - for (int i = 0; i < tensors.size(); ++i) { - Expected requirements_expected = - is_input ? compiled_model.GetInputBufferRequirements(signature_key, i) - : compiled_model.GetOutputBufferRequirements(signature_key, i); - LITERT_ASSIGN_OR_RETURN(LiteRtTensorBufferRequirementsT * requirements, - requirements_expected); - - LITERT_ASSIGN_OR_RETURN( - LiteRtTensorBufferT * tensor_buffer, - CreateBufferFromRequirements(*tensors[i], *requirements)); - tensor_buffers.push_back(tensor_buffer); - } - return tensor_buffers; -} - -// Creates input buffers of the given model and requirements. 
-Expected> CreateInputBuffersFromRequirements( - LiteRtModelT& model, absl::string_view signature_key, - LiteRtCompiledModelT& compiled_model) { - return CreateInputOutputBuffersFromRequirements(model, signature_key, - compiled_model, - /*is_input=*/true); -} - -// Creates output buffers of the given model and requirements. -Expected> CreateOutputBuffersFromRequirements( - LiteRtModelT& model, absl::string_view signature_key, - LiteRtCompiledModelT& compiled_model) { - return CreateInputOutputBuffersFromRequirements(model, signature_key, - compiled_model, - /*is_input=*/false); -} - -TEST(CompiledModelTest, Basic) { - // Environment setup. - LITERT_ASSERT_OK_AND_ASSIGN(LiteRtEnvironmentT::Ptr env, - LiteRtEnvironmentT::CreateWithOptions({})); - LiteRtEnvironmentT* env_ptr = env.release(); - - // Create LiteRtModel and check signatures. - std::string path = testing::GetTestFilePath(kModelFileName); - LiteRtModel model; - ASSERT_EQ(LiteRtCreateModelFromFile(path.c_str(), &model), kLiteRtStatusOk); - - absl::Span signatures = model->Signatures(); - ASSERT_EQ(signatures.size(), 1); - absl::string_view signature_key = signatures[0]->Key(); - EXPECT_EQ(signature_key, LiteRtSignatureT::kDefaultSignatureKey); - - const std::vector& input_names = signatures[0]->InputNames(); - EXPECT_THAT(input_names, ElementsAre("arg0", "arg1")); - - const std::vector& output_names = signatures[0]->OutputNames(); - EXPECT_THAT(output_names, ElementsAre("tfl.add")); - - // Create CompiledModel with options. 
- LiteRtCompilationOptions jit_compilation_options; - ASSERT_EQ(LiteRtCreateCompilationOptions(&jit_compilation_options), - kLiteRtStatusOk); - ASSERT_EQ(LiteRtSetCompilationOptionsHardwareAccelerators( - jit_compilation_options, kLiteRtHwAcceleratorCpu), - kLiteRtStatusOk); - LITERT_ASSERT_OK_AND_ASSIGN( - LiteRtCompiledModelT::Ptr compiled_model, - LiteRtCompiledModelT::Create(env_ptr, model, jit_compilation_options)); - LiteRtDestroyCompilationOptions(jit_compilation_options); - - // Check CompiledModel buffer requirements. - // input and output expect host memory. - LITERT_ASSERT_OK_AND_ASSIGN( - LiteRtTensorBufferRequirementsT * input_buffer_requirements_arg0, - compiled_model->GetInputBufferRequirements( - /*signature_key=*/LiteRtSignatureT::kDefaultSignatureKey, - /*input_index=*/0)); - const std::vector& input_buffer_types_arg0 = - input_buffer_requirements_arg0->SupportedBufferTypes(); - EXPECT_THAT(input_buffer_types_arg0, - ElementsAre(kLiteRtTensorBufferTypeHostMemory)); - - LITERT_ASSERT_OK_AND_ASSIGN( - LiteRtTensorBufferRequirementsT * input_buffer_requirements_arg1, - compiled_model->GetInputBufferRequirements( - /*signature_key=*/LiteRtSignatureT::kDefaultSignatureKey, - /*input_index=*/1)); - const std::vector& input_buffer_types_arg1 = - input_buffer_requirements_arg1->SupportedBufferTypes(); - EXPECT_THAT(input_buffer_types_arg1, - ElementsAre(kLiteRtTensorBufferTypeHostMemory)); - - LITERT_ASSERT_OK_AND_ASSIGN( - LiteRtTensorBufferRequirementsT * output_buffer_requirements, - compiled_model->GetOutputBufferRequirements( - /*signature_key=*/LiteRtSignatureT::kDefaultSignatureKey, - /*output_index=*/0)); - const std::vector& output_buffer_types = - output_buffer_requirements->SupportedBufferTypes(); - EXPECT_THAT(output_buffer_types, - ElementsAre(kLiteRtTensorBufferTypeHostMemory)); - - // Create and fill input and output LiteRtTensorBuffers. Buffers are - // created to match CompiledModel's TensorBufferRequirements. 
- LITERT_ASSERT_OK_AND_ASSIGN(std::vector input_buffers, - CreateInputBuffersFromRequirements( - *model, signature_key, *compiled_model)); - LITERT_ASSERT_OK_AND_ASSIGN(std::vector output_buffers, - CreateOutputBuffersFromRequirements( - *model, signature_key, *compiled_model)); - - LiteRtTensorBuffer& input_0_buffer = input_buffers[0]; - { - TensorBuffer cpu_buffer(input_0_buffer, /*owned=*/false); - cpu_buffer.Write( - absl::MakeConstSpan(kTestInput0Tensor, kTestInput0Size)); - } - LiteRtTensorBuffer& input_1_buffer = input_buffers[1]; - { - TensorBuffer cpu_buffer(input_1_buffer, /*owned=*/false); - cpu_buffer.Write( - absl::MakeConstSpan(kTestInput1Tensor, kTestInput1Size)); - } - - // Execute model. - bool async = false; - compiled_model->Run(signature_key, input_buffers, output_buffers, async); - - // Check model output. - { - void* host_mem_addr; - ASSERT_EQ(LiteRtLockTensorBuffer(output_buffers[0], &host_mem_addr), - kLiteRtStatusOk); - absl::Span output = absl::MakeSpan( - static_cast(host_mem_addr), kTestOutputSize); - for (auto i = 0; i < kTestOutputSize; ++i) { - ABSL_LOG(INFO) << output[i] << "\t" << kTestOutputTensor[i]; - } - EXPECT_THAT(output, Pointwise(FloatNear(1e-5), kTestOutputTensor)); - ASSERT_EQ(LiteRtUnlockTensorBuffer(output_buffers[0]), kLiteRtStatusOk); - } - - // Since Buffers in LiteRtTensorBuffer, we need to destroy them explicitly. - for (auto& input_buffer : input_buffers) { - LiteRtDestroyTensorBuffer(input_buffer); - } - for (auto& output_buffer : output_buffers) { - LiteRtDestroyTensorBuffer(output_buffer); - } - - LiteRtDestroyModel(model); - LiteRtDestroyEnvironment(env_ptr); -} - -TEST(CompiledModelTest, UseAhwbBuffer) { -#if !defined(__ANDROID__) - GTEST_SKIP() << "The rest of this test is specific to Android devices"; -#endif - // Environment setup. 
- LITERT_ASSERT_OK_AND_ASSIGN(LiteRtEnvironmentT::Ptr env, - LiteRtEnvironmentT::CreateWithOptions({})); - LiteRtEnvironmentT* env_ptr = env.release(); - - // Create LiteRtModel and check signatures. - std::string path = testing::GetTestFilePath(kModelFileName); - LiteRtModel model; - ASSERT_EQ(LiteRtCreateModelFromFile(path.c_str(), &model), kLiteRtStatusOk); - - absl::Span signatures = model->Signatures(); - ASSERT_EQ(signatures.size(), 1); - absl::string_view signature_key = signatures[0]->Key(); - EXPECT_EQ(signature_key, LiteRtSignatureT::kDefaultSignatureKey); - - const std::vector& input_names = signatures[0]->InputNames(); - EXPECT_THAT(input_names, ElementsAre("arg0", "arg1")); - - const std::vector& output_names = signatures[0]->OutputNames(); - EXPECT_THAT(output_names, ElementsAre("tfl.add")); - - // Create CompiledModel with options. - LiteRtCompilationOptions jit_compilation_options; - ASSERT_EQ(LiteRtCreateCompilationOptions(&jit_compilation_options), - kLiteRtStatusOk); - ASSERT_EQ(LiteRtSetCompilationOptionsHardwareAccelerators( - jit_compilation_options, kLiteRtHwAcceleratorCpu), - kLiteRtStatusOk); - LITERT_ASSERT_OK_AND_ASSIGN( - LiteRtCompiledModelT::Ptr compiled_model, - LiteRtCompiledModelT::Create(env_ptr, model, jit_compilation_options)); - LiteRtDestroyCompilationOptions(jit_compilation_options); - - // Check input and output buffer requirements expect host memory. 
- LITERT_ASSERT_OK_AND_ASSIGN( - LiteRtTensorBufferRequirementsT * input_buffer_requirements_arg0, - compiled_model->GetInputBufferRequirements( - /*signature_key=*/LiteRtSignatureT::kDefaultSignatureKey, - /*input_index=*/0)); - const std::vector& input_buffer_types_arg0 = - input_buffer_requirements_arg0->SupportedBufferTypes(); - EXPECT_THAT(input_buffer_types_arg0, - ElementsAre(kLiteRtTensorBufferTypeHostMemory)); - - LITERT_ASSERT_OK_AND_ASSIGN( - LiteRtTensorBufferRequirementsT * input_buffer_requirements_arg1, - compiled_model->GetInputBufferRequirements( - /*signature_key=*/LiteRtSignatureT::kDefaultSignatureKey, - /*input_index=*/1)); - const std::vector& input_buffer_types_arg1 = - input_buffer_requirements_arg1->SupportedBufferTypes(); - EXPECT_THAT(input_buffer_types_arg1, - ElementsAre(kLiteRtTensorBufferTypeHostMemory)); - - LITERT_ASSERT_OK_AND_ASSIGN( - LiteRtTensorBufferRequirementsT * output_buffer_requirements, - compiled_model->GetOutputBufferRequirements( - /*signature_key=*/LiteRtSignatureT::kDefaultSignatureKey, - /*output_index=*/0)); - const std::vector& output_buffer_types = - output_buffer_requirements->SupportedBufferTypes(); - EXPECT_THAT(output_buffer_types, - ElementsAre(kLiteRtTensorBufferTypeHostMemory)); - - // Create and fill input and output buffers. CompiledModel's - // TensorBufferRequirements expect host memory,but we create AHWB buffers. 
- LITERT_ASSERT_OK_AND_ASSIGN( - std::vector input_buffers, - CreateInputBuffersOfType(*model, signature_key, - kLiteRtTensorBufferTypeAhwb, - sizeof(float) * kTestInput0Size)); - - LITERT_ASSERT_OK_AND_ASSIGN( - std::vector output_buffers, - CreateOutputBuffersOfType(*model, signature_key, - kLiteRtTensorBufferTypeAhwb, - sizeof(float) * kTestOutputSize)); - - LiteRtTensorBuffer& input_0_buffer = input_buffers[0]; - EXPECT_EQ(input_0_buffer->buffer_type(), kLiteRtTensorBufferTypeAhwb); - { - TensorBuffer ahwb_buffer(input_0_buffer, /*owned=*/false); - ahwb_buffer.Write( - absl::MakeConstSpan(kTestInput0Tensor, kTestInput0Size)); - } - LiteRtTensorBuffer& input_1_buffer = input_buffers[1]; - { - TensorBuffer ahwb_buffer(input_1_buffer, /*owned=*/false); - ahwb_buffer.Write( - absl::MakeConstSpan(kTestInput1Tensor, kTestInput1Size)); - } - - // Execute model. - bool async = false; - compiled_model->Run(signature_key, input_buffers, output_buffers, async); - - // Check model output. - { - void* host_mem_addr; - ASSERT_EQ(LiteRtLockTensorBuffer(output_buffers[0], &host_mem_addr), - kLiteRtStatusOk); - absl::Span output = absl::MakeSpan( - static_cast(host_mem_addr), kTestOutputSize); - for (auto i = 0; i < kTestOutputSize; ++i) { - ABSL_LOG(INFO) << output[i] << "\t" << kTestOutputTensor[i]; - } - EXPECT_THAT(output, Pointwise(FloatNear(1e-5), kTestOutputTensor)); - ASSERT_EQ(LiteRtUnlockTensorBuffer(output_buffers[0]), kLiteRtStatusOk); - } - - // Since Buffers in LiteRtTensorBuffer, we need to destroy them explicitly. - for (auto& input_buffer : input_buffers) { - LiteRtDestroyTensorBuffer(input_buffer); - } - for (auto& output_buffer : output_buffers) { - LiteRtDestroyTensorBuffer(output_buffer); - } - - LiteRtDestroyModel(model); - LiteRtDestroyEnvironment(env_ptr); -} - -TEST(CompiledModelTest, UseOpenCLBuffer) { - // MSAN does not support GPU tests. 
-#if defined(MEMORY_SANITIZER) || defined(THREAD_SANITIZER) - GTEST_SKIP() << "GPU tests are not supported In msan"; -#endif - - if (!litert::internal::OpenClBuffer::IsSupported()) { - GTEST_SKIP() << "OpenCL buffers are not supported on this platform; " - "skipping the test"; - } - // Environment setup. - LITERT_ASSERT_OK_AND_ASSIGN(LiteRtEnvironmentT::Ptr env, - LiteRtEnvironmentT::CreateWithOptions({})); - LiteRtEnvironmentT* env_ptr = env.release(); - - // Create LiteRtModel and check signatures. - std::string path = testing::GetTestFilePath(kModelFileName); - LiteRtModel model; - ASSERT_EQ(LiteRtCreateModelFromFile(path.c_str(), &model), kLiteRtStatusOk); - - absl::Span signatures = model->Signatures(); - ASSERT_EQ(signatures.size(), 1); - absl::string_view signature_key = signatures[0]->Key(); - EXPECT_EQ(signature_key, LiteRtSignatureT::kDefaultSignatureKey); - - const std::vector& input_names = signatures[0]->InputNames(); - EXPECT_THAT(input_names, ElementsAre("arg0", "arg1")); - - const std::vector& output_names = signatures[0]->OutputNames(); - EXPECT_THAT(output_names, ElementsAre("tfl.add")); - - // Create CompiledModel with options. - LiteRtCompilationOptions jit_compilation_options; - ASSERT_EQ(LiteRtCreateCompilationOptions(&jit_compilation_options), - kLiteRtStatusOk); - ASSERT_EQ(LiteRtSetCompilationOptionsHardwareAccelerators( - jit_compilation_options, kLiteRtHwAcceleratorCpu), - kLiteRtStatusOk); - LITERT_ASSERT_OK_AND_ASSIGN( - LiteRtCompiledModelT::Ptr compiled_model, - LiteRtCompiledModelT::Create(env_ptr, model, jit_compilation_options)); - LiteRtDestroyCompilationOptions(jit_compilation_options); - - // Check ComiledModel buffer requirements. - // input and output expect host memory. 
- LITERT_ASSERT_OK_AND_ASSIGN( - LiteRtTensorBufferRequirementsT * input_buffer_requirements_arg0, - compiled_model->GetInputBufferRequirements( - /*signature_key=*/LiteRtSignatureT::kDefaultSignatureKey, - /*input_index=*/0)); - const std::vector& input_buffer_types_arg0 = - input_buffer_requirements_arg0->SupportedBufferTypes(); - EXPECT_THAT(input_buffer_types_arg0, - ElementsAre(kLiteRtTensorBufferTypeHostMemory)); - - LITERT_ASSERT_OK_AND_ASSIGN( - LiteRtTensorBufferRequirementsT * input_buffer_requirements_arg1, - compiled_model->GetInputBufferRequirements( - /*signature_key=*/LiteRtSignatureT::kDefaultSignatureKey, - /*input_index=*/1)); - const std::vector& input_buffer_types_arg1 = - input_buffer_requirements_arg1->SupportedBufferTypes(); - EXPECT_THAT(input_buffer_types_arg1, - ElementsAre(kLiteRtTensorBufferTypeHostMemory)); - - LITERT_ASSERT_OK_AND_ASSIGN( - LiteRtTensorBufferRequirementsT * output_buffer_requirements, - compiled_model->GetOutputBufferRequirements( - /*signature_key=*/LiteRtSignatureT::kDefaultSignatureKey, - /*output_index=*/0)); - const std::vector& output_buffer_types = - output_buffer_requirements->SupportedBufferTypes(); - EXPECT_THAT(output_buffer_types, - ElementsAre(kLiteRtTensorBufferTypeHostMemory)); - - // Create and fill input and output buffers. CompiledModel's - // TensorBufferRequirements expect host memory,but we create OpenCL buffers. - LITERT_ASSERT_OK_AND_ASSIGN( - std::vector input_buffers, - CreateInputBuffersOfType(*model, signature_key, - kLiteRtTensorBufferTypeOpenCl, - sizeof(float) * kTestInput0Size)); - - LITERT_ASSERT_OK_AND_ASSIGN( - std::vector output_buffers, - CreateOutputBuffersOfType(*model, signature_key, - kLiteRtTensorBufferTypeOpenCl, - sizeof(float) * kTestOutputSize)); - - // Fill model inputs. 
- LiteRtTensorBuffer& input_0_buffer = input_buffers[0]; - EXPECT_EQ(input_0_buffer->buffer_type(), kLiteRtTensorBufferTypeOpenCl); - { - TensorBuffer opencl_buffer(input_0_buffer, /*owned=*/false); - opencl_buffer.Write( - absl::MakeConstSpan(kTestInput0Tensor, kTestInput0Size)); - } - LiteRtTensorBuffer& input_1_buffer = input_buffers[1]; - { - TensorBuffer opencl_buffer(input_1_buffer, /*owned=*/false); - opencl_buffer.Write( - absl::MakeConstSpan(kTestInput1Tensor, kTestInput1Size)); - } - - // Execute model. - bool async = false; - compiled_model->Run(signature_key, input_buffers, output_buffers, async); - - // Check model output. - { - void* host_mem_addr; - ASSERT_EQ(LiteRtLockTensorBuffer(output_buffers[0], &host_mem_addr), - kLiteRtStatusOk); - absl::Span output = absl::MakeSpan( - static_cast(host_mem_addr), kTestOutputSize); - for (auto i = 0; i < kTestOutputSize; ++i) { - ABSL_LOG(INFO) << output[i] << "\t" << kTestOutputTensor[i]; - } - EXPECT_THAT(output, Pointwise(FloatNear(1e-5), kTestOutputTensor)); - - ASSERT_EQ(LiteRtUnlockTensorBuffer(output_buffers[0]), kLiteRtStatusOk); - } - - // Since Buffers in LiteRtTensorBuffer, we need to destroy them explicitly. - for (auto& input_buffer : input_buffers) { - LiteRtDestroyTensorBuffer(input_buffer); - } - for (auto& output_buffer : output_buffers) { - LiteRtDestroyTensorBuffer(output_buffer); - } - - LiteRtDestroyModel(model); - LiteRtDestroyEnvironment(env_ptr); -} -} // namespace -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/runtime/compiler/BUILD b/tensorflow/lite/experimental/litert/runtime/compiler/BUILD deleted file mode 100644 index 43bef76096cbe1..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/compiler/BUILD +++ /dev/null @@ -1,89 +0,0 @@ -# Copyright 2024 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], -) - -cc_test( - name = "jit_compilation_qualcomm_test", - srcs = ["jit_compilation_qualcomm_test.cc"], - data = [ - "//tensorflow/lite/experimental/litert/test:simple_model", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler:qnn_compiler_plugin_so", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch:dispatch_api_so", - ], - linkopts = select({ - "//tensorflow:android": ["-landroid"], - "//conditions:default": [], - }), - deps = [ - "//tensorflow/lite:framework", - "//tensorflow/lite/c:c_api_opaque", - "//tensorflow/lite/c:common", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/cc:litert_compiled_model", - "//tensorflow/lite/experimental/litert/cc:litert_environment", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_tensor_buffer", - "//tensorflow/lite/experimental/litert/test:common", - "//tensorflow/lite/experimental/litert/test:matchers", - "//tensorflow/lite/experimental/litert/test:simple_model_npu", - "//tensorflow/lite/kernels:builtin_ops", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:absl_log", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest_main", - ], -) - -cc_test( - name = "jit_compilation_mediatek_test", - srcs = 
["jit_compilation_mediatek_test.cc"], - data = [ - "//tensorflow/lite/experimental/litert/test:simple_model", - "//tensorflow/lite/experimental/litert/vendors/mediatek/compiler:compiler_plugin_so", - "//tensorflow/lite/experimental/litert/vendors/mediatek/dispatch:dispatch_api_so", - ], - linkopts = select({ - "//tensorflow:android": ["-landroid"], - "//conditions:default": [], - }), - tags = [ - "no_oss", - "nobuilder", - "notap", - ], - deps = [ - "//tensorflow/lite:framework", - "//tensorflow/lite/c:c_api_opaque", - "//tensorflow/lite/c:common", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/cc:litert_compiled_model", - "//tensorflow/lite/experimental/litert/cc:litert_environment", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_tensor_buffer", - "//tensorflow/lite/experimental/litert/test:common", - "//tensorflow/lite/experimental/litert/test:matchers", - "//tensorflow/lite/experimental/litert/test:simple_model_npu", - "//tensorflow/lite/kernels:builtin_ops", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:absl_log", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest_main", - ], -) diff --git a/tensorflow/lite/experimental/litert/runtime/compiler/jit_compilation_mediatek_test.cc b/tensorflow/lite/experimental/litert/runtime/compiler/jit_compilation_mediatek_test.cc deleted file mode 100644 index 4e3b2f24d87c2c..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/compiler/jit_compilation_mediatek_test.cc +++ /dev/null @@ -1,97 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include - -#include -#include -#include "absl/log/absl_log.h" -#include "absl/log/log.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_compiled_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_environment.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/test/common.h" -#include "tensorflow/lite/experimental/litert/test/matchers.h" -#include "tensorflow/lite/experimental/litert/test/testdata/simple_model_test_vectors.h" - -constexpr absl::string_view kCompilerPluginLibSearchPath = "/data/local/tmp"; -constexpr absl::string_view kDispatchLibraryDir = "/data/local/tmp"; - -using testing::FloatNear; -using testing::Pointwise; - -TEST(JitCompilation, MediaTek) { - const std::array environment_options = { - litert::Environment::Option{ - /*.tag=*/litert::Environment::OptionTag::CompilerPluginLibraryDir, - /*.value=*/kCompilerPluginLibSearchPath, - }, - litert::Environment::Option{ - litert::Environment::OptionTag::DispatchLibraryDir, - kDispatchLibraryDir, - }, - }; - LITERT_ASSERT_OK_AND_ASSIGN(auto environment, - litert::Environment::Create(environment_options)); - - auto model_path = litert::testing::GetTestFilePath(kModelFileName); - LITERT_ASSERT_OK_AND_ASSIGN(auto model, - litert::Model::CreateFromFile(model_path)); - - auto num_signatures = 
model.GetNumSignatures(); - ASSERT_EQ(num_signatures, 1); - -#if !defined(__ANDROID__) - GTEST_SKIP() << "The rest of this test is specific to Android devices with a " - "MediaTek NPU"; -#endif - - LITERT_ASSERT_OK_AND_ASSIGN(auto compiled_model, - litert::CompiledModel::Create( - environment, model, kLiteRtHwAcceleratorNpu)); - - LITERT_ASSERT_OK_AND_ASSIGN( - auto input_buffers, - compiled_model.CreateInputBuffers(/*signature_index=*/0)); - EXPECT_EQ(input_buffers.size(), 2); - - LITERT_ASSERT_OK_AND_ASSIGN( - auto output_buffers, - compiled_model.CreateOutputBuffers(/*signature_index=*/0)); - EXPECT_EQ(output_buffers.size(), 1); - - LITERT_ASSERT_OK(input_buffers[0].Write( - absl::MakeConstSpan(kTestInput0Tensor, kTestInput0Size))); - LITERT_ASSERT_OK(input_buffers[1].Write( - absl::MakeConstSpan(kTestInput1Tensor, kTestInput1Size))); - - // Execute model. - compiled_model.Run(/*signature_index=*/0, input_buffers, output_buffers); - - // Check model output. - { - LITERT_ASSERT_OK_AND_ASSIGN( - auto lock_and_addr, - litert::TensorBufferScopedLock::Create(output_buffers[0])); - auto output = absl::MakeSpan(lock_and_addr.second, kTestOutputSize); - for (auto i = 0; i < kTestOutputSize; ++i) { - ABSL_LOG(INFO) << "Result: " << output[i] << "\t" << kTestOutputTensor[i]; - } - EXPECT_THAT(output, Pointwise(FloatNear(1e-5), kTestOutputTensor)); - } -} diff --git a/tensorflow/lite/experimental/litert/runtime/compiler/jit_compilation_qualcomm_test.cc b/tensorflow/lite/experimental/litert/runtime/compiler/jit_compilation_qualcomm_test.cc deleted file mode 100644 index 1f7a3366f86af5..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/compiler/jit_compilation_qualcomm_test.cc +++ /dev/null @@ -1,99 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include -#include - -#include -#include -#include "absl/log/absl_log.h" -#include "absl/log/log.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_compiled_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_environment.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/test/common.h" -#include "tensorflow/lite/experimental/litert/test/matchers.h" -#include "tensorflow/lite/experimental/litert/test/testdata/simple_model_test_vectors.h" - -constexpr absl::string_view kCompilerPluginLibSearchPath = "/data/local/tmp"; -constexpr absl::string_view kDispatchLibraryDir = "/data/local/tmp"; - -using testing::FloatNear; -using testing::Pointwise; - -TEST(JitCompilation, Qualcomm) { - const std::array environment_options = { - litert::Environment::Option{ - /*.tag=*/litert::Environment::OptionTag::CompilerPluginLibraryDir, - /*.value=*/kCompilerPluginLibSearchPath, - }, - litert::Environment::Option{ - litert::Environment::OptionTag::DispatchLibraryDir, - kDispatchLibraryDir, - }, - }; - LITERT_ASSERT_OK_AND_ASSIGN(auto environment, - litert::Environment::Create(environment_options)); - - auto model_path = litert::testing::GetTestFilePath(kModelFileName); - LITERT_ASSERT_OK_AND_ASSIGN(auto model, - litert::Model::CreateFromFile(model_path)); - - auto 
num_signatures = model.GetNumSignatures(); - ASSERT_EQ(num_signatures, 1); - -#if !defined(__ANDROID__) - GTEST_SKIP() << "The rest of this test is specific to Android devices with a " - "Qualcomm HTP"; -#endif - - LITERT_ASSERT_OK_AND_ASSIGN(auto compiled_model, - litert::CompiledModel::Create( - environment, model, kLiteRtHwAcceleratorNpu)); - - LITERT_ASSERT_OK_AND_ASSIGN( - auto input_buffers, - compiled_model.CreateInputBuffers(/*signature_index=*/0)); - EXPECT_EQ(input_buffers.size(), 2); - - LITERT_ASSERT_OK_AND_ASSIGN( - auto output_buffers, - compiled_model.CreateOutputBuffers(/*signature_index=*/0)); - EXPECT_EQ(output_buffers.size(), 1); - - LITERT_ASSERT_OK(input_buffers[0].Write( - absl::MakeConstSpan(kTestInput0Tensor, kTestInput0Size))); - LITERT_ASSERT_OK(input_buffers[1].Write( - absl::MakeConstSpan(kTestInput1Tensor, kTestInput1Size))); - - // Execute model. - compiled_model.Run(/*signature_index=*/0, input_buffers, output_buffers); - - // Check model output. - { - LITERT_ASSERT_OK_AND_ASSIGN( - auto lock_and_addr, - litert::TensorBufferScopedLock::Create(output_buffers[0])); - auto output = absl::MakeSpan(lock_and_addr.second, kTestOutputSize); - for (auto i = 0; i < kTestOutputSize; ++i) { - ABSL_LOG(INFO) << "Result: " << output[i] << "\t" << kTestOutputTensor[i]; - } - EXPECT_THAT(output, Pointwise(FloatNear(1e-5), kTestOutputTensor)); - } -} diff --git a/tensorflow/lite/experimental/litert/runtime/dispatch/BUILD b/tensorflow/lite/experimental/litert/runtime/dispatch/BUILD deleted file mode 100644 index 6e8886c434994f..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/dispatch/BUILD +++ /dev/null @@ -1,206 +0,0 @@ -# Copyright 2024 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -load("//tensorflow/lite/experimental/litert/build_common:special_rule.bzl", "gles_linkopts") - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = [ - "//tensorflow/lite/experimental/litert:__subpackages__", - "//third_party/odml/infra/perf/mobile_tests/litert:__subpackages__", - ], -) - -# Dispatch API implementation, it is used by the dispatch delegate to call the vendor's dispatch -# API. -cc_library( - name = "dispatch", - srcs = [ - "litert_dispatch.cc", - ], - hdrs = [ - "//tensorflow/lite/experimental/litert/vendors/c:litert_dispatch.h", - "//tensorflow/lite/experimental/litert/vendors/c:litert_dispatch_api.h", - ], - deps = [ - "//tensorflow/lite/experimental/litert/c:litert_any", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_event", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_tensor_buffer", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_shared_library", - "//tensorflow/lite/experimental/litert/core:dynamic_loading", - "//tensorflow/lite/experimental/litert/core:version", - ], - alwayslink = 1, -) - -cc_library( - name = "dispatch_delegate", - srcs = [ - "dispatch_delegate.cc", - "dispatch_delegate_kernel.cc", - ], - hdrs = [ - "dispatch_delegate_kernel.h", - 
"dispatch_delegate_options.h", - "//tensorflow/lite/experimental/litert/c:litert_dispatch_delegate.h", - "//tensorflow/lite/experimental/litert/cc:litert_dispatch_delegate.h", - ], - deps = [ - ":dispatch", - "//tensorflow/lite/c:c_api", - "//tensorflow/lite/c:c_api_opaque", - "//tensorflow/lite/c:c_api_types", - "//tensorflow/lite/c:common", - "//tensorflow/lite/core/c:c_api_opaque_without_op_resolver", - "//tensorflow/lite/delegates/utils:simple_opaque_delegate", - "//tensorflow/lite/experimental/litert/c:litert_any", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_environment_options", - "//tensorflow/lite/experimental/litert/c:litert_event", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_tensor_buffer", - "//tensorflow/lite/experimental/litert/cc:litert_any", - "//tensorflow/lite/experimental/litert/cc:litert_buffer_ref", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_tensor_buffer", - "//tensorflow/lite/experimental/litert/core:build_stamp", - "//tensorflow/lite/experimental/litert/core:dispatch_op_schema", - "//tensorflow/lite/experimental/litert/core:environment_options", - "//tensorflow/lite/experimental/litert/runtime:external_litert_buffer_context", - "//tensorflow/lite/experimental/litert/runtime:tfl_utils", - "//tensorflow/lite/experimental/litert/vendors/c:litert_dispatch_c_api", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - ], -) - -cc_test( - name = "dispatch_delegate_google_tensor_test", - srcs = ["dispatch_delegate_google_tensor_test.cc"], - data = [ - "//tensorflow/lite/experimental/litert/test:testdata/shared_input_cpu_npu.tflite", - 
"//tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch:dispatch_api_so", - ], - linkopts = select({ - "//tensorflow:android": ["-landroid"], - "//conditions:default": [], - }) + gles_linkopts(), - deps = [ - ":dispatch_delegate", - "@com_google_googletest//:gtest_main", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:absl_log", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/darwinn/driver_shared/fence:fence_test_util", - "//tensorflow/lite:framework", - "//tensorflow/lite/c:c_api_opaque", - "//tensorflow/lite/experimental/litert/c:litert_runtime_c_api_shared_lib", - "//tensorflow/lite/experimental/litert/cc:litert_buffer_ref", - "//tensorflow/lite/experimental/litert/cc:litert_compilation_options", - "//tensorflow/lite/experimental/litert/cc:litert_compiled_model", - "//tensorflow/lite/experimental/litert/cc:litert_environment", - "//tensorflow/lite/experimental/litert/cc:litert_event", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_tensor_buffer", - "//tensorflow/lite/experimental/litert/core/model:model_buffer", - "//tensorflow/lite/experimental/litert/core/util:flatbuffer_tools", - "//tensorflow/lite/experimental/litert/runtime:external_litert_buffer_context", - "//tensorflow/lite/experimental/litert/test:common", - "//tensorflow/lite/experimental/litert/test:matchers", - "//tensorflow/lite/experimental/litert/test:simple_model_npu", - ], -) - -cc_test( - name = "dispatch_delegate_qualcomm_test", - srcs = ["dispatch_delegate_qualcomm_test.cc"], - data = [ - "//tensorflow/lite/experimental/litert/test:testdata/shared_input_cpu_npu.tflite", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch:dispatch_api_so", - ], - linkopts = select({ - "//tensorflow:android": ["-landroid"], - "//conditions:default": [], - }), - deps = [ 
- ":dispatch_delegate", - "//tensorflow/lite:framework", - "//tensorflow/lite/c:c_api_opaque", - "//tensorflow/lite/experimental/litert/c:litert_runtime_c_api_shared_lib", - "//tensorflow/lite/experimental/litert/cc:litert_compilation_options", - "//tensorflow/lite/experimental/litert/cc:litert_compiled_model", - "//tensorflow/lite/experimental/litert/cc:litert_dispatch_delegate", - "//tensorflow/lite/experimental/litert/cc:litert_environment", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_tensor_buffer", - "//tensorflow/lite/experimental/litert/runtime:external_litert_buffer_context", - "//tensorflow/lite/experimental/litert/test:common", - "//tensorflow/lite/experimental/litert/test:simple_model_npu", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:absl_log", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest_main", - ], -) - -cc_test( - name = "dispatch_delegate_mediatek_test", - srcs = ["dispatch_delegate_mediatek_test.cc"], - data = [ - "//tensorflow/lite/experimental/litert/test:testdata/shared_input_cpu_npu.tflite", - "//tensorflow/lite/experimental/litert/vendors/mediatek/dispatch:dispatch_api_so", - ], - linkopts = select({ - "//tensorflow:android": ["-landroid"], - "//conditions:default": [], - }), - tags = [ - "no_oss", - "nobuilder", - "notap", - ], - deps = [ - ":dispatch_delegate", - "//tensorflow/lite:framework", - "//tensorflow/lite/c:c_api_opaque", - "//tensorflow/lite/c:common", - "//tensorflow/lite/experimental/litert/c:litert_runtime_c_api_shared_lib", - "//tensorflow/lite/experimental/litert/cc:litert_compilation_options", - "//tensorflow/lite/experimental/litert/cc:litert_compiled_model", - "//tensorflow/lite/experimental/litert/cc:litert_environment", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_tensor_buffer", - 
"//tensorflow/lite/experimental/litert/core/model:model_buffer", - "//tensorflow/lite/experimental/litert/core/util:flatbuffer_tools", - "//tensorflow/lite/experimental/litert/runtime:external_litert_buffer_context", - "//tensorflow/lite/experimental/litert/test:common", - "//tensorflow/lite/experimental/litert/test:simple_model_npu", - "//tensorflow/lite/kernels:builtin_ops", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:absl_log", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest_main", - ], -) diff --git a/tensorflow/lite/experimental/litert/runtime/dispatch/README.md b/tensorflow/lite/experimental/litert/runtime/dispatch/README.md deleted file mode 100644 index 5a2e33e0806a8c..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/dispatch/README.md +++ /dev/null @@ -1,20 +0,0 @@ -## Google Tensor - -Test case can dispatch_delegate_google_tensor_test can be run on a device with a -Pixel 9 device with the following comands - -$ ../../google/run_test_on_android.sh dispatch_delegate_google_tensor_test - -## Qualcomm - -Test case can dispatch_delegate_qualcomm_test can be run on a Samsung S24 device -with the following comands - -$ ../../google/run_test_on_android.sh dispatch_delegate_qualcomm_test - -## MediaTek - -Test case can dispatch_delegate_mediatek_test can be run on a device with a -MetiaTek mt6989 SoC with the following comands - -$ ../../google/run_test_on_android.sh dispatch_delegate_mediatek_test diff --git a/tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate.cc b/tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate.cc deleted file mode 100644 index 2b69430a2eae19..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate.cc +++ /dev/null @@ -1,173 +0,0 @@ -// Copyright 2024 Google LLC. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include -#include -#include -#include - -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "tensorflow/lite/c/c_api_opaque.h" -#include "tensorflow/lite/c/c_api_types.h" -#include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/delegates/utils/simple_opaque_delegate.h" -#include "tensorflow/lite/experimental/litert/c/litert_dispatch_delegate.h" -#include "tensorflow/lite/experimental/litert/c/litert_environment_options.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/cc/litert_dispatch_delegate.h" -#include "tensorflow/lite/experimental/litert/core/build_stamp.h" -#include "tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate_kernel.h" -#include "tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate_options.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" - -namespace { - -using ::litert::internal::kLiteRtDispatchOpCustomCode; - -// A TFL Delegate that can recognize subgraphs that run on Dispatch API capable -// accelerators, e.g. TPU, DSP, ... It replaces such subgraphs and offloads -// their work through the Dispatch API. 
-class DispatchDelegate : public tflite::SimpleOpaqueDelegateInterface { - public: - static TfLiteOpaqueDelegate* Create(LiteRtDispatchDelegateOptions* options_) { - litert::DispatchDelegateOptionsPtr options( - options_, LiteRtDestroyDispatchDelegateOptions); - if (!options) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return nullptr; - } - - std::unique_ptr managed_sb_delegate( - new DispatchDelegate(std::move(options))); - return tflite::TfLiteOpaqueDelegateFactory::CreateSimpleDelegate( - std::move(managed_sb_delegate), - kTfLiteDelegateFlagsAllowDynamicTensors); - } - - bool IsNodeSupportedByDelegate(const TfLiteOperator* op, - const TfLiteOpaqueNode* node, - TfLiteOpaqueContext* context) const override; - - TfLiteStatus Initialize(TfLiteOpaqueContext* context) override; - - const char* Name() const override; - - std::unique_ptr - CreateDelegateKernelInterface() override; - - private: - static constexpr absl::string_view kDelegateName = "DispatchDelegate"; - - explicit DispatchDelegate(litert::DispatchDelegateOptionsPtr&& options) - : options_(std::move(options)) {} - - litert::DispatchDelegateOptionsPtr options_; - int dispatch_graph_name_id_ = 0; -}; - -bool DispatchDelegate::IsNodeSupportedByDelegate( - const TfLiteOperator* op, const TfLiteOpaqueNode* node, - TfLiteOpaqueContext* context) const { - auto custom_code = absl::string_view(TfLiteOperatorGetCustomName(op)); - return custom_code == kLiteRtDispatchOpCustomCode; -} - -TfLiteStatus DispatchDelegate::Initialize(TfLiteOpaqueContext* context) { - return kTfLiteOk; -} - -const char* DispatchDelegate::Name() const { return kDelegateName.data(); } - -std::unique_ptr -DispatchDelegate::CreateDelegateKernelInterface() { - std::string dispatch_graph_name = - absl::StrFormat("DispatchGraph_%d", dispatch_graph_name_id_++); - - auto kernel = litert::internal::DispatchDelegateKernel::Create( - std::move(dispatch_graph_name), *options_); - if (kernel) { - return std::move(*kernel); - } else { - 
LITERT_FATAL("Failed to create a dispatch delegate kernel: %s", - kernel.Error().Message().c_str()); - return nullptr; - } -} - -} // namespace - -LiteRtDispatchDelegateOptions* LiteRtCreateDefaultDispatchDelegateOptions( - LiteRtEnvironmentOptions environment_options) { - return new LiteRtDispatchDelegateOptions(environment_options); -} - -TfLiteStatus LiteRtAddDispatchDelegateOption( - LiteRtDispatchDelegateOptions* options, LiteRtDispatchOption option) { - if (!options) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kTfLiteError; - } - - options->AddOption(option); - return kTfLiteOk; -} - -TfLiteStatus LiteRtDispatchDelegateAddAllocBaseOption( - LiteRtDispatchDelegateOptions* options, const void* alloc_base) { - AddAllocBaseOption(alloc_base, *options); - return kTfLiteOk; -} - -TfLiteStatus LiteRtDispatchDelegateAddAllocFdOption( - LiteRtDispatchDelegateOptions* options, int alloc_fd) { - AddAllocFdOption(alloc_fd, *options); - return kTfLiteOk; -} - -void LiteRtDestroyDispatchDelegateOptions( - LiteRtDispatchDelegateOptions* options) { - delete options; -} - -TfLiteOpaqueDelegate* LiteRtCreateDispatchDelegate( - LiteRtEnvironmentOptions environment_options, - LiteRtDispatchDelegateOptions* options) { - if (!options) { - options = LiteRtCreateDefaultDispatchDelegateOptions(environment_options); - } - return DispatchDelegate::Create(options); -} - -void LiteRtDestroyDispatchDelegate(TfLiteOpaqueDelegate* delegate) { - tflite::TfLiteOpaqueDelegateFactory::DeleteSimpleDelegate(delegate); -} - -namespace litert { - -DispatchDelegateOptionsPtr CreateDispatchDelegateOptionsPtr( - LiteRtEnvironmentOptions environment_options) { - return {LiteRtCreateDefaultDispatchDelegateOptions(environment_options), - LiteRtDestroyDispatchDelegateOptions}; -} - -DispatchDelegatePtr CreateDispatchDelegatePtr( - LiteRtEnvironmentOptions environment_options, - DispatchDelegateOptionsPtr&& options) { - return DispatchDelegatePtr( - 
LiteRtCreateDispatchDelegate(environment_options, options.release()), - LiteRtDestroyDispatchDelegate); -} -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate_google_tensor_test.cc b/tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate_google_tensor_test.cc deleted file mode 100644 index b75f91627cb0db..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate_google_tensor_test.cc +++ /dev/null @@ -1,529 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include -#include -#include -#include -#include - -#include "tensorflow/lite/experimental/litert/c/litert_environment.h" -#include "tensorflow/lite/experimental/litert/c/litert_environment_options.h" -#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" -#include "tensorflow/lite/experimental/litert/cc/litert_compilation_options.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_requirements.h" - -#if defined(__ANDROID__) -#include "platforms/darwinn/tachyon/core/fence/fence.h" -#endif -#include -#include -#include "absl/log/absl_log.h" -#include "absl/log/log.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "third_party/darwinn/driver_shared/fence/fence_test_util.h" -#include "tensorflow/lite/c/c_api_opaque.h" -#include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_dispatch_delegate.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/cc/litert_compiled_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_dispatch_delegate.h" -#include "tensorflow/lite/experimental/litert/cc/litert_environment.h" -#include "tensorflow/lite/experimental/litert/cc/litert_event.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/core/model/model_buffer.h" -#include "tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h" -#include "tensorflow/lite/experimental/litert/runtime/external_litert_buffer_context.h" -#include "tensorflow/lite/experimental/litert/test/common.h" -#include "tensorflow/lite/experimental/litert/test/matchers.h" -#include 
"tensorflow/lite/experimental/litert/test/testdata/simple_model_test_vectors.h" -#include "tensorflow/lite/interpreter.h" -#include "tensorflow/lite/signature_runner.h" - -using litert::testing::MakeRuntimeFromTestFileWithNpuModel; -using testing::FloatNear; -using testing::Pointwise; -using Fence = std::shared_ptr; -using ::testing::ElementsAre; - -namespace litert { -namespace { - -constexpr absl::string_view kNpuFile = kGoogleTensorModelFileName; -constexpr absl::string_view kTfliteFile = "simple_model_npu.tflite"; -constexpr absl::string_view kDispatchLibraryDir = "/data/local/tmp"; - -litert::Expected CreateDefaultEnvironment() { - const std::vector environment_options = { - litert::Environment::Option{ - litert::Environment::OptionTag::DispatchLibraryDir, - kDispatchLibraryDir, - }, - }; - return litert::Environment::Create(absl::MakeConstSpan(environment_options)); -} - -TEST(DispatchDelegate, GoogleTensorCpuBuffer) { - LITERT_ASSERT_OK_AND_ASSIGN( - testing::TflRuntime::Ptr runtime, - MakeRuntimeFromTestFileWithNpuModel(kTfliteFile, kNpuFile)); - tflite::Interpreter& interpreter = runtime->Interpreter(); - - LITERT_ASSERT_OK_AND_ASSIGN(Environment env, CreateDefaultEnvironment()); - - internal::ExternalLiteRtBufferContext buffer_context; - interpreter.SetExternalContext(kTfLiteLiteRtBufferContext, &buffer_context); - - EXPECT_EQ(interpreter.nodes_size(), 1); - EXPECT_EQ(interpreter.inputs().size(), 2); - EXPECT_EQ(interpreter.outputs().size(), 1); - ASSERT_EQ(interpreter.execution_plan().size(), 1); - - LiteRtEnvironmentOptions env_options = nullptr; - LiteRtGetEnvironmentOptions(env.Get(), &env_options); - DispatchDelegateOptionsPtr dispatch_delegate_options = - CreateDispatchDelegateOptionsPtr(env_options); - LiteRtDispatchDelegateAddAllocBaseOption(dispatch_delegate_options.get(), - runtime->Flatbuffer().Buf().Data()); - DispatchDelegatePtr dispatch_delegate = CreateDispatchDelegatePtr( - env_options, std::move(dispatch_delegate_options)); - -#if 
!defined(__ANDROID__) - GTEST_SKIP() << "The rest of this test is specific to Android devices with a " - "GoogleTensor eTPU"; -#endif - - ASSERT_EQ(interpreter.ModifyGraphWithDelegate(dispatch_delegate.get()), - kTfLiteOk); - - // Get the list of signatures and check it. - auto signature_defs = interpreter.signature_keys(); - ASSERT_EQ(signature_defs.size(), 1); - - tflite::impl::SignatureRunner* runner = - interpreter.GetSignatureRunner(/*signature_key=*/nullptr); - ASSERT_NE(runner, nullptr); - - EXPECT_EQ(runner->AllocateTensors(), kTfLiteOk); - - // Fill model inputs. - ASSERT_STREQ(runner->input_names()[0], "arg0"); - TfLiteTensor* input_0_tensor = runner->input_tensor("arg0"); - ASSERT_NE(input_0_tensor, nullptr); - float* input_0 = input_0_tensor->data.f; - std::memcpy(input_0, kTestInput0Tensor, sizeof(kTestInput0Tensor)); - - ASSERT_STREQ(runner->input_names()[1], "arg1"); - TfLiteTensor* input_1_tensor = runner->input_tensor("arg1"); - ASSERT_NE(input_1_tensor, nullptr); - auto* input_1 = input_1_tensor->data.f; - std::memcpy(input_1, kTestInput1Tensor, sizeof(kTestInput1Tensor)); - - EXPECT_EQ(runner->Invoke(), kTfLiteOk); - - // Check model output. - ASSERT_STREQ(runner->output_names()[0], "tfl.custom"); - auto output_tensor = runner->output_tensor("tfl.custom"); - ASSERT_NE(output_tensor, nullptr); - auto output = absl::MakeSpan(output_tensor->data.f, kTestOutputSize); - for (auto i = 0; i < kTestOutputSize; ++i) { - ABSL_LOG(INFO) << output[i] << "\t" << kTestOutputTensor[i]; - } - EXPECT_THAT(output, Pointwise(::testing::FloatNear(1e-5), kTestOutputTensor)); -} - -TEST(DispatchDelegate, GoogleTensorHwBuffer) { - // Environment setup. 
- LITERT_ASSERT_OK_AND_ASSIGN(Environment env, CreateDefaultEnvironment()); - - LITERT_ASSERT_OK_AND_ASSIGN( - testing::TflRuntime::Ptr runtime, - MakeRuntimeFromTestFileWithNpuModel(kTfliteFile, kNpuFile)); - tflite::Interpreter& interpreter = runtime->Interpreter(); - - internal::ExternalLiteRtBufferContext buffer_context; - interpreter.SetExternalContext(kTfLiteLiteRtBufferContext, &buffer_context); - - EXPECT_EQ(interpreter.nodes_size(), 1); - EXPECT_EQ(interpreter.inputs().size(), 2); - EXPECT_EQ(interpreter.outputs().size(), 1); - ASSERT_EQ(interpreter.execution_plan().size(), 1); - - LiteRtEnvironmentOptions env_options = nullptr; - LiteRtGetEnvironmentOptions(env.Get(), &env_options); - - DispatchDelegateOptionsPtr dispatch_delegate_options = - CreateDispatchDelegateOptionsPtr(env_options); - LiteRtDispatchDelegateAddAllocBaseOption(dispatch_delegate_options.get(), - runtime->Flatbuffer().Buf().Data()); - DispatchDelegatePtr dispatch_delegate = CreateDispatchDelegatePtr( - env_options, std::move(dispatch_delegate_options)); - -#if !defined(__ANDROID__) - GTEST_SKIP() << "The rest of this test is specific to Android devices with a " - "GoogleTensor eTPU"; -#endif - - ASSERT_EQ(interpreter.ModifyGraphWithDelegate(dispatch_delegate.get()), - kTfLiteOk); - - // Create and register tensor buffers for all inputs and outputs. 
- std::vector input_buffers; - for (int i = 0; i < interpreter.inputs().size(); ++i) { - LITERT_ASSERT_OK_AND_ASSIGN( - TensorBufferRequirements * input_buffer_requirements, - buffer_context.GetBufferRequirement(interpreter.input_tensor(i))); - ASSERT_EQ(input_buffer_requirements->SupportedTypes()->at(0), - kLiteRtTensorBufferTypeAhwb); - LITERT_ASSERT_OK_AND_ASSIGN( - TensorBuffer input_buffer, - buffer_context.CreateBufferForTensor(interpreter.input_tensor(i))); - ASSERT_TRUE(input_buffer.IsOwned()); - ASSERT_EQ(*input_buffer.BufferType(), kLiteRtTensorBufferTypeAhwb); - LITERT_ASSERT_OK_AND_ASSIGN(TensorBuffer duplicate_buffer, - input_buffer.Duplicate()); - auto status = buffer_context.RegisterTensorBuffer( - interpreter.input_tensor(i), std::move(duplicate_buffer)); - ASSERT_EQ(status, kLiteRtStatusOk); - input_buffers.push_back(std::move(input_buffer)); - } - - std::vector output_buffers; - for (int i = 0; i < interpreter.outputs().size(); ++i) { - LITERT_ASSERT_OK_AND_ASSIGN( - TensorBufferRequirements * output_buffer_requirements, - buffer_context.GetBufferRequirement(interpreter.output_tensor(i))); - ASSERT_NE(output_buffer_requirements, nullptr); - ASSERT_EQ(output_buffer_requirements->SupportedTypes()->at(0), - kLiteRtTensorBufferTypeAhwb); - LITERT_ASSERT_OK_AND_ASSIGN( - TensorBuffer output_buffer, - buffer_context.CreateBufferForTensor(interpreter.output_tensor(i))); - ASSERT_TRUE(output_buffer.IsOwned()); - ASSERT_EQ(*output_buffer.BufferType(), kLiteRtTensorBufferTypeAhwb); - LITERT_ASSERT_OK_AND_ASSIGN(TensorBuffer duplicate_buffer, - output_buffer.Duplicate()); - auto status = buffer_context.RegisterTensorBuffer( - interpreter.output_tensor(i), std::move(duplicate_buffer)); - ASSERT_EQ(status, kLiteRtStatusOk); - output_buffers.push_back(std::move(output_buffer)); - } - - // Get the list of signatures and check it. 
- auto signature_defs = interpreter.signature_keys(); - ASSERT_EQ(signature_defs.size(), 1); - - tflite::impl::SignatureRunner* runner = - interpreter.GetSignatureRunner(/*signature_key=*/nullptr); - ASSERT_NE(runner, nullptr); - - EXPECT_EQ(runner->AllocateTensors(), kTfLiteOk); - - // Fill model inputs. - ASSERT_STREQ(runner->input_names()[0], "arg0"); - auto& input_0_buffer = input_buffers[0]; - input_0_buffer.Write( - absl::MakeConstSpan(kTestInput0Tensor, kTestInput0Size)); - - ASSERT_STREQ(runner->input_names()[1], "arg1"); - auto& input_1_buffer = input_buffers[1]; - input_1_buffer.Write( - absl::MakeConstSpan(kTestInput1Tensor, kTestInput1Size)); - - EXPECT_EQ(runner->Invoke(), kTfLiteOk); - - // Check model output. - ASSERT_STREQ(runner->output_names()[0], "tfl.custom"); - auto& output_buffer = output_buffers[0]; - float output_buffer_data[kTestOutputSize]; - auto output_span = absl::MakeSpan(output_buffer_data, kTestOutputSize); - auto read_success = output_buffer.Read(output_span); - ASSERT_TRUE(read_success); - for (auto i = 0; i < kTestOutputSize; ++i) { - ABSL_LOG(INFO) << "Result: " << output_span.at(i) << "\t" - << kTestOutputTensor[i]; - } - EXPECT_THAT(output_span, Pointwise(FloatNear(1e-5), kTestOutputTensor)); -} - -TEST(DispatchDelegate, CompiledModel) { -#if !defined(__ANDROID__) - GTEST_SKIP() << "The rest of this test is specific to Android devices with a " - "GoogleTensor eTPU"; -#endif - // Environment setup. - LITERT_ASSERT_OK_AND_ASSIGN(Environment env, CreateDefaultEnvironment()); - - // Create Model and check signatures. 
- LITERT_ASSERT_OK_AND_ASSIGN( - OwningBufferRef model_with_byte_code, - internal::GetModelBufWithByteCode(testing::GetTestFilePath(kTfliteFile), - testing::GetTestFilePath(kNpuFile))); - LITERT_ASSERT_OK_AND_ASSIGN(Model model, - Model::CreateFromBuffer(model_with_byte_code)); - - LITERT_ASSERT_OK_AND_ASSIGN(std::vector signatures, - model.GetSignatures()); - EXPECT_EQ(signatures.size(), 1); - Signature& signature = signatures.at(0); - EXPECT_EQ(signature.Key(), Model::DefaultSignatureKey()); - size_t signature_index = 0; - - std::vector input_names = signature.InputNames(); - EXPECT_THAT(input_names, ElementsAre("arg0", "arg1")); - - std::vector output_names = signature.OutputNames(); - EXPECT_THAT(output_names, ElementsAre("tfl.custom")); - - // Create CompiledModel. - LITERT_ASSERT_OK_AND_ASSIGN(CompiledModel compiled_model, - CompiledModel::Create(env, model)); - - // Check CompiledModel buffer requirements. - // input and output expect AHWB. - LITERT_ASSERT_OK_AND_ASSIGN( - TensorBufferRequirements input_buffer_requirements_arg0, - compiled_model.GetInputBufferRequirements(signature_index, - /*input_name=*/"arg0")); - LITERT_ASSERT_OK_AND_ASSIGN( - std::vector input_buffer_types_arg0, - input_buffer_requirements_arg0.SupportedTypes()); - EXPECT_THAT(input_buffer_types_arg0, - ElementsAre(kLiteRtTensorBufferTypeAhwb)); - - LITERT_ASSERT_OK_AND_ASSIGN( - TensorBufferRequirements input_buffer_requirements_arg1, - compiled_model.GetInputBufferRequirements(signature_index, - /*input_name=*/"arg1")); - LITERT_ASSERT_OK_AND_ASSIGN( - std::vector input_buffer_types_arg1, - input_buffer_requirements_arg1.SupportedTypes()); - EXPECT_THAT(input_buffer_types_arg1, - ElementsAre(kLiteRtTensorBufferTypeAhwb)); - - LITERT_ASSERT_OK_AND_ASSIGN( - TensorBufferRequirements output_buffer_requirements, - compiled_model.GetOutputBufferRequirements(signature_index, - /*output_name=*/"tfl.custom")); - LITERT_ASSERT_OK_AND_ASSIGN( - std::vector output_buffer_types, - 
output_buffer_requirements.SupportedTypes()); - EXPECT_THAT(output_buffer_types, ElementsAre(kLiteRtTensorBufferTypeAhwb)); - - // Create and fill input and output tensor buffers. - LITERT_ASSERT_OK_AND_ASSIGN( - std::vector input_buffers, - compiled_model.CreateInputBuffers(signature_index)); - LITERT_ASSERT_OK_AND_ASSIGN( - std::vector output_buffers, - compiled_model.CreateOutputBuffers(signature_index)); - ASSERT_TRUE(input_buffers[0].Write( - absl::MakeConstSpan(kTestInput0Tensor, kTestInput0Size))); - ASSERT_TRUE(input_buffers[1].Write( - absl::MakeConstSpan(kTestInput1Tensor, kTestInput1Size))); - - // Execute compiled model. - compiled_model.Run(signature_index, input_buffers, output_buffers); - - // Check model output. - float output_buffer_data[kTestOutputSize]; - absl::Span output_span = - absl::MakeSpan(output_buffer_data, kTestOutputSize); - ASSERT_TRUE(output_buffers[0].Read(output_span)); - for (auto i = 0; i < kTestOutputSize; ++i) { - ABSL_LOG(INFO) << "Result: " << output_span.at(i) << "\t" - << kTestOutputTensor[i]; - } - EXPECT_THAT(output_span, Pointwise(FloatNear(1e-5), kTestOutputTensor)); -} - -TEST(DispatchDelegate, CompiledModelSharedInput) { - auto model_with_byte_code = internal::GetModelBufWithByteCode( - testing::GetTestFilePath("shared_input_cpu_npu.tflite"), - testing::GetTestFilePath(kNpuFile)); - ASSERT_TRUE(model_with_byte_code); - auto model = Model::CreateFromBuffer(*model_with_byte_code); - ASSERT_TRUE(model); - -#if !defined(__ANDROID__) - GTEST_SKIP() << "The rest of this test is specific to Android devices with a " - "GoogleTensor eTPU"; -#endif - auto jit_compilation_options = CompilationOptions::Create(); - ASSERT_TRUE(jit_compilation_options); - ASSERT_TRUE(jit_compilation_options->SetHardwareAccelerators( - kLiteRtHwAcceleratorCpu)); - - const std::vector environment_options = { - litert::Environment::Option{ - litert::Environment::OptionTag::DispatchLibraryDir, - kDispatchLibraryDir, - }, - }; - auto env = - 
litert::Environment::Create(absl::MakeConstSpan(environment_options)); - ASSERT_TRUE(env); - auto res_compiled_model = - CompiledModel::Create(*env, *model, *jit_compilation_options); - ASSERT_TRUE(res_compiled_model) << "Failed to initialize CompiledModel"; - auto& compiled_model = *res_compiled_model; - - size_t signature_index = 0; - auto signature = *model->GetSignature(signature_index); - auto input_buffers = *compiled_model.CreateInputBuffers(signature_index); - auto output_buffers = *compiled_model.CreateOutputBuffers(signature_index); - - // Fill model inputs. - auto input_names = signature.InputNames(); - EXPECT_EQ(input_names.size(), 2); - EXPECT_EQ(input_names.at(0), "arg0"); - EXPECT_EQ(input_names.at(1), "arg1"); - ASSERT_TRUE(input_buffers[0].Write( - absl::MakeConstSpan(kTestInput0Tensor, kTestInput0Size))); - ASSERT_TRUE(input_buffers[1].Write( - absl::MakeConstSpan(kTestInput1Tensor, kTestInput1Size))); - - // Execute model. - compiled_model.Run(signature_index, input_buffers, output_buffers); - - // Check model output. 
- auto output_names = signature.OutputNames(); - EXPECT_EQ(output_names.size(), 2); - { - EXPECT_EQ(output_names.at(0), "tfl.add"); - float output_buffer_data[kTestOutputSize]; - auto output_span = absl::MakeSpan(output_buffer_data, kTestOutputSize); - ASSERT_TRUE(output_buffers[0].Read(output_span)); - for (auto i = 0; i < kTestOutputSize; ++i) { - ABSL_LOG(INFO) << "Result: " << output_span.at(i) << "\t" - << kTestOutputTensor[i]; - } - EXPECT_THAT(output_span, Pointwise(FloatNear(1e-5), kTestOutputTensor)); - } - { - EXPECT_EQ(output_names.at(1), "tfl.custom"); - float output_buffer_data[kTestOutputSize]; - auto output_span = absl::MakeSpan(output_buffer_data, kTestOutputSize); - ASSERT_TRUE(output_buffers[1].Read(output_span)); - for (auto i = 0; i < kTestOutputSize; ++i) { - ABSL_LOG(INFO) << "Result: " << output_span.at(i) << "\t" - << kTestOutputTensor[i]; - } - EXPECT_THAT(output_span, Pointwise(FloatNear(1e-5), kTestOutputTensor)); - } -} - -TEST(DispatchDelegate, CompiledModelAsync) { -#if !defined(__ANDROID__) - GTEST_SKIP() - << "The rest of this test is specific to Android devices with a " - "GoogleTensor eTPU"; -#endif - // Environment setup. - LITERT_ASSERT_OK_AND_ASSIGN(Environment env, CreateDefaultEnvironment()); - - // Create Model and check signatures. 
- LITERT_ASSERT_OK_AND_ASSIGN( - OwningBufferRef model_with_byte_code, - internal::GetModelBufWithByteCode(testing::GetTestFilePath(kTfliteFile), - testing::GetTestFilePath(kNpuFile))); - - LITERT_ASSERT_OK_AND_ASSIGN(Model model, - Model::CreateFromBuffer(model_with_byte_code)); - - LITERT_ASSERT_OK_AND_ASSIGN(std::vector signatures, - model.GetSignatures()); - EXPECT_EQ(signatures.size(), 1); - Signature& signature = signatures.at(0); - absl::string_view signature_key = signature.Key(); - EXPECT_EQ(signature_key, Model::DefaultSignatureKey()); - size_t signature_index = 0; - - std::vector input_names = signature.InputNames(); - EXPECT_THAT(input_names, ElementsAre("arg0", "arg1")); - - std::vector output_names = signature.OutputNames(); - EXPECT_THAT(output_names, ElementsAre("tfl.custom")); - - // Create CompiledModel. - LITERT_ASSERT_OK_AND_ASSIGN(CompiledModel compiled_model, - CompiledModel::Create(env, model)); - - // Create and fill input and output tensor buffers. - LITERT_ASSERT_OK_AND_ASSIGN( - std::vector input_buffers, - compiled_model.CreateInputBuffers(signature_index)); - - LITERT_ASSERT_OK_AND_ASSIGN( - std::vector output_buffers, - compiled_model.CreateOutputBuffers(signature_index)); - - LITERT_ASSERT_OK_AND_ASSIGN(auto input_0_cpu_addr_and_lock, - TensorBufferScopedLock::Create(input_buffers[0])); - - LITERT_ASSERT_OK_AND_ASSIGN(auto input_1_cpu_addr_and_lock, - TensorBufferScopedLock::Create(input_buffers[1])); - - // Attach events to input buffers. 
- Fence input_fence_0 = platforms::darwinn::fence_util::CreateFence(); - LITERT_ASSERT_OK_AND_ASSIGN( - Event input_event_0, - litert::Event::CreateFromSyncFenceFd(input_fence_0->GetFd(), - /*owns_fd=*/false)); - input_buffers[0].SetEvent(std::move(input_event_0)); - - Fence input_fence_1 = platforms::darwinn::fence_util::CreateFence(); - LITERT_ASSERT_OK_AND_ASSIGN( - Event input_event_1, - litert::Event::CreateFromSyncFenceFd(input_fence_1->GetFd(), - /*owns_fd=*/false)); - input_buffers[1].SetEvent(std::move(input_event_1)); - - // Start the model asynchronously. - bool async; - compiled_model.RunAsync(signature_index, input_buffers, output_buffers, - async); - ASSERT_TRUE(async); - ASSERT_TRUE(output_buffers[0].HasEvent()); - - // Set input values. - std::memcpy(input_0_cpu_addr_and_lock.second, kTestInput0Tensor, - sizeof(kTestInput0Tensor)); - std::memcpy(input_1_cpu_addr_and_lock.second, kTestInput1Tensor, - sizeof(kTestInput1Tensor)); - - // Signal input fences so that the inference can start. - ASSERT_OK(input_fence_0->Signal(/*success=*/true)); - ASSERT_OK(input_fence_1->Signal(/*success=*/true)); - - // Check model output. - float output_buffer_data[kTestOutputSize]; - absl::Span output_span = - absl::MakeSpan(output_buffer_data, kTestOutputSize); - // The next read operation will block on the output buffer's sync fence. - ASSERT_TRUE(output_buffers[0].Read(output_span)); - // Print and confirm the output values are correct. 
- for (auto i = 0; i < kTestOutputSize; ++i) { - ABSL_LOG(INFO) << "Result: " << output_span.at(i) << "\t" - << kTestOutputTensor[i]; - } - EXPECT_THAT(output_span, Pointwise(FloatNear(1e-5), kTestOutputTensor)); -} - -} // namespace -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate_kernel.cc b/tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate_kernel.cc deleted file mode 100644 index b408ae39a027dc..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate_kernel.cc +++ /dev/null @@ -1,657 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate_kernel.h" - -#include -#include -#include -#include -#include -#include - -#include "tensorflow/lite/c/c_api_opaque.h" -#include "tensorflow/lite/c/c_api_types.h" -#include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/core/c/c_api_opaque.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_dispatch_delegate.h" -#include "tensorflow/lite/experimental/litert/c/litert_event.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h" -#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_requirements.h" -#include "tensorflow/lite/experimental/litert/core/dispatch_op_schema.h" -#include "tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate_options.h" -#include "tensorflow/lite/experimental/litert/runtime/external_litert_buffer_context.h" -#include "tensorflow/lite/experimental/litert/runtime/tfl_utils.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" - -namespace litert { -namespace internal { - -DispatchDelegateKernel::~DispatchDelegateKernel() { - for (size_t i = 0; i < input_tensor_buffer_handles_.size(); ++i) { - (void)LiteRtDispatchDetachInput(invocation_context_, i, - input_tensor_buffer_handles_[i]); - } - - for (size_t i = 0; i < output_tensor_buffer_handles_.size(); ++i) { - (void)LiteRtDispatchDetachOutput(invocation_context_, i, - output_tensor_buffer_handles_[i]); - } - - if (invocation_context_) { - 
(void)LiteRtDispatchInvocationContextDestroy(invocation_context_); - } - - for (auto& buffer_handle : input_tensor_buffer_handles_) { - (void)LiteRtDispatchUnregisterTensorBuffer(device_context_, buffer_handle); - } - - for (auto& buffer_handle : output_tensor_buffer_handles_) { - (void)LiteRtDispatchUnregisterTensorBuffer(device_context_, buffer_handle); - } - - if (device_context_) { - (void)LiteRtDispatchDeviceContextDestroy(device_context_); - } - - input_tensor_buffers_.clear(); - output_tensor_buffers_.clear(); -} - -Expected DispatchDelegateKernel::Create( - std::string&& graph_name, const LiteRtDispatchDelegateOptions& options) { - auto dispatch_options = options.GetDispatchOptions(); - if (auto status = LiteRtDispatchInitialize(dispatch_options.data(), - dispatch_options.size()); - status != kLiteRtStatusOk) { - LITERT_LOG(LITERT_ERROR, "Failed to initialize Dispatch API: %d", status); - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to initialize Dispatch API"); - } - - const char* vendor_id; - if (auto status = LiteRtDispatchGetVendorId(&vendor_id); - status != kLiteRtStatusOk) { - LITERT_LOG(LITERT_ERROR, "Failed to get Dispatch API vendor ID: %d", - status); - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to get Dispatch API vendor ID"); - } - LITERT_LOG(LITERT_INFO, "Dispatch API vendor ID: %s", vendor_id); - - const char* build_id; - if (auto status = LiteRtDispatchGetBuildId(&build_id); - status != kLiteRtStatusOk) { - LITERT_LOG(LITERT_ERROR, "Failed to get Dispatch API build ID: %d", status); - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to get Dispatch API build ID"); - } - LITERT_LOG(LITERT_INFO, "Dispatch API build ID: %s", build_id); - - LiteRtApiVersion api_version; - if (auto status = LiteRtDispatchGetApiVersion(&api_version); - status != kLiteRtStatusOk) { - LITERT_LOG(LITERT_ERROR, "Failed to get LiteRT Dispatch API version: %d", - status); - return 
Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to get LiteRT Dispatch API version"); - } - LITERT_LOG(LITERT_INFO, "Dispatch API version: %d.%d.%d", api_version.major, - api_version.minor, api_version.patch); - // Check if the versions mach. - if (api_version.major != LITERT_API_VERSION_MAJOR || - api_version.minor < LITERT_API_VERSION_MINOR) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Found Dispatch API with an unsupported version"); - } - - int capabilities; - if (auto status = LiteRtDispatchGetCapabilities(&capabilities); - status != kLiteRtStatusOk) { - LITERT_LOG(LITERT_ERROR, "Failed to get Dispatch API capabilities: %d", - status); - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to get Dispatch API capabilities"); - } - LITERT_LOG(LITERT_INFO, "Dispatch API capabilities: %d", capabilities); - - if (!(capabilities & kLiteRtDispatchCapabilitiesBasic)) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Dispatch API has insufficient capabilities"); - } - - bool async_dispatch = (capabilities & kLiteRtDispatchCapabilitiesAsync); - if (async_dispatch) { - LITERT_LOG(LITERT_INFO, "Found async dispatch capabilities"); - } - - LiteRtDispatchDeviceContext device_context; - if (auto status = LiteRtDispatchDeviceContextCreate(&device_context); - status != kLiteRtStatusOk) { - LITERT_LOG(LITERT_ERROR, "Failed to get Dispatch API device context: %d", - status); - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to create Dispatch API device context"); - } - - return Ptr(new DispatchDelegateKernel(options, std::move(graph_name), - device_context, async_dispatch)); -} - -TfLiteStatus DispatchDelegateKernel::Init( - TfLiteOpaqueContext* context, const TfLiteOpaqueDelegateParams* params) { - if (params->nodes_to_replace->size != 1) { - LITERT_LOG(LITERT_ERROR, - "Models with more than one dispatch node are not yet supported"); - return kTfLiteError; - } - - auto node_id = params->nodes_to_replace->data[0]; - 
TfLiteOpaqueNode* node; - TfLiteOperator* op; - if (auto status = TfLiteOpaqueContextGetNodeAndRegistration(context, node_id, - &node, &op); - status != kTfLiteOk) { - LITERT_LOG(LITERT_ERROR, "Failed to get node and registration: %d", status); - return status; - } - - const void* init_data; - int init_data_size; - if (auto status = TfLiteOpaqueNodeGetCustomInitialData(node, &init_data, - &init_data_size); - status != kTfLiteOk) { - LITERT_LOG(LITERT_ERROR, "Failed to get custom initial data: %d", status); - return status; - } - if (!init_data || !init_data_size) { - LITERT_LOG(LITERT_ERROR, "Found custom op with missing initial data"); - return kTfLiteError; - } - - BufferRef custom_opts(init_data, init_data_size); - - // Read offset and size (relative to alloc_base) from the custom options (and - // name). - const auto dispatch_opts = GetDispatchOpOptions(custom_opts); - if (dispatch_opts.bytecode_offset == 0) { - LITERT_LOG(LITERT_ERROR, "Found dispatch op with missing bytecode offset"); - return kTfLiteError; - } - - // Find pointer to the start of the loaded model buffer. - const auto alloc_base = FindAllocBase(options_); - if (!alloc_base) { - LITERT_LOG(LITERT_ERROR, - "Could not find requried delegate options \"alloc_base\""); - return kTfLiteError; - } - - const auto alloc_fd = FindAllocFd(options_); - - // Get location of bytecode in the model buffer relative to alloc_base. - LiteRtMemBuffer exec_bytecode_buffer = { - /*.fd=*/alloc_fd ? 
*alloc_fd : -1, - /*.base_addr=*/*alloc_base, - /*.offset=*/dispatch_opts.bytecode_offset, - /*.size=*/dispatch_opts.bytecode_size}; - const auto& function_name = dispatch_opts.name; - const int num_inputs = params->input_tensors->size; - const int num_outputs = params->output_tensors->size; - - if (auto status = LiteRtDispatchInvocationContextCreate( - device_context_, kLiteRtDispatchExecutableTypeMlModel, - &exec_bytecode_buffer, function_name.data(), num_inputs, num_outputs, - &invocation_context_); - status != kLiteRtStatusOk) { - LITERT_LOG(LITERT_ERROR, "Failed to create invocation context: %d", status); - return kTfLiteError; - } - - input_tensor_buffers_require_cpu_sync_.resize(num_inputs); - input_tensor_buffers_.resize(num_inputs); - input_tensor_buffer_handles_.resize(num_inputs); - input_tensor_buffer_used_size_.resize(num_inputs); - - output_tensor_buffers_require_cpu_sync_.resize(num_outputs); - output_tensor_buffers_.resize(num_outputs); - output_tensor_buffer_handles_.resize(num_outputs); - output_tensor_buffer_used_size_.resize(num_outputs); - - void* external_context; - TfLiteOpaqueContextGetExternalContext(context, &external_context, - kTfLiteLiteRtBufferContext); - if (!external_context) { - LITERT_LOG(LITERT_ERROR, "External context not found"); - return kTfLiteError; - } - - buffer_context_ = - reinterpret_cast( - external_context); - - // Register input and output buffer requirements. 
- size_t num_node_inputs = TfLiteOpaqueNodeNumberOfInputs(node); - for (size_t i = 0; i < num_node_inputs; ++i) { - auto* tfl_opaque_tensor = TfLiteOpaqueNodeGetInput(context, node, i); - if (!tfl_opaque_tensor) { - LITERT_LOG(LITERT_ERROR, "Failed to get TFL node input %d", i); - return kTfLiteError; - } - auto tensor_type = ConvertTensorType(tfl_opaque_tensor); - if (!tensor_type) { - LITERT_LOG(LITERT_ERROR, "%s", tensor_type.Error().Message().c_str()); - return kTfLiteError; - } - auto input_buffer_requirements = - GetBufferRequirements(*tensor_type, i, /*is_input=*/true); - if (auto res = buffer_context_->RegisterBufferRequirement( - tfl_opaque_tensor, std::move(*input_buffer_requirements)); - res != kLiteRtStatusOk) { - LITERT_LOG(LITERT_ERROR, "Failed to register buffer requirement"); - return kTfLiteError; - } - } - - size_t num_node_outputs = TfLiteOpaqueNodeNumberOfOutputs(node); - for (size_t i = 0; i < num_node_outputs; ++i) { - auto* tfl_opaque_tensor = TfLiteOpaqueNodeGetOutput(context, node, i); - if (!tfl_opaque_tensor) { - LITERT_LOG(LITERT_ERROR, "Failed to get TFL node output %d", i); - return kTfLiteError; - } - auto tensor_type = ConvertTensorType(tfl_opaque_tensor); - if (!tensor_type) { - LITERT_LOG(LITERT_ERROR, "%s", tensor_type.Error().Message().c_str()); - return kTfLiteError; - } - auto output_buffer_requirements = - GetBufferRequirements(*tensor_type, i, /*is_input=*/false); - if (auto res = buffer_context_->RegisterBufferRequirement( - tfl_opaque_tensor, std::move(*output_buffer_requirements)); - res != kLiteRtStatusOk) { - LITERT_LOG(LITERT_ERROR, "Failed to register buffer requirement"); - return kTfLiteError; - } - } - - return kTfLiteOk; -} - -Expected -DispatchDelegateKernel::GetBufferRequirements( - const RankedTensorType& tensor_type, int io_tensor_index, - bool is_input) const { - auto litert_tensor_type = static_cast(tensor_type); - LiteRtTensorBufferRequirements tensor_buffer_requirements; - if (is_input) { - if (auto status 
= LiteRtDispatchGetInputRequirements( - invocation_context_, /*input_index=*/io_tensor_index, - &litert_tensor_type, &tensor_buffer_requirements); - status != kLiteRtStatusOk) { - LITERT_LOG(LITERT_ERROR, - "Failed to get tensor buffer requirements for input %d: %d", - io_tensor_index, status); - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to get tensor buffer requirements for input"); - } - - } else { - if (auto status = LiteRtDispatchGetOutputRequirements( - invocation_context_, /*output_index=*/io_tensor_index, - &litert_tensor_type, &tensor_buffer_requirements); - status != kLiteRtStatusOk) { - LITERT_LOG(LITERT_ERROR, - "Failed to get tensor buffer requirements for output %d: %d", - io_tensor_index, status); - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to get tensor buffer requirements for output"); - } - } - - return TensorBufferRequirements(tensor_buffer_requirements, - /*owned=*/true); -} - -TfLiteStatus DispatchDelegateKernel::CreateAndSetBuffer( - const TfLiteOpaqueTensor* tfl_opaque_tensor, int buffer_index, - bool is_input) { - auto& cached_tensor_buffer = is_input ? input_tensor_buffers_[buffer_index] - : output_tensor_buffers_[buffer_index]; - - auto tensor_type = ConvertTensorType(tfl_opaque_tensor); - if (!tensor_type) { - LITERT_LOG(LITERT_ERROR, "%s", tensor_type.Error().Message().c_str()); - return kTfLiteError; - } - - // Check if we can reuse a cached tensor buffer or we need to create a new - // one. - if (static_cast(cached_tensor_buffer)) { - if (auto cached_tensor_type = cached_tensor_buffer.TensorType(); - !cached_tensor_type) { - LITERT_LOG(LITERT_ERROR, "%s", - cached_tensor_type.Error().Message().c_str()); - return kTfLiteError; - } - - if (tensor_type->Layout() == cached_tensor_buffer.TensorType()->Layout()) { - // We can reuse the cached tensor buffer. - return kTfLiteOk; - } - - // We cannot reuse the cached tensor buffer; proceed below. 
- } - - auto tensor_buffer_requirements = - GetBufferRequirements(*tensor_type, buffer_index, is_input); - if (!tensor_buffer_requirements) { - LITERT_LOG(LITERT_ERROR, "%s", - tensor_buffer_requirements.Error().Message().c_str()); - return kTfLiteError; - } - - auto supported_tensor_buffer_types = - tensor_buffer_requirements->SupportedTypes(); - if (!supported_tensor_buffer_types) { - LITERT_LOG(LITERT_ERROR, "%s", - supported_tensor_buffer_types.Error().Message().c_str()); - return kTfLiteError; - } - - if (supported_tensor_buffer_types->empty()) { - LITERT_LOG(LITERT_ERROR, - "Insufficient number of supported tensor buffer types"); - return kTfLiteError; - } - - // For now we simply pick the first buffer type that's supported. - LiteRtTensorBufferType tensor_buffer_type = - (*supported_tensor_buffer_types)[0]; - - auto tensor_buffer_size = tensor_buffer_requirements->BufferSize(); - if (!tensor_buffer_size) { - LITERT_LOG(LITERT_ERROR, "%s", - tensor_buffer_size.Error().Message().c_str()); - return kTfLiteError; - } - - auto litert_tensor_type = static_cast(*tensor_type); - LiteRtTensorBuffer litert_tensor_buffer; - if (auto status = LiteRtCreateManagedTensorBuffer( - tensor_buffer_type, &litert_tensor_type, *tensor_buffer_size, - &litert_tensor_buffer); - status != kLiteRtStatusOk) { - LITERT_LOG(LITERT_ERROR, "Failed to create managed tensor buffer: %d", - status); - return kTfLiteError; - } - - return RegisterLiteRtTensorBuffer(TensorBuffer(litert_tensor_buffer), - *tensor_buffer_size, buffer_index, - is_input); -} - -TfLiteStatus DispatchDelegateKernel::RegisterLiteRtTensorBuffer( - TensorBuffer&& tensor_buffer, size_t buffer_used_size, int buffer_index, - bool is_input) { - LiteRtTensorBufferHandle buffer_handle; - if (auto status = LiteRtDispatchRegisterTensorBuffer( - device_context_, tensor_buffer.Get(), &buffer_handle); - status != kLiteRtStatusOk) { - LITERT_LOG(LITERT_ERROR, "Failed to register tensor buffer: %d", status); - return kTfLiteError; - } 
- - if (is_input) { - if (auto status = LiteRtDispatchAttachInput(invocation_context_, - buffer_index, buffer_handle); - status != kLiteRtStatusOk) { - (void)LiteRtDispatchUnregisterTensorBuffer(device_context_, - buffer_handle); - LITERT_LOG(LITERT_ERROR, "Failed to attach tensor buffer to input %d: %d", - buffer_index, status); - return kTfLiteError; - } - - if (tensor_buffer.HasEvent()) { - auto event = tensor_buffer.GetEvent(); - if (!event) { - LITERT_LOG(LITERT_ERROR, - "Failed to get event from tensor buffer %d: %s", - buffer_index, event.Error().Message().c_str()); - return kTfLiteError; - } - - if (!async_dispatch_) { - // If the Dispatch API runtime doesn't support async execution, then - // wait for the event on the CPU. - LITERT_LOG(LITERT_WARNING, "Waiting on an input event on the CPU..."); - if (auto status = event->Wait(/*timeout_in_ms=*/-1); !status) { - LITERT_LOG(LITERT_ERROR, "Failed to wait on event: %s", - status.Error().Message().c_str()); - return kTfLiteError; - } - - } else { - if (auto status = LiteRtDispatchAttachInputEvent( - invocation_context_, buffer_index, event->Get()); - status != kLiteRtStatusOk) { - LITERT_LOG(LITERT_ERROR, "Failed to attach event to input %d: %d", - buffer_index, status); - return kTfLiteError; - } - } - } - - } else { - if (auto status = LiteRtDispatchAttachOutput(invocation_context_, - buffer_index, buffer_handle); - status != kLiteRtStatusOk) { - (void)LiteRtDispatchUnregisterTensorBuffer(device_context_, - buffer_handle); - LITERT_LOG(LITERT_ERROR, - "Failed to attach tensor buffer to output %d: %d", - buffer_index, status); - return kTfLiteError; - } - } - - if (is_input) { - input_tensor_buffers_[buffer_index] = std::move(tensor_buffer); - input_tensor_buffer_handles_[buffer_index] = buffer_handle; - input_tensor_buffer_used_size_[buffer_index] = buffer_used_size; - } else { - output_tensor_buffers_[buffer_index] = std::move(tensor_buffer); - output_tensor_buffer_handles_[buffer_index] = buffer_handle; - 
output_tensor_buffer_used_size_[buffer_index] = buffer_used_size; - } - return kTfLiteOk; -} - -TfLiteStatus DispatchDelegateKernel::Prepare(TfLiteOpaqueContext* context, - TfLiteOpaqueNode* node) { - return kTfLiteOk; -} - -TfLiteStatus DispatchDelegateKernel::RegisterLiteRtTensorBuffers( - TfLiteOpaqueContext* context, TfLiteOpaqueNode* node) { - size_t num_node_inputs = TfLiteOpaqueNodeNumberOfInputs(node); - for (size_t i = 0; i < num_node_inputs; ++i) { - auto* tfl_opaque_tensor = TfLiteOpaqueNodeGetInput(context, node, i); - auto tensor_buffer = buffer_context_->GetTensorBuffer(tfl_opaque_tensor); - if (tensor_buffer.HasValue()) { - // TODO - b/379176766: If the provided TensorBuffer is not supported - // types, we need to create a new one and convert the data from the - // provided TensorBuffer. - auto buffer_size = tensor_buffer->Size(); - if (!buffer_size) { - LITERT_LOG(LITERT_ERROR, "%s", buffer_size.Error().Message().c_str()); - return kTfLiteError; - } - if (auto status = RegisterLiteRtTensorBuffer(std::move(*tensor_buffer), - *buffer_size, i, - /*is_input=*/true); - status != kTfLiteOk) { - return status; - } - input_tensor_buffers_require_cpu_sync_[i] = false; - } else { - LITERT_LOG(LITERT_VERBOSE, - "Input#%d TensorBuffer is not registered. Create a new one", - i); - if (auto status = - CreateAndSetBuffer(tfl_opaque_tensor, i, /*is_input=*/true); - status != kTfLiteOk) { - return status; - } - input_tensor_buffers_require_cpu_sync_[i] = true; - } - } - - size_t num_node_outputs = TfLiteOpaqueNodeNumberOfOutputs(node); - for (size_t i = 0; i < num_node_outputs; ++i) { - auto* tfl_opaque_tensor = TfLiteOpaqueNodeGetOutput(context, node, i); - auto tensor_buffer = buffer_context_->GetTensorBuffer(tfl_opaque_tensor); - if (tensor_buffer.HasValue()) { - // TODO - b/379176766: If the provided TensorBuffer is not supported - // types, we need to create a new one and convert the data back to the - // provided TensorBuffer. 
- auto buffer_size = tensor_buffer->Size(); - if (!buffer_size) { - LITERT_LOG(LITERT_ERROR, "%s", buffer_size.Error().Message().c_str()); - return kTfLiteError; - } - if (auto status = RegisterLiteRtTensorBuffer(std::move(*tensor_buffer), - *buffer_size, i, - /*is_input=*/false); - status != kTfLiteOk) { - return status; - } - output_tensor_buffers_require_cpu_sync_[i] = false; - } else { - LITERT_LOG(LITERT_VERBOSE, - "Output#%d TensorBuffer is not registered. Create a new one", - i); - if (auto status = - CreateAndSetBuffer(tfl_opaque_tensor, i, /*is_input=*/false); - status != kTfLiteOk) { - return status; - } - output_tensor_buffers_require_cpu_sync_[i] = true; - } - } - - return kTfLiteOk; -} - -TfLiteStatus DispatchDelegateKernel::Eval(TfLiteOpaqueContext* context, - TfLiteOpaqueNode* node) { - if (auto status = RegisterLiteRtTensorBuffers(context, node); - status != kTfLiteOk) { - LITERT_LOG(LITERT_ERROR, "Failed to register tensor buffers: %d", status); - return kTfLiteError; - } - - size_t num_node_inputs = TfLiteOpaqueNodeNumberOfInputs(node); - if (num_node_inputs != input_tensor_buffers_.size()) { - LITERT_LOG(LITERT_ERROR, "Invalid number of inputs"); - return kTfLiteError; - } - - for (size_t i = 0; i < num_node_inputs; ++i) { - if (!input_tensor_buffers_require_cpu_sync_[i]) { - continue; - } - auto* tfl_opaque_tensor = TfLiteOpaqueNodeGetInput(context, node, i); - void* tensor_data = TfLiteOpaqueTensorData(tfl_opaque_tensor); - auto& tensor_buffer = input_tensor_buffers_[i]; - - auto lock_and_addr = TensorBufferScopedLock::Create(tensor_buffer); - if (!lock_and_addr) { - LITERT_LOG(LITERT_ERROR, "%s", lock_and_addr.Error().Message().c_str()); - return kTfLiteError; - } - - size_t buffer_size = input_tensor_buffer_used_size_[i]; - std::memcpy(lock_and_addr->second, tensor_data, buffer_size); - } - - size_t num_node_outputs = TfLiteOpaqueNodeNumberOfOutputs(node); - if (num_node_outputs != output_tensor_buffers_.size()) { - LITERT_LOG(LITERT_ERROR, 
"Invalid number of outputs"); - return kTfLiteError; - } - - if (async_dispatch_ && buffer_context_->IsAsyncExecutionMode()) { - std::vector output_events(num_node_outputs); - if (auto status = LiteRtDispatchInvokeAsync( - invocation_context_, output_events.size(), output_events.data()); - status != kLiteRtStatusOk) { - LITERT_LOG(LITERT_ERROR, "Failed to invoke context asynchronously: %d", - status); - return kTfLiteError; - } - for (size_t i = 0; i < output_events.size(); ++i) { - auto output_event = output_events[i]; - if (output_event) { - auto& tensor_buffer = output_tensor_buffers_[i]; - if (auto status = tensor_buffer.SetEvent(Event(output_event)); - !status) { - LITERT_LOG(LITERT_ERROR, - "Failed to set event on output tensor buffer: %s", - status.Error().Message().c_str()); - return kTfLiteError; - } - } - } - - } else { - if (auto status = LiteRtDispatchInvoke(invocation_context_); - status != kLiteRtStatusOk) { - LITERT_LOG(LITERT_ERROR, "Failed to invoke context: %d", status); - return kTfLiteError; - } - } - - for (size_t i = 0; i < num_node_outputs; ++i) { - if (!output_tensor_buffers_require_cpu_sync_[i]) { - continue; - } - auto* tfl_opaque_tensor = TfLiteOpaqueNodeGetOutput(context, node, i); - void* tensor_data = TfLiteOpaqueTensorData(tfl_opaque_tensor); - auto& tensor_buffer = output_tensor_buffers_[i]; - - auto lock_and_addr = TensorBufferScopedLock::Create(tensor_buffer); - if (!lock_and_addr) { - LITERT_LOG(LITERT_ERROR, "%s", lock_and_addr.Error().Message().c_str()); - return kTfLiteError; - } - - size_t buffer_size = output_tensor_buffer_used_size_[i]; - std::memcpy(tensor_data, lock_and_addr->second, buffer_size); - } - - return kTfLiteOk; -} - -} // namespace internal -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate_kernel.h b/tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate_kernel.h deleted file mode 100644 index c53cc09e9d780c..00000000000000 --- 
a/tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate_kernel.h +++ /dev/null @@ -1,121 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_DISPATCH_DISPATCH_DELEGATE_KERNEL_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_DISPATCH_DISPATCH_DELEGATE_KERNEL_H_ - -#include -#include -#include -#include -#include - -#include "tensorflow/lite/c/c_api_types.h" -#include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/delegates/utils/simple_opaque_delegate.h" -#include "tensorflow/lite/experimental/litert/c/litert_dispatch_delegate.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_requirements.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" - -namespace litert::internal { - -class ExternalLiteRtBufferContext; - -// A TFL kernel that the interpreter calls to dispatch execution through the -// Dispatch API. 
-class DispatchDelegateKernel - : public tflite::SimpleOpaqueDelegateKernelInterface { - public: - using Ptr = std::unique_ptr; - - ~DispatchDelegateKernel() override; - - static Expected Create(std::string&& graph_name, - const LiteRtDispatchDelegateOptions& options); - - TfLiteStatus Init(TfLiteOpaqueContext* context, - const TfLiteOpaqueDelegateParams* params) override; - - TfLiteStatus Prepare(TfLiteOpaqueContext* context, - TfLiteOpaqueNode* node) override; - - TfLiteStatus Eval(TfLiteOpaqueContext* context, - TfLiteOpaqueNode* node) override; - - private: - DispatchDelegateKernel(const LiteRtDispatchDelegateOptions& options, - std::string&& graph_name, - LiteRtDispatchDeviceContext device_context, - bool async_dispatch) - : options_(options), - graph_name_(std::move(graph_name)), - device_context_(device_context), - async_dispatch_(async_dispatch) {} - - Expected GetBufferRequirements( - const RankedTensorType& tensor_type, int io_tensor_index, - bool is_input) const; - - // Creates a new tensor buffer for the given tensor. After that the created - // tensor buffer is registered with RegisterLiteRtTensorBuffer(). - TfLiteStatus CreateAndSetBuffer(const TfLiteOpaqueTensor* tfl_opaque_tensor, - int buffer_index, bool is_input); - - // Registers the given LiteRtTensorBuffer (and its size) with the Dispatch - // API. - // Also update the internal state (input_tensor_buffers_, etc.) to keep track - // of the registered tensor buffers. - TfLiteStatus RegisterLiteRtTensorBuffer(TensorBuffer&& tensor_buffer, - size_t used_size, int buffer_index, - bool is_input); - - // Registers LiteRtTensorBuffers for all inputs and outputs of the given - // node. - // Also update the internal state (input_tensor_buffers_, etc.) to keep track - // of the registered tensor buffers. 
- TfLiteStatus RegisterLiteRtTensorBuffers(TfLiteOpaqueContext* context, - TfLiteOpaqueNode* node); - - const LiteRtDispatchDelegateOptions& options_; - std::string graph_name_; - LiteRtDispatchDeviceContext device_context_; - LiteRtDispatchInvocationContext invocation_context_ = nullptr; - // Indicates whether the Dispatch API can be invoked asynchronously. - const bool async_dispatch_; - - ExternalLiteRtBufferContext* buffer_context_ = nullptr; - - // Indicates whether the input tensor buffer requires a CPU sync before - // invoking the Dispatch API. - std::vector input_tensor_buffers_require_cpu_sync_; - - std::vector input_tensor_buffers_; - std::vector input_tensor_buffer_handles_; - std::vector input_tensor_buffer_used_size_; - - // Indicates whether the output tensor buffer requires a CPU sync after - // invoking the Dispatch API. - std::vector output_tensor_buffers_require_cpu_sync_; - - std::vector output_tensor_buffers_; - std::vector output_tensor_buffer_handles_; - std::vector output_tensor_buffer_used_size_; -}; - -} // namespace litert::internal - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_DISPATCH_DISPATCH_DELEGATE_KERNEL_H_ diff --git a/tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate_mediatek_test.cc b/tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate_mediatek_test.cc deleted file mode 100644 index aa26b3c211c9ad..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate_mediatek_test.cc +++ /dev/null @@ -1,406 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include -#include -#include - -#include -#include -#include "absl/log/absl_log.h" -#include "absl/log/log.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "tensorflow/lite/c/c_api_opaque.h" -#include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_compilation_options.h" -#include "tensorflow/lite/experimental/litert/c/litert_dispatch_delegate.h" -#include "tensorflow/lite/experimental/litert/c/litert_environment.h" -#include "tensorflow/lite/experimental/litert/c/litert_environment_options.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_types.h" -#include "tensorflow/lite/experimental/litert/cc/litert_compiled_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_dispatch_delegate.h" -#include "tensorflow/lite/experimental/litert/cc/litert_environment.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/core/model/model_buffer.h" -#include "tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h" -#include "tensorflow/lite/experimental/litert/runtime/external_litert_buffer_context.h" -#include "tensorflow/lite/experimental/litert/test/common.h" -#include "tensorflow/lite/experimental/litert/test/testdata/simple_model_test_vectors.h" -#include "tensorflow/lite/interpreter.h" -#include "tensorflow/lite/signature_runner.h" - 
-namespace litert { -namespace { - -using ::litert::testing::MakeRuntimeFromTestFileWithNpuModel; -using ::testing::FloatNear; -using ::testing::Pointwise; - -static constexpr absl::string_view kNpuFile = kMediaTekModelFileName; -static constexpr absl::string_view kTfliteFile = "simple_model_npu.tflite"; -static constexpr absl::string_view kDispatchLibraryDir = "/data/local/tmp"; - -TEST(DispatchDelegate, MediaTekCpuBuffer) { - auto runtime = MakeRuntimeFromTestFileWithNpuModel(kTfliteFile, kNpuFile); - ASSERT_TRUE(runtime) << "Failed to initialize tflite interpreter"; - auto& rt = **runtime; - auto& interpreter = rt.Interpreter(); - const std::vector environment_options = { - litert::Environment::Option{ - litert::Environment::OptionTag::DispatchLibraryDir, - kDispatchLibraryDir, - }, - }; - auto env = - litert::Environment::Create(absl::MakeConstSpan(environment_options)); - ASSERT_TRUE(env); - - litert::internal::ExternalLiteRtBufferContext buffer_context; - interpreter.SetExternalContext(kTfLiteLiteRtBufferContext, &buffer_context); - - EXPECT_EQ(interpreter.nodes_size(), 1); - EXPECT_EQ(interpreter.inputs().size(), 2); - EXPECT_EQ(interpreter.outputs().size(), 1); - ASSERT_EQ(interpreter.execution_plan().size(), 1); - - LiteRtEnvironmentOptions env_options; - LiteRtGetEnvironmentOptions(env->Get(), &env_options); - - auto dispatch_delegate_options = - CreateDispatchDelegateOptionsPtr(env_options); - LiteRtDispatchDelegateAddAllocBaseOption(dispatch_delegate_options.get(), - rt.Flatbuffer().Buf().Data()); - auto dispatch_delegate = CreateDispatchDelegatePtr( - env_options, std::move(dispatch_delegate_options)); - -#if !defined(__ANDROID__) - GTEST_SKIP() << "The rest of this test is specific to Android devices with a " - "MediaTek NPU"; -#endif - - ASSERT_EQ(interpreter.ModifyGraphWithDelegate(dispatch_delegate.get()), - kTfLiteOk); - - // Get the list of signatures and check it. 
- auto signature_defs = interpreter.signature_keys(); - ASSERT_EQ(signature_defs.size(), 1); - - tflite::impl::SignatureRunner* runner = - interpreter.GetSignatureRunner(/*signature_key=*/nullptr); - ASSERT_NE(runner, nullptr); - - EXPECT_EQ(runner->AllocateTensors(), kTfLiteOk); - - // Fill model inputs. - ASSERT_STREQ(runner->input_names()[0], "arg0"); - auto input_0_tensor = runner->input_tensor("arg0"); - ASSERT_NE(input_0_tensor, nullptr); - auto* input_0 = input_0_tensor->data.f; - std::memcpy(input_0, kTestInput0Tensor, sizeof(kTestInput0Tensor)); - - ASSERT_STREQ(runner->input_names()[1], "arg1"); - auto input_1_tensor = runner->input_tensor("arg1"); - ASSERT_NE(input_1_tensor, nullptr); - auto* input_1 = input_1_tensor->data.f; - std::memcpy(input_1, kTestInput1Tensor, sizeof(kTestInput1Tensor)); - - EXPECT_EQ(runner->Invoke(), kTfLiteOk); - - // Check model output. - ASSERT_STREQ(runner->output_names()[0], "tfl.custom"); - auto output_tensor = runner->output_tensor("tfl.custom"); - ASSERT_NE(output_tensor, nullptr); - auto output = absl::MakeSpan(output_tensor->data.f, kTestOutputSize); - for (auto i = 0; i < kTestOutputSize; ++i) { - ABSL_LOG(INFO) << output[i] << "\t" << kTestOutputTensor[i]; - } - EXPECT_THAT(output, Pointwise(::testing::FloatNear(1e-5), kTestOutputTensor)); -} - -TEST(DispatchDelegate, MediaTekHwBuffer) { - auto runtime = MakeRuntimeFromTestFileWithNpuModel(kTfliteFile, kNpuFile); - ASSERT_TRUE(runtime) << "Failed to initialize tflite interpreter"; - auto& rt = **runtime; - auto& interpreter = rt.Interpreter(); - const std::vector environment_options = { - litert::Environment::Option{ - litert::Environment::OptionTag::DispatchLibraryDir, - kDispatchLibraryDir, - }, - }; - auto env = - litert::Environment::Create(absl::MakeConstSpan(environment_options)); - ASSERT_TRUE(env); - - litert::internal::ExternalLiteRtBufferContext buffer_context; - interpreter.SetExternalContext(kTfLiteLiteRtBufferContext, &buffer_context); - - 
EXPECT_EQ(interpreter.nodes_size(), 1); - EXPECT_EQ(interpreter.inputs().size(), 2); - EXPECT_EQ(interpreter.outputs().size(), 1); - ASSERT_EQ(interpreter.execution_plan().size(), 1); - - LiteRtEnvironmentOptions env_options; - LiteRtGetEnvironmentOptions(env->Get(), &env_options); - - auto dispatch_delegate_options = - CreateDispatchDelegateOptionsPtr(env_options); - LiteRtDispatchDelegateAddAllocBaseOption(dispatch_delegate_options.get(), - rt.Flatbuffer().Buf().Data()); - auto dispatch_delegate = CreateDispatchDelegatePtr( - env_options, std::move(dispatch_delegate_options)); - -#if !defined(__ANDROID__) - GTEST_SKIP() << "The rest of this test is specific to Android devices with a " - "MediaTek NPU"; -#endif - - ASSERT_EQ(interpreter.ModifyGraphWithDelegate(dispatch_delegate.get()), - kTfLiteOk); - - // Create and register tensor buffers for all inputs and outputs. - - std::vector input_buffers; - for (int i = 0; i < interpreter.inputs().size(); ++i) { - auto input_buffer_requirements = - buffer_context.GetBufferRequirement(interpreter.input_tensor(i)); - ASSERT_TRUE(input_buffer_requirements); - ASSERT_EQ((*input_buffer_requirements)->SupportedTypes().Value()[0], - kLiteRtTensorBufferTypeAhwb); - auto input_buffer = - buffer_context.CreateBufferForTensor(interpreter.input_tensor(i)); - ASSERT_TRUE(input_buffer); - ASSERT_TRUE(input_buffer->IsOwned()); - ASSERT_EQ(*input_buffer->BufferType(), kLiteRtTensorBufferTypeAhwb); - auto duplicate_buffer = (*input_buffer).Duplicate(); - ASSERT_TRUE(duplicate_buffer); - auto status = buffer_context.RegisterTensorBuffer( - interpreter.input_tensor(i), std::move(*duplicate_buffer)); - ASSERT_EQ(status, kLiteRtStatusOk); - input_buffers.push_back(std::move(*input_buffer)); - } - - std::vector output_buffers; - for (int i = 0; i < interpreter.outputs().size(); ++i) { - auto output_buffer_requirements = - buffer_context.GetBufferRequirement(interpreter.output_tensor(i)); - ASSERT_TRUE(output_buffer_requirements); - 
ASSERT_EQ((*output_buffer_requirements)->SupportedTypes().Value()[0], - kLiteRtTensorBufferTypeAhwb); - auto output_buffer = - buffer_context.CreateBufferForTensor(interpreter.output_tensor(i)); - ASSERT_TRUE(output_buffer.HasValue()); - ASSERT_TRUE(output_buffer->IsOwned()); - ASSERT_EQ(*output_buffer->BufferType(), kLiteRtTensorBufferTypeAhwb); - auto duplicate_buffer = (*output_buffer).Duplicate(); - ASSERT_TRUE(duplicate_buffer); - auto status = buffer_context.RegisterTensorBuffer( - interpreter.output_tensor(i), std::move(*duplicate_buffer)); - ASSERT_EQ(status, kLiteRtStatusOk); - output_buffers.push_back(std::move(*output_buffer)); - } - - // Get the list of signatures and check it. - auto signature_defs = interpreter.signature_keys(); - ASSERT_EQ(signature_defs.size(), 1); - - tflite::impl::SignatureRunner* runner = - interpreter.GetSignatureRunner(/*signature_key=*/nullptr); - ASSERT_NE(runner, nullptr); - - EXPECT_EQ(runner->AllocateTensors(), kTfLiteOk); - - // Fill model inputs. - ASSERT_STREQ(runner->input_names()[0], "arg0"); - auto& input_0_buffer = input_buffers[0]; - input_0_buffer.Write( - absl::MakeConstSpan(kTestInput0Tensor, kTestInput0Size)); - - ASSERT_STREQ(runner->input_names()[1], "arg1"); - auto& input_1_buffer = input_buffers[1]; - input_1_buffer.Write( - absl::MakeConstSpan(kTestInput1Tensor, kTestInput1Size)); - - EXPECT_EQ(runner->Invoke(), kTfLiteOk); - - // Check model output. 
- ASSERT_STREQ(runner->output_names()[0], "tfl.custom"); - auto& output_buffer = output_buffers[0]; - float output_buffer_data[kTestOutputSize]; - auto output_span = absl::MakeSpan(output_buffer_data, kTestOutputSize); - auto read_success = output_buffer.Read(output_span); - ASSERT_TRUE(read_success); - for (auto i = 0; i < kTestOutputSize; ++i) { - ABSL_LOG(INFO) << "Result: " << output_span.at(i) << "\t" - << kTestOutputTensor[i]; - } - EXPECT_THAT(output_span, Pointwise(FloatNear(1e-5), kTestOutputTensor)); -} - -TEST(DispatchDelegate, CompiledModel) { - auto model_with_byte_code = - internal::GetModelBufWithByteCode(testing::GetTestFilePath(kTfliteFile), - testing::GetTestFilePath(kNpuFile)); - ASSERT_TRUE(model_with_byte_code); - auto model = Model::CreateFromBuffer(*model_with_byte_code); - ASSERT_TRUE(model); - -#if !defined(__ANDROID__) - GTEST_SKIP() << "The rest of this test is specific to Android devices with a " - "MediaTek NPU"; -#endif - auto jit_compilation_options = CompilationOptions::Create(); - ASSERT_TRUE(jit_compilation_options); - ASSERT_TRUE(jit_compilation_options->SetHardwareAccelerators( - kLiteRtHwAcceleratorCpu)); - - const std::vector environment_options = { - litert::Environment::Option{ - litert::Environment::OptionTag::DispatchLibraryDir, - kDispatchLibraryDir, - }, - }; - auto env = - litert::Environment::Create(absl::MakeConstSpan(environment_options)); - ASSERT_TRUE(env); - auto res_compiled_model = - CompiledModel::Create(*env, *model, *jit_compilation_options); - ASSERT_TRUE(res_compiled_model) << "Failed to initialize CompiledModel"; - auto& compiled_model = *res_compiled_model; - - auto signatures = model->GetSignatures(); - ASSERT_TRUE(signatures); - EXPECT_EQ(signatures->size(), 1); - auto& signature = signatures->at(0); - auto signature_key = signature.Key(); - EXPECT_EQ(signature_key, Model::DefaultSignatureKey()); - size_t signature_index = 0; - - auto input_buffers_res = 
compiled_model.CreateInputBuffers(signature_index); - EXPECT_TRUE(input_buffers_res); - auto& input_buffers = *input_buffers_res; - - auto output_buffers_res = compiled_model.CreateOutputBuffers(signature_index); - EXPECT_TRUE(output_buffers_res); - auto& output_buffers = *output_buffers_res; - - // Fill model inputs. - auto input_names = signature.InputNames(); - EXPECT_EQ(input_names.size(), 2); - EXPECT_EQ(input_names.at(0), "arg0"); - EXPECT_EQ(input_names.at(1), "arg1"); - ASSERT_TRUE(input_buffers[0].Write( - absl::MakeConstSpan(kTestInput0Tensor, kTestInput0Size))); - ASSERT_TRUE(input_buffers[1].Write( - absl::MakeConstSpan(kTestInput1Tensor, kTestInput1Size))); - - // Execute model. - compiled_model.Run(signature_index, input_buffers, output_buffers); - - // Check model output. - auto output_names = signature.OutputNames(); - EXPECT_EQ(output_names.size(), 1); - EXPECT_EQ(output_names.at(0), "tfl.custom"); - float output_buffer_data[kTestOutputSize]; - auto output_span = absl::MakeSpan(output_buffer_data, kTestOutputSize); - ASSERT_TRUE(output_buffers[0].Read(output_span)); - for (auto i = 0; i < kTestOutputSize; ++i) { - ABSL_LOG(INFO) << "Result: " << output_span.at(i) << "\t" - << kTestOutputTensor[i]; - } - EXPECT_THAT(output_span, Pointwise(FloatNear(1e-5), kTestOutputTensor)); -} - -TEST(DispatchDelegate, CompiledModelSharedInput) { - auto model_with_byte_code = internal::GetModelBufWithByteCode( - testing::GetTestFilePath("shared_input_cpu_npu.tflite"), - testing::GetTestFilePath(kNpuFile)); - ASSERT_TRUE(model_with_byte_code); - auto model = Model::CreateFromBuffer(*model_with_byte_code); - ASSERT_TRUE(model); - -#if !defined(__ANDROID__) - GTEST_SKIP() << "The rest of this test is specific to Android devices with a " - "MediaTek NPU"; -#endif - auto jit_compilation_options = CompilationOptions::Create(); - ASSERT_TRUE(jit_compilation_options); - ASSERT_TRUE(jit_compilation_options->SetHardwareAccelerators( - kLiteRtHwAcceleratorCpu)); - - const 
std::vector environment_options = { - litert::Environment::Option{ - litert::Environment::OptionTag::DispatchLibraryDir, - kDispatchLibraryDir, - }, - }; - auto env = - litert::Environment::Create(absl::MakeConstSpan(environment_options)); - ASSERT_TRUE(env); - auto res_compiled_model = - CompiledModel::Create(*env, *model, *jit_compilation_options); - ASSERT_TRUE(res_compiled_model) << "Failed to initialize CompiledModel"; - auto& compiled_model = *res_compiled_model; - - size_t signature_index = 0; - auto signature = *model->GetSignature(signature_index); - auto input_buffers = *compiled_model.CreateInputBuffers(signature_index); - auto output_buffers = *compiled_model.CreateOutputBuffers(signature_index); - - // Fill model inputs. - auto input_names = signature.InputNames(); - EXPECT_EQ(input_names.size(), 2); - EXPECT_EQ(input_names.at(0), "arg0"); - EXPECT_EQ(input_names.at(1), "arg1"); - ASSERT_TRUE(input_buffers[0].Write( - absl::MakeConstSpan(kTestInput0Tensor, kTestInput0Size))); - ASSERT_TRUE(input_buffers[1].Write( - absl::MakeConstSpan(kTestInput1Tensor, kTestInput1Size))); - - // Execute model. - compiled_model.Run(signature_index, input_buffers, output_buffers); - - // Check model output. 
- auto output_names = signature.OutputNames(); - EXPECT_EQ(output_names.size(), 2); - { - EXPECT_EQ(output_names.at(0), "tfl.add"); - float output_buffer_data[kTestOutputSize]; - auto output_span = absl::MakeSpan(output_buffer_data, kTestOutputSize); - ASSERT_TRUE(output_buffers[0].Read(output_span)); - for (auto i = 0; i < kTestOutputSize; ++i) { - ABSL_LOG(INFO) << "Result: " << output_span.at(i) << "\t" - << kTestOutputTensor[i]; - } - EXPECT_THAT(output_span, Pointwise(FloatNear(1e-5), kTestOutputTensor)); - } - { - EXPECT_EQ(output_names.at(1), "tfl.custom"); - float output_buffer_data[kTestOutputSize]; - auto output_span = absl::MakeSpan(output_buffer_data, kTestOutputSize); - ASSERT_TRUE(output_buffers[1].Read(output_span)); - for (auto i = 0; i < kTestOutputSize; ++i) { - ABSL_LOG(INFO) << "Result: " << output_span.at(i) << "\t" - << kTestOutputTensor[i]; - } - EXPECT_THAT(output_span, Pointwise(FloatNear(1e-5), kTestOutputTensor)); - } -} - -} // namespace -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate_options.h b/tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate_options.h deleted file mode 100644 index c4847fdaac0e86..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate_options.h +++ /dev/null @@ -1,123 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_DISPATCH_DISPATCH_DELEGATE_OPTIONS_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_DISPATCH_DISPATCH_DELEGATE_OPTIONS_H_ - -#include -#include - -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/c/litert_any.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_dispatch_delegate.h" -#include "tensorflow/lite/experimental/litert/c/litert_environment_options.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/cc/litert_any.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/core/environment_options.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" - -class LiteRtDispatchDelegateOptions { - public: - explicit LiteRtDispatchDelegateOptions( - const LiteRtEnvironmentOptionsT* environment_options) { - if (!environment_options) { - return; - } - auto option = - environment_options->GetOption(kLiteRtEnvOptionTagDispatchLibraryDir); - if (!option.HasValue()) { - return; - } - - if (option->type != kLiteRtAnyTypeString) { - LITERT_LOG(LITERT_WARNING, - "Ignoring option kLiteRtEnvOptionTagDispatchLibraryDir due " - "to invalid value"); - return; - } - - LiteRtDispatchOption dispatch_option = { - /*.name=*/kDispatchOptionSharedLibraryDir, - /*.value=*/*option, - }; - AddOption(dispatch_option); - } - - // Push a new dispatch option. - void AddOption(LiteRtDispatchOption option) { options_.push_back(option); } - - // Get all dispatch options. - const std::vector& GetDispatchOptions() const { - return options_; - } - - // Find a dispatch option under the given name if it exists. 
- litert::Expected FindDispatchOption(absl::string_view name) const { - for (const auto& option : options_) { - if (option.name != name) { - continue; - } - return litert::ToStdAny(option.value); - } - return litert::Unexpected(kLiteRtStatusErrorInvalidArgument); - } - - private: - std::vector options_; -}; - -// -// Common options -// - -static constexpr absl::string_view kAllocBase = "alloc_base"; -static constexpr absl::string_view kAllocFd = "alloc_fd"; - -inline void AddAllocBaseOption(const void* alloc_base, - LiteRtDispatchDelegateOptions& opts) { - LiteRtAny opt; - opt.type = kLiteRtAnyTypeVoidPtr; - opt.ptr_value = alloc_base; - opts.AddOption(LiteRtDispatchOption{kAllocBase.data(), opt}); -} - -inline litert::Expected FindAllocBase( - const LiteRtDispatchDelegateOptions& opts) { - auto alloc_base = opts.FindDispatchOption(kAllocBase); - if (!alloc_base) { - return alloc_base.Error(); - } - return std::any_cast(*alloc_base); -} - -inline void AddAllocFdOption(int alloc_fd, - LiteRtDispatchDelegateOptions& opts) { - LiteRtAny opt; - opt.type = kLiteRtAnyTypeVoidPtr; - opt.int_value = alloc_fd; - opts.AddOption(LiteRtDispatchOption{kAllocBase.data(), opt}); -} - -inline litert::Expected FindAllocFd( - const LiteRtDispatchDelegateOptions& opts) { - auto alloc_fd = opts.FindDispatchOption(kAllocFd); - if (!alloc_fd) { - return alloc_fd.Error(); - } - return std::any_cast(*alloc_fd); -} - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_DISPATCH_DISPATCH_DELEGATE_OPTIONS_H_ diff --git a/tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate_qualcomm_test.cc b/tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate_qualcomm_test.cc deleted file mode 100644 index 2b18e48f63b5c0..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate_qualcomm_test.cc +++ /dev/null @@ -1,406 +0,0 @@ -// Copyright 2024 Google LLC. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include -#include -#include -#include - -#include -#include -#include "absl/log/absl_log.h" -#include "absl/log/log.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "tensorflow/lite/c/c_api_opaque.h" -#include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_dispatch_delegate.h" -#include "tensorflow/lite/experimental/litert/c/litert_environment.h" -#include "tensorflow/lite/experimental/litert/c/litert_environment_options.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/cc/litert_compilation_options.h" -#include "tensorflow/lite/experimental/litert/cc/litert_compiled_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_dispatch_delegate.h" -#include "tensorflow/lite/experimental/litert/cc/litert_environment.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/runtime/external_litert_buffer_context.h" -#include "tensorflow/lite/experimental/litert/test/common.h" -#include "tensorflow/lite/experimental/litert/test/testdata/simple_model_test_vectors.h" -#include "tensorflow/lite/interpreter.h" -#include "tensorflow/lite/signature_runner.h" - 
-namespace litert { -namespace { - -using ::litert::testing::MakeRuntimeFromTestFileWithNpuModel; -using ::testing::FloatNear; -using ::testing::Pointwise; - -static constexpr absl::string_view kNpuFile = kQualcommModelFileName; -static constexpr absl::string_view kTfliteFile = "simple_model_npu.tflite"; -static constexpr absl::string_view kDispatchLibraryDir = "/data/local/tmp"; - -TEST(DispatchDelegate, QualcommCpuBuffer) { - auto runtime = MakeRuntimeFromTestFileWithNpuModel(kTfliteFile, kNpuFile); - ASSERT_TRUE(runtime) << "Failed to initialize tflite interpreter"; - auto& rt = **runtime; - auto& interpreter = rt.Interpreter(); - - const std::vector environment_options = { - litert::Environment::Option{ - litert::Environment::OptionTag::DispatchLibraryDir, - kDispatchLibraryDir, - }, - }; - auto env = - litert::Environment::Create(absl::MakeConstSpan(environment_options)); - ASSERT_TRUE(env); - - litert::internal::ExternalLiteRtBufferContext buffer_context; - interpreter.SetExternalContext(kTfLiteLiteRtBufferContext, &buffer_context); - - EXPECT_EQ(interpreter.nodes_size(), 1); - EXPECT_EQ(interpreter.inputs().size(), 2); - EXPECT_EQ(interpreter.outputs().size(), 1); - ASSERT_EQ(interpreter.execution_plan().size(), 1); - - LiteRtEnvironmentOptions env_options; - LiteRtGetEnvironmentOptions(env->Get(), &env_options); - - auto dispatch_delegate_options = - CreateDispatchDelegateOptionsPtr(env_options); - LiteRtDispatchDelegateAddAllocBaseOption(dispatch_delegate_options.get(), - rt.Flatbuffer().Buf().Data()); - auto dispatch_delegate = CreateDispatchDelegatePtr( - env_options, std::move(dispatch_delegate_options)); - -#if !defined(__ANDROID__) - GTEST_SKIP() << "The rest of this test is specific to Android devices with a " - "Qualcomm HTP"; -#endif - - ASSERT_EQ(interpreter.ModifyGraphWithDelegate(dispatch_delegate.get()), - kTfLiteOk); - - // Get the list of signatures and check it. 
- auto signature_defs = interpreter.signature_keys(); - ASSERT_EQ(signature_defs.size(), 1); - - tflite::impl::SignatureRunner* runner = - interpreter.GetSignatureRunner(/*signature_key=*/nullptr); - ASSERT_NE(runner, nullptr); - - EXPECT_EQ(runner->AllocateTensors(), kTfLiteOk); - - // Fill model inputs. - ASSERT_STREQ(runner->input_names()[0], "arg0"); - auto input_0_tensor = runner->input_tensor("arg0"); - ASSERT_NE(input_0_tensor, nullptr); - auto* input_0 = input_0_tensor->data.f; - std::memcpy(input_0, kTestInput0Tensor, sizeof(kTestInput0Tensor)); - - ASSERT_STREQ(runner->input_names()[1], "arg1"); - auto input_1_tensor = runner->input_tensor("arg1"); - ASSERT_NE(input_1_tensor, nullptr); - auto* input_1 = input_1_tensor->data.f; - std::memcpy(input_1, kTestInput1Tensor, sizeof(kTestInput1Tensor)); - - EXPECT_EQ(runner->Invoke(), kTfLiteOk); - - // Check model output. - ASSERT_STREQ(runner->output_names()[0], "tfl.custom"); - auto output_tensor = runner->output_tensor("tfl.custom"); - ASSERT_NE(output_tensor, nullptr); - auto output = absl::MakeSpan(output_tensor->data.f, kTestOutputSize); - for (auto i = 0; i < kTestOutputSize; ++i) { - ABSL_LOG(INFO) << output[i] << "\t" << kTestOutputTensor[i]; - } - EXPECT_THAT(output, Pointwise(::testing::FloatNear(1e-5), kTestOutputTensor)); -} - -TEST(DispatchDelegate, QualcommHwBuffer) { - auto runtime = MakeRuntimeFromTestFileWithNpuModel(kTfliteFile, kNpuFile); - ASSERT_TRUE(runtime) << "Failed to initialize tflite interpreter"; - auto& rt = **runtime; - auto& interpreter = rt.Interpreter(); - const std::vector environment_options = { - litert::Environment::Option{ - litert::Environment::OptionTag::DispatchLibraryDir, - kDispatchLibraryDir, - }, - }; - auto env = - litert::Environment::Create(absl::MakeConstSpan(environment_options)); - ASSERT_TRUE(env); - - litert::internal::ExternalLiteRtBufferContext buffer_context; - interpreter.SetExternalContext(kTfLiteLiteRtBufferContext, &buffer_context); - - 
EXPECT_EQ(interpreter.nodes_size(), 1); - EXPECT_EQ(interpreter.inputs().size(), 2); - EXPECT_EQ(interpreter.outputs().size(), 1); - ASSERT_EQ(interpreter.execution_plan().size(), 1); - - LiteRtEnvironmentOptions env_options; - LiteRtGetEnvironmentOptions(env->Get(), &env_options); - - auto dispatch_delegate_options = - CreateDispatchDelegateOptionsPtr(env_options); - LiteRtDispatchDelegateAddAllocBaseOption(dispatch_delegate_options.get(), - rt.Flatbuffer().Buf().Data()); - auto dispatch_delegate = CreateDispatchDelegatePtr( - env_options, std::move(dispatch_delegate_options)); - -#if !defined(__ANDROID__) - GTEST_SKIP() << "The rest of this test is specific to Android devices with a " - "Qualcomm HTP"; -#endif - - ASSERT_EQ(interpreter.ModifyGraphWithDelegate(dispatch_delegate.get()), - kTfLiteOk); - - // Create and register tensor buffers for all inputs and outputs. - - std::vector input_buffers; - for (int i = 0; i < interpreter.inputs().size(); ++i) { - auto input_buffer_requirements = - buffer_context.GetBufferRequirement(interpreter.input_tensor(i)); - ASSERT_TRUE(input_buffer_requirements); - ASSERT_EQ((*input_buffer_requirements)->SupportedTypes().Value()[0], - kLiteRtTensorBufferTypeFastRpc); - auto input_buffer = - buffer_context.CreateBufferForTensor(interpreter.input_tensor(i)); - ASSERT_TRUE(input_buffer); - ASSERT_TRUE(input_buffer->IsOwned()); - ASSERT_EQ(*input_buffer->BufferType(), kLiteRtTensorBufferTypeFastRpc); - auto duplicate_buffer = (*input_buffer).Duplicate(); - ASSERT_TRUE(duplicate_buffer); - auto status = buffer_context.RegisterTensorBuffer( - interpreter.input_tensor(i), std::move(*duplicate_buffer)); - ASSERT_EQ(status, kLiteRtStatusOk); - input_buffers.push_back(std::move(*input_buffer)); - } - - std::vector output_buffers; - for (int i = 0; i < interpreter.outputs().size(); ++i) { - auto output_buffer_requirements = - buffer_context.GetBufferRequirement(interpreter.output_tensor(i)); - ASSERT_TRUE(output_buffer_requirements); - 
ASSERT_EQ((*output_buffer_requirements)->SupportedTypes().Value()[0], - kLiteRtTensorBufferTypeFastRpc); - auto output_buffer = - buffer_context.CreateBufferForTensor(interpreter.output_tensor(i)); - ASSERT_TRUE(output_buffer.HasValue()); - ASSERT_TRUE(output_buffer->IsOwned()); - ASSERT_EQ(*output_buffer->BufferType(), kLiteRtTensorBufferTypeFastRpc); - auto duplicate_buffer = (*output_buffer).Duplicate(); - ASSERT_TRUE(duplicate_buffer); - auto status = buffer_context.RegisterTensorBuffer( - interpreter.output_tensor(i), std::move(*duplicate_buffer)); - ASSERT_EQ(status, kLiteRtStatusOk); - output_buffers.push_back(std::move(*output_buffer)); - } - - // Get the list of signatures and check it. - auto signature_defs = interpreter.signature_keys(); - ASSERT_EQ(signature_defs.size(), 1); - - tflite::impl::SignatureRunner* runner = - interpreter.GetSignatureRunner(/*signature_key=*/nullptr); - ASSERT_NE(runner, nullptr); - - EXPECT_EQ(runner->AllocateTensors(), kTfLiteOk); - - // Fill model inputs. - ASSERT_STREQ(runner->input_names()[0], "arg0"); - auto& input_0_buffer = input_buffers[0]; - input_0_buffer.Write( - absl::MakeConstSpan(kTestInput0Tensor, kTestInput0Size)); - - ASSERT_STREQ(runner->input_names()[1], "arg1"); - auto& input_1_buffer = input_buffers[1]; - input_1_buffer.Write( - absl::MakeConstSpan(kTestInput1Tensor, kTestInput1Size)); - - EXPECT_EQ(runner->Invoke(), kTfLiteOk); - - // Check model output. 
- ASSERT_STREQ(runner->output_names()[0], "tfl.custom"); - auto& output_buffer = output_buffers[0]; - float output_buffer_data[kTestOutputSize]; - auto output_span = absl::MakeSpan(output_buffer_data, kTestOutputSize); - auto read_success = output_buffer.Read(output_span); - ASSERT_TRUE(read_success); - for (auto i = 0; i < kTestOutputSize; ++i) { - ABSL_LOG(INFO) << "Result: " << output_span.at(i) << "\t" - << kTestOutputTensor[i]; - } - EXPECT_THAT(output_span, Pointwise(FloatNear(1e-5), kTestOutputTensor)); -} - -TEST(DispatchDelegate, CompiledModel) { - auto model_with_byte_code = - internal::GetModelBufWithByteCode(testing::GetTestFilePath(kTfliteFile), - testing::GetTestFilePath(kNpuFile)); - ASSERT_TRUE(model_with_byte_code); - auto model = Model::CreateFromBuffer(*model_with_byte_code); - ASSERT_TRUE(model); - -#if !defined(__ANDROID__) - GTEST_SKIP() << "The rest of this test is specific to Android devices with a " - "Qualcomm HTP"; -#endif - auto jit_compilation_options = CompilationOptions::Create(); - ASSERT_TRUE(jit_compilation_options); - ASSERT_TRUE(jit_compilation_options->SetHardwareAccelerators( - kLiteRtHwAcceleratorCpu)); - - const std::vector environment_options = { - litert::Environment::Option{ - litert::Environment::OptionTag::DispatchLibraryDir, - kDispatchLibraryDir, - }, - }; - auto env = - litert::Environment::Create(absl::MakeConstSpan(environment_options)); - ASSERT_TRUE(env); - auto res_compiled_model = - CompiledModel::Create(*env, *model, *jit_compilation_options); - ASSERT_TRUE(res_compiled_model) << "Failed to initialize CompiledModel"; - auto& compiled_model = *res_compiled_model; - - auto signatures = model->GetSignatures(); - ASSERT_TRUE(signatures); - EXPECT_EQ(signatures->size(), 1); - auto& signature = signatures->at(0); - auto signature_key = signature.Key(); - EXPECT_EQ(signature_key, Model::DefaultSignatureKey()); - size_t signature_index = 0; - - auto input_buffers_res = 
compiled_model.CreateInputBuffers(signature_index); - EXPECT_TRUE(input_buffers_res); - auto& input_buffers = *input_buffers_res; - - auto output_buffers_res = compiled_model.CreateOutputBuffers(signature_index); - EXPECT_TRUE(output_buffers_res); - auto& output_buffers = *output_buffers_res; - - // Fill model inputs. - auto input_names = signature.InputNames(); - EXPECT_EQ(input_names.size(), 2); - EXPECT_EQ(input_names.at(0), "arg0"); - EXPECT_EQ(input_names.at(1), "arg1"); - ASSERT_TRUE(input_buffers[0].Write( - absl::MakeConstSpan(kTestInput0Tensor, kTestInput0Size))); - ASSERT_TRUE(input_buffers[1].Write( - absl::MakeConstSpan(kTestInput1Tensor, kTestInput1Size))); - - // Execute model. - compiled_model.Run(signature_index, input_buffers, output_buffers); - - // Check model output. - auto output_names = signature.OutputNames(); - EXPECT_EQ(output_names.size(), 1); - EXPECT_EQ(output_names.at(0), "tfl.custom"); - float output_buffer_data[kTestOutputSize]; - auto output_span = absl::MakeSpan(output_buffer_data, kTestOutputSize); - ASSERT_TRUE(output_buffers[0].Read(output_span)); - for (auto i = 0; i < kTestOutputSize; ++i) { - ABSL_LOG(INFO) << "Result: " << output_span.at(i) << "\t" - << kTestOutputTensor[i]; - } - EXPECT_THAT(output_span, Pointwise(FloatNear(1e-5), kTestOutputTensor)); -} - -TEST(DispatchDelegate, QualcommSharedInput) { - auto model_with_byte_code = internal::GetModelBufWithByteCode( - testing::GetTestFilePath("shared_input_cpu_npu.tflite"), - testing::GetTestFilePath(kNpuFile)); - ASSERT_TRUE(model_with_byte_code); - auto model = Model::CreateFromBuffer(*model_with_byte_code); - ASSERT_TRUE(model); - -#if !defined(__ANDROID__) - GTEST_SKIP() << "The rest of this test is specific to Android devices with a " - "Qualcomm HTP"; -#endif - auto jit_compilation_options = CompilationOptions::Create(); - ASSERT_TRUE(jit_compilation_options); - ASSERT_TRUE(jit_compilation_options->SetHardwareAccelerators( - kLiteRtHwAcceleratorCpu)); - - const 
std::vector environment_options = { - litert::Environment::Option{ - litert::Environment::OptionTag::DispatchLibraryDir, - kDispatchLibraryDir, - }, - }; - auto env = - litert::Environment::Create(absl::MakeConstSpan(environment_options)); - ASSERT_TRUE(env); - auto res_compiled_model = - CompiledModel::Create(*env, *model, *jit_compilation_options); - ASSERT_TRUE(res_compiled_model) << "Failed to initialize CompiledModel"; - auto& compiled_model = *res_compiled_model; - - size_t signature_index = 0; - auto signature = *model->GetSignature(signature_index); - auto input_buffers = *compiled_model.CreateInputBuffers(signature_index); - auto output_buffers = *compiled_model.CreateOutputBuffers(signature_index); - - // Fill model inputs. - auto input_names = signature.InputNames(); - EXPECT_EQ(input_names.size(), 2); - EXPECT_EQ(input_names.at(0), "arg0"); - EXPECT_EQ(input_names.at(1), "arg1"); - ASSERT_TRUE(input_buffers[0].Write( - absl::MakeConstSpan(kTestInput0Tensor, kTestInput0Size))); - ASSERT_TRUE(input_buffers[1].Write( - absl::MakeConstSpan(kTestInput1Tensor, kTestInput1Size))); - - // Execute model. - compiled_model.Run(signature_index, input_buffers, output_buffers); - - // Check model output. 
- auto output_names = signature.OutputNames(); - EXPECT_EQ(output_names.size(), 2); - { - EXPECT_EQ(output_names.at(0), "tfl.add"); - float output_buffer_data[kTestOutputSize]; - auto output_span = absl::MakeSpan(output_buffer_data, kTestOutputSize); - ASSERT_TRUE(output_buffers[0].Read(output_span)); - for (auto i = 0; i < kTestOutputSize; ++i) { - ABSL_LOG(INFO) << "Result: " << output_span.at(i) << "\t" - << kTestOutputTensor[i]; - } - EXPECT_THAT(output_span, Pointwise(FloatNear(1e-5), kTestOutputTensor)); - } - { - EXPECT_EQ(output_names.at(1), "tfl.custom"); - float output_buffer_data[kTestOutputSize]; - auto output_span = absl::MakeSpan(output_buffer_data, kTestOutputSize); - ASSERT_TRUE(output_buffers[1].Read(output_span)); - for (auto i = 0; i < kTestOutputSize; ++i) { - ABSL_LOG(INFO) << "Result: " << output_span.at(i) << "\t" - << kTestOutputTensor[i]; - } - EXPECT_THAT(output_span, Pointwise(FloatNear(1e-5), kTestOutputTensor)); - } -} - -} // namespace -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/runtime/dispatch/litert_dispatch.cc b/tensorflow/lite/experimental/litert/runtime/dispatch/litert_dispatch.cc deleted file mode 100644 index f725941832d5be..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/dispatch/litert_dispatch.cc +++ /dev/null @@ -1,571 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" - -#include - -#include -#include -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_event.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/cc/litert_shared_library.h" -#include "tensorflow/lite/experimental/litert/core/dynamic_loading.h" -#include "tensorflow/lite/experimental/litert/core/version.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch_api.h" - -#define INVOKE_FUNC(function, ...) \ - if (!TheApi.interface) { \ - LITERT_LOG(LITERT_ERROR, "Dispatch API interface not found"); \ - return kLiteRtStatusErrorRuntimeFailure; \ - } \ - if (!TheApi.interface->function) { \ - LITERT_LOG(LITERT_ERROR, #function " not found"); \ - return kLiteRtStatusErrorRuntimeFailure; \ - } \ - return TheApi.interface->function(__VA_ARGS__); - -#define INVOKE_ASYNC_FUNC(function, ...) \ - if (!TheApi.async_interface) { \ - LITERT_LOG(LITERT_ERROR, "Dispatch API async interface not found"); \ - return kLiteRtStatusErrorRuntimeFailure; \ - } \ - if (!TheApi.async_interface->function) { \ - LITERT_LOG(LITERT_ERROR, #function " not found"); \ - return kLiteRtStatusErrorRuntimeFailure; \ - } \ - return TheApi.async_interface->function(__VA_ARGS__); - -#define INVOKE_GRAPH_FUNC(function, ...) 
\ - if (!TheApi.graph_interface) { \ - LITERT_LOG(LITERT_ERROR, "Dispatch API graoh interface not found"); \ - return kLiteRtStatusErrorRuntimeFailure; \ - } \ - if (!TheApi.graph_interface->function) { \ - LITERT_LOG(LITERT_ERROR, #function " not found"); \ - return kLiteRtStatusErrorRuntimeFailure; \ - } \ - return TheApi.graph_interface->function(__VA_ARGS__); - -namespace { - -litert::SharedLibrary* DispatchSharedLibrary = nullptr; -bool IsTheApiInitialized = false; -LiteRtDispatchApi TheApi = { - /*.version=*/{/*.major=*/0, /*.minor=*/0, /*.patch=*/0}, - /*.interface=*/nullptr, - /*.async_interface=*/nullptr, - /*.graph_interface=*/nullptr, -}; - -LiteRtStatus Initialize(const LiteRtDispatchOption* options, int num_options) { - INVOKE_FUNC(initialize, options, num_options); -} - -litert::Expected GetSharedLibraryPath( - const LiteRtDispatchOption* options, int num_options) { - std::vector dispatch_lib_paths; - for (auto i = 0; i < num_options; ++i) { - auto& option = options[i]; - if (!strcmp(option.name, kDispatchOptionSharedLibraryDir)) { - litert::internal::FindLiteRtDispatchSharedLibs(option.value.str_value, - dispatch_lib_paths); - } - } - if (dispatch_lib_paths.empty()) { - LITERT_LOG(LITERT_ERROR, "No dispatch library found"); - return litert::Error(kLiteRtStatusErrorRuntimeFailure); - } - if (dispatch_lib_paths.size() > 1) { - LITERT_LOG(LITERT_WARNING, "Multiple dispatch libraries found"); - } - return dispatch_lib_paths[0]; -} -} // namespace - -// ///////////////////////////////////////////////////////////////////////////// -// Basic Execution API -// ///////////////////////////////////////////////////////////////////////////// - -LiteRtStatus LiteRtDispatchInitialize(const LiteRtDispatchOption* options, - int num_options) { - if (IsTheApiInitialized) { - return kLiteRtStatusOk; - } - - // TODO(piyu): support Android systems where libraries are not unpacked in the - // system directory. 
- LITERT_ASSIGN_OR_RETURN(auto shared_lib_path, - GetSharedLibraryPath(options, num_options)); - - LITERT_LOG(LITERT_INFO, "Loading shared library: %s", - shared_lib_path.c_str()); - - if (!DispatchSharedLibrary) { - DispatchSharedLibrary = new litert::SharedLibrary(); - } - - LITERT_ASSIGN_OR_RETURN( - *DispatchSharedLibrary, - litert::SharedLibrary::Load(shared_lib_path, - litert::RtldFlags::Now().Local())); - - using LiteRtDispatchGetApi_t = LiteRtStatus (*)(LiteRtDispatchApi*); - LITERT_ASSIGN_OR_RETURN( - auto LiteRtDispatchGetApi, - DispatchSharedLibrary->LookupSymbol( - "LiteRtDispatchGetApi")); - - if (auto status = LiteRtDispatchGetApi(&TheApi); status != kLiteRtStatusOk) { - return status; - } - - if (!litert::internal::IsSameVersionAsRuntime(TheApi.version)) { - LITERT_LOG(LITERT_ERROR, "Unsupported dispatch runtime version"); - return kLiteRtStatusErrorWrongVersion; - } - - auto status = Initialize(options, num_options); - if (status == kLiteRtStatusOk) { - IsTheApiInitialized = true; - } - return status; -} - -LiteRtStatus LiteRtDispatchGetApiVersion(LiteRtApiVersion* api_version) { - if (!api_version) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kLiteRtStatusErrorInvalidArgument; - } - *api_version = TheApi.version; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtDispatchGetVendorId(const char** vendor_id) { - if (!vendor_id) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kLiteRtStatusErrorInvalidArgument; - } - INVOKE_FUNC(get_vendor_id, vendor_id); -} - -LiteRtStatus LiteRtDispatchGetBuildId(const char** build_id) { - if (!build_id) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kLiteRtStatusErrorInvalidArgument; - } - INVOKE_FUNC(get_build_id, build_id); -} - -LiteRtStatus LiteRtDispatchGetCapabilities(int* capabilities) { - if (!capabilities) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kLiteRtStatusErrorInvalidArgument; - } - INVOKE_FUNC(get_capabilities, capabilities); -} - -LiteRtStatus 
LiteRtDispatchDeviceContextCreate( - LiteRtDispatchDeviceContext* device_context) { - if (!device_context) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kLiteRtStatusErrorInvalidArgument; - } - INVOKE_FUNC(device_context_create, device_context); -} - -LiteRtStatus LiteRtDispatchDeviceContextDestroy( - LiteRtDispatchDeviceContext device_context) { - if (!device_context) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kLiteRtStatusErrorInvalidArgument; - } - INVOKE_FUNC(device_context_destroy, device_context); -} - -LiteRtStatus LiteRtDispatchGetInputRequirements( - LiteRtDispatchInvocationContext invocation_context, int input_index, - const LiteRtRankedTensorType* tensor_type, - LiteRtTensorBufferRequirements* tensor_buffer_requirements) { - if (!invocation_context || !tensor_type || !tensor_buffer_requirements) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kLiteRtStatusErrorInvalidArgument; - } - INVOKE_FUNC(get_input_requirements, invocation_context, input_index, - tensor_type, tensor_buffer_requirements); -} - -LiteRtStatus LiteRtDispatchGetOutputRequirements( - LiteRtDispatchInvocationContext invocation_context, int output_index, - const LiteRtRankedTensorType* tensor_type, - LiteRtTensorBufferRequirements* tensor_buffer_requirements) { - if (!invocation_context || !tensor_type || !tensor_buffer_requirements) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kLiteRtStatusErrorInvalidArgument; - } - INVOKE_FUNC(get_output_requirements, invocation_context, output_index, - tensor_type, tensor_buffer_requirements); -} - -LiteRtStatus LiteRtDispatchRegisterTensorBuffer( - LiteRtDispatchDeviceContext device_context, - LiteRtTensorBuffer tensor_buffer, - LiteRtTensorBufferHandle* tensor_buffer_handle) { - if (!device_context || !tensor_buffer || !tensor_buffer_handle) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kLiteRtStatusErrorInvalidArgument; - } - INVOKE_FUNC(register_tensor_buffer, device_context, tensor_buffer, - 
tensor_buffer_handle); -} - -LiteRtStatus LiteRtDispatchUnregisterTensorBuffer( - LiteRtDispatchDeviceContext device_context, - LiteRtTensorBufferHandle tensor_buffer_handle) { - if (!device_context) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kLiteRtStatusErrorInvalidArgument; - } - INVOKE_FUNC(unregister_tensor_buffer, device_context, tensor_buffer_handle); -} - -LiteRtStatus LiteRtDispatchInvocationContextCreate( - LiteRtDispatchDeviceContext device_context, - LiteRtDispatchExecutableType exec_type, - const LiteRtMemBuffer* exec_bytecode_buffer, const char* function_name, - int num_inputs, int num_outputs, - LiteRtDispatchInvocationContext* invocation_context) { - if (!device_context || !exec_bytecode_buffer || !invocation_context) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kLiteRtStatusErrorInvalidArgument; - } - INVOKE_FUNC(invocation_context_create, device_context, exec_type, - exec_bytecode_buffer, function_name, num_inputs, num_outputs, - invocation_context); -} - -LiteRtStatus LiteRtDispatchInvocationContextDestroy( - LiteRtDispatchInvocationContext invocation_context) { - if (!invocation_context) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kLiteRtStatusErrorInvalidArgument; - } - INVOKE_FUNC(invocation_context_destroy, invocation_context); -} - -LiteRtStatus LiteRtDispatchAttachInput( - LiteRtDispatchInvocationContext invocation_context, int graph_input_index, - LiteRtTensorBufferHandle tensor_buffer_handle) { - if (!invocation_context) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kLiteRtStatusErrorInvalidArgument; - } - INVOKE_FUNC(attach_input, invocation_context, graph_input_index, - tensor_buffer_handle); -} - -LiteRtStatus LiteRtDispatchAttachOutput( - LiteRtDispatchInvocationContext invocation_context, int graph_output_index, - LiteRtTensorBufferHandle tensor_buffer_handle) { - if (!invocation_context) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kLiteRtStatusErrorInvalidArgument; - } - if 
(!TheApi.interface) { - LITERT_LOG(LITERT_ERROR, "Dispatch API interface not found"); - return kLiteRtStatusErrorRuntimeFailure; - } - if (!TheApi.interface->attach_output) { - LITERT_LOG(LITERT_ERROR, "attach_output_tensor_buffer not found"); - return kLiteRtStatusErrorRuntimeFailure; - } - INVOKE_FUNC(attach_output, invocation_context, graph_output_index, - tensor_buffer_handle); -} - -LiteRtStatus LiteRtDispatchDetachInput( - LiteRtDispatchInvocationContext invocation_context, int graph_input_index, - LiteRtTensorBufferHandle tensor_buffer_handle) { - if (!invocation_context) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kLiteRtStatusErrorInvalidArgument; - } - INVOKE_FUNC(detach_input, invocation_context, graph_input_index, - tensor_buffer_handle); -} - -LiteRtStatus LiteRtDispatchDetachOutput( - LiteRtDispatchInvocationContext invocation_context, int graph_output_index, - LiteRtTensorBufferHandle tensor_buffer_handle) { - if (!invocation_context) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kLiteRtStatusErrorInvalidArgument; - } - INVOKE_FUNC(detach_output, invocation_context, graph_output_index, - tensor_buffer_handle); -} - -LiteRtStatus LiteRtDispatchInvoke( - LiteRtDispatchInvocationContext invocation_context) { - if (!invocation_context) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kLiteRtStatusErrorInvalidArgument; - } - INVOKE_FUNC(invoke, invocation_context); -} - -LiteRtStatus LiteRtDispatchStartMetricsCollection( - LiteRtDispatchInvocationContext invocation_context, int detail_level) { - if (!invocation_context) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kLiteRtStatusErrorInvalidArgument; - } else if (detail_level < 0) { - LITERT_LOG(LITERT_ERROR, "Invalid detail level"); - return kLiteRtStatusErrorInvalidArgument; - } - INVOKE_FUNC(start_metrics_collection, invocation_context, detail_level); -} - -LiteRtStatus LiteRtDispatchStopMetricsCollection( - LiteRtDispatchInvocationContext invocation_context, - 
LiteRtDispatchMetrics* metrics) { - if (!invocation_context || !metrics) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kLiteRtStatusErrorInvalidArgument; - } - INVOKE_FUNC(stop_metrics_collection, invocation_context, metrics); -} - -LiteRtStatus LiteRtDispatchGetNumMetrics(LiteRtDispatchMetrics metrics, - int* num_metrics) { - if (!metrics || !num_metrics) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kLiteRtStatusErrorInvalidArgument; - } - INVOKE_FUNC(get_num_metrics, metrics, num_metrics); -} - -LiteRtStatus LiteRtDispatchGetMetric(LiteRtDispatchMetrics metrics, - int metric_index, LiteRtMetric* metric) { - if (!metrics || !metric) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kLiteRtStatusErrorInvalidArgument; - } - INVOKE_FUNC(get_metric, metrics, metric_index, metric); -} - -LiteRtStatus LiteRtDispatchDestroyMetrics(LiteRtDispatchMetrics metrics) { - if (!metrics) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kLiteRtStatusErrorInvalidArgument; - } - INVOKE_FUNC(destroy_metrics, metrics); -} - -// ///////////////////////////////////////////////////////////////////////////// -// Async Execution API -// ///////////////////////////////////////////////////////////////////////////// - -LiteRtStatus LiteRtDispatchAttachInputEvent( - LiteRtDispatchInvocationContext invocation_context, int graph_input_index, - LiteRtEvent input_event) { - if (!invocation_context || !input_event) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kLiteRtStatusErrorInvalidArgument; - } - INVOKE_ASYNC_FUNC(attach_input_event, invocation_context, graph_input_index, - input_event); -} - -LiteRtStatus LiteRtDispatchInvokeAsync( - LiteRtDispatchInvocationContext invocation_context, int num_output_events, - LiteRtEvent* output_events) { - if (!invocation_context || !output_events) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kLiteRtStatusErrorInvalidArgument; - } - INVOKE_ASYNC_FUNC(invoke_async, invocation_context, num_output_events, - output_events); -} 
- -// ///////////////////////////////////////////////////////////////////////////// -// Graph Execution API -// ///////////////////////////////////////////////////////////////////////////// - -LiteRtStatus LiteRtDispatchGraphCreate( - LiteRtDispatchDeviceContext device_context, LiteRtDispatchGraph* graph) { - if (!device_context || !graph) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kLiteRtStatusErrorInvalidArgument; - } - INVOKE_GRAPH_FUNC(graph_create, device_context, graph); -} - -LiteRtStatus LiteRtDispatchGraphDestroy(LiteRtDispatchGraph graph) { - if (!graph) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kLiteRtStatusErrorInvalidArgument; - } - INVOKE_GRAPH_FUNC(graph_destroy, graph); -} - -LiteRtStatus LiteRtDispatchAddNode(LiteRtDispatchGraph graph, - LiteRtDispatchNodeId node_id, - LiteRtDispatchNodeType node_type) { - if (!graph) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kLiteRtStatusErrorInvalidArgument; - } - INVOKE_GRAPH_FUNC(add_node, graph, node_id, node_type); -} - -LiteRtStatus LiteRtDispatchAddEdge(LiteRtDispatchGraph graph, - LiteRtDispatchEdgeId edge_id) { - if (!graph) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kLiteRtStatusErrorInvalidArgument; - } - INVOKE_GRAPH_FUNC(add_edge, graph, edge_id); -} - -LiteRtStatus LiteRtDispatchConnectNodeInput(LiteRtDispatchGraph graph, - LiteRtDispatchNodeId node_id, - int input_index, - LiteRtDispatchEdgeId edge_id) { - if (!graph) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kLiteRtStatusErrorInvalidArgument; - } - INVOKE_GRAPH_FUNC(connect_node_input, graph, node_id, input_index, edge_id); -} - -LiteRtStatus LiteRtDispatchConnectNodeOutput(LiteRtDispatchGraph graph, - LiteRtDispatchNodeId node_id, - int output_index, - LiteRtDispatchEdgeId edge_id) { - if (!graph) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kLiteRtStatusErrorInvalidArgument; - } - INVOKE_GRAPH_FUNC(connect_node_output, graph, node_id, output_index, edge_id); -} - -LiteRtStatus 
LiteRtDispatchConnectGraphInput(LiteRtDispatchGraph graph, - int input_index, - LiteRtDispatchEdgeId edge_id) { - if (!graph) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kLiteRtStatusErrorInvalidArgument; - } - INVOKE_GRAPH_FUNC(connect_graph_input, graph, input_index, edge_id); -} - -LiteRtStatus LiteRtDispatchConnectGraphOutput(LiteRtDispatchGraph graph, - int output_index, - LiteRtDispatchEdgeId edge_id) { - if (!graph) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kLiteRtStatusErrorInvalidArgument; - } - INVOKE_GRAPH_FUNC(connect_graph_output, graph, output_index, edge_id); -} - -LiteRtStatus LiteRtDispatchLoadExecutable( - LiteRtDispatchDeviceContext device_context, - LiteRtDispatchExecutableType type, const LiteRtMemBuffer* bytecode_buffer, - LiteRtDispatchExecutableHandle* exec_handle) { - if (!device_context || !bytecode_buffer || !exec_handle) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kLiteRtStatusErrorInvalidArgument; - } - if (!TheApi.graph_interface) { - LITERT_LOG(LITERT_ERROR, "Dispatch API graph interface not found"); - return kLiteRtStatusErrorRuntimeFailure; - } - if (!TheApi.graph_interface->load_executable) { - LITERT_LOG(LITERT_ERROR, "load_executable not found"); - return kLiteRtStatusErrorRuntimeFailure; - } - INVOKE_GRAPH_FUNC(load_executable, device_context, type, bytecode_buffer, - exec_handle); -} - -LiteRtStatus LiteRtDispatchUnloadExecutable( - LiteRtDispatchDeviceContext device_context, - LiteRtDispatchExecutableHandle exec_handle) { - if (!device_context) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kLiteRtStatusErrorInvalidArgument; - } - INVOKE_GRAPH_FUNC(unload_executable, device_context, exec_handle); -} - -LiteRtStatus LiteRtDispatchAssignNodeFunction( - LiteRtDispatchGraph graph, LiteRtDispatchNodeId node_id, - LiteRtDispatchExecutableHandle exec_handle, const char* function_name) { - if (!graph) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kLiteRtStatusErrorInvalidArgument; - } - 
INVOKE_GRAPH_FUNC(assign_node_function, graph, node_id, exec_handle, - function_name); -} - -LiteRtStatus LiteRtDispatchAnnotateGraph(LiteRtDispatchGraph graph, - const char* key, const char* value) { - if (!graph) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kLiteRtStatusErrorInvalidArgument; - } - INVOKE_GRAPH_FUNC(annotate_graph, graph, key, value); -} - -LiteRtStatus LiteRtDispatchAnnotateNode(LiteRtDispatchGraph graph, - LiteRtDispatchNodeId node_id, - const char* key, const char* value) { - if (!graph) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kLiteRtStatusErrorInvalidArgument; - } - INVOKE_GRAPH_FUNC(annotate_node, graph, node_id, key, value); -} - -LiteRtStatus LiteRtDispatchAnnotateEdge(LiteRtDispatchGraph graph, - LiteRtDispatchEdgeId edge_id, - const char* key, const char* value) { - if (!graph) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kLiteRtStatusErrorInvalidArgument; - } - INVOKE_GRAPH_FUNC(annotate_edge, graph, edge_id, key, value); -} - -LiteRtStatus LiteRtDispatchInvocationContextCreateFromGraph( - LiteRtDispatchDeviceContext device_context, LiteRtDispatchGraph graph, - LiteRtDispatchInvocationContext* invocation_context) { - if (!device_context || !graph || !invocation_context) { - LITERT_LOG(LITERT_ERROR, "Null input"); - return kLiteRtStatusErrorInvalidArgument; - } - INVOKE_GRAPH_FUNC(invocation_context_create_from_graph, device_context, graph, - invocation_context); -} diff --git a/tensorflow/lite/experimental/litert/runtime/dmabuf_buffer.cc b/tensorflow/lite/experimental/litert/runtime/dmabuf_buffer.cc deleted file mode 100644 index 450ed56f8dc2aa..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/dmabuf_buffer.cc +++ /dev/null @@ -1,195 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/runtime/dmabuf_buffer.h" - -#include -#include -#include - -#include "absl/base/attributes.h" -#include "absl/base/const_init.h" -#include "absl/container/node_hash_map.h" -#include "absl/synchronization/mutex.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" - -#if LITERT_HAS_DMABUF_SUPPORT -#include -#include -#endif // LITERT_HAS_DMABUF_SUPPORT - -namespace litert { -namespace internal { - -#if LITERT_HAS_DMABUF_SUPPORT -namespace { - -class DmaBufLibrary { - public: - using Ptr = std::unique_ptr; - - ~DmaBufLibrary() { - if (allocator_) { - free_allocator_(allocator_); - } - } - - static Expected Create() { - DlHandle dlhandle(::dlopen("libdmabufheap.so", RTLD_LAZY | RTLD_LOCAL), - ::dlclose); - if (!dlhandle) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "libdmabufheap.so not found"); - } - - auto create_allocator = reinterpret_cast( - ::dlsym(dlhandle.get(), "CreateDmabufHeapBufferAllocator")); - if (!create_allocator) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "CreateDmabufHeapBufferAllocator not found"); - } - - auto free_allocator = reinterpret_cast( - ::dlsym(dlhandle.get(), "FreeDmabufHeapBufferAllocator")); - if (!free_allocator) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "FreeDmabufHeapBufferAllocator not found"); - } - - auto alloc_buffer = reinterpret_cast( - ::dlsym(dlhandle.get(), "DmabufHeapAlloc")); - if (!alloc_buffer) { - return 
Unexpected(kLiteRtStatusErrorRuntimeFailure, - "DmabufHeapAlloc not found"); - } - - void* allocator = create_allocator(); - if (!allocator) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "CreateDmabufHeapBufferAllocator failed"); - } - - return Ptr(new DmaBufLibrary(std::move(dlhandle), allocator, free_allocator, - alloc_buffer)); - } - - Expected Alloc(size_t size) { - int fd = alloc_buffer_(allocator_, kDmaBufHeap, size, /*flags=*/0, - /*legacy_align=*/0); - if (fd < 0) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to allocate DMA-BUF buffer"); - } - void* addr = - ::mmap(nullptr, size, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); - if (addr == MAP_FAILED) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to mem-map DMA-BUF buffer"); - } - records_[addr] = Record{.fd = fd, .addr = addr, .size = size}; - return DmaBufBuffer{.fd = fd, .addr = addr}; - } - - void Free(void* addr) { - auto iter = records_.find(addr); - if (iter == records_.end()) { - return; - } - auto& record = iter->second; - ::munmap(record.addr, record.size); - ::close(record.fd); - records_.erase(iter); - } - - private: - static constexpr const char* kDmaBufHeap = "system"; - - struct Record { - int fd; - void* addr; - size_t size; - }; - - using DlHandle = std::unique_ptr; - using CreateAllocator = void* (*)(); - using FreeAllocator = void (*)(void*); - using AllocBuffer = int (*)(void*, const char*, size_t, unsigned int, size_t); - - DmaBufLibrary(DlHandle&& dlhandle, void* allocator, - FreeAllocator free_allocator, AllocBuffer alloc_buffer) - : dlhandle_(std::move(dlhandle)) { - allocator_ = allocator; - free_allocator_ = free_allocator; - alloc_buffer_ = alloc_buffer; - } - - DlHandle dlhandle_; - void* allocator_; - FreeAllocator free_allocator_; - AllocBuffer alloc_buffer_; - absl::node_hash_map records_; -}; - -DmaBufLibrary* TheDmaBufLibrary; -ABSL_CONST_INIT absl::Mutex TheMutex(absl::kConstInit); - -Expected InitLibraryIfNeededUnlocked() { 
- if (!TheDmaBufLibrary) { - if (auto library = DmaBufLibrary::Create(); library) { - TheDmaBufLibrary = library->release(); - } else { - return Unexpected(library.Error()); - } - } - return {}; -} - -} // namespace -#endif // LITERT_HAS_DMABUF_SUPPORT - -bool DmaBufBuffer::IsSupported() { -#if LITERT_HAS_DMABUF_SUPPORT - absl::MutexLock lock(&TheMutex); - auto status = InitLibraryIfNeededUnlocked(); - return static_cast(status); -#else // LITERT_HAS_DMABUF_SUPPORT - return false; -#endif // LITERT_HAS_DMABUF_SUPPORT -} - -Expected DmaBufBuffer::Alloc(size_t size) { -#if LITERT_HAS_DMABUF_SUPPORT - absl::MutexLock lock(&TheMutex); - if (auto status = InitLibraryIfNeededUnlocked(); !status) { - return Unexpected(status.Error()); - } - return TheDmaBufLibrary->Alloc(size); -#else // LITERT_HAS_DMABUF_SUPPORT - return Unexpected(kLiteRtStatusErrorUnsupported, - "DmaBufBuffer::Alloc not implemented for this platform"); -#endif // LITERT_HAS_DMABUF_SUPPORT -} - -void DmaBufBuffer::Free(void* addr) { -#if LITERT_HAS_DMABUF_SUPPORT - absl::MutexLock lock(&TheMutex); - if (TheDmaBufLibrary) { - TheDmaBufLibrary->Free(addr); - } -#endif // LITERT_HAS_DMABUF_SUPPORT -} - -} // namespace internal -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/runtime/dmabuf_buffer.h b/tensorflow/lite/experimental/litert/runtime/dmabuf_buffer.h deleted file mode 100644 index a391e0cf892a56..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/dmabuf_buffer.h +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_DMABUF_BUFFER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_DMABUF_BUFFER_H_ - -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" - -namespace litert::internal { - -struct DmaBufBuffer { - int fd; - void* addr; - - static bool IsSupported(); - static Expected Alloc(size_t size); - static void Free(void* addr); -}; - -} // namespace litert::internal - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_DMABUF_BUFFER_H_ diff --git a/tensorflow/lite/experimental/litert/runtime/event.cc b/tensorflow/lite/experimental/litert/runtime/event.cc deleted file mode 100644 index 70f2cbb5beb512..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/event.cc +++ /dev/null @@ -1,121 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/runtime/event.h" - -#include - -#include -#include - -#include "absl/strings/str_format.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_event_type.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" - -#if LITERT_HAS_SYNC_FENCE_SUPPORT -#include -#include -#endif // LITERT_HAS_SYNC_FENCE_SUPPORT -#if LITERT_HAS_OPENCL_SUPPORT -#include "tensorflow/lite/experimental/litert/runtime/gpu_environment.h" -#include "tensorflow/lite/experimental/litert/runtime/opencl/cl_event.h" -#endif // LITERT_HAS_OPENCL_SUPPORT - -using litert::Error; -using litert::Expected; - -Expected LiteRtEventT::Wait(int64_t timeout_in_ms) { - if (type == LiteRtEventTypeSyncFenceFd) { -#if LITERT_HAS_SYNC_FENCE_SUPPORT - struct pollfd fds = { - .fd = fd, - .events = POLLIN, - }; - - int ret; - do { - ret = ::poll(&fds, 1, timeout_in_ms); - if (ret == 1) { - break; - } else if (ret == 0) { - return Error(kLiteRtStatusErrorTimeoutExpired, "Timeout expired"); - } - } while (ret == -1 && (errno == EINTR || errno == EAGAIN)); - - if (ret < 0) { - return Error(kLiteRtStatusErrorRuntimeFailure, "Error waiting for fence"); - } - - return {}; - -#else - return Error(kLiteRtStatusErrorUnsupported, - "LiteRtEventWait not implemented for this platform"); -#endif - } else if (type == LiteRtEventTypeOpenCl) { -#if LITERT_HAS_OPENCL_SUPPORT - return litert::cl::WaitForEvents(/*num_events=*/1, - /*event_list=*/&opencl_event); -#else - return Error(kLiteRtStatusErrorUnsupported, - "LiteRtEventWait not implemented for this platform"); -#endif - } - return Error(kLiteRtStatusErrorInvalidArgument, "Invalid event type"); -} - -#if LITERT_HAS_SYNC_FENCE_SUPPORT -namespace { -inline bool IsFdValid(int fd) { - return ::fcntl(fd, F_GETFD) != -1 || errno != EBADF; -} -} // namespace -#endif - 
-LiteRtEventT::~LiteRtEventT() { -#if LITERT_HAS_SYNC_FENCE_SUPPORT - if (type == LiteRtEventTypeSyncFenceFd && owns_fd && IsFdValid(fd)) { - ::close(fd); - } -#endif -} - -Expected LiteRtEventT::Signal() { -#if LITERT_HAS_OPENCL_SUPPORT - if (type == LiteRtEventTypeOpenCl) { - return litert::cl::SetUserEventStatus(opencl_event); - } -#endif - return Error(kLiteRtStatusErrorInvalidArgument, - "The event signal is not supported"); -} - -Expected LiteRtEventT::CreateManaged(LiteRtEventType type) { -#if LITERT_HAS_OPENCL_SUPPORT - if (type == LiteRtEventTypeOpenCl) { - auto& env = litert::internal::GpuEnvironmentSingleton::GetInstance(); - LITERT_ASSIGN_OR_RETURN( - cl_event user_event, - litert::cl::CreateUserEvent(env.getContext()->context())); - return new LiteRtEventT{ - .type = LiteRtEventTypeOpenCl, - .opencl_event = user_event, - }; - } -#endif - return Error(kLiteRtStatusErrorInvalidArgument, - absl::StrFormat("CreateManaged doesn't support type %d", type)); -} diff --git a/tensorflow/lite/experimental/litert/runtime/event.h b/tensorflow/lite/experimental/litert/runtime/event.h deleted file mode 100644 index df93d5cfac10b4..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/event.h +++ /dev/null @@ -1,53 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_EVENT_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_EVENT_H_ - -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_event_type.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" - -#if LITERT_HAS_OPENCL_SUPPORT -extern "C" { -typedef struct _cl_event* cl_event; -} -#endif // LITERT_HAS_OPENCL_SUPPORT - -struct LiteRtEventT { - LiteRtEventType type = LiteRtEventTypeUnknown; -#if LITERT_HAS_SYNC_FENCE_SUPPORT - int fd = -1; - bool owns_fd = false; -#endif -#if LITERT_HAS_OPENCL_SUPPORT - cl_event opencl_event; -#endif - ~LiteRtEventT(); - litert::Expected Wait(int64_t timeout_in_ms); - litert::Expected GetSyncFenceFd() const { -#if LITERT_HAS_SYNC_FENCE_SUPPORT - return fd; -#else - return litert::Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Sync fence is not supported on this platform"); -#endif - } - litert::Expected Signal(); - static litert::Expected CreateManaged(LiteRtEventType type); -}; - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_EVENT_H_ diff --git a/tensorflow/lite/experimental/litert/runtime/external_litert_buffer_context.cc b/tensorflow/lite/experimental/litert/runtime/external_litert_buffer_context.cc deleted file mode 100644 index 63ace18c1a85da..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/external_litert_buffer_context.cc +++ /dev/null @@ -1,125 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/runtime/external_litert_buffer_context.h" - -#include - -#include "tensorflow/lite/c/c_api_opaque.h" -#include "tensorflow/lite/c/c_api_types.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_requirements.h" -#include "tensorflow/lite/experimental/litert/runtime/tfl_utils.h" - -namespace litert { -namespace internal { - -LiteRtStatus ExternalLiteRtBufferContext::RegisterBufferRequirement( - const TfLiteOpaqueTensor* tensor, - TensorBufferRequirements&& buffer_requirements) { - if (buffer_requirements_.find(tensor) != buffer_requirements_.end()) { - LITERT_LOG(LITERT_ERROR, - "RegisterBufferRequirement already exists for tensor: %p", - tensor); - return kLiteRtStatusErrorRuntimeFailure; - } - buffer_requirements_[tensor] = std::move(buffer_requirements); - return kLiteRtStatusOk; -} - -litert::Expected -ExternalLiteRtBufferContext::GetBufferRequirement( - const TfLiteOpaqueTensor* tensor) { - auto it = buffer_requirements_.find(tensor); - if (it == buffer_requirements_.end()) { - return litert::Unexpected(kLiteRtStatusErrorNotFound, - "Buffer requirement not found"); - } - return &(it->second); -} - -LiteRtStatus ExternalLiteRtBufferContext::RegisterTensorBuffer( - const TfLiteOpaqueTensor* tensor, TensorBuffer&& tensor_buffer) { - tensor_buffers_[tensor] = std::move(tensor_buffer); - return kLiteRtStatusOk; -} - -litert::Expected ExternalLiteRtBufferContext::GetTensorBuffer( - const 
TfLiteOpaqueTensor* tensor) { - auto it = tensor_buffers_.find(tensor); - if (it == tensor_buffers_.end()) { - return litert::Unexpected(kLiteRtStatusErrorNotFound, - "Tensor buffer not found"); - } - - auto duplicate_tensor_buffer = it->second.Duplicate(); - if (!duplicate_tensor_buffer) { - return litert::Unexpected(duplicate_tensor_buffer.Error()); - } - return std::move(duplicate_tensor_buffer.Value()); -} - -litert::Expected -ExternalLiteRtBufferContext::CreateBufferForTensor( - const TfLiteOpaqueTensor* tensor) { - auto tensor_buffer_requirements = GetBufferRequirement(tensor); - if (!tensor_buffer_requirements) { - return litert::Unexpected(tensor_buffer_requirements.Error()); - } - - auto tensor_type = litert::internal::ConvertTensorType(tensor); - if (!tensor_type) { - return litert::Unexpected(tensor_type.Error()); - } - - auto supported_tensor_buffer_types = - (*tensor_buffer_requirements)->SupportedTypes(); - if (!supported_tensor_buffer_types) { - return litert::Unexpected(supported_tensor_buffer_types.Error()); - } - if (supported_tensor_buffer_types->empty()) { - return litert::Unexpected( - kLiteRtStatusErrorRuntimeFailure, - "Insufficient number of supported tensor buffer types"); - } - - // For now we simply pick the first buffer type that's supported. 
- LiteRtTensorBufferType tensor_buffer_type = - (*supported_tensor_buffer_types)[0]; - - auto tensor_buffer_size = (*tensor_buffer_requirements)->BufferSize(); - if (!tensor_buffer_size) { - return litert::Unexpected(tensor_buffer_size.Error()); - } - auto litert_tensor_type = static_cast(*tensor_type); - - LiteRtTensorBuffer litert_tensor_buffer; - if (auto status = LiteRtCreateManagedTensorBuffer( - tensor_buffer_type, &litert_tensor_type, *tensor_buffer_size, - &litert_tensor_buffer); - status != kLiteRtStatusOk) { - return litert::Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to create managed tensor buffer"); - } - - return TensorBuffer(litert_tensor_buffer, /*owned=*/true); -} - -} // namespace internal -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/runtime/external_litert_buffer_context.h b/tensorflow/lite/experimental/litert/runtime/external_litert_buffer_context.h deleted file mode 100644 index 81fc1fcdea9871..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/external_litert_buffer_context.h +++ /dev/null @@ -1,134 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_EXTERNAL_LITERT_BUFFER_CONTEXT_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_EXTERNAL_LITERT_BUFFER_CONTEXT_H_ - -#include -#include -#include - -#include "tensorflow/lite/c/c_api_opaque.h" -#include "tensorflow/lite/c/c_api_types.h" -#include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_requirements.h" - -namespace litert::internal { - -class ExternalLiteRtBufferContext : public TfLiteExternalContext { - public: - ExternalLiteRtBufferContext() = default; - ~ExternalLiteRtBufferContext() = default; - - // Registers a tensor buffer requirements for the given tensor. - // The registered TensorBufferRequirements object is owned by - // ExternalLiteRtBufferContext. - // Note: Currently, the system pre-registers tensor buffer requirements before - // they're actually used. A more efficient approach would be to query - // DelegateKernel only when these requirements are needed. 
- LiteRtStatus RegisterBufferRequirement( - const TfLiteOpaqueTensor* tensor, - TensorBufferRequirements&& buffer_requirements); - - inline LiteRtStatus RegisterBufferRequirement( - const TfLiteTensor* tensor, - TensorBufferRequirements&& buffer_requirements) { - return RegisterBufferRequirement( - reinterpret_cast(tensor), - std::move(buffer_requirements)); - } - - inline LiteRtStatus RegisterLiteRtBufferRequirement( - const TfLiteTensor* tensor, - LiteRtTensorBufferRequirements& litert_buffer_requirements) { - return RegisterBufferRequirement( - reinterpret_cast(tensor), - TensorBufferRequirements(litert_buffer_requirements, - /*owned=*/true)); - } - - // Gets a registered tensor buffer requirements for the given tensor. - // The returned TensorBufferRequirements object is still owned by - // ExternalLiteRtBufferContext. - litert::Expected GetBufferRequirement( - const TfLiteOpaqueTensor* tensor); - - inline litert::Expected GetBufferRequirement( - const TfLiteTensor* tensor) { - return GetBufferRequirement( - reinterpret_cast(tensor)); - } - - // Registers a tensor buffer for the given tensor. - // The registered TensorBuffer object is owned by ExternalLiteRtBufferContext. - LiteRtStatus RegisterTensorBuffer(const TfLiteOpaqueTensor* tensor, - TensorBuffer&& tensor_buffer); - - inline LiteRtStatus RegisterTensorBuffer(const TfLiteTensor* tensor, - TensorBuffer&& tensor_buffer) { - return RegisterTensorBuffer( - reinterpret_cast(tensor), - std::move(tensor_buffer)); - } - - // Gets a registered tensor buffer for the given tensor. - // The returned TensorBuffer object is duplication (reference counted) - // of registered TensorBuffer. - litert::Expected GetTensorBuffer( - const TfLiteOpaqueTensor* tensor); - - inline litert::Expected GetTensorBuffer( - const TfLiteTensor* tensor) { - return GetTensorBuffer(reinterpret_cast(tensor)); - } - - // Creates a tensor buffer for the given tensor. - // The callers takes ownership of the returned TensorBuffer object. 
- litert::Expected CreateBufferForTensor( - const TfLiteOpaqueTensor* tensor); - - inline litert::Expected CreateBufferForTensor( - const TfLiteTensor* tensor) { - return CreateBufferForTensor( - reinterpret_cast(tensor)); - } - - // Sets the async execution mode. It's set by CompiledModel and used by - // DelegateKernel to decide whether to use async execution mode. - inline void SetAsyncExecutionMode(bool async_execution_mode) { - async_execution_mode_ = async_execution_mode; - } - - // Returns true if the async execution mode is set. - inline bool IsAsyncExecutionMode() const { return async_execution_mode_; } - - private: - std::unordered_map - buffer_requirements_; - std::unordered_map tensor_buffers_; - - ExternalLiteRtBufferContext(const ExternalLiteRtBufferContext&) = delete; - ExternalLiteRtBufferContext& operator=(const ExternalLiteRtBufferContext&) = - delete; - - bool async_execution_mode_ = false; -}; - -} // namespace litert::internal - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_EXTERNAL_LITERT_BUFFER_CONTEXT_H_ diff --git a/tensorflow/lite/experimental/litert/runtime/fastrpc_buffer.cc b/tensorflow/lite/experimental/litert/runtime/fastrpc_buffer.cc deleted file mode 100644 index d0ec124b3177da..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/fastrpc_buffer.cc +++ /dev/null @@ -1,158 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/runtime/fastrpc_buffer.h" - -#include -#include -#include -#include - -#include "absl/base/attributes.h" -#include "absl/base/const_init.h" -#include "absl/synchronization/mutex.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" - -#if LITERT_HAS_FASTRPC_SUPPORT -#include -#endif // LITERT_HAS_FASTRPC_SUPPORT - -namespace litert { -namespace internal { - -#if LITERT_HAS_FASTRPC_SUPPORT -namespace { - -class FastRpcMemLibrary { - public: - using Ptr = std::unique_ptr; - - static Expected Create() { - DlHandle dlhandle(::dlopen("libcdsprpc.so", RTLD_NOW | RTLD_LOCAL), - ::dlclose); - if (!dlhandle) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "libcdsprpc.so not found"); - } - - auto rpcmem_alloc = - reinterpret_cast(::dlsym(dlhandle.get(), "rpcmem_alloc")); - if (!rpcmem_alloc) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "rpcmem_alloc not found"); - } - - auto rpcmem_free = - reinterpret_cast(::dlsym(dlhandle.get(), "rpcmem_free")); - if (!rpcmem_free) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "rpcmem_free not found"); - } - - auto rpcmem_to_fd = - reinterpret_cast(::dlsym(dlhandle.get(), "rpcmem_to_fd")); - if (!rpcmem_to_fd) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "rpcmem_to_fd not found"); - } - - return Ptr(new FastRpcMemLibrary(std::move(dlhandle), rpcmem_alloc, - rpcmem_free, rpcmem_to_fd)); - } - - void* Alloc(size_t size) const { - return rpcmem_alloc_(kRpcmemHeapIdSystem, kRpcmemDefaultFlags, size); - } - - void Free(void* buffer) const { return rpcmem_free_(buffer); } - - int ToFd(void* buffer) const { return rpcmem_to_fd_(buffer); } - - private: - static constexpr int kRpcmemHeapIdSystem = 25; - static constexpr uint32_t kRpcmemDefaultFlags = 1; - - using DlHandle = std::unique_ptr; - using RpcMemAlloc = void* (*)(int, uint32_t, int); - using RpcMemFree = void 
(*)(void*); - using RpcMemToFd = int (*)(void*); - - FastRpcMemLibrary(DlHandle&& dlhandle, RpcMemAlloc rpcmem_alloc, - RpcMemFree rpcmem_free, RpcMemToFd rpcmem_to_fd) - : dlhandle_(std::move(dlhandle)) { - rpcmem_alloc_ = rpcmem_alloc; - rpcmem_free_ = rpcmem_free; - rpcmem_to_fd_ = rpcmem_to_fd; - } - - DlHandle dlhandle_; - RpcMemAlloc rpcmem_alloc_; - RpcMemFree rpcmem_free_; - RpcMemToFd rpcmem_to_fd_; -}; - -FastRpcMemLibrary* TheFastRpcMemLibrary; -ABSL_CONST_INIT absl::Mutex TheMutex(absl::kConstInit); - -Expected InitLibraryIfNeededUnlocked() { - if (!TheFastRpcMemLibrary) { - if (auto library = FastRpcMemLibrary::Create(); library) { - TheFastRpcMemLibrary = library->release(); - } else { - return Unexpected(library.Error()); - } - } - return {}; -} - -} // namespace -#endif // LITERT_HAS_FASTRPC_SUPPORT - -bool FastRpcBuffer::IsSupported() { -#if LITERT_HAS_FASTRPC_SUPPORT - absl::MutexLock lock(&TheMutex); - auto status = InitLibraryIfNeededUnlocked(); - return static_cast(status); -#else // LITERT_HAS_FASTRPC_SUPPORT - return false; -#endif // LITERT_HAS_FASTRPC_SUPPORT -} - -Expected FastRpcBuffer::Alloc(size_t size) { -#if LITERT_HAS_FASTRPC_SUPPORT - absl::MutexLock lock(&TheMutex); - if (auto status = InitLibraryIfNeededUnlocked(); !status) { - return status.Error(); - } - void* addr = TheFastRpcMemLibrary->Alloc(size); - int fd = TheFastRpcMemLibrary->ToFd(addr); - return FastRpcBuffer{.fd = fd, .addr = addr}; -#else // LITERT_HAS_FASTRPC_SUPPORT - return Unexpected(kLiteRtStatusErrorUnsupported, - "FastRpcBuffer::Alloc not implemented for this platform"); -#endif // LITERT_HAS_FASTRPC_SUPPORT -} - -void FastRpcBuffer::Free(void* addr) { -#if LITERT_HAS_FASTRPC_SUPPORT - absl::MutexLock lock(&TheMutex); - if (TheFastRpcMemLibrary) { - TheFastRpcMemLibrary->Free(addr); - } -#endif // LITERT_HAS_FASTRPC_SUPPORT -} - -} // namespace internal -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/runtime/fastrpc_buffer.h 
b/tensorflow/lite/experimental/litert/runtime/fastrpc_buffer.h deleted file mode 100644 index fa934ce0b693df..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/fastrpc_buffer.h +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_FASTRPC_BUFFER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_FASTRPC_BUFFER_H_ - -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" - -namespace litert::internal { - -struct FastRpcBuffer { - int fd; - void* addr; - - static bool IsSupported(); - static Expected Alloc(size_t size); - static void Free(void* addr); -}; - -} // namespace litert::internal - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_FASTRPC_BUFFER_H_ diff --git a/tensorflow/lite/experimental/litert/runtime/gl_buffer.cc b/tensorflow/lite/experimental/litert/runtime/gl_buffer.cc deleted file mode 100644 index 6befdf4b844bb2..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/gl_buffer.cc +++ /dev/null @@ -1,338 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/runtime/gl_buffer.h" - -#include - -#include -#include -#include - -#include "absl/strings/str_cat.h" -#include "absl/synchronization/mutex.h" -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_gl_types.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" - -#if LITERT_HAS_OPENGL_SUPPORT -#include -#include -#include -#include - -#include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h" -#endif // LITERT_HAS_OPENGL_SUPPORT - -namespace litert { -namespace internal { - -#if LITERT_HAS_AHWB_SUPPORT - -PFNGLBUFFERSTORAGEEXTERNALEXTPROC glBufferStorageExternalEXT; -PFNEGLGETNATIVECLIENTBUFFERANDROIDPROC eglGetNativeClientBufferANDROID; -PFNEGLDUPNATIVEFENCEFDANDROIDPROC eglDupNativeFenceFDANDROID; -PFNEGLCREATESYNCKHRPROC eglCreateSyncKHR; -PFNEGLWAITSYNCKHRPROC eglWaitSyncKHR; -PFNEGLCLIENTWAITSYNCKHRPROC eglClientWaitSyncKHR; -PFNEGLDESTROYSYNCKHRPROC eglDestroySyncKHR; - -bool IsAhwbToGlInteropSupported() { - static const bool extensions_allowed = [] { - eglGetNativeClientBufferANDROID = - reinterpret_cast( - eglGetProcAddress("eglGetNativeClientBufferANDROID")); - glBufferStorageExternalEXT = - reinterpret_cast( - eglGetProcAddress("glBufferStorageExternalEXT")); - 
eglDupNativeFenceFDANDROID = - reinterpret_cast( - eglGetProcAddress("eglDupNativeFenceFDANDROID")); - eglCreateSyncKHR = reinterpret_cast( - eglGetProcAddress("eglCreateSyncKHR")); - eglWaitSyncKHR = reinterpret_cast( - eglGetProcAddress("eglWaitSyncKHR")); - eglClientWaitSyncKHR = reinterpret_cast( - eglGetProcAddress("eglClientWaitSyncKHR")); - eglDestroySyncKHR = reinterpret_cast( - eglGetProcAddress("eglDestroySyncKHR")); - return eglClientWaitSyncKHR && eglWaitSyncKHR && - eglGetNativeClientBufferANDROID && glBufferStorageExternalEXT && - eglCreateSyncKHR && eglDupNativeFenceFDANDROID && eglDestroySyncKHR; - }(); - return extensions_allowed; -} - -Expected GlBuffer::AllocFromAhwbBuffer(AhwbBuffer& ahwb_buffer) { - LITERT_RETURN_IF_ERROR( - IsAhwbToGlInteropSupported(), - Unexpected(kLiteRtStatusErrorRuntimeFailure, - "AHardwareBuffer to GL interop is not supported")); - LITERT_RETURN_IF_ERROR( - ahwb_buffer.ahwb != nullptr, - Unexpected(kLiteRtStatusErrorRuntimeFailure, "AHardwareBuffer is null")); - - // Create GL buffer id. - GLuint gl_id; - glGenBuffers(1, &gl_id); - glBindBuffer(GL_SHADER_STORAGE_BUFFER, gl_id); - - // Create EGLClientBuffer from AHardwareBuffer. - EGLClientBuffer native_buffer = - eglGetNativeClientBufferANDROID(ahwb_buffer.ahwb); - LITERT_RETURN_IF_ERROR( - native_buffer != nullptr, - Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to create EGLClientBuffer from AHardwareBuffer")); - - LITERT_ASSIGN_OR_RETURN( - size_t size_bytes, - litert::internal::AhwbBuffer::GetSize(ahwb_buffer.ahwb)); - LITERT_RETURN_IF_ERROR(size_bytes != 0, - Unexpected(kLiteRtStatusErrorRuntimeFailure, - "AHardwareBuffer size is 0")); - - // Create OpenGl buffer object backed by the AHardwareBuffer. - glBufferStorageExternalEXT( - GL_SHADER_STORAGE_BUFFER, 0, size_bytes, native_buffer, - GL_MAP_READ_BIT | GL_MAP_WRITE_BIT | GL_MAP_COHERENT_BIT_EXT | - GL_MAP_PERSISTENT_BIT_EXT); - // Check for OpenGL errors. 
- absl::Status status = tflite::gpu::gl::GetOpenGlErrors(); - if (!status.ok()) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - absl::StrCat("glBufferStorageExternalEXT: Failed to " - "create GL buffer from AHardwareBuffer: ", - status.message())); - } - // Unbind the buffer. - glBindBuffer(GL_SHADER_STORAGE_BUFFER, 0); - - // Create GL buffer object. We assume ownership of the GL buffer id so that it - // will be automatically deallocated when the internal::GlBuffer is destroyed. - tflite::gpu::gl::GlBuffer tflite_gl_buffer(GL_SHADER_STORAGE_BUFFER, gl_id, - size_bytes, /*offset=*/0, - /*has_ownership=*/true); - return GlBuffer(std::move(tflite_gl_buffer), ahwb_buffer.ahwb); -} -#endif // LITERT_HAS_AHWB_SUPPORT - -GlBuffer::GlBuffer(LiteRtGLenum target, LiteRtGLuint id, size_t size_bytes, - size_t offset, LiteRtGlBufferDeallocator deallocator) { -#if LITERT_HAS_OPENGL_SUPPORT - size_bytes_ = size_bytes; - - if (deallocator != nullptr) { - tflite_gl_buffer_ = tflite::gpu::gl::GlBuffer( - target, id, size_bytes, offset, /*has_ownership=*/false); - deallocator_ = std::move(deallocator); - } else { - tflite_gl_buffer_ = tflite::gpu::gl::GlBuffer( - target, id, size_bytes, offset, /*has_ownership=*/true); - deallocator_ = nullptr; - } -#else - LITERT_LOG(LITERT_ERROR, "GlBuffer::GlBuffer() is not supported"); -#endif // LITERT_HAS_OPENGL_SUPPORT -} - -GlBuffer::GlBuffer(GlBuffer&& other) { -#if LITERT_HAS_OPENGL_SUPPORT - tflite_gl_buffer_ = std::move(other.tflite_gl_buffer_); - deallocator_ = std::move(other.deallocator_); - data_ = other.data_; - size_bytes_ = other.size_bytes_; -#if LITERT_HAS_AHWB_SUPPORT - ahwb_ = other.ahwb_; -#endif // LITERT_HAS_AHWB_SUPPORT - // Reset the other GlBuffer to a default state. 
- other.data_ = nullptr; - other.size_bytes_ = 0; -#if LITERT_HAS_AHWB_SUPPORT - other.ahwb_ = nullptr; -#endif // LITERT_HAS_AHWB_SUPPORT -#else - LITERT_LOG(LITERT_ERROR, "GlBuffer::GlBuffer() is not supported"); -#endif // LITERT_HAS_OPENGL_SUPPORT -} - -GlBuffer::~GlBuffer() { -#if LITERT_HAS_OPENGL_SUPPORT - if (deallocator_ != nullptr) { - deallocator_(reinterpret_cast(tflite_gl_buffer_.id())); - } - if (data_ != nullptr) { - free(data_); - } -#else - LITERT_LOG(LITERT_ERROR, "GlBuffer::~GlBuffer() is not supported"); -#endif // LITERT_HAS_OPENGL_SUPPORT -} - -LiteRtGLenum GlBuffer::target() const { -#if LITERT_HAS_OPENGL_SUPPORT - return tflite_gl_buffer_.target(); -#else - LITERT_LOG(LITERT_ERROR, "GlBuffer::target() is not supported"); - return 0; -#endif // LITERT_HAS_OPENGL_SUPPORT -} -LiteRtGLuint GlBuffer::id() const { -#if LITERT_HAS_OPENGL_SUPPORT - return tflite_gl_buffer_.id(); -#else - LITERT_LOG(LITERT_ERROR, "GlBuffer::id() is not supported"); - return 0; -#endif // LITERT_HAS_OPENGL_SUPPORT -} -size_t GlBuffer::size_bytes() const { -#if LITERT_HAS_OPENGL_SUPPORT - return tflite_gl_buffer_.bytes_size(); -#else - LITERT_LOG(LITERT_ERROR, "GlBuffer::size_bytes() is not supported"); - return 0; -#endif // LITERT_HAS_OPENGL_SUPPORT -} -size_t GlBuffer::offset() const { -#if LITERT_HAS_OPENGL_SUPPORT - return tflite_gl_buffer_.offset(); -#else - LITERT_LOG(LITERT_ERROR, "GlBuffer::offset() is not supported"); - return 0; -#endif -} - -Expected GlBuffer::Alloc(size_t size_bytes) { -#if LITERT_HAS_OPENGL_SUPPORT - tflite::gpu::gl::GlBuffer tflite_gl_buffer; - - if (!tflite::gpu::gl::CreateReadWriteShaderStorageBuffer( - size_bytes, &tflite_gl_buffer) - .ok()) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to allocate GL buffer"); - } - - return GlBuffer(std::move(tflite_gl_buffer)); -#else - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "OpenGL buffers are not supported"); -#endif // LITERT_HAS_OPENGL_SUPPORT -} - -template 
Expected GlBuffer::Lock(); -template Expected GlBuffer::Lock(); -template Expected GlBuffer::Unlock(); -template Expected GlBuffer::Unlock(); - -template -Expected GlBuffer::Lock() { -#if LITERT_HAS_OPENGL_SUPPORT - absl::MutexLock lock(&mutex_); -#if LITERT_HAS_AHWB_SUPPORT - if (ahwb_ != nullptr) { - LITERT_ASSIGN_OR_RETURN(void* data, - litert::internal::AhwbBuffer::Lock(ahwb_)); - return static_cast(data); - } -#endif // LITERT_HAS_AHWB_SUPPORT - if (data_ == nullptr) { - // Ensure the data is aligned. - if (auto rc = posix_memalign(&data_, LITERT_HOST_MEMORY_BUFFER_ALIGNMENT, - size_bytes_); - rc) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to allocate aligned memory"); - } - if (auto status = tflite_gl_buffer_.Read( - absl::MakeSpan(static_cast(data_), size_bytes_ / sizeof(T))); - !status.ok()) { - return Unexpected( - kLiteRtStatusErrorRuntimeFailure, - absl::StrCat("Failed to read GL buffer: ", status.message())); - } - } - return Expected(static_cast(data_)); -#else - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "GlBuffer::Lock() is not supported"); -#endif // LITERT_HAS_OPENGL_SUPPORT -} - -template -Expected GlBuffer::Unlock() { -#if LITERT_HAS_OPENGL_SUPPORT - absl::MutexLock lock(&mutex_); -#if LITERT_HAS_AHWB_SUPPORT - if (ahwb_ != nullptr) { - return litert::internal::AhwbBuffer::Unlock(ahwb_); - } -#endif // LITERT_HAS_AHWB_SUPPORT - if (data_ == nullptr) { - return Error( - kLiteRtStatusErrorRuntimeFailure, - "Cannot unlock a buffer that wasn't locked in the first place"); - } - if (auto status = tflite_gl_buffer_.Write(absl::MakeSpan( - static_cast(data_), size_bytes_ / sizeof(T))); - !status.ok()) { - return Unexpected( - kLiteRtStatusErrorRuntimeFailure, - absl::StrCat("Failed to write GL buffer: ", status.message())); - } - return Expected(); -#else - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "GlBuffer::Unlock() is not supported"); -#endif // LITERT_HAS_OPENGL_SUPPORT -} - -Expected 
GlBuffer::CreateEglSyncAndFence() { -#if LITERT_HAS_OPENGL_SUPPORT && LITERT_HAS_AHWB_SUPPORT - LITERT_RETURN_IF_ERROR( - IsAhwbToGlInteropSupported(), - Unexpected(kLiteRtStatusErrorRuntimeFailure, - "AHardwareBuffer to GL interop is not supported")); - - auto egl_display = eglGetDisplay(EGL_DEFAULT_DISPLAY); - LITERT_RETURN_IF_ERROR(egl_display != EGL_NO_DISPLAY, - Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to get EGL display")); - - EGLSyncKHR egl_sync = - eglCreateSyncKHR(egl_display, EGL_SYNC_NATIVE_FENCE_ANDROID, nullptr); - LITERT_RETURN_IF_ERROR( - egl_sync != EGL_NO_SYNC_KHR, - Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to create EGL sync from AHardwareBuffer")); - - int native_fence = eglDupNativeFenceFDANDROID(egl_display, egl_sync); - LITERT_RETURN_IF_ERROR( - native_fence != -1, - Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to dup native fence from AHardwareBuffer")); - - return native_fence; -#else - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "AHardwareBuffer to GL interop is not supported"); -#endif // LITERT_HAS_OPENGL_SUPPORT && LITERT_HAS_AHWB_SUPPORT -} - -} // namespace internal -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/runtime/gl_buffer.h b/tensorflow/lite/experimental/litert/runtime/gl_buffer.h deleted file mode 100644 index f691317aa87405..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/gl_buffer.h +++ /dev/null @@ -1,105 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_GL_BUFFER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_GL_BUFFER_H_ - -#include -#include - -#include "absl/synchronization/mutex.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_gl_types.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" - -#if LITERT_HAS_OPENGL_SUPPORT -#include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h" -#endif // LITERT_HAS_OPENGL_SUPPORT - -#if LITERT_HAS_AHWB_SUPPORT -#include "tensorflow/lite/experimental/litert/runtime/ahwb_buffer.h" -#endif // LITERT_HAS_AHWB_SUPPORT - -namespace litert::internal { - -class GlBuffer { - public: -#if LITERT_HAS_OPENGL_SUPPORT - explicit GlBuffer(tflite::gpu::gl::GlBuffer&& tflite_gl_buffer -#if LITERT_HAS_AHWB_SUPPORT - , - AHardwareBuffer* ahwb = nullptr -#endif // LITERT_HAS_AHWB_SUPPORT - ) - : tflite_gl_buffer_(std::move(tflite_gl_buffer)), - deallocator_(nullptr), - size_bytes_(tflite_gl_buffer.bytes_size()) -#if LITERT_HAS_AHWB_SUPPORT - , - ahwb_(ahwb) -#endif // LITERT_HAS_AHWB_SUPPORT - { - } -#endif // LITERT_HAS_OPENGL_SUPPORT - - GlBuffer(LiteRtGLenum target, LiteRtGLuint id, size_t size_bytes, - size_t offset, LiteRtGlBufferDeallocator deallocator); - - GlBuffer(GlBuffer&& other); - - ~GlBuffer(); - - static bool IsSupported() { return true; } - static Expected Alloc(size_t size_bytes); - -#if LITERT_HAS_AHWB_SUPPORT - static Expected AllocFromAhwbBuffer(AhwbBuffer& ahwb_buffer); -#endif // LITERT_HAS_AHWB_SUPPORT - - template - Expected Lock(); - - template - Expected Unlock(); - - LiteRtGLenum target() const; - LiteRtGLuint id() const; - size_t size_bytes() const; - size_t offset() const; - - // Creates an EGL sync object on the GPU command queue and 
returns a native - // fence associated with the sync object. - // Note: This function assumes that all GL operations have been already added - // to the GPU command queue. - static Expected CreateEglSyncAndFence(); - - private: - absl::Mutex mutex_; -#if LITERT_HAS_OPENGL_SUPPORT - tflite::gpu::gl::GlBuffer tflite_gl_buffer_; - LiteRtGlBufferDeallocator deallocator_; - // The cpu memory buffer pointer. - void* data_ = nullptr; - // The size of the buffer in bytes. - size_t size_bytes_ = 0; -#endif // LITERT_HAS_OPENGL_SUPPORT -#if LITERT_HAS_AHWB_SUPPORT - AHardwareBuffer* ahwb_ = nullptr; -#endif // LITERT_HAS_AHWB_SUPPORT -}; - -} // namespace litert::internal - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_GL_BUFFER_H_ diff --git a/tensorflow/lite/experimental/litert/runtime/gl_buffer_test.cc b/tensorflow/lite/experimental/litert/runtime/gl_buffer_test.cc deleted file mode 100644 index 905056a860cccc..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/gl_buffer_test.cc +++ /dev/null @@ -1,183 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/c/litert_common.h" - -#if LITERT_HAS_OPENGL_SUPPORT -#include - -#include -#include "tensorflow/lite/delegates/gpu/gl/egl_environment.h" -#include "tensorflow/lite/experimental/litert/runtime/gl_buffer.h" -#include "tensorflow/lite/experimental/litert/test/matchers.h" - -#if LITERT_HAS_AHWB_SUPPORT -#include "tensorflow/lite/experimental/litert/runtime/ahwb_buffer.h" -#endif // LITERT_HAS_AHWB_SUPPORT - -namespace litert { -namespace internal { -namespace { - -using ::testing::FloatEq; -using ::testing::FloatNear; -using ::testing::Pointwise; - -constexpr const float kTensorData[] = {10, 20, 30, 40}; - -TEST(Buffer, GlBufferAlloc) { - if (!GlBuffer::IsSupported()) { - GTEST_SKIP() << "OpenGL buffers are not supported on this platform"; - } - std::unique_ptr env; - ASSERT_TRUE(tflite::gpu::gl::EglEnvironment::NewEglEnvironment(&env).ok()); - - auto buffer = GlBuffer::Alloc(4 * sizeof(float)); - ASSERT_TRUE(buffer); - - // Test lock and unlock. - LITERT_ASSERT_OK_AND_ASSIGN(float* data, buffer->Lock()); - EXPECT_NE(data, nullptr); - LITERT_ASSERT_OK(buffer->Unlock()); -} - -#if LITERT_HAS_AHWB_SUPPORT -TEST(Buffer, GlBufferAllocFromAhwb) { - if (!GlBuffer::IsSupported()) { - GTEST_SKIP() << "OpenGL buffers are not supported on this platform"; - } - // TODO(gcarranza): Incorporate this into LiteRT environment. - std::unique_ptr env; - ASSERT_TRUE(tflite::gpu::gl::EglEnvironment::NewEglEnvironment(&env).ok()); - - LITERT_ASSERT_OK_AND_ASSIGN(AhwbBuffer ahwb_buffer, - AhwbBuffer::Alloc(4 * sizeof(float))); - // Write to AHWB on CPU. - LITERT_ASSERT_OK_AND_ASSIGN( - void* ahwb_host_data, - litert::internal::AhwbBuffer::Lock(ahwb_buffer.ahwb)); - std::memcpy(ahwb_host_data, kTensorData, sizeof(kTensorData)); - LITERT_ASSERT_OK(litert::internal::AhwbBuffer::Unlock(ahwb_buffer.ahwb)); - - // Create GL buffer from AHWB. 
- LITERT_ASSERT_OK_AND_ASSIGN(GlBuffer gl_buffer, - GlBuffer::AllocFromAhwbBuffer(ahwb_buffer)); - - // Read from GL buffer backed by AHWB. - LITERT_ASSERT_OK_AND_ASSIGN(float* gl_host_data, gl_buffer.Lock()); - ASSERT_NE(gl_host_data, nullptr); - EXPECT_EQ(std::memcmp(gl_host_data, kTensorData, sizeof(kTensorData)), 0); - LITERT_EXPECT_OK(gl_buffer.Unlock()); -} - -TEST(Buffer, NegativeFenceAhwbRead) { - LITERT_ASSERT_OK_AND_ASSIGN(AhwbBuffer ahwb_buffer, - AhwbBuffer::Alloc(4 * sizeof(float))); - - LiteRtEventT event; - LITERT_ASSERT_OK_AND_ASSIGN(int fence_fd, event.GetSyncFenceFd()); - ASSERT_EQ(fence_fd, -1); - // Since fence is -1, there should be no wait on fence. - LITERT_ASSERT_OK_AND_ASSIGN(void* ahwb_host_data, - AhwbBuffer::Lock(ahwb_buffer.ahwb, &event)); - ASSERT_TRUE(ahwb_host_data != nullptr); - LITERT_ASSERT_OK(AhwbBuffer::Unlock(ahwb_buffer.ahwb)); -} - -// Utility function to fill the GPU buffer. -void FillGlBuffer(GLuint id, std::size_t size) { - std::string shader_source = R"( #version 310 es - precision highp float; - layout(local_size_x = 1, local_size_y = 1) in; - layout(std430, binding = 0) buffer Output {float elements[];} output_data; - void main() { - uint v = gl_GlobalInvocationID.x * 2u; - output_data.elements[v] = float(v) / 10.0; - output_data.elements[v + 1u] = float(v + 1u) / 10.0; - })"; - GLuint shader = glCreateShader(GL_COMPUTE_SHADER); - const GLchar* sources[] = {shader_source.c_str()}; - glShaderSource(shader, 1, sources, nullptr); - ABSL_CHECK(glGetError() == GL_NO_ERROR); - glCompileShader(shader); - ABSL_CHECK(glGetError() == GL_NO_ERROR); - - GLuint to_buffer_program = glCreateProgram(); - glAttachShader(to_buffer_program, shader); - ABSL_CHECK(glGetError() == GL_NO_ERROR); - glDeleteShader(shader); - ABSL_CHECK(glGetError() == GL_NO_ERROR); - glLinkProgram(to_buffer_program); - ABSL_CHECK(glGetError() == GL_NO_ERROR); - - glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 0, id); - ABSL_CHECK(glGetError() == GL_NO_ERROR); - 
glUseProgram(to_buffer_program); - ABSL_CHECK(glGetError() == GL_NO_ERROR); - glDispatchCompute(size / 2, 1, 1); - ABSL_CHECK(glGetError() == GL_NO_ERROR); - glBindBuffer(GL_SHADER_STORAGE_BUFFER, 0); - ABSL_CHECK(glGetError() == GL_NO_ERROR); - glDeleteProgram(to_buffer_program); - ABSL_CHECK(glGetError() == GL_NO_ERROR); -} - -TEST(Buffer, GpuWriteAhwbRead) { - std::unique_ptr env; - ASSERT_TRUE(tflite::gpu::gl::EglEnvironment::NewEglEnvironment(&env).ok()); - - LITERT_ASSERT_OK_AND_ASSIGN(AhwbBuffer ahwb_buffer, - AhwbBuffer::Alloc(4 * sizeof(float))); - // Write to AHWB on CPU. - LITERT_ASSERT_OK_AND_ASSIGN( - void* ahwb_host_data, - litert::internal::AhwbBuffer::Lock(ahwb_buffer.ahwb)); - std::memcpy(ahwb_host_data, kTensorData, sizeof(kTensorData)); - LITERT_ASSERT_OK(litert::internal::AhwbBuffer::Unlock(ahwb_buffer.ahwb)); - - // Create GL buffer from AHWB. - LITERT_ASSERT_OK_AND_ASSIGN(GlBuffer gl_buffer, - GlBuffer::AllocFromAhwbBuffer(ahwb_buffer)); - - // Schedule GPU write to GL buffer. - FillGlBuffer(gl_buffer.id(), 4); - - // Create EGL sync and fence before AHWB read. - LITERT_ASSERT_OK_AND_ASSIGN(int native_fence, - GlBuffer::CreateEglSyncAndFence()); - - // Wrap native fence in LiteRT event. - LiteRtEventT gpu_write_event = {.fd = native_fence, .owns_fd = true}; - - // Read from AHWB on CPU, waiting for GPU write to complete. - LITERT_ASSERT_OK_AND_ASSIGN( - void* ahwb_host_data_after_write_data, - AhwbBuffer::Lock(ahwb_buffer.ahwb, &gpu_write_event)); - ASSERT_NE(ahwb_host_data_after_write_data, nullptr); - auto ahwb_host_data_after_write = absl::MakeSpan( - reinterpret_cast(ahwb_host_data_after_write_data), 4); - // Check that the data is the same as the GPU write. 
- std::vector expected_data = {0.0f, 0.1f, 0.2f, 0.3f}; - EXPECT_THAT(ahwb_host_data_after_write, - Pointwise(FloatNear(1e-5), expected_data)); - LITERT_ASSERT_OK(AhwbBuffer::Unlock(ahwb_buffer.ahwb)); -} - -#endif // LITERT_HAS_AHWB_SUPPORT - -} // namespace -} // namespace internal -} // namespace litert - -#endif // LITERT_HAS_OPENGL_SUPPORT diff --git a/tensorflow/lite/experimental/litert/runtime/gl_texture.cc b/tensorflow/lite/experimental/litert/runtime/gl_texture.cc deleted file mode 100644 index 6b453f32957f21..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/gl_texture.cc +++ /dev/null @@ -1,110 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/runtime/gl_texture.h" - -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_gl_types.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" - -#if LITERT_HAS_OPENGL_SUPPORT -#include "tensorflow/lite/delegates/gpu/gl/gl_texture.h" -#endif // LITERT_HAS_OPENGL_SUPPORT - -namespace litert { -namespace internal { - -LiteRtGLenum GlTexture::target() const { -#if LITERT_HAS_OPENGL_SUPPORT - return tflite_gl_texture_.target(); -#endif - LITERT_LOG(LITERT_ERROR, "GlTexture::target() is not supported"); - return 0; -} - -LiteRtGLuint GlTexture::id() const { -#if LITERT_HAS_OPENGL_SUPPORT - return tflite_gl_texture_.id(); -#endif - LITERT_LOG(LITERT_ERROR, "GlTexture::id() is not supported"); - return 0; -} - -LiteRtGLenum GlTexture::format() const { -#if LITERT_HAS_OPENGL_SUPPORT - return tflite_gl_texture_.format(); -#endif - LITERT_LOG(LITERT_ERROR, "GlTexture::format() is not supported"); - return 0; -} - -size_t GlTexture::size_bytes() const { -#if LITERT_HAS_OPENGL_SUPPORT - return tflite_gl_texture_.bytes_size(); -#endif - LITERT_LOG(LITERT_ERROR, "GlTexture::size_bytes() is not supported"); - return 0; -} - -LiteRtGLint GlTexture::layer() const { -#if LITERT_HAS_OPENGL_SUPPORT - return tflite_gl_texture_.layer(); -#else - LITERT_LOG(LITERT_ERROR, "GlTexture::layer() is not supported"); - return 0; -#endif -} - -GlTexture::GlTexture(LiteRtGLenum target, LiteRtGLuint id, LiteRtGLenum format, - size_t size_bytes, LiteRtGLint layer, - LiteRtGlTextureDeallocator deallocator) { -#if LITERT_HAS_OPENGL_SUPPORT - if (deallocator != nullptr) { - tflite_gl_texture_ = tflite::gpu::gl::GlTexture( - target, id, format, size_bytes, layer, /*has_ownership=*/false); - deallocator_ = std::move(deallocator); - } else { - tflite_gl_texture_ = tflite::gpu::gl::GlTexture( 
- target, id, format, size_bytes, layer, /*has_ownership=*/true); - deallocator_ = nullptr; - } -#else - LITERT_LOG(LITERT_ERROR, "GlTexture::GlTexture() is not supported"); -#endif // LITERT_HAS_OPENGL_SUPPORT -} - -GlTexture::GlTexture(GlTexture&& other) { -#if LITERT_HAS_OPENGL_SUPPORT - tflite_gl_texture_ = std::move(other.tflite_gl_texture_); - deallocator_ = std::move(other.deallocator_); -#else - LITERT_LOG(LITERT_ERROR, "GlTexture::GlTexture() is not supported"); -#endif // LITERT_HAS_OPENGL_SUPPORT -} - -GlTexture::~GlTexture() { -#if LITERT_HAS_OPENGL_SUPPORT - if (deallocator_ != nullptr) { - deallocator_(reinterpret_cast(tflite_gl_texture_.id())); - } -#else - LITERT_LOG(LITERT_ERROR, "GlTexture::~GlTexture() is not supported"); -#endif // LITERT_HAS_OPENGL_SUPPORT -} - -} // namespace internal -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/runtime/gl_texture.h b/tensorflow/lite/experimental/litert/runtime/gl_texture.h deleted file mode 100644 index 3d358fe9515a32..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/gl_texture.h +++ /dev/null @@ -1,57 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_GL_TEXTURE_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_GL_TEXTURE_H_ - -#include - -#include "absl/synchronization/mutex.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_gl_types.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" - -#if LITERT_HAS_OPENGL_SUPPORT -#include "tensorflow/lite/delegates/gpu/gl/gl_texture.h" -#endif // LITERT_HAS_OPENGL_SUPPORT - -namespace litert::internal { - -class GlTexture { - public: - GlTexture(LiteRtGLenum target, LiteRtGLuint id, LiteRtGLenum format, - size_t size_bytes, LiteRtGLint layer, - LiteRtGlTextureDeallocator deallocator); - - GlTexture(GlTexture&& other); - - ~GlTexture(); - - LiteRtGLenum target() const; - LiteRtGLuint id() const; - LiteRtGLenum format() const; - size_t size_bytes() const; - LiteRtGLint layer() const; - - private: - absl::Mutex mutex_; -#if LITERT_HAS_OPENGL_SUPPORT - tflite::gpu::gl::GlTexture tflite_gl_texture_; - LiteRtGlTextureDeallocator deallocator_; -#endif // LITERT_HAS_OPENGL_SUPPORT -}; - -} // namespace litert::internal - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_GL_TEXTURE_H_ diff --git a/tensorflow/lite/experimental/litert/runtime/gpu_environment.cc b/tensorflow/lite/experimental/litert/runtime/gpu_environment.cc deleted file mode 100644 index c30f9055570dee..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/gpu_environment.cc +++ /dev/null @@ -1,94 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/runtime/gpu_environment.h" - -#include -#include "tensorflow/lite/experimental/litert/c/litert_any.h" -#include "tensorflow/lite/experimental/litert/c/litert_environment.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/core/environment.h" -#include "tensorflow/lite/experimental/litert/runtime/opencl/cl_command_queue.h" -#include "tensorflow/lite/experimental/litert/runtime/opencl/cl_context.h" -#include "tensorflow/lite/experimental/litert/runtime/opencl/cl_device.h" -#include "tensorflow/lite/experimental/litert/runtime/opencl/opencl_wrapper.h" - -namespace litert { -namespace internal { - -GpuEnvironmentSingleton::GpuEnvironmentSingleton( - LiteRtEnvironmentT* environment) { - cl_device_id device_id = nullptr; - cl_platform_id platform_id = nullptr; - cl_context context = nullptr; - cl_command_queue command_queue = nullptr; - if (environment) { - auto device_option = - environment->GetOption(kLiteRtEnvOptionTagOpenClDeviceId); - if (device_option.has_value() && device_option->type == kLiteRtAnyTypeInt) { - device_id = reinterpret_cast(device_option->int_value); - } - auto platform_option = - environment->GetOption(kLiteRtEnvOptionTagOpenClPlatformId); - if (platform_option.has_value() && - platform_option->type == kLiteRtAnyTypeInt) { - platform_id = - reinterpret_cast(platform_option->int_value); - } - auto context_option = - environment->GetOption(kLiteRtEnvOptionTagOpenClContext); - if (context_option.has_value() && - 
context_option->type == kLiteRtAnyTypeInt) { - context = reinterpret_cast(context_option->int_value); - } - auto command_queue_option = - environment->GetOption(kLiteRtEnvOptionTagOpenClCommandQueue); - if (command_queue_option.has_value() && - command_queue_option->type == kLiteRtAnyTypeInt) { - command_queue = - reinterpret_cast(command_queue_option->int_value); - } - } - if (device_id && platform_id) { - device_ = litert::cl::ClDevice(device_id, platform_id); - } else { - auto status = litert::cl::CreateDefaultGPUDevice(&device_); - if (!status.ok()) { - LITERT_LOG(LITERT_ERROR, "Failed to create OpenCL device"); - } - } - if (context) { - context_ = litert::cl::ClContext(context, /*has_ownership=*/false); - } else { - auto status = litert::cl::CreateClContext(device_, &context_); - if (!status.ok()) { - LITERT_LOG(LITERT_ERROR, "Failed to create OpenCL contxt"); - } - } - if (command_queue) { - command_queue_ = - litert::cl::ClCommandQueue(command_queue, /*has_ownership=*/false); - } else { - auto status = - litert::cl::CreateClCommandQueue(device_, context_, &command_queue_); - if (!status.ok()) { - LITERT_LOG(LITERT_ERROR, "Failed to create OpenCL command queue"); - } - } -} - -GpuEnvironmentSingleton* GpuEnvironmentSingleton::instance_ = nullptr; - -} // namespace internal -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/runtime/gpu_environment.h b/tensorflow/lite/experimental/litert/runtime/gpu_environment.h deleted file mode 100644 index 38f4d9215da15d..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/gpu_environment.h +++ /dev/null @@ -1,74 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_GPU_ENVIRONMENT_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_GPU_ENVIRONMENT_H_ - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_environment.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/core/environment.h" -#include "tensorflow/lite/experimental/litert/runtime/opencl/cl_command_queue.h" -#include "tensorflow/lite/experimental/litert/runtime/opencl/cl_context.h" -#include "tensorflow/lite/experimental/litert/runtime/opencl/cl_device.h" - -namespace litert::internal { - -// Inner singleton class that is for storing the MLD global environment. -// This class is used to store OpenCL, OpenGL environment objects. -class GpuEnvironmentSingleton { - public: - GpuEnvironmentSingleton(const GpuEnvironmentSingleton&) = delete; - GpuEnvironmentSingleton& operator=(const GpuEnvironmentSingleton&) = delete; - ~GpuEnvironmentSingleton() = default; - litert::cl::ClDevice* getDevice() { return &device_; } - litert::cl::ClContext* getContext() { return &context_; } - litert::cl::ClCommandQueue* getCommandQueue() { return &command_queue_; } - - static GpuEnvironmentSingleton& GetInstance() { - if (instance_ == nullptr) { - instance_ = new GpuEnvironmentSingleton(nullptr); - } - return *instance_; - } - - // Create the singleton instance with the given environment. 
- // It will fail if the singleton instance already exists. - static Expected Create( - LiteRtEnvironmentT* environment) { - if (instance_ == nullptr) { - instance_ = new GpuEnvironmentSingleton(environment); - LITERT_LOG(LITERT_INFO, "Created LiteRT EnvironmentSingleton."); - } else { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "EnvironmentSingleton already exists"); - } - return instance_; - } - - private: - // Load the OpenCL device, context and command queue from the environment if - // available. Otherwise, create the default device, context and command queue. - explicit GpuEnvironmentSingleton(LiteRtEnvironmentT* environment); - - litert::cl::ClDevice device_; - litert::cl::ClContext context_; - litert::cl::ClCommandQueue command_queue_; - static GpuEnvironmentSingleton* instance_; -}; - -} // namespace litert::internal - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_GPU_ENVIRONMENT_H_ diff --git a/tensorflow/lite/experimental/litert/runtime/gpu_environment_test.cc b/tensorflow/lite/experimental/litert/runtime/gpu_environment_test.cc deleted file mode 100644 index 5bf76f813d95e5..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/gpu_environment_test.cc +++ /dev/null @@ -1,77 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/runtime/gpu_environment.h" - -#include -#include -#include - -#include -#include -#include "third_party/ml_drift/cl/environment.h" -#include "third_party/ml_drift/cl/opencl_wrapper.h" -#include "tensorflow/lite/experimental/litert/c/litert_environment.h" -#include "tensorflow/lite/experimental/litert/cc/litert_any.h" -#include "tensorflow/lite/experimental/litert/runtime/opencl/opencl_wrapper.h" - -namespace litert { -namespace { - -TEST(EnvironmentSingletonTest, OpenClEnvironment) { - // MSAN does not support GPU tests. -#if defined(MEMORY_SANITIZER) || defined(THREAD_SANITIZER) - GTEST_SKIP() << "GPU tests are not supported in MSAN"; -#endif - - if (!ml_drift::cl::LoadOpenCL().ok()) { - GTEST_SKIP() << "OpenCL not loaded for ml_drift"; - } - if (!litert::cl::LoadOpenCL().ok()) { - GTEST_SKIP() << "OpenCL not loaded for litert"; - } - - ml_drift::cl::Environment env; - ASSERT_OK(ml_drift::cl::CreateEnvironment(&env)); - - const std::array environment_options = { - LiteRtEnvOption{ - /*.tag=*/kLiteRtEnvOptionTagOpenClContext, - /*.value=*/ - *ToLiteRtAny( - std::any(reinterpret_cast(env.context().context()))), - }, - LiteRtEnvOption{ - /*.tag=*/kLiteRtEnvOptionTagOpenClCommandQueue, - /*.value=*/ - *ToLiteRtAny( - std::any(reinterpret_cast(env.queue()->queue()))), - }, - }; - auto litert_envt = LiteRtEnvironmentT::CreateWithOptions(environment_options); - ASSERT_TRUE(litert_envt); - auto singleton_env = - litert::internal::GpuEnvironmentSingleton::Create(litert_envt->get()); - ASSERT_TRUE(singleton_env); - EXPECT_EQ((*singleton_env)->getContext()->context(), env.context().context()); - EXPECT_EQ((*singleton_env)->getCommandQueue()->queue(), env.queue()->queue()); - - // Create another singleton environment should fail. 
- auto another_singleton_env = - litert::internal::GpuEnvironmentSingleton::Create(litert_envt->get()); - EXPECT_FALSE(another_singleton_env); -} - -} // namespace -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/runtime/ion_buffer.cc b/tensorflow/lite/experimental/litert/runtime/ion_buffer.cc deleted file mode 100644 index 41a3ee09c82643..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/ion_buffer.cc +++ /dev/null @@ -1,196 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/runtime/ion_buffer.h" - -#include -#include -#include - -#include "absl/base/attributes.h" -#include "absl/base/const_init.h" -#include "absl/container/node_hash_map.h" -#include "absl/synchronization/mutex.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" - -#if LITERT_HAS_ION_SUPPORT -#include -#include -#endif // LITERT_HAS_ION_SUPPORT - -namespace litert { -namespace internal { - -#if LITERT_HAS_ION_SUPPORT -namespace { - -class IonLibrary { - public: - using Ptr = std::unique_ptr; - - ~IonLibrary() { - if (client_fd_ > 0) { - ion_close_(client_fd_); - } - } - - static Expected Create() { - DlHandle dlhandle(::dlopen("libion.so", RTLD_NOW | RTLD_LOCAL), ::dlclose); - if (!dlhandle) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "libion.so not found"); - } - - auto ion_open = - reinterpret_cast(::dlsym(dlhandle.get(), "ion_open")); - if (!ion_open) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, "ion_open not found"); - } - - auto ion_close = - reinterpret_cast(::dlsym(dlhandle.get(), "ion_close")); - if (!ion_close) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "ion_close not found"); - } - - auto ion_alloc_fd = - reinterpret_cast(::dlsym(dlhandle.get(), "ion_alloc_fd")); - if (!ion_alloc_fd) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "ion_alloc_fd not found"); - } - - int client_fd = ion_open(); - if (client_fd < 0) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to open ion device"); - } - - return Ptr(new IonLibrary(std::move(dlhandle), client_fd, ion_close, - ion_alloc_fd)); - } - - Expected Alloc(size_t size, size_t alignment) { - int heap_id_mask = 1 << kIonHeapId; - int fd; - if (auto status = ion_alloc_fd_(client_fd_, size, alignment, heap_id_mask, - kIonFlags, &fd); - status != 0) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to 
allocate DMA-BUF buffer"); - } - void* addr = - ::mmap(nullptr, size, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); - if (addr == MAP_FAILED) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to mem-map DMA-BUF buffer"); - } - records_[addr] = Record{.fd = fd, .addr = addr, .size = size}; - return IonBuffer{.fd = fd, .addr = addr}; - } - - void Free(void* addr) { - auto iter = records_.find(addr); - if (iter == records_.end()) { - return; - } - auto& record = iter->second; - ::munmap(record.addr, record.size); - ::close(record.fd); - records_.erase(iter); - } - - private: - static constexpr const int kIonHeapId = 25; - static constexpr const int kIonFlags = 1; - - struct Record { - int fd; - void* addr; - size_t size; - }; - - using DlHandle = std::unique_ptr; - using IonOpen = int (*)(); - using IonClose = int (*)(int); - using IonAllocFd = int (*)(int, size_t, size_t, unsigned int, unsigned int, - int*); - - IonLibrary(DlHandle&& dlhandle, int client_fd, IonClose ion_close, - IonAllocFd ion_alloc_fd) - : dlhandle_(std::move(dlhandle)), - client_fd_(client_fd), - ion_close_(ion_close), - ion_alloc_fd_(ion_alloc_fd) {} - - DlHandle dlhandle_; - int client_fd_; - IonClose ion_close_; - IonAllocFd ion_alloc_fd_; - absl::node_hash_map records_; -}; - -IonLibrary* TheIonLibrary; -ABSL_CONST_INIT absl::Mutex TheMutex(absl::kConstInit); - -Expected InitLibraryIfNeededUnlocked() { - if (!TheIonLibrary) { - if (auto library = IonLibrary::Create(); library) { - TheIonLibrary = library->release(); - } else { - return Unexpected(library.Error()); - } - } - return {}; -} - -} // namespace -#endif // LITERT_HAS_ION_SUPPORT - -bool IonBuffer::IsSupported() { -#if LITERT_HAS_ION_SUPPORT - absl::MutexLock lock(&TheMutex); - auto status = InitLibraryIfNeededUnlocked(); - return static_cast(status); -#else // LITERT_HAS_ION_SUPPORT - return false; -#endif // LITERT_HAS_ION_SUPPORT -} - -Expected IonBuffer::Alloc(size_t size, size_t alignment) { -#if 
LITERT_HAS_ION_SUPPORT - absl::MutexLock lock(&TheMutex); - if (auto status = InitLibraryIfNeededUnlocked(); !status) { - return status.Error(); - } - return TheIonLibrary->Alloc(size, alignment); -#else // LITERT_HAS_ION_SUPPORT - return Unexpected(kLiteRtStatusErrorUnsupported, - "IonBuffer::Alloc not implemented for this platform"); -#endif // LITERT_HAS_ION_SUPPORT -} - -void IonBuffer::Free(void* addr) { -#if LITERT_HAS_ION_SUPPORT - absl::MutexLock lock(&TheMutex); - if (TheIonLibrary) { - TheIonLibrary->Free(addr); - } -#endif // LITERT_HAS_ION_SUPPORT -} - -} // namespace internal -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/runtime/ion_buffer.h b/tensorflow/lite/experimental/litert/runtime/ion_buffer.h deleted file mode 100644 index 38a0b19abdc137..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/ion_buffer.h +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_ION_BUFFER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_ION_BUFFER_H_ - -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" - -namespace litert::internal { - -struct IonBuffer { - int fd; - void* addr; - - static bool IsSupported(); - static Expected Alloc(size_t size, size_t alignment); - static void Free(void* addr); -}; - -} // namespace litert::internal - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_ION_BUFFER_H_ diff --git a/tensorflow/lite/experimental/litert/runtime/open_cl_buffer.cc b/tensorflow/lite/experimental/litert/runtime/open_cl_buffer.cc deleted file mode 100644 index d99a9875472981..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/open_cl_buffer.cc +++ /dev/null @@ -1,120 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/runtime/open_cl_buffer.h" - -#include - -#include -#include -#include -#include - -#include "absl/synchronization/mutex.h" -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/runtime/gpu_environment.h" -#include "tensorflow/lite/experimental/litert/runtime/opencl/buffer.h" -#include "tensorflow/lite/experimental/litert/runtime/opencl/cl_command_queue.h" -#include "tensorflow/lite/experimental/litert/runtime/opencl/cl_context.h" -#include "tensorflow/lite/experimental/litert/runtime/opencl/opencl_wrapper.h" - -namespace litert { -namespace internal { - -template Expected OpenClBuffer::Lock(); -template Expected OpenClBuffer::Lock(); -template Expected OpenClBuffer::Unlock(); -template Expected OpenClBuffer::Unlock(); - -template -Expected OpenClBuffer::Lock() { - absl::MutexLock lock(&mutex_); - // The buffer has not been locked, so we need to read from the OpenCL - // buffer. - if (data_ == nullptr) { - litert::cl::ClCommandQueue* queue = - GpuEnvironmentSingleton::GetInstance().getCommandQueue(); - std::vector result; - auto status = buffer_.ReadData(queue, &result); - if (!status.ok()) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to read OpenCL buffer"); - } - // Ensure the data is aligned. - if (auto rc = - posix_memalign(&data_, LITERT_HOST_MEMORY_BUFFER_ALIGNMENT, size_); - rc) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to allocate aligned memory"); - } - // Copy the data from the OpenCL buffer to the aligned memory. - // TODO(piyu): Consider adding support in MLD OpenCL buffer to directly - // write to the aligned memory. 
- std::copy(result.begin(), result.end(), static_cast(data_)); - } - return Expected(static_cast(data_)); -} - -template -Expected OpenClBuffer::Unlock() { - absl::MutexLock lock(&mutex_); - litert::cl::ClCommandQueue* queue = - GpuEnvironmentSingleton::GetInstance().getCommandQueue(); - // The buffer has not been locked, so we don't need to write back. - if (data_ == nullptr) { - return Error( - kLiteRtStatusErrorRuntimeFailure, - "Cannot unlock a buffer that wasn't locked in the first place"); - } - size_t write_size = (size_ + sizeof(T) - 1) / sizeof(T); - auto status = buffer_.WriteData( - queue, absl::MakeSpan(static_cast(data_), write_size)); - - if (status.ok()) { - return Expected(); - } - return Unexpected( - kLiteRtStatusErrorRuntimeFailure, - "The data failed to write to the OpenCL buffer when unlocked"); -} - -bool OpenClBuffer::IsSupported() { - static bool is_supported = ::litert::cl::LoadOpenCL().ok(); - return is_supported; -} - -Expected OpenClBuffer::Alloc(size_t bytes_size) { - LITERT_RETURN_IF_ERROR( - IsSupported(), - Unexpected(kLiteRtStatusErrorRuntimeFailure, "OpenCL is not supported")); - - litert::cl::Buffer buffer; - - litert::cl::ClContext* cl_context = - GpuEnvironmentSingleton::GetInstance().getContext(); - auto result = - litert::cl::CreateReadWriteBuffer(bytes_size, cl_context, &buffer); - if (!result.ok()) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to create OpenCL buffer"); - } - - return Expected(std::move(buffer), bytes_size); -} -} // namespace internal -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/runtime/open_cl_buffer.h b/tensorflow/lite/experimental/litert/runtime/open_cl_buffer.h deleted file mode 100644 index cf5c422afd1adb..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/open_cl_buffer.h +++ /dev/null @@ -1,92 +0,0 @@ -// Copyright 2024 Google LLC. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_OPEN_CL_BUFFER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_OPEN_CL_BUFFER_H_ - -#include -#include -#include - -#include "absl/synchronization/mutex.h" -#include -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/runtime/opencl/buffer.h" -#include "tensorflow/lite/experimental/litert/runtime/opencl/opencl_wrapper.h" - -namespace litert::internal { - -/** - * The OpenCL buffer class that provides GPU memory allocation and two-way sync - * between the CPU memory and the GPU OpenCL buffer. - */ -class OpenClBuffer { - public: - OpenClBuffer(OpenClBuffer&& other) { - data_ = other.data_; - buffer_ = std::move(other.buffer_); - size_ = other.size_; - other.data_ = nullptr; - other.size_ = 0; - } - - OpenClBuffer(litert::cl::Buffer buffer, size_t size) - : buffer_(std::move(buffer)), size_(size) {} - - OpenClBuffer(cl_mem buffer, size_t size, LiteRtOpenClDeallocator deallocator) - : deallocator_(deallocator), size_(size) { - if (deallocator_ != nullptr) { - buffer_ = litert::cl::CreateBufferShared(buffer); - } else { // The buffer will be deallocated automatically. 
- buffer_ = litert::cl::Buffer(buffer, size); - } - } - - ~OpenClBuffer() { - if (deallocator_ != nullptr) { - deallocator_(buffer_.GetMemoryPtr()); - } - if (data_ != nullptr) { - free(data_); - }; - } - - cl_mem GetMemoryPtr() { return buffer_.GetMemoryPtr(); } - // Allocates a CPU memory and conducts a copy from the OpenCL buffer to the - // CPU memory. - template - Expected Lock(); - - // Writes the data from the CPU memory to the OpenCL buffer. - template - Expected Unlock(); - - static bool IsSupported(); - static Expected Alloc(size_t bytes_size); - size_t size_bytes() const { return size_; } - - private: - absl::Mutex mutex_; - // The cpu memory buffer pointer. - void* data_ = nullptr; - litert::cl::Buffer buffer_; - LiteRtOpenClDeallocator deallocator_ = nullptr; - // The size of the buffer in bytes. - size_t size_ = 0; -}; - -} // namespace litert::internal - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_OPEN_CL_BUFFER_H_ diff --git a/tensorflow/lite/experimental/litert/runtime/opencl/BUILD b/tensorflow/lite/experimental/litert/runtime/opencl/BUILD deleted file mode 100644 index 21b74be48a932e..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/opencl/BUILD +++ /dev/null @@ -1,129 +0,0 @@ -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], -) - -cc_library( - name = "cl_command_queue", - srcs = [ - "cl_command_queue.cc", - ], - hdrs = [ - "cl_command_queue.h", - ], - deps = [ - ":cl_context", - ":cl_device", - ":opencl_wrapper", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@opencl_headers", - ], -) - -cc_library( - name = "cl_device", - srcs = [ - "cl_device.cc", - ], - hdrs = [ - "cl_device.h", - ], - deps = [ - ":opencl_wrapper", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings:str_format", - "@opencl_headers", - ], -) - 
-cc_library( - name = "cl_context", - srcs = [ - "cl_context.cc", - ], - hdrs = [ - "cl_context.h", - ], - deps = [ - ":cl_device", - ":opencl_wrapper", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - "@opencl_headers", - ], -) - -cc_library( - name = "cl_event", - srcs = [ - "cl_event.cc", - ], - hdrs = [ - "cl_event.h", - ], - deps = [ - ":opencl_wrapper", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "@com_google_absl//absl/strings:str_format", - "@opencl_headers", - ], -) - -cc_library( - name = "opencl_wrapper", - srcs = [ - "opencl_wrapper.cc", - ], - hdrs = [ - "opencl_wrapper.h", - ], - visibility = [ - "//tensorflow/lite/experimental/litert:__subpackages__", - "//third_party/odml/infra:__subpackages__", - ], - deps = [ - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - "@opencl_headers", - ], -) - -cc_library( - name = "buffer", - srcs = [ - "buffer.cc", - ], - hdrs = [ - "buffer.h", - ], - deps = [ - ":cl_command_queue", - ":cl_context", - ":opencl_wrapper", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@opencl_headers", - ], -) - -cc_test( - name = "buffer_test", - srcs = ["buffer_test.cc"], - # require GPU to run OpenCL tests. - tags = [ - "requires-gpu-nvidia", - ], - deps = [ - ":buffer", - ":cl_command_queue", - ":cl_context", - ":cl_device", - ":opencl_wrapper", - "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest_main", - ], -) diff --git a/tensorflow/lite/experimental/litert/runtime/opencl/buffer.cc b/tensorflow/lite/experimental/litert/runtime/opencl/buffer.cc deleted file mode 100644 index c2878a4839517a..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/opencl/buffer.cc +++ /dev/null @@ -1,114 +0,0 @@ -// Copyright 2024 The TensorFlow Authors. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// This file is a copy of third_party/ml_drift/cl/buffer.cc. -#include "tensorflow/lite/experimental/litert/runtime/opencl/buffer.h" - -#include -#include -#include - -#include "absl/status/status.h" -#include "absl/strings/str_cat.h" -#include -#include -#include "tensorflow/lite/experimental/litert/runtime/opencl/cl_context.h" -#include "tensorflow/lite/experimental/litert/runtime/opencl/opencl_wrapper.h" - -namespace litert { -namespace cl { -absl::Status CreateClBuffer(cl_context context, size_t size_in_bytes, - bool read_only, void* data, cl_mem* result) { - cl_mem_flags flags = read_only ? 
CL_MEM_READ_ONLY : CL_MEM_READ_WRITE; - if (data) { - flags |= CL_MEM_COPY_HOST_PTR; - } - cl_int error_code; - *result = clCreateBuffer(context, flags, size_in_bytes, data, &error_code); - if (!*result) { - return absl::UnknownError( - absl::StrCat("Failed to allocate device memory (clCreateBuffer): ", - std::to_string(error_code))); - } - return absl::OkStatus(); -} -absl::Status CreateBuffer(size_t size_in_bytes, bool gpu_read_only, - const void* data, ClContext* context, - Buffer* result) { - cl_mem buffer; - auto status = CreateClBuffer(context->context(), size_in_bytes, gpu_read_only, - const_cast(data), &buffer); - if (!status.ok()) { - return status; - } - *result = Buffer(buffer, size_in_bytes); - - return absl::OkStatus(); -} - -Buffer::Buffer(cl_mem buffer, size_t size_in_bytes, bool is_sub_buffer) - : buffer_(buffer), size_(size_in_bytes), is_sub_buffer_(is_sub_buffer) {} - -Buffer::Buffer(cl_mem buffer) - : buffer_(buffer), size_(0), is_sub_buffer_(false), owner_(false) {} - -Buffer::Buffer(Buffer&& buffer) - : buffer_(buffer.buffer_), - size_(buffer.size_), - is_sub_buffer_(buffer.is_sub_buffer_), - owner_(buffer.owner_) { - buffer.buffer_ = nullptr; - buffer.size_ = 0; - buffer.is_sub_buffer_ = false; -} - -Buffer& Buffer::operator=(Buffer&& buffer) { - if (this != &buffer) { - Release(); - std::swap(size_, buffer.size_); - std::swap(buffer_, buffer.buffer_); - std::swap(is_sub_buffer_, buffer.is_sub_buffer_); - std::swap(owner_, buffer.owner_); - } - return *this; -} - -void Buffer::Release() { - if (owner_ && buffer_) { - clReleaseMemObject(buffer_); - buffer_ = nullptr; - size_ = 0; - is_sub_buffer_ = false; - } -} - -Buffer CreateBufferShared(cl_mem buffer) { return Buffer(buffer); } - -absl::Status CreateReadOnlyBuffer(size_t size_in_bytes, ClContext* context, - Buffer* result) { - return CreateBuffer(size_in_bytes, true, nullptr, context, result); -} - -absl::Status CreateReadOnlyBuffer(size_t size_in_bytes, const void* data, - ClContext* 
context, Buffer* result) { - return CreateBuffer(size_in_bytes, true, data, context, result); -} - -absl::Status CreateReadWriteBuffer(size_t size_in_bytes, ClContext* context, - Buffer* result) { - return CreateBuffer(size_in_bytes, false, nullptr, context, result); -} - -} // namespace cl -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/runtime/opencl/buffer.h b/tensorflow/lite/experimental/litert/runtime/opencl/buffer.h deleted file mode 100644 index b1cb09f065508f..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/opencl/buffer.h +++ /dev/null @@ -1,125 +0,0 @@ -// Copyright 2024 The TensorFlow Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// This file is a copy of third_party/ml_drift/cl/buffer.h. -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_OPENCL_BUFFER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_OPENCL_BUFFER_H_ - -#include -#include -#include - -#include "absl/status/status.h" -#include "absl/types/span.h" -#include -#include "tensorflow/lite/experimental/litert/runtime/opencl/cl_command_queue.h" -#include "tensorflow/lite/experimental/litert/runtime/opencl/cl_context.h" - -namespace litert { -namespace cl { - -// Buffer represent linear GPU data storage with arbitrary data format. -// Buffer is moveable but not copyable. 
-class Buffer { - public: - Buffer() = default; // just for using Buffer as a class members - Buffer(cl_mem buffer, size_t size_in_bytes, bool is_sub_buffer = false); - explicit Buffer(cl_mem buffer); - - // Move only - Buffer(Buffer&& buffer); - Buffer& operator=(Buffer&& buffer); - Buffer(const Buffer&) = delete; - Buffer& operator=(const Buffer&) = delete; - - ~Buffer() { Release(); } - - // for profiling and memory statistics - uint64_t GetMemorySizeInBytes() const { return size_; } - - cl_mem GetMemoryPtr() const { return buffer_; } - - bool IsSubBuffer() const { return is_sub_buffer_; } - - // Writes data to a buffer. Data should point to a region that - // has exact size in bytes as size_in_bytes(constructor parameter). - template - absl::Status WriteData(ClCommandQueue* queue, absl::Span data); - - // Reads data from Buffer into CPU memory. - template - absl::Status ReadData(ClCommandQueue* queue, std::vector* result) const; - - private: - void Release(); - - cl_mem buffer_ = nullptr; - size_t size_ = 0; - bool is_sub_buffer_ = false; - bool owner_ = true; -}; - -Buffer CreateBufferShared(cl_mem buffer); - -absl::Status CreateClBuffer(cl_context context, size_t size_in_bytes, - bool read_only, void* data, cl_mem* result); - -absl::Status CreateBuffer(size_t size_in_bytes, bool gpu_read_only, - const void* data, ClContext* context, Buffer* result); - -absl::Status CreateReadOnlyBuffer(size_t size_in_bytes, ClContext* context, - Buffer* result); - -absl::Status CreateReadOnlyBuffer(size_t size_in_bytes, const void* data, - ClContext* context, Buffer* result); - -absl::Status CreateReadWriteBuffer(size_t size_in_bytes, ClContext* context, - Buffer* result); - -absl::Status CreateReadWriteSubBuffer(const Buffer& parent, - size_t origin_in_bytes, - size_t size_in_bytes, ClContext* context, - Buffer* result); - -template -absl::Status Buffer::WriteData(ClCommandQueue* queue, - const absl::Span data) { - if (sizeof(T) * data.size() > size_) { - return 
absl::InvalidArgumentError( - "absl::Span data size is greater from buffer allocated size."); - } - auto status = queue->EnqueueWriteBuffer(buffer_, size_, data.data()); - if (!status.ok()) { - return status; - } - return absl::OkStatus(); -} - -template -absl::Status Buffer::ReadData(ClCommandQueue* queue, - std::vector* result) const { - if (size_ % sizeof(T) != 0) { - return absl::UnknownError("Wrong element size(typename T is not correct?"); - } - - const int elements_count = size_ / sizeof(T); - result->resize(elements_count); - - return queue->EnqueueReadBuffer(buffer_, size_, result->data()); -} - -} // namespace cl -} // namespace litert - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_OPENCL_BUFFER_H_ diff --git a/tensorflow/lite/experimental/litert/runtime/opencl/buffer_test.cc b/tensorflow/lite/experimental/litert/runtime/opencl/buffer_test.cc deleted file mode 100644 index 84280ef6af23b9..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/opencl/buffer_test.cc +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright 2024 The ML Drift Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/runtime/opencl/buffer.h" - -#include - -#include -#include -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/runtime/opencl/cl_command_queue.h" -#include "tensorflow/lite/experimental/litert/runtime/opencl/cl_context.h" -#include "tensorflow/lite/experimental/litert/runtime/opencl/cl_device.h" -#include "tensorflow/lite/experimental/litert/runtime/opencl/opencl_wrapper.h" - -using ::testing::FloatNear; -using ::testing::Pointwise; - -namespace litert { -namespace internal { - -TEST(OpenCLTest, BufferTestFloat) { - // MSAN does not support GPU tests. -#if defined(MEMORY_SANITIZER) || defined(THREAD_SANITIZER) - GTEST_SKIP() << "GPU tests are not supported In msan"; -#endif - - if (!litert::cl::LoadOpenCL().ok()) { - GTEST_SKIP() << "OpenCL buffers are not supported on this platform; " - "skipping the test"; - } - const std::vector data = {1.0, 2.0, 3.0, -4.0, 5.1}; - litert::cl::Buffer buffer; - litert::cl::ClContext context; - litert::cl::ClDevice device; - litert::cl::ClCommandQueue queue; - ASSERT_TRUE(CreateDefaultGPUDevice(&device).ok()); - ASSERT_TRUE(CreateClContext(device, &context).ok()); - ASSERT_TRUE(CreateClCommandQueue(device, context, &queue).ok()); - ASSERT_TRUE(CreateReadWriteBuffer(sizeof(float) * 5, &context, &buffer).ok()); - ASSERT_TRUE( - buffer.WriteData(&queue, absl::MakeConstSpan(data.data(), data.size())) - .ok()); - std::vector gpu_data; - ASSERT_TRUE(buffer.ReadData(&queue, &gpu_data).ok()); - - EXPECT_THAT(gpu_data, Pointwise(FloatNear(0.0f), data)); -} -} // namespace internal -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/runtime/opencl/cl_command_queue.cc b/tensorflow/lite/experimental/litert/runtime/opencl/cl_command_queue.cc deleted file mode 100644 index 278862c3f87d20..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/opencl/cl_command_queue.cc +++ /dev/null @@ -1,141 +0,0 @@ -// Copyright 2024 The TensorFlow Authors. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// This file is a copy of third_party/ml_drift/cl/cl_command_queue.cc. -#include "tensorflow/lite/experimental/litert/runtime/opencl/cl_command_queue.h" - -#include -#include -#include -#include - -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/str_cat.h" -#include -#include "tensorflow/lite/experimental/litert/runtime/opencl/cl_context.h" -#include "tensorflow/lite/experimental/litert/runtime/opencl/cl_device.h" -#include "tensorflow/lite/experimental/litert/runtime/opencl/opencl_wrapper.h" - -namespace litert { -namespace cl { -namespace { - -absl::StatusOr CreateClCommandQueueWithProperties( - const ClDevice& device, const ClContext& context, - cl_command_queue_properties queue_properties) { - int error_code; - cl_command_queue queue; - if (clCreateCommandQueueWithProperties) { - std::vector props; - if (queue_properties != 0) { - props.push_back(CL_QUEUE_PROPERTIES); - props.push_back(queue_properties); - } - props.push_back(0); - - queue = clCreateCommandQueueWithProperties(context.context(), device.id(), - props.data(), &error_code); - } else { - // Backwards compatibility for OpenCL versions before 2.0. 
- queue = clCreateCommandQueue(context.context(), device.id(), - queue_properties, &error_code); - } - if (!queue) { - return absl::UnknownError(absl::StrCat( - "Failed to create a command queue - ", std::to_string(error_code))); - } - return queue; -} - -} // namespace - -ClCommandQueue::ClCommandQueue() = default; - -ClCommandQueue::ClCommandQueue(cl_command_queue queue, bool has_ownership) - : queue_(queue), has_ownership_(has_ownership) {} - -ClCommandQueue::ClCommandQueue(ClCommandQueue&& queue) - : queue_(queue.queue_), has_ownership_(queue.has_ownership_) { - queue.queue_ = nullptr; -} - -ClCommandQueue& ClCommandQueue::operator=(ClCommandQueue&& queue) { - if (this != &queue) { - Release(); - std::swap(queue_, queue.queue_); - has_ownership_ = queue.has_ownership_; - } - return *this; -} - -ClCommandQueue::~ClCommandQueue() { Release(); } - -void ClCommandQueue::Release() { - if (has_ownership_ && queue_) { - clReleaseCommandQueue(queue_); - queue_ = nullptr; - } -} - -absl::Status ClCommandQueue::EnqueueWriteBuffer(cl_mem memory, - size_t size_in_bytes, - const void* data, bool async) { - const cl_bool blocking = async ? CL_FALSE : CL_TRUE; - auto error_code = clEnqueueWriteBuffer( - queue_, memory, blocking, 0, size_in_bytes, data, 0, nullptr, nullptr); - if (error_code != CL_SUCCESS) { - return absl::UnknownError( - absl::StrCat("Failed to upload data to GPU (clEnqueueWriteBuffer) - ", - std::to_string(error_code))); - } - return absl::OkStatus(); -} - -absl::Status ClCommandQueue::EnqueueReadBuffer(cl_mem memory, - size_t size_in_bytes, void* data, - bool async) { - const cl_bool blocking = async ? 
CL_FALSE : CL_TRUE; - auto error_code = clEnqueueReadBuffer( - queue_, memory, blocking, 0, size_in_bytes, data, 0, nullptr, nullptr); - if (error_code != CL_SUCCESS) { - return absl::UnknownError( - absl::StrCat("Failed to read data from GPU (clEnqueueReadBuffer) - ", - std::to_string(error_code))); - } - return absl::OkStatus(); -} - -absl::Status ClCommandQueue::WaitForCompletion() { - auto error_code = clFinish(queue_); - if (error_code != CL_SUCCESS) { - return absl::UnknownError( - absl::StrCat("Failed to clFinish - ", std::to_string(error_code))); - } - return absl::OkStatus(); -} - -absl::Status CreateClCommandQueue(const ClDevice& device, - const ClContext& context, - ClCommandQueue* result) { - auto queue = CreateClCommandQueueWithProperties(device, context, 0); - if (!queue.ok()) { - return queue.status(); - } - *result = ClCommandQueue(*queue, true); - return absl::OkStatus(); -} - -} // namespace cl -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/runtime/opencl/cl_command_queue.h b/tensorflow/lite/experimental/litert/runtime/opencl/cl_command_queue.h deleted file mode 100644 index a7691d52e6c65b..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/opencl/cl_command_queue.h +++ /dev/null @@ -1,82 +0,0 @@ -// Copyright 2024 The TensorFlow Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// This file is a copy of third_party/ml_drift/cl/cl_command_queue.h. 
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_OPENCL_CL_COMMAND_QUEUE_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_OPENCL_CL_COMMAND_QUEUE_H_ - -#include -#include - -#include "absl/status/status.h" -#include -#include "tensorflow/lite/experimental/litert/runtime/opencl/cl_context.h" -#include "tensorflow/lite/experimental/litert/runtime/opencl/cl_device.h" - -namespace litert { -namespace cl { - -// A wrapper around opencl command queue -class ClCommandQueue { - public: - ClCommandQueue(); - ClCommandQueue(cl_command_queue queue, bool has_ownership); - - // Move only - ClCommandQueue(ClCommandQueue&& queue); - ClCommandQueue& operator=(ClCommandQueue&& queue); - ClCommandQueue(const ClCommandQueue&) = delete; - ClCommandQueue& operator=(const ClCommandQueue&) = delete; - - virtual ~ClCommandQueue(); - - cl_command_queue queue() const { return queue_; } - - absl::Status EnqueueWriteBuffer(cl_mem memory, size_t size_in_bytes, - const void* data, bool async = false); - absl::Status EnqueueReadBuffer(cl_mem memory, size_t size_in_bytes, - void* data, bool async = false); - - absl::Status WaitForCompletion(); - - protected: - void Release(); - - cl_command_queue queue_ = nullptr; - bool has_ownership_ = false; -}; - -class ProfilingCommandQueue : public ClCommandQueue { - public: - ProfilingCommandQueue(); - explicit ProfilingCommandQueue(cl_command_queue queue); - - // Move only - ProfilingCommandQueue(ProfilingCommandQueue&& queue); - ProfilingCommandQueue& operator=(ProfilingCommandQueue&& queue); - ProfilingCommandQueue(const ProfilingCommandQueue&) = delete; - ProfilingCommandQueue& operator=(const ProfilingCommandQueue&) = delete; - - private: - std::string current_label_; -}; - -absl::Status CreateClCommandQueue(const ClDevice& device, - const ClContext& context, - ClCommandQueue* result); - -} // namespace cl -} // namespace litert - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_OPENCL_CL_COMMAND_QUEUE_H_ diff --git 
a/tensorflow/lite/experimental/litert/runtime/opencl/cl_context.cc b/tensorflow/lite/experimental/litert/runtime/opencl/cl_context.cc deleted file mode 100644 index 5eb5f4949d37f8..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/opencl/cl_context.cc +++ /dev/null @@ -1,105 +0,0 @@ -// Copyright 2024 The TensorFlow Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/runtime/opencl/cl_context.h" - -#include -#include - -#include "absl/status/status.h" -#include "absl/strings/str_cat.h" -#include -#include "tensorflow/lite/experimental/litert/runtime/opencl/cl_device.h" -#include "tensorflow/lite/experimental/litert/runtime/opencl/opencl_wrapper.h" - -namespace litert { -namespace cl { -namespace { - -absl::Status CreateClContext(const ClDevice& device, - const std::vector& props, - ClContext* result) { - int error_code; - cl_device_id device_id = device.id(); - std::vector props_local = props; - if (!props_local.empty()) { - props_local.push_back(0); - } - cl_context_properties* properties_ptr = - props_local.empty() ? 
nullptr : props_local.data(); - cl_context context = clCreateContext(properties_ptr, 1, &device_id, nullptr, - nullptr, &error_code); - if (!context) { - return absl::UnknownError( - absl::StrCat("Failed to create a compute context - ", error_code)); - } - - *result = ClContext(context, true); - return absl::OkStatus(); -} - -} // namespace - -ClContext::ClContext() = default; - -ClContext::ClContext(cl_context context, bool has_ownership) - : context_(context), has_ownership_(has_ownership) {} - -ClContext::ClContext(cl_context context, bool has_ownership, ClDevice& device) - : context_(context), has_ownership_(has_ownership) {} - -ClContext::ClContext(ClContext&& context) - : context_(context.context_), has_ownership_(context.has_ownership_) { - context.context_ = nullptr; -} - -ClContext& ClContext::operator=(ClContext&& context) { - if (this != &context) { - Release(); - std::swap(context_, context.context_); - has_ownership_ = context.has_ownership_; - } - return *this; -} - -ClContext::~ClContext() { Release(); } - -void ClContext::Release() { - if (has_ownership_ && context_) { - clReleaseContext(context_); - context_ = nullptr; - } -} - -absl::Status CreateClContext(const ClDevice& device, ClContext* result) { - std::vector props; - return CreateClContext(device, props, result); -} - -absl::Status CreateClGlContext(const ClDevice& device, - cl_context_properties egl_context, - cl_context_properties egl_display, - ClContext* result) { - cl_context_properties platform = - reinterpret_cast(device.platform()); - - std::vector props = {CL_GL_CONTEXT_KHR, egl_context, - CL_EGL_DISPLAY_KHR, egl_display, - CL_CONTEXT_PLATFORM, platform}; - - return CreateClContext(device, props, result); -} - -} // namespace cl -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/runtime/opencl/cl_context.h b/tensorflow/lite/experimental/litert/runtime/opencl/cl_context.h deleted file mode 100644 index 880e42b7c4a5c1..00000000000000 --- 
a/tensorflow/lite/experimental/litert/runtime/opencl/cl_context.h +++ /dev/null @@ -1,57 +0,0 @@ -// Copyright 2024 The TensorFlow Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_OPENCL_CL_CONTEXT_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_OPENCL_CL_CONTEXT_H_ - -#include "absl/status/status.h" -#include -#include "tensorflow/lite/experimental/litert/runtime/opencl/cl_device.h" - -namespace litert { -namespace cl { - -// A RAII wrapper around opencl context -class ClContext { - public: - ClContext(); - ClContext(cl_context context, bool has_ownership); - ClContext(cl_context context, bool has_ownership, ClDevice& device); - // Move only - ClContext(ClContext&& context); - ClContext& operator=(ClContext&& context); - ClContext(const ClContext&) = delete; - ClContext& operator=(const ClContext&) = delete; - - ~ClContext(); - - cl_context context() const { return context_; } - - private: - void Release(); - - cl_context context_ = nullptr; - bool has_ownership_ = false; -}; - -absl::Status CreateClContext(const ClDevice& device, ClContext* result); -absl::Status CreateClGlContext(const ClDevice& device, - cl_context_properties egl_context, - cl_context_properties egl_display, - ClContext* result); - -} // namespace cl -} // namespace litert - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_OPENCL_CL_CONTEXT_H_ diff --git 
a/tensorflow/lite/experimental/litert/runtime/opencl/cl_device.cc b/tensorflow/lite/experimental/litert/runtime/opencl/cl_device.cc deleted file mode 100644 index 5677e50927a3c8..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/opencl/cl_device.cc +++ /dev/null @@ -1,104 +0,0 @@ -// Copyright 2024 The TensorFlow Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// this is a copy of ml_drift/cl/cl_device.cc -#include "tensorflow/lite/experimental/litert/runtime/opencl/cl_device.h" - -#include -#include -#include -#include - -#include "absl/status/status.h" -#include "absl/strings/str_format.h" -#include -#include -#include "tensorflow/lite/experimental/litert/runtime/opencl/opencl_wrapper.h" - -namespace litert { -namespace cl { - -ClDevice::ClDevice(cl_device_id id, cl_platform_id platform_id) - : id_(id), platform_id_(platform_id) {} - -ClDevice::ClDevice(const ClDevice& device) = default; - -ClDevice& ClDevice::operator=(const ClDevice& device) { - if (this != &device) { - id_ = device.id_; - platform_id_ = device.platform_id_; - } - return *this; -} - -ClDevice::ClDevice(ClDevice&& device) - : id_(device.id_), platform_id_(device.platform_id_) { - device.id_ = nullptr; - device.platform_id_ = nullptr; -} - -ClDevice& ClDevice::operator=(ClDevice&& device) { - if (this != &device) { - id_ = nullptr; - platform_id_ = nullptr; - std::swap(id_, device.id_); - std::swap(platform_id_, device.platform_id_); - } - return *this; -} - 
-absl::Status CreateDefaultGPUDevice(ClDevice* result) { - cl_uint num_platforms; - cl_int status = clGetPlatformIDs(0, nullptr, &num_platforms); - if (status != CL_SUCCESS) { - return absl::UnknownError( - absl::StrFormat("clGetPlatformIDs returned %d", status)); - } - if (num_platforms == 0) { - return absl::UnknownError("No supported OpenCL platform."); - } - std::vector platforms(num_platforms); - status = clGetPlatformIDs(num_platforms, platforms.data(), nullptr); - if (status != CL_SUCCESS) { - return absl::UnknownError( - absl::StrFormat("clGetPlatformIDs returned %d", status)); - } - - cl_platform_id platform_id = platforms[0]; - cl_uint num_devices; - status = - clGetDeviceIDs(platform_id, CL_DEVICE_TYPE_GPU, 0, nullptr, &num_devices); - if (status != CL_SUCCESS) { - return absl::UnknownError( - absl::StrFormat("clGetDeviceIDs returned %d", status)); - } - if (num_devices == 0) { - return absl::UnknownError("No GPU on current platform."); - } - - std::vector devices(num_devices); - status = clGetDeviceIDs(platform_id, CL_DEVICE_TYPE_GPU, num_devices, - devices.data(), nullptr); - if (status != CL_SUCCESS) { - return absl::UnknownError( - absl::StrFormat("clGetDeviceIDs returned %d", status)); - } - - *result = ClDevice(devices[0], platform_id); - LoadOpenCLFunctionExtensions(platform_id); - return absl::OkStatus(); -} - -} // namespace cl -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/runtime/opencl/cl_device.h b/tensorflow/lite/experimental/litert/runtime/opencl/cl_device.h deleted file mode 100644 index 71d93e64ace879..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/opencl/cl_device.h +++ /dev/null @@ -1,73 +0,0 @@ -// Copyright 2024 The ML Drift Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_OPENCL_CL_DEVICE_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_OPENCL_CL_DEVICE_H_ - -#include - -#include "absl/status/status.h" -#include -#include - -namespace litert { -namespace cl { - -// A wrapper around opencl device id -class ClDevice { - public: - ClDevice() = default; - ClDevice(cl_device_id id, cl_platform_id platform_id); - - ClDevice(ClDevice&& device); - ClDevice& operator=(ClDevice&& device); - ClDevice(const ClDevice&); - ClDevice& operator=(const ClDevice&); - - ~ClDevice() = default; - - cl_device_id id() const { return id_; } - cl_platform_id platform() const { return platform_id_; } - std::string GetPlatformVersion() const; - - private: - cl_device_id id_ = nullptr; - cl_platform_id platform_id_ = nullptr; -}; - -absl::Status CreateDefaultGPUDevice(ClDevice* result); - -template -T GetDeviceInfo(cl_device_id id, cl_device_info info) { - T result; - cl_int error = clGetDeviceInfo(id, info, sizeof(T), &result, nullptr); - if (error != CL_SUCCESS) { - return {}; - } - return result; -} - -template -absl::Status GetDeviceInfo(cl_device_id id, cl_device_info info, T* result) { - cl_int error = clGetDeviceInfo(id, info, sizeof(T), result, nullptr); - if (error != CL_SUCCESS) { - return absl::InvalidArgumentError("cl error:" + std::to_string(error)); - } - return absl::OkStatus(); -} - -} // namespace cl -} // namespace litert - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_OPENCL_CL_DEVICE_H_ diff --git 
a/tensorflow/lite/experimental/litert/runtime/opencl/cl_event.cc b/tensorflow/lite/experimental/litert/runtime/opencl/cl_event.cc deleted file mode 100644 index 4fd14a130b9dec..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/opencl/cl_event.cc +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/runtime/opencl/cl_event.h" - -#include "absl/strings/str_format.h" -#include -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/runtime/opencl/opencl_wrapper.h" - -namespace litert { -namespace cl { - -Expected WaitForEvents(int num_events, const cl_event* event_list) { - cl_int res = clWaitForEvents(num_events, event_list); - if (res != CL_SUCCESS) { - return Error( - kLiteRtStatusErrorRuntimeFailure, - absl::StrFormat("clWaitForEvents fails with error code %d", res)); - } - return {}; -} - -Expected SetUserEventStatus(cl_event event) { - cl_int res = clSetUserEventStatus(event, CL_COMPLETE); - if (res != CL_SUCCESS) { - return Error( - kLiteRtStatusErrorRuntimeFailure, - absl::StrFormat("clSetUserEventStatus fails with error code %d", res)); - } - return {}; -} - -Expected CreateUserEvent(cl_context context) { - cl_int res; - cl_event user_event = clCreateUserEvent(context, &res); - if (res != CL_SUCCESS) { - 
return Error( - kLiteRtStatusErrorRuntimeFailure, - absl::StrFormat("clCreateUserEvent fails with error code %d", res)); - } - return user_event; -} - -} // namespace cl -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/runtime/opencl/cl_event.h b/tensorflow/lite/experimental/litert/runtime/opencl/cl_event.h deleted file mode 100644 index 1ba38b99a90555..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/opencl/cl_event.h +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_OPENCL_CL_EVENT_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_OPENCL_CL_EVENT_H_ - -#include -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" - -namespace litert { -namespace cl { - -Expected WaitForEvents(int num_events, const cl_event* event_list); - -Expected SetUserEventStatus(cl_event event); - -Expected CreateUserEvent(cl_context context); - -} // namespace cl -} // namespace litert - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_OPENCL_CL_EVENT_H_ diff --git a/tensorflow/lite/experimental/litert/runtime/opencl/opencl_wrapper.cc b/tensorflow/lite/experimental/litert/runtime/opencl/opencl_wrapper.cc deleted file mode 100644 index 79c4e33e2eb72f..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/opencl/opencl_wrapper.cc +++ /dev/null @@ -1,470 +0,0 @@ -// Copyright 2024 The Tensorflow Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// This file is copied from third_party/ml_drift/cl/opencl_wrapper.cc. 
-#include "tensorflow/lite/experimental/litert/runtime/opencl/opencl_wrapper.h" - -#if defined(_WIN32) -#define __WINDOWS__ -#endif - -#ifdef __WINDOWS__ -#include -#else -#include -#endif - -#include - -#include "absl/strings/str_cat.h" - -namespace litert { -namespace cl { - -#ifdef __ANDROID__ -#define LoadFunction(function) \ - if (use_wrapper) { \ - function = reinterpret_cast(loadOpenCLPointer(#function)); \ - } else { \ - function = reinterpret_cast(dlsym(libopencl, #function)); \ - } - -namespace { - -// Loads a library from Android SP-HAL namespace which includes libraries from -// the path /vendor/lib[64] directly and several sub-folders in it. -// First tries using dlopen(), which should work if the process is running with -// linker namespace "sphal" (so has permissions to sphal paths). -// If it fails, for example if process is running with linker default namespace -// because it's a sub-process of the app, then tries loading the library using -// a sphal helper loader function from Vendor NDK support library. 
-void* AndroidDlopenSphalLibrary(const char* filename, int dlopen_flags) { - void* lib = dlopen(filename, dlopen_flags); - if (lib != nullptr) { - return lib; - } - static void* (*android_load_sphal_library)(const char*, int) = nullptr; - if (android_load_sphal_library != nullptr) { - return android_load_sphal_library(filename, dlopen_flags); - } - android_load_sphal_library = - reinterpret_cast( - dlsym(RTLD_NEXT, "android_load_sphal_library")); - if (android_load_sphal_library == nullptr) { - void* vndk = dlopen("libvndksupport.so", RTLD_NOW); - if (vndk != nullptr) { - android_load_sphal_library = - reinterpret_cast( - dlsym(vndk, "android_load_sphal_library")); - } - if (android_load_sphal_library == nullptr) { - return nullptr; - } - } - return android_load_sphal_library(filename, dlopen_flags); -} - -} // namespace - -#elif defined(__WINDOWS__) -#define LoadFunction(function) \ - function = \ - reinterpret_cast(GetProcAddress(libopencl, #function)); -#else -#define LoadFunction(function) \ - function = reinterpret_cast(dlsym(libopencl, #function)); -#endif - -#define LoadFunctionExtension(plat_id, function) \ - function = reinterpret_cast( \ - clGetExtensionFunctionAddressForPlatform(plat_id, #function)); - -#ifdef __WINDOWS__ -void LoadOpenCLFunctions(HMODULE libopencl); -#else -void LoadOpenCLFunctions(void* libopencl, bool use_wrapper); -#endif - -absl::Status LoadOpenCL() { -#ifdef __WINDOWS__ - HMODULE libopencl = LoadLibraryA("OpenCL.dll"); - if (libopencl) { - LoadOpenCLFunctions(libopencl); - return absl::OkStatus(); - } else { - DWORD error_code = GetLastError(); - return absl::UnknownError(absl::StrCat( - "Can not open OpenCL library on this device, error code - ", - error_code)); - } -#else - void* libopencl = nullptr; -#ifdef __APPLE__ - static const char* kClLibName = - "/System/Library/Frameworks/OpenCL.framework/OpenCL"; -#else - static const char* kClLibName = "libOpenCL.so"; -#endif -#ifdef __ANDROID__ - libopencl = 
AndroidDlopenSphalLibrary(kClLibName, RTLD_NOW | RTLD_LOCAL); - if (!libopencl) { - // Legacy Pixel phone or auto path? - libopencl = - AndroidDlopenSphalLibrary("libOpenCL-pixel.so", RTLD_NOW | RTLD_LOCAL); - if (!libopencl) { - libopencl = - AndroidDlopenSphalLibrary("libOpenCL-car.so", RTLD_NOW | RTLD_LOCAL); - } - if (libopencl) { - typedef void (*enableOpenCL_t)(); - enableOpenCL_t enableOpenCL = - reinterpret_cast(dlsym(libopencl, "enableOpenCL")); - enableOpenCL(); - LoadOpenCLFunctions(libopencl, true); - return absl::OkStatus(); - } - } -#else - libopencl = dlopen(kClLibName, RTLD_NOW | RTLD_LOCAL); -#endif - if (libopencl) { - LoadOpenCLFunctions(libopencl, false); - return absl::OkStatus(); - } - // record error - std::string error(dlerror()); - - // Check if OpenCL functions are found via OpenCL ICD Loader. - LoadOpenCLFunctions(libopencl, /*use_wrapper=*/false); - if (clGetPlatformIDs != nullptr) { - cl_uint num_platforms; - cl_int status = clGetPlatformIDs(0, nullptr, &num_platforms); - if (status == CL_SUCCESS && num_platforms != 0) { - return absl::OkStatus(); - } - return absl::UnknownError("OpenCL is not supported."); - } - return absl::UnknownError( - absl::StrCat("Can not open OpenCL library on this device - ", error)); -#endif -} - -void LoadOpenCLFunctionExtensions(cl_platform_id platform_id) { - // cl_khr_command_buffer extension - LoadFunctionExtension(platform_id, clCreateCommandBufferKHR); - LoadFunctionExtension(platform_id, clRetainCommandBufferKHR); - LoadFunctionExtension(platform_id, clReleaseCommandBufferKHR); - LoadFunctionExtension(platform_id, clFinalizeCommandBufferKHR); - LoadFunctionExtension(platform_id, clEnqueueCommandBufferKHR); - LoadFunctionExtension(platform_id, clCommandNDRangeKernelKHR); - LoadFunctionExtension(platform_id, clGetCommandBufferInfoKHR); -} - -#ifdef __WINDOWS__ -void LoadOpenCLFunctions(HMODULE libopencl) { -#else -void LoadOpenCLFunctions(void* libopencl, bool use_wrapper) { -#ifdef __ANDROID__ - 
typedef void* (*loadOpenCLPointer_t)(const char* name); - loadOpenCLPointer_t loadOpenCLPointer; - if (use_wrapper) { - loadOpenCLPointer = reinterpret_cast( - dlsym(libopencl, "loadOpenCLPointer")); - } -#endif -#endif - - LoadFunction(clGetPlatformIDs); - LoadFunction(clGetPlatformInfo); - LoadFunction(clGetDeviceIDs); - LoadFunction(clGetDeviceInfo); - LoadFunction(clCreateSubDevices); - LoadFunction(clRetainDevice); - LoadFunction(clReleaseDevice); - LoadFunction(clCreateContext); - LoadFunction(clCreateContextFromType); - LoadFunction(clRetainContext); - LoadFunction(clReleaseContext); - LoadFunction(clGetContextInfo); - LoadFunction(clCreateCommandQueueWithProperties); - LoadFunction(clRetainCommandQueue); - LoadFunction(clReleaseCommandQueue); - LoadFunction(clGetCommandQueueInfo); - LoadFunction(clCreateBuffer); - LoadFunction(clCreateSubBuffer); - LoadFunction(clCreateImage); - LoadFunction(clCreatePipe); - LoadFunction(clRetainMemObject); - LoadFunction(clReleaseMemObject); - LoadFunction(clGetSupportedImageFormats); - LoadFunction(clGetMemObjectInfo); - LoadFunction(clGetImageInfo); - LoadFunction(clGetPipeInfo); - LoadFunction(clSetMemObjectDestructorCallback); - LoadFunction(clSVMAlloc); - LoadFunction(clSVMFree); - LoadFunction(clCreateSamplerWithProperties); - LoadFunction(clRetainSampler); - LoadFunction(clReleaseSampler); - LoadFunction(clGetSamplerInfo); - LoadFunction(clCreateProgramWithSource); - LoadFunction(clCreateProgramWithBinary); - LoadFunction(clCreateProgramWithBuiltInKernels); - LoadFunction(clRetainProgram); - LoadFunction(clReleaseProgram); - LoadFunction(clBuildProgram); - LoadFunction(clCompileProgram); - LoadFunction(clLinkProgram); - LoadFunction(clUnloadPlatformCompiler); - LoadFunction(clGetProgramInfo); - LoadFunction(clGetProgramBuildInfo); - LoadFunction(clCreateKernel); - LoadFunction(clCreateKernelsInProgram); - LoadFunction(clRetainKernel); - LoadFunction(clReleaseKernel); - LoadFunction(clSetKernelArg); - 
LoadFunction(clSetKernelArgSVMPointer); - LoadFunction(clSetKernelExecInfo); - LoadFunction(clGetKernelInfo); - LoadFunction(clGetKernelArgInfo); - LoadFunction(clGetKernelWorkGroupInfo); - LoadFunction(clWaitForEvents); - LoadFunction(clGetEventInfo); - LoadFunction(clCreateUserEvent); - LoadFunction(clRetainEvent); - LoadFunction(clReleaseEvent); - LoadFunction(clSetUserEventStatus); - LoadFunction(clSetEventCallback); - LoadFunction(clGetEventProfilingInfo); - LoadFunction(clFlush); - LoadFunction(clFinish); - LoadFunction(clEnqueueReadBuffer); - LoadFunction(clEnqueueReadBufferRect); - LoadFunction(clEnqueueWriteBuffer); - LoadFunction(clEnqueueWriteBufferRect); - LoadFunction(clEnqueueFillBuffer); - LoadFunction(clEnqueueCopyBuffer); - LoadFunction(clEnqueueCopyBufferRect); - LoadFunction(clEnqueueReadImage); - LoadFunction(clEnqueueWriteImage); - LoadFunction(clEnqueueFillImage); - LoadFunction(clEnqueueCopyImage); - LoadFunction(clEnqueueCopyImageToBuffer); - LoadFunction(clEnqueueCopyBufferToImage); - LoadFunction(clEnqueueMapBuffer); - LoadFunction(clEnqueueMapImage); - LoadFunction(clEnqueueUnmapMemObject); - LoadFunction(clEnqueueMigrateMemObjects); - LoadFunction(clEnqueueNDRangeKernel); - LoadFunction(clEnqueueNativeKernel); - LoadFunction(clEnqueueMarkerWithWaitList); - LoadFunction(clEnqueueBarrierWithWaitList); - LoadFunction(clEnqueueSVMFree); - LoadFunction(clEnqueueSVMMemcpy); - LoadFunction(clEnqueueSVMMemFill); - LoadFunction(clEnqueueSVMMap); - LoadFunction(clEnqueueSVMUnmap); - LoadFunction(clGetExtensionFunctionAddressForPlatform); - LoadFunction(clCreateImage2D); - LoadFunction(clCreateImage3D); - LoadFunction(clEnqueueMarker); - LoadFunction(clEnqueueWaitForEvents); - LoadFunction(clEnqueueBarrier); - LoadFunction(clUnloadCompiler); - LoadFunction(clGetExtensionFunctionAddress); - LoadFunction(clCreateCommandQueue); - LoadFunction(clCreateSampler); - LoadFunction(clEnqueueTask); - - // OpenGL sharing - LoadFunction(clCreateFromGLBuffer); - 
LoadFunction(clCreateFromGLTexture); - LoadFunction(clEnqueueAcquireGLObjects); - LoadFunction(clEnqueueReleaseGLObjects); - - // cl_khr_egl_event extension - LoadFunction(clCreateEventFromEGLSyncKHR); - - // EGL sharing - LoadFunction(clCreateFromEGLImageKHR); - LoadFunction(clEnqueueAcquireEGLObjectsKHR); - LoadFunction(clEnqueueReleaseEGLObjectsKHR); - - // OpenCL 3.0 - LoadFunction(clCreateBufferWithProperties); - LoadFunction(clCreateImageWithProperties); -} - -// No OpenCL support, do not set function addresses -PFN_clGetPlatformIDs clGetPlatformIDs; -PFN_clGetPlatformInfo clGetPlatformInfo; -PFN_clGetDeviceIDs clGetDeviceIDs; -PFN_clGetDeviceInfo clGetDeviceInfo; -PFN_clCreateSubDevices clCreateSubDevices; -PFN_clRetainDevice clRetainDevice; -PFN_clReleaseDevice clReleaseDevice; -PFN_clCreateContext clCreateContext; -PFN_clCreateContextFromType clCreateContextFromType; -PFN_clRetainContext clRetainContext; -PFN_clReleaseContext clReleaseContext; -PFN_clGetContextInfo clGetContextInfo; -PFN_clCreateCommandQueueWithProperties clCreateCommandQueueWithProperties; -PFN_clRetainCommandQueue clRetainCommandQueue; -PFN_clReleaseCommandQueue clReleaseCommandQueue; -PFN_clGetCommandQueueInfo clGetCommandQueueInfo; -PFN_clCreateBuffer clCreateBuffer; -PFN_clCreateSubBuffer clCreateSubBuffer; -PFN_clCreateImage clCreateImage; -PFN_clCreatePipe clCreatePipe; -PFN_clRetainMemObject clRetainMemObject; -PFN_clReleaseMemObject clReleaseMemObject; -PFN_clGetSupportedImageFormats clGetSupportedImageFormats; -PFN_clGetMemObjectInfo clGetMemObjectInfo; -PFN_clGetImageInfo clGetImageInfo; -PFN_clGetPipeInfo clGetPipeInfo; -PFN_clSetMemObjectDestructorCallback clSetMemObjectDestructorCallback; -PFN_clSVMAlloc clSVMAlloc; -PFN_clSVMFree clSVMFree; -PFN_clCreateSamplerWithProperties clCreateSamplerWithProperties; -PFN_clRetainSampler clRetainSampler; -PFN_clReleaseSampler clReleaseSampler; -PFN_clGetSamplerInfo clGetSamplerInfo; -PFN_clCreateProgramWithSource 
clCreateProgramWithSource; -PFN_clCreateProgramWithBinary clCreateProgramWithBinary; -PFN_clCreateProgramWithBuiltInKernels clCreateProgramWithBuiltInKernels; -PFN_clRetainProgram clRetainProgram; -PFN_clReleaseProgram clReleaseProgram; -PFN_clBuildProgram clBuildProgram; -PFN_clCompileProgram clCompileProgram; -PFN_clLinkProgram clLinkProgram; -PFN_clUnloadPlatformCompiler clUnloadPlatformCompiler; -PFN_clGetProgramInfo clGetProgramInfo; -PFN_clGetProgramBuildInfo clGetProgramBuildInfo; -PFN_clCreateKernel clCreateKernel; -PFN_clCreateKernelsInProgram clCreateKernelsInProgram; -PFN_clRetainKernel clRetainKernel; -PFN_clReleaseKernel clReleaseKernel; -PFN_clSetKernelArg clSetKernelArg; -PFN_clSetKernelArgSVMPointer clSetKernelArgSVMPointer; -PFN_clSetKernelExecInfo clSetKernelExecInfo; -PFN_clGetKernelInfo clGetKernelInfo; -PFN_clGetKernelArgInfo clGetKernelArgInfo; -PFN_clGetKernelWorkGroupInfo clGetKernelWorkGroupInfo; -PFN_clWaitForEvents clWaitForEvents; -PFN_clGetEventInfo clGetEventInfo; -PFN_clCreateUserEvent clCreateUserEvent; -PFN_clRetainEvent clRetainEvent; -PFN_clReleaseEvent clReleaseEvent; -PFN_clSetUserEventStatus clSetUserEventStatus; -PFN_clSetEventCallback clSetEventCallback; -PFN_clGetEventProfilingInfo clGetEventProfilingInfo; -PFN_clFlush clFlush; -PFN_clFinish clFinish; -PFN_clEnqueueReadBuffer clEnqueueReadBuffer; -PFN_clEnqueueReadBufferRect clEnqueueReadBufferRect; -PFN_clEnqueueWriteBuffer clEnqueueWriteBuffer; -PFN_clEnqueueWriteBufferRect clEnqueueWriteBufferRect; -PFN_clEnqueueFillBuffer clEnqueueFillBuffer; -PFN_clEnqueueCopyBuffer clEnqueueCopyBuffer; -PFN_clEnqueueCopyBufferRect clEnqueueCopyBufferRect; -PFN_clEnqueueReadImage clEnqueueReadImage; -PFN_clEnqueueWriteImage clEnqueueWriteImage; -PFN_clEnqueueFillImage clEnqueueFillImage; -PFN_clEnqueueCopyImage clEnqueueCopyImage; -PFN_clEnqueueCopyImageToBuffer clEnqueueCopyImageToBuffer; -PFN_clEnqueueCopyBufferToImage clEnqueueCopyBufferToImage; -PFN_clEnqueueMapBuffer 
clEnqueueMapBuffer; -PFN_clEnqueueMapImage clEnqueueMapImage; -PFN_clEnqueueUnmapMemObject clEnqueueUnmapMemObject; -PFN_clEnqueueMigrateMemObjects clEnqueueMigrateMemObjects; -PFN_clEnqueueNDRangeKernel clEnqueueNDRangeKernel; -PFN_clEnqueueNativeKernel clEnqueueNativeKernel; -PFN_clEnqueueMarkerWithWaitList clEnqueueMarkerWithWaitList; -PFN_clEnqueueBarrierWithWaitList clEnqueueBarrierWithWaitList; -PFN_clEnqueueSVMFree clEnqueueSVMFree; -PFN_clEnqueueSVMMemcpy clEnqueueSVMMemcpy; -PFN_clEnqueueSVMMemFill clEnqueueSVMMemFill; -PFN_clEnqueueSVMMap clEnqueueSVMMap; -PFN_clEnqueueSVMUnmap clEnqueueSVMUnmap; -PFN_clGetExtensionFunctionAddressForPlatform - clGetExtensionFunctionAddressForPlatform; -PFN_clCreateImage2D clCreateImage2D; -PFN_clCreateImage3D clCreateImage3D; -PFN_clEnqueueMarker clEnqueueMarker; -PFN_clEnqueueWaitForEvents clEnqueueWaitForEvents; -PFN_clEnqueueBarrier clEnqueueBarrier; -PFN_clUnloadCompiler clUnloadCompiler; -PFN_clGetExtensionFunctionAddress clGetExtensionFunctionAddress; -PFN_clCreateCommandQueue clCreateCommandQueue; -PFN_clCreateSampler clCreateSampler; -PFN_clEnqueueTask clEnqueueTask; - -// OpenGL sharing -PFN_clCreateFromGLBuffer clCreateFromGLBuffer; -PFN_clCreateFromGLTexture clCreateFromGLTexture; -PFN_clEnqueueAcquireGLObjects clEnqueueAcquireGLObjects; -PFN_clEnqueueReleaseGLObjects clEnqueueReleaseGLObjects; - -// cl_khr_egl_event extension -PFN_clCreateEventFromEGLSyncKHR clCreateEventFromEGLSyncKHR; - -// EGL sharing -PFN_clCreateFromEGLImageKHR clCreateFromEGLImageKHR; -PFN_clEnqueueAcquireEGLObjectsKHR clEnqueueAcquireEGLObjectsKHR; -PFN_clEnqueueReleaseEGLObjectsKHR clEnqueueReleaseEGLObjectsKHR; - -// cl_khr_command_buffer extension -PFN_clCreateCommandBufferKHR clCreateCommandBufferKHR; -PFN_clRetainCommandBufferKHR clRetainCommandBufferKHR; -PFN_clReleaseCommandBufferKHR clReleaseCommandBufferKHR; -PFN_clFinalizeCommandBufferKHR clFinalizeCommandBufferKHR; -PFN_clEnqueueCommandBufferKHR clEnqueueCommandBufferKHR; 
-PFN_clCommandNDRangeKernelKHR clCommandNDRangeKernelKHR; -PFN_clGetCommandBufferInfoKHR clGetCommandBufferInfoKHR; - -// OpenCL 3.0 -PFN_clCreateBufferWithProperties clCreateBufferWithProperties; -PFN_clCreateImageWithProperties clCreateImageWithProperties; - -cl_mem CreateImage2DLegacy(cl_context context, cl_mem_flags flags, - const cl_image_format* image_format, - const cl_image_desc* image_desc, void* host_ptr, - cl_int* errcode_ret) { - if (clCreateImage) { // clCreateImage available since OpenCL 1.2 - return clCreateImage(context, flags, image_format, image_desc, host_ptr, - errcode_ret); - } else { - return clCreateImage2D(context, flags, image_format, - image_desc->image_width, image_desc->image_height, - image_desc->image_row_pitch, host_ptr, errcode_ret); - } -} - -cl_mem CreateImage3DLegacy(cl_context context, cl_mem_flags flags, - const cl_image_format* image_format, - const cl_image_desc* image_desc, void* host_ptr, - cl_int* errcode_ret) { - if (clCreateImage) { // clCreateImage available since OpenCL 1.2 - return clCreateImage(context, flags, image_format, image_desc, host_ptr, - errcode_ret); - } else { - return clCreateImage3D(context, flags, image_format, - image_desc->image_width, image_desc->image_height, - image_desc->image_depth, image_desc->image_row_pitch, - image_desc->image_slice_pitch, host_ptr, - errcode_ret); - } -} -} // namespace cl -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/runtime/opencl/opencl_wrapper.h b/tensorflow/lite/experimental/litert/runtime/opencl/opencl_wrapper.h deleted file mode 100644 index cfbeb805dbb49d..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/opencl/opencl_wrapper.h +++ /dev/null @@ -1,737 +0,0 @@ -// Copyright 2024 The TensorFlow Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// This file is copied from third_party/ml_drift/cl/opencl_wrapper.h. -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_OPENCL_OPENCL_WRAPPER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_OPENCL_OPENCL_WRAPPER_H_ - -#include - -#include "absl/status/status.h" -#include // IWYU pragma: export -#include // IWYU pragma: export -#include // IWYU pragma: export -#include // IWYU pragma: export -#include // IWYU pragma: export - -namespace litert { -namespace cl { - -absl::Status LoadOpenCL(); -void LoadOpenCLFunctionExtensions(cl_platform_id platform_id); - -typedef cl_int(CL_API_CALL *PFN_clGetPlatformIDs)( - cl_uint /* num_entries */, cl_platform_id * /* platforms */, - cl_uint * /* num_platforms */) CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clGetPlatformInfo)( - cl_platform_id /* platform */, cl_platform_info /* param_name */, - size_t /* param_value_size */, void * /* param_value */, - size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clGetDeviceIDs)( - cl_platform_id /* platform */, cl_device_type /* device_type */, - cl_uint /* num_entries */, cl_device_id * /* devices */, - cl_uint * /* num_devices */) CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clGetDeviceInfo)( - cl_device_id /* device */, cl_device_info /* param_name */, - size_t /* param_value_size */, void * /* param_value */, - size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clCreateSubDevices)( - cl_device_id /* in_device */, - const 
cl_device_partition_property * /* properties */, - cl_uint /* num_devices */, cl_device_id * /* out_devices */, - cl_uint * /* num_devices_ret */) CL_API_SUFFIX__VERSION_1_2; -typedef cl_int(CL_API_CALL *PFN_clRetainDevice)(cl_device_id /* device */) - CL_API_SUFFIX__VERSION_1_2; -typedef cl_int(CL_API_CALL *PFN_clReleaseDevice)(cl_device_id /* device */) - CL_API_SUFFIX__VERSION_1_2; -typedef cl_context(CL_API_CALL *PFN_clCreateContext)( - const cl_context_properties * /* properties */, cl_uint /* num_devices */, - const cl_device_id * /* devices */, - void(CL_CALLBACK * /* pfn_notify */)(const char *, const void *, size_t, - void *), - void * /* user_data */, - cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_1_0; -typedef cl_context(CL_API_CALL *PFN_clCreateContextFromType)( - const cl_context_properties * /* properties */, - cl_device_type /* device_type */, - void(CL_CALLBACK * /* pfn_notify*/)(const char *, const void *, size_t, - void *), - void * /* user_data */, - cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clRetainContext)(cl_context /* context */) - CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clReleaseContext)(cl_context /* context */) - CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clGetContextInfo)( - cl_context /* context */, cl_context_info /* param_name */, - size_t /* param_value_size */, void * /* param_value */, - size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_1_0; -typedef cl_command_queue(CL_API_CALL *PFN_clCreateCommandQueueWithProperties)( - cl_context /* context */, cl_device_id /* device */, - const cl_queue_properties * /* properties */, - cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_2_0; -typedef cl_int(CL_API_CALL *PFN_clRetainCommandQueue)( - cl_command_queue /* command_queue */) CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clReleaseCommandQueue)( - cl_command_queue /* command_queue */) CL_API_SUFFIX__VERSION_1_0; -typedef 
cl_int(CL_API_CALL *PFN_clGetCommandQueueInfo)( - cl_command_queue /* command_queue */, - cl_command_queue_info /* param_name */, size_t /* param_value_size */, - void * /* param_value */, - size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_1_0; -typedef cl_mem(CL_API_CALL *PFN_clCreateBuffer)( - cl_context /* context */, cl_mem_flags /* flags */, size_t /* size */, - void * /* host_ptr */, - cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_1_0; -typedef cl_mem(CL_API_CALL *PFN_clCreateSubBuffer)( - cl_mem /* buffer */, cl_mem_flags /* flags */, - cl_buffer_create_type /* buffer_create_type */, - const void * /* buffer_create_info */, - cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_1_1; -typedef cl_mem(CL_API_CALL *PFN_clCreateImage)( - cl_context /* context */, cl_mem_flags /* flags */, - const cl_image_format * /* image_format */, - const cl_image_desc * /* image_desc */, void * /* host_ptr */, - cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_1_2; -typedef cl_mem(CL_API_CALL *PFN_clCreatePipe)( - cl_context /* context */, cl_mem_flags /* flags */, - cl_uint /* pipe_packet_size */, cl_uint /* pipe_max_packets */, - const cl_pipe_properties * /* properties */, - cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_2_0; -typedef cl_int(CL_API_CALL *PFN_clRetainMemObject)(cl_mem /* memobj */) - CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clReleaseMemObject)(cl_mem /* memobj */) - CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clGetSupportedImageFormats)( - cl_context /* context */, cl_mem_flags /* flags */, - cl_mem_object_type /* image_type */, cl_uint /* num_entries */, - cl_image_format * /* image_formats */, - cl_uint * /* num_image_formats */) CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clGetMemObjectInfo)( - cl_mem /* memobj */, cl_mem_info /* param_name */, - size_t /* param_value_size */, void * /* param_value */, - size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_1_0; -typedef 
cl_int(CL_API_CALL *PFN_clGetImageInfo)( - cl_mem /* image */, cl_image_info /* param_name */, - size_t /* param_value_size */, void * /* param_value */, - size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clGetPipeInfo)( - cl_mem /* pipe */, cl_pipe_info /* param_name */, - size_t /* param_value_size */, void * /* param_value */, - size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_2_0; -typedef cl_int(CL_API_CALL *PFN_clSetMemObjectDestructorCallback)( - cl_mem /* memobj */, - void(CL_CALLBACK * /*pfn_notify*/)(cl_mem /* memobj */, - void * /*user_data*/), - void * /*user_data */) CL_API_SUFFIX__VERSION_1_1; -typedef void *(CL_API_CALL *PFN_clSVMAlloc)( - cl_context /* context */, cl_svm_mem_flags /* flags */, size_t /* size */, - cl_uint /* alignment */)CL_API_SUFFIX__VERSION_2_0; -typedef void(CL_API_CALL *PFN_clSVMFree)(cl_context /* context */, - void * /* svm_pointer */) - CL_API_SUFFIX__VERSION_2_0; -typedef cl_sampler(CL_API_CALL *PFN_clCreateSamplerWithProperties)( - cl_context /* context */, - const cl_sampler_properties * /* normalized_coords */, - cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_2_0; -typedef cl_int(CL_API_CALL *PFN_clRetainSampler)(cl_sampler /* sampler */) - CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clReleaseSampler)(cl_sampler /* sampler */) - CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clGetSamplerInfo)( - cl_sampler /* sampler */, cl_sampler_info /* param_name */, - size_t /* param_value_size */, void * /* param_value */, - size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_1_0; -typedef cl_program(CL_API_CALL *PFN_clCreateProgramWithSource)( - cl_context /* context */, cl_uint /* count */, const char ** /* strings */, - const size_t * /* lengths */, - cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_1_0; -typedef cl_program(CL_API_CALL *PFN_clCreateProgramWithBinary)( - cl_context /* context */, cl_uint /* num_devices */, - 
const cl_device_id * /* device_list */, const size_t * /* lengths */, - const unsigned char ** /* binaries */, cl_int * /* binary_status */, - cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_1_0; -typedef cl_program(CL_API_CALL *PFN_clCreateProgramWithBuiltInKernels)( - cl_context /* context */, cl_uint /* num_devices */, - const cl_device_id * /* device_list */, const char * /* kernel_names */, - cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_1_2; -typedef cl_int(CL_API_CALL *PFN_clRetainProgram)(cl_program /* program */) - CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clReleaseProgram)(cl_program /* program */) - CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clBuildProgram)( - cl_program /* program */, cl_uint /* num_devices */, - const cl_device_id * /* device_list */, const char * /* options */, - void(CL_CALLBACK * /* pfn_notify */)(cl_program /* program */, - void * /* user_data */), - void * /* user_data */) CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clCompileProgram)( - cl_program /* program */, cl_uint /* num_devices */, - const cl_device_id * /* device_list */, const char * /* options */, - cl_uint /* num_input_headers */, const cl_program * /* input_headers */, - const char ** /* header_include_names */, - void(CL_CALLBACK * /* pfn_notify */)(cl_program /* program */, - void * /* user_data */), - void * /* user_data */) CL_API_SUFFIX__VERSION_1_2; -typedef cl_program(CL_API_CALL *PFN_clLinkProgram)( - cl_context /* context */, cl_uint /* num_devices */, - const cl_device_id * /* device_list */, const char * /* options */, - cl_uint /* num_input_programs */, const cl_program * /* input_programs */, - void(CL_CALLBACK * /* pfn_notify */)(cl_program /* program */, - void * /* user_data */), - void * /* user_data */, - cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_1_2; -typedef cl_int(CL_API_CALL *PFN_clUnloadPlatformCompiler)( - cl_platform_id /* platform */) CL_API_SUFFIX__VERSION_1_2; -typedef 
cl_int(CL_API_CALL *PFN_clGetProgramInfo)( - cl_program /* program */, cl_program_info /* param_name */, - size_t /* param_value_size */, void * /* param_value */, - size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clGetProgramBuildInfo)( - cl_program /* program */, cl_device_id /* device */, - cl_program_build_info /* param_name */, size_t /* param_value_size */, - void * /* param_value */, - size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_1_0; -typedef cl_kernel(CL_API_CALL *PFN_clCreateKernel)( - cl_program /* program */, const char * /* kernel_name */, - cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clCreateKernelsInProgram)( - cl_program /* program */, cl_uint /* num_kernels */, - cl_kernel * /* kernels */, - cl_uint * /* num_kernels_ret */) CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clRetainKernel)(cl_kernel /* kernel */) - CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clReleaseKernel)(cl_kernel /* kernel */) - CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clSetKernelArg)( - cl_kernel /* kernel */, cl_uint /* arg_index */, size_t /* arg_size */, - const void * /* arg_value */) CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clSetKernelArgSVMPointer)( - cl_kernel /* kernel */, cl_uint /* arg_index */, - const void * /* arg_value */) CL_API_SUFFIX__VERSION_2_0; -typedef cl_int(CL_API_CALL *PFN_clSetKernelExecInfo)( - cl_kernel /* kernel */, cl_kernel_exec_info /* param_name */, - size_t /* param_value_size */, - const void * /* param_value */) CL_API_SUFFIX__VERSION_2_0; -typedef cl_int(CL_API_CALL *PFN_clGetKernelInfo)( - cl_kernel /* kernel */, cl_kernel_info /* param_name */, - size_t /* param_value_size */, void * /* param_value */, - size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clGetKernelArgInfo)( - cl_kernel /* kernel */, cl_uint /* 
arg_indx */, - cl_kernel_arg_info /* param_name */, size_t /* param_value_size */, - void * /* param_value */, - size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_1_2; -typedef cl_int(CL_API_CALL *PFN_clGetKernelWorkGroupInfo)( - cl_kernel /* kernel */, cl_device_id /* device */, - cl_kernel_work_group_info /* param_name */, size_t /* param_value_size */, - void * /* param_value */, - size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clWaitForEvents)( - cl_uint /* num_events */, - const cl_event * /* event_list */) CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clGetEventInfo)( - cl_event /* event */, cl_event_info /* param_name */, - size_t /* param_value_size */, void * /* param_value */, - size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_1_0; -typedef cl_event(CL_API_CALL *PFN_clCreateUserEvent)(cl_context /* context */, - cl_int * /* errcode_ret */) - CL_API_SUFFIX__VERSION_1_1; -typedef cl_int(CL_API_CALL *PFN_clRetainEvent)(cl_event /* event */) - CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clReleaseEvent)(cl_event /* event */) - CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clSetUserEventStatus)( - cl_event /* event */, - cl_int /* execution_status */) CL_API_SUFFIX__VERSION_1_1; -typedef cl_int(CL_API_CALL *PFN_clSetEventCallback)( - cl_event /* event */, cl_int /* command_exec_callback_type */, - void(CL_CALLBACK * /* pfn_notify */)(cl_event, cl_int, void *), - void * /* user_data */) CL_API_SUFFIX__VERSION_1_1; -typedef cl_int(CL_API_CALL *PFN_clGetEventProfilingInfo)( - cl_event /* event */, cl_profiling_info /* param_name */, - size_t /* param_value_size */, void * /* param_value */, - size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clFlush)(cl_command_queue /* command_queue */) - CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clFinish)(cl_command_queue /* command_queue */) - 
CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clEnqueueReadBuffer)( - cl_command_queue /* command_queue */, cl_mem /* buffer */, - cl_bool /* blocking_read */, size_t /* offset */, size_t /* size */, - void * /* ptr */, cl_uint /* num_events_in_wait_list */, - const cl_event * /* event_wait_list */, - cl_event * /* event */) CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clEnqueueReadBufferRect)( - cl_command_queue /* command_queue */, cl_mem /* buffer */, - cl_bool /* blocking_read */, const size_t * /* buffer_offset */, - const size_t * /* host_offset */, const size_t * /* region */, - size_t /* buffer_row_pitch */, size_t /* buffer_slice_pitch */, - size_t /* host_row_pitch */, size_t /* host_slice_pitch */, - void * /* ptr */, cl_uint /* num_events_in_wait_list */, - const cl_event * /* event_wait_list */, - cl_event * /* event */) CL_API_SUFFIX__VERSION_1_1; -typedef cl_int(CL_API_CALL *PFN_clEnqueueWriteBuffer)( - cl_command_queue /* command_queue */, cl_mem /* buffer */, - cl_bool /* blocking_write */, size_t /* offset */, size_t /* size */, - const void * /* ptr */, cl_uint /* num_events_in_wait_list */, - const cl_event * /* event_wait_list */, - cl_event * /* event */) CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clEnqueueWriteBufferRect)( - cl_command_queue /* command_queue */, cl_mem /* buffer */, - cl_bool /* blocking_write */, const size_t * /* buffer_offset */, - const size_t * /* host_offset */, const size_t * /* region */, - size_t /* buffer_row_pitch */, size_t /* buffer_slice_pitch */, - size_t /* host_row_pitch */, size_t /* host_slice_pitch */, - const void * /* ptr */, cl_uint /* num_events_in_wait_list */, - const cl_event * /* event_wait_list */, - cl_event * /* event */) CL_API_SUFFIX__VERSION_1_1; -typedef cl_int(CL_API_CALL *PFN_clEnqueueFillBuffer)( - cl_command_queue /* command_queue */, cl_mem /* buffer */, - const void * /* pattern */, size_t /* pattern_size */, size_t /* offset */, - 
size_t /* size */, cl_uint /* num_events_in_wait_list */, - const cl_event * /* event_wait_list */, - cl_event * /* event */) CL_API_SUFFIX__VERSION_1_2; -typedef cl_int(CL_API_CALL *PFN_clEnqueueCopyBuffer)( - cl_command_queue /* command_queue */, cl_mem /* src_buffer */, - cl_mem /* dst_buffer */, size_t /* src_offset */, size_t /* dst_offset */, - size_t /* size */, cl_uint /* num_events_in_wait_list */, - const cl_event * /* event_wait_list */, - cl_event * /* event */) CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clEnqueueCopyBufferRect)( - cl_command_queue /* command_queue */, cl_mem /* src_buffer */, - cl_mem /* dst_buffer */, const size_t * /* src_origin */, - const size_t * /* dst_origin */, const size_t * /* region */, - size_t /* src_row_pitch */, size_t /* src_slice_pitch */, - size_t /* dst_row_pitch */, size_t /* dst_slice_pitch */, - cl_uint /* num_events_in_wait_list */, - const cl_event * /* event_wait_list */, - cl_event * /* event */) CL_API_SUFFIX__VERSION_1_1; -typedef cl_int(CL_API_CALL *PFN_clEnqueueReadImage)( - cl_command_queue /* command_queue */, cl_mem /* image */, - cl_bool /* blocking_read */, const size_t * /* origin[3] */, - const size_t * /* region[3] */, size_t /* row_pitch */, - size_t /* slice_pitch */, void * /* ptr */, - cl_uint /* num_events_in_wait_list */, - const cl_event * /* event_wait_list */, - cl_event * /* event */) CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clEnqueueWriteImage)( - cl_command_queue /* command_queue */, cl_mem /* image */, - cl_bool /* blocking_write */, const size_t * /* origin[3] */, - const size_t * /* region[3] */, size_t /* input_row_pitch */, - size_t /* input_slice_pitch */, const void * /* ptr */, - cl_uint /* num_events_in_wait_list */, - const cl_event * /* event_wait_list */, - cl_event * /* event */) CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clEnqueueFillImage)( - cl_command_queue /* command_queue */, cl_mem /* image */, - const void * /* 
fill_color */, const size_t * /* origin[3] */, - const size_t * /* region[3] */, cl_uint /* num_events_in_wait_list */, - const cl_event * /* event_wait_list */, - cl_event * /* event */) CL_API_SUFFIX__VERSION_1_2; -typedef cl_int(CL_API_CALL *PFN_clEnqueueCopyImage)( - cl_command_queue /* command_queue */, cl_mem /* src_image */, - cl_mem /* dst_image */, const size_t * /* src_origin[3] */, - const size_t * /* dst_origin[3] */, const size_t * /* region[3] */, - cl_uint /* num_events_in_wait_list */, - const cl_event * /* event_wait_list */, - cl_event * /* event */) CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clEnqueueCopyImageToBuffer)( - cl_command_queue /* command_queue */, cl_mem /* src_image */, - cl_mem /* dst_buffer */, const size_t * /* src_origin[3] */, - const size_t * /* region[3] */, size_t /* dst_offset */, - cl_uint /* num_events_in_wait_list */, - const cl_event * /* event_wait_list */, - cl_event * /* event */) CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clEnqueueCopyBufferToImage)( - cl_command_queue /* command_queue */, cl_mem /* src_buffer */, - cl_mem /* dst_image */, size_t /* src_offset */, - const size_t * /* dst_origin[3] */, const size_t * /* region[3] */, - cl_uint /* num_events_in_wait_list */, - const cl_event * /* event_wait_list */, - cl_event * /* event */) CL_API_SUFFIX__VERSION_1_0; -typedef void *(CL_API_CALL *PFN_clEnqueueMapBuffer)( - cl_command_queue /* command_queue */, cl_mem /* buffer */, - cl_bool /* blocking_map */, cl_map_flags /* map_flags */, - size_t /* offset */, size_t /* size */, - cl_uint /* num_events_in_wait_list */, - const cl_event * /* event_wait_list */, cl_event * /* event */, - cl_int * /* errcode_ret */)CL_API_SUFFIX__VERSION_1_0; -typedef void *(CL_API_CALL *PFN_clEnqueueMapImage)( - cl_command_queue /* command_queue */, cl_mem /* image */, - cl_bool /* blocking_map */, cl_map_flags /* map_flags */, - const size_t * /* origin[3] */, const size_t * /* region[3] */, - 
size_t * /* image_row_pitch */, size_t * /* image_slice_pitch */, - cl_uint /* num_events_in_wait_list */, - const cl_event * /* event_wait_list */, cl_event * /* event */, - cl_int * /* errcode_ret */)CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clEnqueueUnmapMemObject)( - cl_command_queue /* command_queue */, cl_mem /* memobj */, - void * /* mapped_ptr */, cl_uint /* num_events_in_wait_list */, - const cl_event * /* event_wait_list */, - cl_event * /* event */) CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clEnqueueMigrateMemObjects)( - cl_command_queue /* command_queue */, cl_uint /* num_mem_objects */, - const cl_mem * /* mem_objects */, cl_mem_migration_flags /* flags */, - cl_uint /* num_events_in_wait_list */, - const cl_event * /* event_wait_list */, - cl_event * /* event */) CL_API_SUFFIX__VERSION_1_2; -typedef cl_int(CL_API_CALL *PFN_clEnqueueNDRangeKernel)( - cl_command_queue /* command_queue */, cl_kernel /* kernel */, - cl_uint /* work_dim */, const size_t * /* global_work_offset */, - const size_t * /* global_work_size */, const size_t * /* local_work_size */, - cl_uint /* num_events_in_wait_list */, - const cl_event * /* event_wait_list */, - cl_event * /* event */) CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clEnqueueNativeKernel)( - cl_command_queue /* command_queue */, - void(CL_CALLBACK * /*user_func*/)(void *), void * /* args */, - size_t /* cb_args */, cl_uint /* num_mem_objects */, - const cl_mem * /* mem_list */, const void ** /* args_mem_loc */, - cl_uint /* num_events_in_wait_list */, - const cl_event * /* event_wait_list */, - cl_event * /* event */) CL_API_SUFFIX__VERSION_1_0; -typedef cl_int(CL_API_CALL *PFN_clEnqueueMarkerWithWaitList)( - cl_command_queue /* command_queue */, cl_uint /* num_events_in_wait_list */, - const cl_event * /* event_wait_list */, - cl_event * /* event */) CL_API_SUFFIX__VERSION_1_2; -typedef cl_int(CL_API_CALL *PFN_clEnqueueBarrierWithWaitList)( - cl_command_queue 
/* command_queue */, cl_uint /* num_events_in_wait_list */, - const cl_event * /* event_wait_list */, - cl_event * /* event */) CL_API_SUFFIX__VERSION_1_2; -typedef cl_int(CL_API_CALL *PFN_clEnqueueSVMFree)( - cl_command_queue /* command_queue */, cl_uint /* num_svm_pointers */, - void *[] /* svm_pointers[] */, - void(CL_CALLBACK * /*pfn_free_func*/)(cl_command_queue /* queue */, - cl_uint /* num_svm_pointers */, - void *[] /* svm_pointers[] */, - void * /* user_data */), - void * /* user_data */, cl_uint /* num_events_in_wait_list */, - const cl_event * /* event_wait_list */, - cl_event * /* event */) CL_API_SUFFIX__VERSION_2_0; -typedef cl_int(CL_API_CALL *PFN_clEnqueueSVMMemcpy)( - cl_command_queue /* command_queue */, cl_bool /* blocking_copy */, - void * /* dst_ptr */, const void * /* src_ptr */, size_t /* size */, - cl_uint /* num_events_in_wait_list */, - const cl_event * /* event_wait_list */, - cl_event * /* event */) CL_API_SUFFIX__VERSION_2_0; -typedef cl_int(CL_API_CALL *PFN_clEnqueueSVMMemFill)( - cl_command_queue /* command_queue */, void * /* svm_ptr */, - const void * /* pattern */, size_t /* pattern_size */, size_t /* size */, - cl_uint /* num_events_in_wait_list */, - const cl_event * /* event_wait_list */, - cl_event * /* event */) CL_API_SUFFIX__VERSION_2_0; -typedef cl_int(CL_API_CALL *PFN_clEnqueueSVMMap)( - cl_command_queue /* command_queue */, cl_bool /* blocking_map */, - cl_map_flags /* flags */, void * /* svm_ptr */, size_t /* size */, - cl_uint /* num_events_in_wait_list */, - const cl_event * /* event_wait_list */, - cl_event * /* event */) CL_API_SUFFIX__VERSION_2_0; -typedef cl_int(CL_API_CALL *PFN_clEnqueueSVMUnmap)( - cl_command_queue /* command_queue */, void * /* svm_ptr */, - cl_uint /* num_events_in_wait_list */, - const cl_event * /* event_wait_list */, - cl_event * /* event */) CL_API_SUFFIX__VERSION_2_0; -typedef void *(CL_API_CALL *PFN_clGetExtensionFunctionAddressForPlatform)( - cl_platform_id /* platform */, - const char * 
/* func_name */)CL_API_SUFFIX__VERSION_1_2; -typedef cl_mem(CL_API_CALL *PFN_clCreateImage2D)( - cl_context /* context */, cl_mem_flags /* flags */, - const cl_image_format * /* image_format */, size_t /* image_width */, - size_t /* image_height */, size_t /* image_row_pitch */, - void * /* host_ptr */, cl_int * /* errcode_ret */); -typedef cl_mem(CL_API_CALL *PFN_clCreateImage3D)( - cl_context /* context */, cl_mem_flags /* flags */, - const cl_image_format * /* image_format */, size_t /* image_width */, - size_t /* image_height */, size_t /* image_depth */, - size_t /* image_row_pitch */, size_t /* image_slice_pitch */, - void * /* host_ptr */, cl_int * /* errcode_ret */); -typedef cl_int(CL_API_CALL *PFN_clEnqueueMarker)( - cl_command_queue /* command_queue */, cl_event * /* event */); -typedef cl_int(CL_API_CALL *PFN_clEnqueueWaitForEvents)( - cl_command_queue /* command_queue */, cl_uint /* num_events */, - const cl_event * /* event_list */); -typedef cl_int(CL_API_CALL *PFN_clEnqueueBarrier)( - cl_command_queue /* command_queue */); -typedef cl_int(CL_API_CALL *PFN_clUnloadCompiler)(); -typedef void *(CL_API_CALL *PFN_clGetExtensionFunctionAddress)( - const char * /* func_name */); -typedef cl_command_queue(CL_API_CALL *PFN_clCreateCommandQueue)( - cl_context /* context */, cl_device_id /* device */, - cl_command_queue_properties /* properties */, cl_int * /* errcode_ret */); -typedef cl_sampler(CL_API_CALL *PFN_clCreateSampler)( - cl_context /* context */, cl_bool /* normalized_coords */, - cl_addressing_mode /* addressing_mode */, cl_filter_mode /* filter_mode */, - cl_int * /* errcode_ret */); -typedef cl_int(CL_API_CALL *PFN_clEnqueueTask)( - cl_command_queue /* command_queue */, cl_kernel /* kernel */, - cl_uint /* num_events_in_wait_list */, - const cl_event * /* event_wait_list */, cl_event * /* event */); - -// OpenGL sharing -typedef cl_mem(CL_API_CALL *PFN_clCreateFromGLBuffer)(cl_context, cl_mem_flags, - cl_GLuint, int *); -typedef 
cl_mem(CL_API_CALL *PFN_clCreateFromGLTexture)( - cl_context /* context */, cl_mem_flags /* flags */, cl_GLenum /* target */, - cl_GLint /* miplevel */, cl_GLuint /* texture */, - cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_1_2; -typedef cl_int(CL_API_CALL *PFN_clEnqueueAcquireGLObjects)( - cl_command_queue /* command_queue */, cl_uint /* num_objects */, - const cl_mem * /* mem_objects */, cl_uint /* num_events_in_wait_list */, - const cl_event * /* event_wait_list */, cl_event * /* event */); -typedef cl_int(CL_API_CALL *PFN_clEnqueueReleaseGLObjects)( - cl_command_queue /* command_queue */, cl_uint /* num_objects */, - const cl_mem * /* mem_objects */, cl_uint /* num_events_in_wait_list */, - const cl_event * /* event_wait_list */, - cl_event * /* event */) CL_API_SUFFIX__VERSION_1_0; - -// cl_khr_egl_event extension - -// CLeglDisplayKHR is an opaque handle to an EGLDisplay -typedef void *CLeglDisplayKHR; - -// CLeglSyncKHR is an opaque handle to an EGLSync object -typedef void *CLeglSyncKHR; - -typedef cl_event(CL_API_CALL *PFN_clCreateEventFromEGLSyncKHR)( - cl_context /* context */, CLeglSyncKHR /* sync */, - CLeglDisplayKHR /* display */, cl_int * /* errcode_ret */); - -// EGL sharing -typedef cl_mem(CL_API_CALL *PFN_clCreateFromEGLImageKHR)( - cl_context /*context*/, CLeglDisplayKHR /*display*/, - CLeglImageKHR /*image*/, cl_mem_flags /*flags*/, - const cl_egl_image_properties_khr * /*properties*/, - cl_int * /*errcode_ret*/); -typedef cl_int(CL_API_CALL *PFN_clEnqueueAcquireEGLObjectsKHR)( - cl_command_queue /*command_queue*/, cl_uint /*num_objects*/, - const cl_mem * /*mem_objects*/, cl_uint /*num_events_in_wait_list*/, - const cl_event * /*event_wait_list*/, cl_event * /*event*/); -typedef cl_int(CL_API_CALL *PFN_clEnqueueReleaseEGLObjectsKHR)( - cl_command_queue /*command_queue*/, cl_uint /*num_objects*/, - const cl_mem * /*mem_objects*/, cl_uint /*num_events_in_wait_list*/, - const cl_event * /*event_wait_list*/, cl_event * /*event*/); - -// 
cl_khr_command_buffer -typedef cl_command_buffer_khr(CL_API_CALL *PFN_clCreateCommandBufferKHR)( - cl_uint /*num_queues*/, const cl_command_queue * /*queues*/, - const cl_command_buffer_properties_khr * /*properties*/, - cl_int * /*errcode_ret*/); - -typedef cl_int(CL_API_CALL *PFN_clRetainCommandBufferKHR)( - cl_command_buffer_khr /*command_buffer*/); - -typedef cl_int(CL_API_CALL *PFN_clReleaseCommandBufferKHR)( - cl_command_buffer_khr /*command_buffer*/); - -typedef cl_int(CL_API_CALL *PFN_clFinalizeCommandBufferKHR)( - cl_command_buffer_khr /*command_buffer*/); - -typedef cl_int(CL_API_CALL *PFN_clEnqueueCommandBufferKHR)( - cl_uint /*num_queues*/, cl_command_queue * /*queues*/, - cl_command_buffer_khr /*command_buffer*/, - cl_uint /*num_events_in_wait_list*/, const cl_event * /*event_wait_list*/, - cl_event * /*event*/); - -#if CL_KHR_COMMAND_BUFFER_EXTENSION_VERSION >= CL_MAKE_VERSION(0, 9, 5) -typedef cl_int(CL_API_CALL *PFN_clCommandNDRangeKernelKHR)( - cl_command_buffer_khr /*command_buffer*/, - cl_command_queue /*command_queue*/, - const cl_command_properties_khr * /*properties*/, cl_kernel /*kernel*/, - cl_uint /*work_dim*/, const size_t * /*global_work_offset*/, - const size_t * /*global_work_size*/, const size_t * /*local_work_size*/, - cl_uint /*num_sync_points_in_wait_list*/, - const cl_sync_point_khr * /*sync_point_wait_list*/, - cl_sync_point_khr * /*sync_point*/, - cl_mutable_command_khr * /*mutable_handle*/); -#else -typedef cl_int(CL_API_CALL *PFN_clCommandNDRangeKernelKHR)( - cl_command_buffer_khr /*command_buffer*/, - cl_command_queue /*command_queue*/, - const cl_ndrange_kernel_command_properties_khr * /*properties*/, - cl_kernel /*kernel*/, cl_uint /*work_dim*/, - const size_t * /*global_work_offset*/, const size_t * /*global_work_size*/, - const size_t * /*local_work_size*/, - cl_uint /*num_sync_points_in_wait_list*/, - const cl_sync_point_khr * /*sync_point_wait_list*/, - cl_sync_point_khr * /*sync_point*/, - cl_mutable_command_khr * 
/*mutable_handle*/); -#endif - -typedef cl_int(CL_API_CALL *PFN_clGetCommandBufferInfoKHR)( - cl_command_buffer_khr /*command_buffer*/, - cl_command_buffer_info_khr /*param_name*/, size_t /*param_value_size*/, - void * /*param_value*/, size_t * /*param_value_size_ret*/); - -// OpenCL 3.0 -typedef cl_mem(CL_API_CALL *PFN_clCreateBufferWithProperties)( - cl_context /*context*/, const cl_mem_properties * /*properties*/, - cl_mem_flags /*flags*/, size_t /*size*/, void * /*host_ptr*/, - cl_int * /*errcode_ret*/); -typedef cl_mem(CL_API_CALL *PFN_clCreateImageWithProperties)( - cl_context /*context*/, const cl_mem_properties * /*properties*/, - cl_mem_flags /*flags*/, const cl_image_format * /*image_format*/, - const cl_image_desc * /*image_desc*/, void * /*host_ptr*/, - cl_int * /*errcode_ret*/); - -extern PFN_clGetPlatformIDs clGetPlatformIDs; -extern PFN_clGetPlatformInfo clGetPlatformInfo; -extern PFN_clGetDeviceIDs clGetDeviceIDs; -extern PFN_clGetDeviceInfo clGetDeviceInfo; -extern PFN_clCreateSubDevices clCreateSubDevices; -extern PFN_clRetainDevice clRetainDevice; -extern PFN_clReleaseDevice clReleaseDevice; -extern PFN_clCreateContext clCreateContext; -extern PFN_clCreateContextFromType clCreateContextFromType; -extern PFN_clRetainContext clRetainContext; -extern PFN_clReleaseContext clReleaseContext; -extern PFN_clGetContextInfo clGetContextInfo; -extern PFN_clCreateCommandQueueWithProperties - clCreateCommandQueueWithProperties; -extern PFN_clRetainCommandQueue clRetainCommandQueue; -extern PFN_clReleaseCommandQueue clReleaseCommandQueue; -extern PFN_clGetCommandQueueInfo clGetCommandQueueInfo; -extern PFN_clCreateBuffer clCreateBuffer; -extern PFN_clCreateSubBuffer clCreateSubBuffer; -extern PFN_clCreateImage clCreateImage; -extern PFN_clCreatePipe clCreatePipe; -extern PFN_clRetainMemObject clRetainMemObject; -extern PFN_clReleaseMemObject clReleaseMemObject; -extern PFN_clGetSupportedImageFormats clGetSupportedImageFormats; -extern PFN_clGetMemObjectInfo 
clGetMemObjectInfo; -extern PFN_clGetImageInfo clGetImageInfo; -extern PFN_clGetPipeInfo clGetPipeInfo; -extern PFN_clSetMemObjectDestructorCallback clSetMemObjectDestructorCallback; -extern PFN_clSVMAlloc clSVMAlloc; -extern PFN_clSVMFree clSVMFree; -extern PFN_clCreateSamplerWithProperties clCreateSamplerWithProperties; -extern PFN_clRetainSampler clRetainSampler; -extern PFN_clReleaseSampler clReleaseSampler; -extern PFN_clGetSamplerInfo clGetSamplerInfo; -extern PFN_clCreateProgramWithSource clCreateProgramWithSource; -extern PFN_clCreateProgramWithBinary clCreateProgramWithBinary; -extern PFN_clCreateProgramWithBuiltInKernels clCreateProgramWithBuiltInKernels; -extern PFN_clRetainProgram clRetainProgram; -extern PFN_clReleaseProgram clReleaseProgram; -extern PFN_clBuildProgram clBuildProgram; -extern PFN_clCompileProgram clCompileProgram; -extern PFN_clLinkProgram clLinkProgram; -extern PFN_clUnloadPlatformCompiler clUnloadPlatformCompiler; -extern PFN_clGetProgramInfo clGetProgramInfo; -extern PFN_clGetProgramBuildInfo clGetProgramBuildInfo; -extern PFN_clCreateKernel clCreateKernel; -extern PFN_clCreateKernelsInProgram clCreateKernelsInProgram; -extern PFN_clRetainKernel clRetainKernel; -extern PFN_clReleaseKernel clReleaseKernel; -extern PFN_clSetKernelArg clSetKernelArg; -extern PFN_clSetKernelArgSVMPointer clSetKernelArgSVMPointer; -extern PFN_clSetKernelExecInfo clSetKernelExecInfo; -extern PFN_clGetKernelInfo clGetKernelInfo; -extern PFN_clGetKernelArgInfo clGetKernelArgInfo; -extern PFN_clGetKernelWorkGroupInfo clGetKernelWorkGroupInfo; -extern PFN_clWaitForEvents clWaitForEvents; -extern PFN_clGetEventInfo clGetEventInfo; -extern PFN_clCreateUserEvent clCreateUserEvent; -extern PFN_clRetainEvent clRetainEvent; -extern PFN_clReleaseEvent clReleaseEvent; -extern PFN_clSetUserEventStatus clSetUserEventStatus; -extern PFN_clSetEventCallback clSetEventCallback; -extern PFN_clGetEventProfilingInfo clGetEventProfilingInfo; -extern PFN_clFlush clFlush; 
-extern PFN_clFinish clFinish; -extern PFN_clEnqueueReadBuffer clEnqueueReadBuffer; -extern PFN_clEnqueueReadBufferRect clEnqueueReadBufferRect; -extern PFN_clEnqueueWriteBuffer clEnqueueWriteBuffer; -extern PFN_clEnqueueWriteBufferRect clEnqueueWriteBufferRect; -extern PFN_clEnqueueFillBuffer clEnqueueFillBuffer; -extern PFN_clEnqueueCopyBuffer clEnqueueCopyBuffer; -extern PFN_clEnqueueCopyBufferRect clEnqueueCopyBufferRect; -extern PFN_clEnqueueReadImage clEnqueueReadImage; -extern PFN_clEnqueueWriteImage clEnqueueWriteImage; -extern PFN_clEnqueueFillImage clEnqueueFillImage; -extern PFN_clEnqueueCopyImage clEnqueueCopyImage; -extern PFN_clEnqueueCopyImageToBuffer clEnqueueCopyImageToBuffer; -extern PFN_clEnqueueCopyBufferToImage clEnqueueCopyBufferToImage; -extern PFN_clEnqueueMapBuffer clEnqueueMapBuffer; -extern PFN_clEnqueueMapImage clEnqueueMapImage; -extern PFN_clEnqueueUnmapMemObject clEnqueueUnmapMemObject; -extern PFN_clEnqueueMigrateMemObjects clEnqueueMigrateMemObjects; -extern PFN_clEnqueueNDRangeKernel clEnqueueNDRangeKernel; -extern PFN_clEnqueueNativeKernel clEnqueueNativeKernel; -extern PFN_clEnqueueMarkerWithWaitList clEnqueueMarkerWithWaitList; -extern PFN_clEnqueueBarrierWithWaitList clEnqueueBarrierWithWaitList; -extern PFN_clEnqueueSVMFree clEnqueueSVMFree; -extern PFN_clEnqueueSVMMemcpy clEnqueueSVMMemcpy; -extern PFN_clEnqueueSVMMemFill clEnqueueSVMMemFill; -extern PFN_clEnqueueSVMMap clEnqueueSVMMap; -extern PFN_clEnqueueSVMUnmap clEnqueueSVMUnmap; -extern PFN_clGetExtensionFunctionAddressForPlatform - clGetExtensionFunctionAddressForPlatform; -extern PFN_clCreateImage2D clCreateImage2D; -extern PFN_clCreateImage3D clCreateImage3D; -extern PFN_clEnqueueMarker clEnqueueMarker; -extern PFN_clEnqueueWaitForEvents clEnqueueWaitForEvents; -extern PFN_clEnqueueBarrier clEnqueueBarrier; -extern PFN_clUnloadCompiler clUnloadCompiler; -extern PFN_clGetExtensionFunctionAddress clGetExtensionFunctionAddress; -extern PFN_clCreateCommandQueue 
clCreateCommandQueue; -extern PFN_clCreateSampler clCreateSampler; -extern PFN_clEnqueueTask clEnqueueTask; - -// OpenGL sharing -extern PFN_clCreateFromGLBuffer clCreateFromGLBuffer; -extern PFN_clCreateFromGLTexture clCreateFromGLTexture; -extern PFN_clEnqueueAcquireGLObjects clEnqueueAcquireGLObjects; -extern PFN_clEnqueueReleaseGLObjects clEnqueueReleaseGLObjects; - -// cl_khr_egl_event extension -extern PFN_clCreateEventFromEGLSyncKHR clCreateEventFromEGLSyncKHR; - -// EGL sharing -extern PFN_clCreateFromEGLImageKHR clCreateFromEGLImageKHR; -extern PFN_clEnqueueAcquireEGLObjectsKHR clEnqueueAcquireEGLObjectsKHR; -extern PFN_clEnqueueReleaseEGLObjectsKHR clEnqueueReleaseEGLObjectsKHR; - -// cl_khr_command_buffer extension -extern PFN_clCreateCommandBufferKHR clCreateCommandBufferKHR; -extern PFN_clRetainCommandBufferKHR clRetainCommandBufferKHR; -extern PFN_clReleaseCommandBufferKHR clReleaseCommandBufferKHR; -extern PFN_clFinalizeCommandBufferKHR clFinalizeCommandBufferKHR; -extern PFN_clEnqueueCommandBufferKHR clEnqueueCommandBufferKHR; -extern PFN_clCommandNDRangeKernelKHR clCommandNDRangeKernelKHR; -extern PFN_clGetCommandBufferInfoKHR clGetCommandBufferInfoKHR; - -// OpenCL 3.0 -extern PFN_clCreateBufferWithProperties clCreateBufferWithProperties; -extern PFN_clCreateImageWithProperties clCreateImageWithProperties; - -// For convenient image creation -// It uses clCreateImage if it available (clCreateImage available since cl 1.2) -// otherwise it will use legacy clCreateImage2D -cl_mem CreateImage2DLegacy(cl_context context, cl_mem_flags flags, - const cl_image_format *image_format, - const cl_image_desc *image_desc, void *host_ptr, - cl_int *errcode_ret); - -// It uses clCreateImage if it available (clCreateImage available since cl 1.2) -// otherwise it will use legacy clCreateImage3D -cl_mem CreateImage3DLegacy(cl_context context, cl_mem_flags flags, - const cl_image_format *image_format, - const cl_image_desc *image_desc, void *host_ptr, - cl_int 
*errcode_ret); - -} // namespace cl -} // namespace litert - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_OPENCL_OPENCL_WRAPPER_H_ diff --git a/tensorflow/lite/experimental/litert/runtime/tensor_buffer.cc b/tensorflow/lite/experimental/litert/runtime/tensor_buffer.cc deleted file mode 100644 index d1b28539f4155d..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/tensor_buffer.cc +++ /dev/null @@ -1,655 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/runtime/tensor_buffer.h" - -#include - -#include -#include -#include -#include -#include -#include -#include - -#include "absl/strings/str_format.h" -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_event.h" -#include "tensorflow/lite/experimental/litert/c/litert_gl_types.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_types.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_utils.h" -#include "tensorflow/lite/experimental/litert/core/util/tensor_type_util.h" -#include "tensorflow/lite/experimental/litert/runtime/ahwb_buffer.h" -#include "tensorflow/lite/experimental/litert/runtime/dmabuf_buffer.h" -#include "tensorflow/lite/experimental/litert/runtime/event.h" -#include "tensorflow/lite/experimental/litert/runtime/fastrpc_buffer.h" -#include "tensorflow/lite/experimental/litert/runtime/gl_buffer.h" -#include "tensorflow/lite/experimental/litert/runtime/gl_texture.h" -#include "tensorflow/lite/experimental/litert/runtime/ion_buffer.h" - -#if LITERT_HAS_OPENCL_SUPPORT -#include -#include "tensorflow/lite/experimental/litert/runtime/open_cl_buffer.h" -#endif // LITERT_HAS_OPENCL_SUPPORT - -using litert::Expected; -using litert::Unexpected; - -namespace { - -template -void Copy(size_t array_size, const T* array, std::vector& vec) { - vec.clear(); - vec.reserve(array_size); - std::copy(array, array + array_size, std::back_inserter(vec)); - array = vec.data(); -} - -} // namespace - -LiteRtTensorBufferT::LiteRtTensorBufferT( - const LiteRtRankedTensorType& tensor_type, - LiteRtTensorBufferType buffer_type, 
size_t buffer_size, - size_t buffer_offset) - : tensor_type_(tensor_type), - buffer_type_(buffer_type), - buffer_size_(buffer_size), - buffer_offset_(buffer_offset), - ref_(1) { - // Copy local memory passed by the caller. - Copy(tensor_type_.layout.rank, tensor_type_.layout.dimensions, dimensions_); - if (tensor_type_.layout.strides) { - Copy(tensor_type_.layout.rank, tensor_type_.layout.strides, strides_); - } -} - -LiteRtTensorBufferT::~LiteRtTensorBufferT() { - switch (buffer_type()) { - case kLiteRtTensorBufferTypeUnknown: - // Nothing to do. - break; - case kLiteRtTensorBufferTypeHostMemory: - if (auto& buffer = std::get(buffer_); buffer.deallocator) { - buffer.deallocator(buffer.addr); - } - break; - case kLiteRtTensorBufferTypeAhwb: - if (auto& buffer = std::get(buffer_); buffer.deallocator) { - buffer.deallocator(buffer.ahwb); - } - break; - case kLiteRtTensorBufferTypeIon: - if (auto& buffer = std::get(buffer_); buffer.deallocator) { - buffer.deallocator(buffer.addr); - } - break; - case kLiteRtTensorBufferTypeDmaBuf: - if (auto& buffer = std::get(buffer_); buffer.deallocator) { - buffer.deallocator(buffer.addr); - } - break; - case kLiteRtTensorBufferTypeFastRpc: - if (auto& buffer = std::get(buffer_); buffer.deallocator) { - buffer.deallocator(buffer.addr); - } - break; - case kLiteRtTensorBufferTypeOpenCl: - // internal opencl buffer is auto-disposed by the - // litert::internal::OpenClBuffer destructor. - break; - case kLiteRtTensorBufferTypeGlBuffer: - // internal gl buffer is auto-disposed by the - // litert::internal::GlBuffer destructor. - case kLiteRtTensorBufferTypeGlTexture: - // internal gl texture is auto-disposed by the - // litert::internal::GlTexture destructor. 
- break; - } -} - -Expected LiteRtTensorBufferT::CreateFromHostMemory( - const LiteRtRankedTensorType& tensor_type, absl::Span host_memory, - LiteRtHostMemoryDeallocator deallocator) { - Ptr tensor_buffer(new LiteRtTensorBufferT( - tensor_type, kLiteRtTensorBufferTypeHostMemory, host_memory.size())); - tensor_buffer->buffer_ = HostBuffer{ - .addr = host_memory.data(), - .deallocator = deallocator, - }; - - if (auto status = tensor_buffer->IsValid(); !status) { - return Unexpected(status.Error()); - } - - return tensor_buffer; -} - -Expected -LiteRtTensorBufferT::CreateManagedOnHostMemory( - const LiteRtRankedTensorType& tensor_type, size_t buffer_size) { - void* host_memory_ptr; - if (auto rc = posix_memalign( - &host_memory_ptr, LITERT_HOST_MEMORY_BUFFER_ALIGNMENT, buffer_size); - rc) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to allocate aligned memory"); - } - - LiteRtHostMemoryDeallocator deallocator = ::free; - LITERT_ASSIGN_OR_RETURN( - LiteRtTensorBufferT::Ptr tensor_buffer, - CreateFromHostMemory( - tensor_type, - absl::MakeSpan(static_cast(host_memory_ptr), buffer_size), - deallocator)); - - return std::move(tensor_buffer); -} - -Expected LiteRtTensorBufferT::CreateFromAhwb( - const LiteRtRankedTensorType& tensor_type, AHardwareBuffer* ahwb, - size_t ahwb_offset, LiteRtAhwbDeallocator deallocator) { - LITERT_ASSIGN_OR_RETURN(size_t buffer_size, - litert::internal::AhwbBuffer::GetSize(ahwb)); - - Ptr tensor_buffer(new LiteRtTensorBufferT( - tensor_type, kLiteRtTensorBufferTypeAhwb, buffer_size, ahwb_offset)); - tensor_buffer->buffer_ = AhwbBuffer{ - .ahwb = ahwb, - .deallocator = deallocator, - }; - - if (auto status = tensor_buffer->IsValid(); !status) { - return Unexpected(status.Error()); - } - - return tensor_buffer; -} - -Expected LiteRtTensorBufferT::CreateManagedAhwbBuffer( - const LiteRtRankedTensorType& tensor_type, size_t buffer_size) { - LITERT_ASSIGN_OR_RETURN(litert::internal::AhwbBuffer buffer, - 
litert::internal::AhwbBuffer::Alloc(buffer_size)); - return CreateFromAhwb(tensor_type, buffer.ahwb, /*ahwb_offset=*/0, - /*deallocator=*/litert::internal::AhwbBuffer::Free); -} - -Expected LiteRtTensorBufferT::CreateFromIonBuffer( - const LiteRtRankedTensorType& tensor_type, void* ion_buffer_addr, - int ion_buffer_fd, size_t ion_buffer_size, size_t ion_buffer_offset, - LiteRtIonDeallocator deallocator) { - if (!ion_buffer_addr) { - return Unexpected(kLiteRtStatusErrorInvalidArgument, - "Invalid ION buffer address"); - } - if (ion_buffer_fd < 0) { - return Unexpected(kLiteRtStatusErrorInvalidArgument, - "Invalid ION buffer fd"); - } - - Ptr tensor_buffer( - new LiteRtTensorBufferT(tensor_type, kLiteRtTensorBufferTypeIon, - ion_buffer_size, ion_buffer_offset)); - tensor_buffer->buffer_ = IonBuffer{ - .addr = ion_buffer_addr, - .fd = ion_buffer_fd, - .deallocator = deallocator, - }; - - if (auto status = tensor_buffer->IsValid(); !status) { - return Unexpected(status.Error()); - } - - return tensor_buffer; -} - -Expected LiteRtTensorBufferT::CreateManagedIonBuffer( - const LiteRtRankedTensorType& tensor_type, size_t buffer_size) { - auto buffer = litert::internal::IonBuffer::Alloc( - buffer_size, /*alignment=*/LITERT_HOST_MEMORY_BUFFER_ALIGNMENT); - if (!buffer) { - return Unexpected(buffer.Error()); - } - return CreateFromIonBuffer(tensor_type, buffer->addr, buffer->fd, buffer_size, - /*ion_buffer_offset=*/0, - litert::internal::IonBuffer::Free); -} - -Expected LiteRtTensorBufferT::CreateFromDmaBufBuffer( - const LiteRtRankedTensorType& tensor_type, void* dmabuf_buffer_addr, - int dmabuf_buffer_fd, size_t dmabuf_buffer_size, - size_t dmabuf_buffer_offset, LiteRtDmaBufDeallocator deallocator) { - if (!dmabuf_buffer_addr) { - return Unexpected(kLiteRtStatusErrorInvalidArgument, - "Invalid DMA-BUF buffer address"); - } - if (dmabuf_buffer_fd < 0) { - return Unexpected(kLiteRtStatusErrorInvalidArgument, - "Invalid DMA-BUF buffer fd"); - } - - Ptr tensor_buffer( - new 
LiteRtTensorBufferT(tensor_type, kLiteRtTensorBufferTypeDmaBuf, - dmabuf_buffer_size, dmabuf_buffer_offset)); - tensor_buffer->buffer_ = DmaBufBuffer{ - .addr = dmabuf_buffer_addr, - .fd = dmabuf_buffer_fd, - .deallocator = deallocator, - }; - - if (auto status = tensor_buffer->IsValid(); !status) { - return Unexpected(status.Error()); - } - - return tensor_buffer; -} - -Expected -LiteRtTensorBufferT::CreateManagedDmaBufBuffer( - const LiteRtRankedTensorType& tensor_type, size_t buffer_size) { - auto buffer = litert::internal::DmaBufBuffer::Alloc(buffer_size); - if (!buffer) { - return Unexpected(buffer.Error()); - } - return CreateFromDmaBufBuffer(tensor_type, buffer->addr, buffer->fd, - buffer_size, /*dmabuf_buffer_offset=*/0, - litert::internal::DmaBufBuffer::Free); -} - -Expected LiteRtTensorBufferT::CreateFromFastRpcBuffer( - const LiteRtRankedTensorType& tensor_type, void* fastrpc_buffer_addr, - int fastrpc_buffer_fd, size_t fastrpc_buffer_size, - size_t fastrpc_buffer_offset, LiteRtFastRpcDeallocator deallocator) { - if (!fastrpc_buffer_addr) { - return Unexpected(kLiteRtStatusErrorInvalidArgument, - "Invalid FastRPC buffer address"); - } - if (fastrpc_buffer_fd < 0) { - return Unexpected(kLiteRtStatusErrorInvalidArgument, - "Invalid FastRPC buffer fd"); - } - - Ptr tensor_buffer( - new LiteRtTensorBufferT(tensor_type, kLiteRtTensorBufferTypeFastRpc, - fastrpc_buffer_size, fastrpc_buffer_offset)); - tensor_buffer->buffer_ = FastRpcBuffer{ - .addr = fastrpc_buffer_addr, - .fd = fastrpc_buffer_fd, - .deallocator = deallocator, - }; - - if (auto status = tensor_buffer->IsValid(); !status) { - return Unexpected(status.Error()); - } - - return tensor_buffer; -} - -Expected -LiteRtTensorBufferT::CreateManagedFastRpcBuffer( - const LiteRtRankedTensorType& tensor_type, size_t buffer_size) { - auto buffer = litert::internal::FastRpcBuffer::Alloc(buffer_size); - if (!buffer) { - return Unexpected(buffer.Error()); - } - return CreateFromFastRpcBuffer(tensor_type, 
buffer->addr, buffer->fd, - buffer_size, /*fastrpc_buffer_offset=*/0, - litert::internal::FastRpcBuffer::Free); -} - -#if LITERT_HAS_OPENCL_SUPPORT -Expected LiteRtTensorBufferT::CreateFromOpenClBuffer( - const LiteRtRankedTensorType& tensor_type, cl_mem buffer, - size_t buffer_size, LiteRtOpenClDeallocator deallocator) { - Ptr tensor_buffer(new LiteRtTensorBufferT( - tensor_type, kLiteRtTensorBufferTypeOpenCl, buffer_size)); - tensor_buffer->buffer_.emplace( - buffer, buffer_size, deallocator); - return tensor_buffer; -} - -Expected -LiteRtTensorBufferT::CreateManagedOpenClBuffer( - const LiteRtRankedTensorType& tensor_type, size_t buffer_size) { - auto buffer = litert::internal::OpenClBuffer::Alloc(buffer_size); - if (!buffer) { - return Unexpected(buffer.Error()); - } - Ptr tensor_buffer(new LiteRtTensorBufferT( - tensor_type, kLiteRtTensorBufferTypeOpenCl, buffer_size)); - tensor_buffer->buffer_.emplace( - std::move(*buffer)); - return tensor_buffer; -} -#endif // LITERT_HAS_OPENCL_SUPPORT - -Expected LiteRtTensorBufferT::CreateFromGlBuffer( - const LiteRtRankedTensorType& tensor_type, LiteRtGLenum target, - LiteRtGLuint id, size_t size_bytes, size_t offset, - LiteRtGlBufferDeallocator deallocator) { - Ptr tensor_buffer(new LiteRtTensorBufferT( - tensor_type, kLiteRtTensorBufferTypeGlBuffer, size_bytes)); - tensor_buffer->buffer_.emplace( - target, id, size_bytes, offset, deallocator); - return tensor_buffer; -} - -Expected LiteRtTensorBufferT::CreateManagedGlBuffer( - const LiteRtRankedTensorType& tensor_type, size_t buffer_size) { - auto buffer = litert::internal::GlBuffer::Alloc(buffer_size); - if (!buffer) { - return Unexpected(buffer.Error()); - } - Ptr tensor_buffer(new LiteRtTensorBufferT( - tensor_type, kLiteRtTensorBufferTypeGlBuffer, buffer_size)); - tensor_buffer->buffer_.emplace( - std::move(*buffer)); - return tensor_buffer; -} - -Expected LiteRtTensorBufferT::CreateFromGlTexture( - const LiteRtRankedTensorType& tensor_type, LiteRtGLenum target, - 
LiteRtGLuint id, LiteRtGLenum format, size_t size_bytes, LiteRtGLint layer, - LiteRtGlTextureDeallocator deallocator) { - Ptr tensor_buffer(new LiteRtTensorBufferT( - tensor_type, kLiteRtTensorBufferTypeGlTexture, size_bytes)); - tensor_buffer->buffer_.emplace( - litert::internal::GlTexture(target, id, format, size_bytes, layer, - deallocator)); - return tensor_buffer; -} - -Expected LiteRtTensorBufferT::CreateManaged( - LiteRtTensorBufferType buffer_type, - const LiteRtRankedTensorType& tensor_type, size_t buffer_size) { - switch (buffer_type) { - case kLiteRtTensorBufferTypeHostMemory: - return CreateManagedOnHostMemory(tensor_type, buffer_size); - case kLiteRtTensorBufferTypeAhwb: - return CreateManagedAhwbBuffer(tensor_type, buffer_size); - case kLiteRtTensorBufferTypeIon: - return CreateManagedIonBuffer(tensor_type, buffer_size); - case kLiteRtTensorBufferTypeDmaBuf: - return CreateManagedDmaBufBuffer(tensor_type, buffer_size); - case kLiteRtTensorBufferTypeFastRpc: - return CreateManagedFastRpcBuffer(tensor_type, buffer_size); - case kLiteRtTensorBufferTypeOpenCl: { -#if LITERT_HAS_OPENCL_SUPPORT - return CreateManagedOpenClBuffer(tensor_type, buffer_size); -#else - return Unexpected(kLiteRtStatusErrorInvalidArgument, - "OpenCL buffers are not supported."); -#endif // LITERT_HAS_OPENCL_SUPPORT - } - case kLiteRtTensorBufferTypeGlBuffer: { - return CreateManagedGlBuffer(tensor_type, buffer_size); - } - case kLiteRtTensorBufferTypeGlTexture: { - return Unexpected(kLiteRtStatusErrorInvalidArgument, - "LiteRT does not support managed GL textures."); - } - default: - return Unexpected(kLiteRtStatusErrorInvalidArgument, - "Unexpected tensor type"); - } -} - -Expected LiteRtTensorBufferT::IsValid() { - // Check for static dimensions. 
- for (auto i = 0; i < tensor_type_.layout.rank; ++i) { - if (tensor_type_.layout.dimensions[i] <= 0) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "TensorBuffer must have all static dimensions"); - } - } - - // Check for valid offset. - if (buffer_offset() >= buffer_size()) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Invalid buffer offset"); - } - - // Check for sufficient size. - if (auto num_bytes = litert::internal::GetNumPackedBytes(tensor_type_); - !num_bytes) { - return Unexpected(num_bytes.Error()); - } else if (*num_bytes > buffer_size() - buffer_offset()) { - const std::string error_message = absl::StrFormat( - "Insufficient buffer size: Required %d bytes, actual size %d bytes", - *num_bytes, buffer_size() - buffer_offset()); - return Unexpected(kLiteRtStatusErrorRuntimeFailure, error_message); - } - - // Check for proper alignment. - if (buffer_type() == kLiteRtTensorBufferTypeHostMemory) { - auto host_buffer = GetHostBuffer(); - if (!host_buffer) { - return Unexpected(host_buffer.Error()); - } - if (reinterpret_cast(*host_buffer) % - LITERT_HOST_MEMORY_BUFFER_ALIGNMENT) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Unaligned host memory pointer"); - } - } - - return {}; -} - -Expected LiteRtTensorBufferT::GetHostBuffer() { - if (buffer_type_ == kLiteRtTensorBufferTypeHostMemory) { - return std::get(buffer_).addr; - } - return Unexpected( - kLiteRtStatusErrorRuntimeFailure, - absl::StrFormat("Cannot get %s buffer from %s tensor buffer", - BufferTypeToString(kLiteRtTensorBufferTypeHostMemory), - BufferTypeToString(buffer_type_))); -} - -Expected LiteRtTensorBufferT::GetAhwbBuffer() { - if (buffer_type_ == kLiteRtTensorBufferTypeAhwb) { - return std::get(buffer_).ahwb; - } - return Unexpected( - kLiteRtStatusErrorRuntimeFailure, - absl::StrFormat("Cannot get %s buffer from %s tensor buffer", - BufferTypeToString(kLiteRtTensorBufferTypeAhwb), - BufferTypeToString(buffer_type_))); -} - -Expected> 
LiteRtTensorBufferT::GetIonBuffer() { - if (buffer_type_ == kLiteRtTensorBufferTypeIon) { - auto buffer = std::get(buffer_); - return std::make_pair(buffer.addr, buffer.fd); - } - return Unexpected( - kLiteRtStatusErrorRuntimeFailure, - absl::StrFormat("Cannot get %s buffer from %s tensor buffer", - BufferTypeToString(kLiteRtTensorBufferTypeIon), - BufferTypeToString(buffer_type_))); -} - -Expected> LiteRtTensorBufferT::GetDmaBufBuffer() { - if (buffer_type_ == kLiteRtTensorBufferTypeDmaBuf) { - auto buffer = std::get(buffer_); - return std::make_pair(buffer.addr, buffer.fd); - } - return Unexpected( - kLiteRtStatusErrorRuntimeFailure, - absl::StrFormat("Cannot get %s buffer from %s tensor buffer", - BufferTypeToString(kLiteRtTensorBufferTypeDmaBuf), - BufferTypeToString(buffer_type_))); -} - -Expected> LiteRtTensorBufferT::GetFastRpcBuffer() { - if (buffer_type_ == kLiteRtTensorBufferTypeFastRpc) { - auto buffer = std::get(buffer_); - return std::make_pair(buffer.addr, buffer.fd); - } - - return Unexpected( - kLiteRtStatusErrorRuntimeFailure, - absl::StrFormat("Cannot get %s buffer from %s tensor buffer", - BufferTypeToString(kLiteRtTensorBufferTypeFastRpc), - BufferTypeToString(buffer_type_))); -} - -#if LITERT_HAS_OPENCL_SUPPORT -Expected -LiteRtTensorBufferT::GetOpenClBuffer() { - if (buffer_type_ == kLiteRtTensorBufferTypeOpenCl) { - return &std::get(buffer_); - } - return Unexpected( - kLiteRtStatusErrorRuntimeFailure, - absl::StrFormat("Cannot get %s buffer from %s tensor buffer", - BufferTypeToString(kLiteRtTensorBufferTypeOpenCl), - BufferTypeToString(buffer_type_))); -} -#endif // LITERT_HAS_OPENCL_SUPPORT - -Expected LiteRtTensorBufferT::GetGlTexture() { - if (buffer_type_ != kLiteRtTensorBufferTypeGlTexture) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Unexpected tensor buffer type"); - } - return &std::get(buffer_); -} - -Expected LiteRtTensorBufferT::GetGlBuffer() { - if (buffer_type_ == kLiteRtTensorBufferTypeGlBuffer) { - return 
&std::get(buffer_); - } -#if LITERT_HAS_AHWB_SUPPORT - if (buffer_type_ == kLiteRtTensorBufferTypeAhwb) { - if (auto it = memory_backed_buffers_.find(kLiteRtTensorBufferTypeGlBuffer); - it != memory_backed_buffers_.end()) { - BufferVariant& memory_backed_buffer = it->second; - return &std::get(memory_backed_buffer); - } - // Create a new GL buffer from the AHWB buffer if not found. - litert::internal::AhwbBuffer ahwb_buffer = { - .ahwb = std::get(buffer_).ahwb}; - - LITERT_ASSIGN_OR_RETURN( - litert::internal::GlBuffer gl_buffer_from_ahwb, - litert::internal::GlBuffer::AllocFromAhwbBuffer(ahwb_buffer)); - - auto [it, inserted] = memory_backed_buffers_.insert( - {kLiteRtTensorBufferTypeGlBuffer, std::move(gl_buffer_from_ahwb)}); - LITERT_RETURN_IF_ERROR( - inserted == true, - Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to insert GL buffer into memory backed buffers")); - return &std::get(it->second); - } -#endif // LITERT_HAS_AHWB_SUPPORT - - return Unexpected( - kLiteRtStatusErrorRuntimeFailure, - absl::StrFormat("Cannot get %s buffer from %s tensor buffer", - BufferTypeToString(kLiteRtTensorBufferTypeGlBuffer), - BufferTypeToString(buffer_type_))); -} - -Expected LiteRtTensorBufferT::Lock() { - if (event_ != nullptr) { - // Only AHWB supports waiting on an input sync fence when locking the - // buffer. For all other buffer types we wait here. - if (buffer_type() != kLiteRtTensorBufferTypeAhwb) { - LITERT_RETURN_IF_ERROR(event_->Wait(/*timeout_in_ms=*/-1)); - } - } - - switch (buffer_type()) { - case kLiteRtTensorBufferTypeHostMemory: - return *GetHostBuffer(); - case kLiteRtTensorBufferTypeAhwb: - return litert::internal::AhwbBuffer::Lock( - *GetAhwbBuffer(), event_ != nullptr ? 
event_.get() : nullptr); - case kLiteRtTensorBufferTypeIon: - return GetIonBuffer()->first; - case kLiteRtTensorBufferTypeDmaBuf: - return GetDmaBufBuffer()->first; - case kLiteRtTensorBufferTypeFastRpc: - return GetFastRpcBuffer()->first; - case kLiteRtTensorBufferTypeOpenCl: { -#if LITERT_HAS_OPENCL_SUPPORT - auto opencl_buffer = *GetOpenClBuffer(); - auto host_memory_ptr = opencl_buffer->Lock(); - if (host_memory_ptr.HasValue()) { - return Expected(host_memory_ptr.Value()); - } else { - return Unexpected(host_memory_ptr.Error()); - } -#else - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "OpenCL buffers are not supported"); -#endif // LITERT_HAS_OPENCL_SUPPORT - } - case kLiteRtTensorBufferTypeGlBuffer: { -#if LITERT_HAS_OPENGL_SUPPORT - auto gl_buffer = *GetGlBuffer(); - auto host_memory_ptr = gl_buffer->Lock(); - if (host_memory_ptr.HasValue()) { - return Expected(host_memory_ptr.Value()); - } else { - return Unexpected(host_memory_ptr.Error()); - } -#else - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "OpenGL buffers are not supported"); -#endif // LITERT_HAS_OPENGL_SUPPORT - } - default: - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Unexpected tensor buffer type"); - } -} - -Expected LiteRtTensorBufferT::Unlock() { - switch (buffer_type()) { - case kLiteRtTensorBufferTypeAhwb: { - auto ahwb = std::get(buffer_).ahwb; - return litert::internal::AhwbBuffer::Unlock(ahwb); - } - case kLiteRtTensorBufferTypeOpenCl: { -#if LITERT_HAS_OPENCL_SUPPORT - auto opencl_buffer = *GetOpenClBuffer(); - return opencl_buffer->Unlock(); -#else - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "OpenCL buffers are not supported"); -#endif // LITERT_HAS_OPENCL_SUPPORT - } - case kLiteRtTensorBufferTypeGlBuffer: { -#if LITERT_HAS_OPENGL_SUPPORT - auto gl_buffer = *GetGlBuffer(); - return gl_buffer->Unlock(); -#else - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "OpenGL buffers are not supported"); -#endif // LITERT_HAS_OPENGL_SUPPORT - } - 
default: - return {}; - } -} diff --git a/tensorflow/lite/experimental/litert/runtime/tensor_buffer.h b/tensorflow/lite/experimental/litert/runtime/tensor_buffer.h deleted file mode 100644 index f0c9d8085a7e42..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/tensor_buffer.h +++ /dev/null @@ -1,239 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_TENSOR_BUFFER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_TENSOR_BUFFER_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_gl_types.h" -#include "tensorflow/lite/experimental/litert/c/litert_layout.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_types.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/runtime/event.h" -#include "tensorflow/lite/experimental/litert/runtime/gl_buffer.h" -#include "tensorflow/lite/experimental/litert/runtime/gl_texture.h" - -#if LITERT_HAS_OPENCL_SUPPORT -#include -#include 
"tensorflow/lite/experimental/litert/runtime/open_cl_buffer.h" -#endif // LITERT_HAS_OPENCL_SUPPORT - -class LiteRtTensorBufferT { - public: - using Ptr = std::unique_ptr; - - ~LiteRtTensorBufferT(); - - // Make this class non-copiable because it includes raw pointers and resource - // handles. - LiteRtTensorBufferT(const LiteRtTensorBufferT&) = delete; - LiteRtTensorBufferT(LiteRtTensorBufferT&&) = delete; - LiteRtTensorBufferT& operator=(const LiteRtTensorBufferT&) = delete; - LiteRtTensorBufferT& operator=(LiteRtTensorBufferT&&) = delete; - - static litert::Expected CreateFromHostMemory( - const LiteRtRankedTensorType& tensor_type, - absl::Span host_memory, - LiteRtHostMemoryDeallocator deallocator = nullptr); - - static litert::Expected CreateFromAhwb( - const LiteRtRankedTensorType& tensor_type, AHardwareBuffer* ahwb, - size_t ahwb_offset, LiteRtAhwbDeallocator deallocator = nullptr); - - static litert::Expected CreateFromIonBuffer( - const LiteRtRankedTensorType& tensor_type, void* ion_buffer_addr, - int ion_buffer_fd, size_t ion_buffer_size, size_t ion_buffer_offset, - LiteRtIonDeallocator deallocator = nullptr); - - static litert::Expected CreateFromDmaBufBuffer( - const LiteRtRankedTensorType& tensor_type, void* dmabuf_buffer_addr, - int dmabuf_buffer_fd, size_t dmabuf_buffer_size, - size_t dmabuf_buffer_offset, - LiteRtDmaBufDeallocator deallocator = nullptr); - - static litert::Expected CreateFromFastRpcBuffer( - const LiteRtRankedTensorType& tensor_type, void* fastrpc_buffer_addr, - int fastrpc_buffer_fd, size_t fastrpc_buffer_size, - size_t fastrpc_buffer_offset, - LiteRtFastRpcDeallocator deallocator = nullptr); - -#if LITERT_HAS_OPENCL_SUPPORT - static litert::Expected CreateFromOpenClBuffer( - const LiteRtRankedTensorType& tensor_type, cl_mem buffer, - size_t opencl_buffer_size, LiteRtOpenClDeallocator deallocator = nullptr); -#endif // LITERT_HAS_OPENCL_SUPPORT - - static litert::Expected CreateFromGlBuffer( - const LiteRtRankedTensorType& 
tensor_type, LiteRtGLenum target, - LiteRtGLuint id, size_t size_bytes, size_t offset, - LiteRtGlBufferDeallocator deallocator = nullptr); - static litert::Expected CreateFromGlTexture( - const LiteRtRankedTensorType& tensor_type, LiteRtGLenum target, - LiteRtGLuint id, LiteRtGLenum format, size_t size_bytes, - LiteRtGLint layer, LiteRtGlTextureDeallocator deallocator = nullptr); - - static litert::Expected CreateManaged( - LiteRtTensorBufferType buffer_type, - const LiteRtRankedTensorType& tensor_type, size_t buffer_size); - - LiteRtRankedTensorType tensor_type() const { return tensor_type_; } - LiteRtTensorBufferType buffer_type() const { return buffer_type_; } - size_t buffer_size() const { return buffer_size_; } - size_t buffer_offset() const { return buffer_offset_; } - - bool HasEvent() const { return event_ != nullptr; } - - litert::Expected GetEvent() const { - if (!HasEvent()) { - return litert::Error(kLiteRtStatusErrorRuntimeFailure, - "TensorBuffer has no event"); - } - return event_.get(); - } - - void SetEvent(LiteRtEventT* e) { - // Take ownership of the event. - event_ = std::unique_ptr(e); - } - void ClearEvent() { event_ = nullptr; } - - litert::Expected GetHostBuffer(); - litert::Expected GetAhwbBuffer(); - litert::Expected> GetIonBuffer(); - litert::Expected> GetDmaBufBuffer(); - litert::Expected> GetFastRpcBuffer(); -#if LITERT_HAS_OPENCL_SUPPORT - litert::Expected GetOpenClBuffer(); -#endif // LITERT_HAS_OPENCL_SUPPORT - litert::Expected GetGlBuffer(); - litert::Expected GetGlTexture(); - - litert::Expected Lock(); - litert::Expected Unlock(); - - // Used to duplicate the current tensor buffer. Internally it increases - // reference count to the underlying buffer. - void Duplicate() const { Ref(); } - - // Increments reference count by one. - void Ref() const { ref_.fetch_add(1, std::memory_order_relaxed); } - - // Decrements reference count by one. If the count remains - // positive, returns false. 
When the count reaches zero, returns - // true. - bool Unref() const { - if (ref_.fetch_sub(1, std::memory_order_acq_rel) == 1) { - return true; - } - return false; - } - - // Gets the current reference count. - int RefCount() const { return ref_.load(std::memory_order_relaxed); } - - private: - struct HostBuffer { - void* addr; - LiteRtHostMemoryDeallocator deallocator; - }; - - struct AhwbBuffer { - AHardwareBuffer* ahwb; - LiteRtAhwbDeallocator deallocator; - }; - - struct IonBuffer { - void* addr; - int fd; - LiteRtIonDeallocator deallocator; - }; - - struct DmaBufBuffer { - void* addr; - int fd; - LiteRtDmaBufDeallocator deallocator; - }; - - struct FastRpcBuffer { - void* addr; - int fd; - LiteRtFastRpcDeallocator deallocator; - }; - - using BufferVariant = - std::variant; - - LiteRtTensorBufferT(const LiteRtRankedTensorType& tensor_type, - LiteRtTensorBufferType buffer_type, size_t buffer_size, - size_t buffer_offset = 0); - - static litert::Expected CreateManagedOnHostMemory( - const LiteRtRankedTensorType& tensor_type, size_t buffer_size); - - static litert::Expected CreateManagedAhwbBuffer( - const LiteRtRankedTensorType& tensor_type, size_t buffer_size); - - static litert::Expected CreateManagedIonBuffer( - const LiteRtRankedTensorType& tensor_type, size_t buffer_size); - - static litert::Expected CreateManagedDmaBufBuffer( - const LiteRtRankedTensorType& tensor_type, size_t buffer_size); - - static litert::Expected CreateManagedFastRpcBuffer( - const LiteRtRankedTensorType& tensor_type, size_t buffer_size); - - static litert::Expected CreateManagedOpenClBuffer( - const LiteRtRankedTensorType& tensor_type, size_t buffer_size); - - static litert::Expected CreateManagedGlBuffer( - const LiteRtRankedTensorType& tensor_type, size_t buffer_size); - - litert::Expected IsValid(); - - LiteRtRankedTensorType tensor_type_; - std::vector> dimensions_; - std::vector> strides_; - LiteRtTensorBufferType buffer_type_; - size_t buffer_size_; - size_t buffer_offset_; - 
BufferVariant buffer_; - std::unique_ptr event_; - mutable std::atomic_int_fast32_t ref_; - // A map of memory backed buffers. Memory backed buffers are backed by the - // memory of buffer_. For example, a GL buffer can be backed by the memory of - // an AHWB buffer. - absl::flat_hash_map - memory_backed_buffers_; -}; - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_TENSOR_BUFFER_H_ diff --git a/tensorflow/lite/experimental/litert/runtime/tensor_buffer_conversion.cc b/tensorflow/lite/experimental/litert/runtime/tensor_buffer_conversion.cc deleted file mode 100644 index aac9c2b37f0af5..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/tensor_buffer_conversion.cc +++ /dev/null @@ -1,210 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/runtime/tensor_buffer_conversion.h" - -#include "absl/strings/str_format.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_types.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_utils.h" - -#if LITERT_HAS_OPENGL_SUPPORT && LITERT_HAS_AHWB_SUPPORT -#include -#include -#include -#include - -#include - -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/runtime/ahwb_buffer.h" -#include "tensorflow/lite/experimental/litert/runtime/gl_buffer.h" - -#endif // LITERT_HAS_OPENGL_SUPPORT && LITERT_HAS_AHWB_SUPPORT -#include "tensorflow/lite/experimental/litert/runtime/tensor_buffer.h" - -namespace litert { -namespace internal { - -#if LITERT_HAS_OPENGL_SUPPORT && LITERT_HAS_OPENCL_SUPPORT - -// TODO(b/383176413): Add gl-cl interop extension. -Expected CopyGlToCl(GlBuffer& src, OpenClBuffer& dest) { - if (src.target() != GL_SHADER_STORAGE_BUFFER) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Unsupported GL target for conversion to OpenCL"); - } - size_t cl_size = dest.size_bytes(); - if (src.bytes_size() != cl_size) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "GL buffer size does not match OpenCL size"); - } - LITERT_ASSIGN_OR_RETURN(void* host_src, src.Lock()); - LITERT_ASSIGN_OR_RETURN(void* host_dest, dest.Lock()); - std::memcpy(host_dest, host_src, src.bytes_size()); - LITERT_RETURN_IF_ERROR(dest.Unlock()); - LITERT_RETURN_IF_ERROR(src.Unlock()); - return {}; -} - -Expected TensorBufferConvertGlToCl( - LiteRtTensorBufferT& tensor_buffer_gl) { - // Create a new CL tensor buffer. 
- LITERT_ASSIGN_OR_RETURN( - LiteRtTensorBufferT::Ptr tensor_buffer_cl, - LiteRtTensorBufferT::CreateManaged(kLiteRtTensorBufferTypeOpenCl, - tensor_buffer_gl.tensor_type(), - tensor_buffer_gl.buffer_size())); - LITERT_ASSIGN_OR_RETURN(OpenClBuffer * cl_buffer, - tensor_buffer_cl->GetOpenClBuffer()); - LITERT_ASSIGN_OR_RETURN(GlBuffer * gl_buffer, tensor_buffer_gl.GetGlBuffer()); - CopyGlToCl(*gl_buffer, *cl_buffer); - return tensor_buffer_cl; -} -#endif // LITERT_HAS_OPENGL_SUPPORT && LITERT_HAS_CL_SUPPORT - -#if LITERT_HAS_OPENGL_SUPPORT && LITERT_HAS_AHWB_SUPPORT -Expected CopyGlToAhwb(GlBuffer& src, AhwbBuffer& dest) { - if (src.target() != GL_SHADER_STORAGE_BUFFER) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Unsupported GL target for conversion to AHWB"); - } - LITERT_ASSIGN_OR_RETURN(size_t ahwb_size, AhwbBuffer::GetSize(dest.ahwb)); - if (src.bytes_size() != ahwb_size) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "GL buffer size does not match AHWB size"); - } - LITERT_ASSIGN_OR_RETURN(void* host_src, src.Lock()); - LITERT_ASSIGN_OR_RETURN(void* host_dest, AhwbBuffer::Lock(dest.ahwb)); - std::memcpy(host_dest, host_src, src.bytes_size()); - LITERT_RETURN_IF_ERROR(AhwbBuffer::Unlock(dest.ahwb)); - LITERT_RETURN_IF_ERROR(src.Unlock()); - return {}; -} - -Expected TensorBufferConvertGlToAhwb( - LiteRtTensorBufferT& tensor_buffer_gl) { - // Create a new AHWB tensor buffer. 
- LITERT_ASSIGN_OR_RETURN( - LiteRtTensorBufferT::Ptr tensor_buffer_ahwb, - LiteRtTensorBufferT::CreateManaged(kLiteRtTensorBufferTypeAhwb, - tensor_buffer_gl.tensor_type(), - tensor_buffer_gl.buffer_size())); - LITERT_ASSIGN_OR_RETURN(AHardwareBuffer * ahwb, - tensor_buffer_ahwb->GetAhwbBuffer()); - AhwbBuffer ahwb_buffer{.ahwb = ahwb}; - LITERT_ASSIGN_OR_RETURN(GlBuffer * gl_buffer, tensor_buffer_gl.GetGlBuffer()); - CopyGlToAhwb(*gl_buffer, ahwb_buffer); - return tensor_buffer_ahwb; -} -#endif // LITERT_HAS_OPENGL_SUPPORT && LITERT_HAS_AHWB_SUPPORT - -#if LITERT_HAS_OPENGL_SUPPORT -Expected CopyHostToGl(void* host_src, GlBuffer& dest) { - LITERT_ASSIGN_OR_RETURN(void* host_dest, dest.Lock()); - std::memcpy(host_dest, host_src, dest.bytes_size()); - return {}; -} - -Expected TensorBufferConvertHostToGl( - LiteRtTensorBufferT& tensor_buffer_host) { - // Create a new GL tensor buffer. - LITERT_ASSIGN_OR_RETURN( - LiteRtTensorBufferT::Ptr tensor_buffer_gl, - LiteRtTensorBufferT::CreateManaged(kLiteRtTensorBufferTypeGlBuffer, - tensor_buffer_host.tensor_type(), - tensor_buffer_host.buffer_size())); - LITERT_ASSIGN_OR_RETURN(void* host_memory, - tensor_buffer_host.GetHostBuffer()); - LITERT_ASSIGN_OR_RETURN(GlBuffer * gl_buffer, - tensor_buffer_gl->GetGlBuffer()); - CopyHostToGl(host_memory, *gl_buffer); - return tensor_buffer_gl; -} -#endif - -Expected TensorBufferConvertHostTo( - LiteRtTensorBufferType buffer_type, LiteRtTensorBufferT& tensor_buffer) { - switch (buffer_type) { - case kLiteRtTensorBufferTypeGlBuffer: -#if LITERT_HAS_OPENGL_SUPPORT - return TensorBufferConvertHostToGl(tensor_buffer); -#else - return Unexpected( - kLiteRtStatusErrorRuntimeFailure, - absl::StrFormat("This buffer conversion is not supported: %s -> %s", - BufferTypeToString(tensor_buffer.buffer_type()), - BufferTypeToString(buffer_type))); -#endif - default: - return Unexpected( - kLiteRtStatusErrorRuntimeFailure, - absl::StrFormat("This buffer conversion is not supported: %s -> %s", - 
BufferTypeToString(tensor_buffer.buffer_type()), - BufferTypeToString(buffer_type))); - } -} - -Expected TensorBufferConvertGlTo( - LiteRtTensorBufferType buffer_type, LiteRtTensorBufferT& tensor_buffer) { - switch (buffer_type) { - case kLiteRtTensorBufferTypeAhwb: -#if LITERT_HAS_OPENGL_SUPPORT && LITERT_HAS_AHWB_SUPPORT - return TensorBufferConvertGlToAhwb(tensor_buffer); -#else - return Unexpected( - kLiteRtStatusErrorRuntimeFailure, - absl::StrFormat("This buffer conversion is not supported: %s -> %s", - BufferTypeToString(tensor_buffer.buffer_type()), - BufferTypeToString(buffer_type))); -#endif - case kLiteRtTensorBufferTypeOpenCl: -#if LITERT_HAS_OPENGL_SUPPORT && LITERT_HAS_OPENCL_SUPPORT - return TensorBufferConvertGlToCl(tensor_buffer); -#else - return Unexpected( - kLiteRtStatusErrorRuntimeFailure, - absl::StrFormat("This buffer conversion is not supported: %s -> %s", - BufferTypeToString(tensor_buffer.buffer_type()), - BufferTypeToString(buffer_type))); -#endif - default: - return Unexpected( - kLiteRtStatusErrorRuntimeFailure, - absl::StrFormat("This buffer conversion is not supported: %s -> %s", - BufferTypeToString(tensor_buffer.buffer_type()), - BufferTypeToString(buffer_type))); - } -} - -Expected TensorBufferConvertTo( - LiteRtTensorBufferType buffer_type, LiteRtTensorBufferT& tensor_buffer) { - switch (tensor_buffer.buffer_type()) { - case kLiteRtTensorBufferTypeHostMemory: - return TensorBufferConvertHostTo(buffer_type, tensor_buffer); - case kLiteRtTensorBufferTypeGlBuffer: - return TensorBufferConvertGlTo(buffer_type, tensor_buffer); - default: - return Unexpected( - kLiteRtStatusErrorRuntimeFailure, - absl::StrFormat("This buffer conversion is not supported: %s -> %s", - BufferTypeToString(tensor_buffer.buffer_type()), - BufferTypeToString(buffer_type))); - } -} - -} // namespace internal -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/runtime/tensor_buffer_conversion.h 
b/tensorflow/lite/experimental/litert/runtime/tensor_buffer_conversion.h deleted file mode 100644 index a3ebe1826303d8..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/tensor_buffer_conversion.h +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_TENSOR_BUFFER_CONVERSION_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_TENSOR_BUFFER_CONVERSION_H_ - -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/runtime/tensor_buffer.h" - -namespace litert::internal { - -// Converts the given tensor buffer to the specified buffer type. A new tensor -// buffer is created and returned. This function locks/unlocks the tensor buffer -// and will involve a copy. -// TODO(b/383176413): Investigate zero/fast-copy conversions. 
-Expected TensorBufferConvertTo( - LiteRtTensorBufferType buffer_type, LiteRtTensorBufferT& tensor_buffer); - -} // namespace litert::internal - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_TENSOR_BUFFER_CONVERSION_H_ diff --git a/tensorflow/lite/experimental/litert/runtime/tensor_buffer_conversion_test.cc b/tensorflow/lite/experimental/litert/runtime/tensor_buffer_conversion_test.cc deleted file mode 100644 index d358d1c1c2f785..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/tensor_buffer_conversion_test.cc +++ /dev/null @@ -1,140 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/runtime/tensor_buffer_conversion.h" - -#include -#include -#include - -#include -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_environment.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_types.h" -#include "tensorflow/lite/experimental/litert/cc/litert_layout.h" -#include "tensorflow/lite/experimental/litert/core/environment.h" -#include "tensorflow/lite/experimental/litert/runtime/tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/test/matchers.h" - -#if LITERT_HAS_OPENGL_SUPPORT -#include "tensorflow/lite/delegates/gpu/gl/egl_environment.h" -#endif // LITERT_HAS_OPENGL_SUPPORT - -namespace { - -constexpr const float kTensorData[] = {10, 20, 30, 40}; - -constexpr const int32_t kTensorDimensions[] = {sizeof(kTensorData) / - sizeof(kTensorData[0])}; - -constexpr const LiteRtRankedTensorType kTensorType = { - /*.element_type=*/kLiteRtElementTypeFloat32, - ::litert::BuildLayout(kTensorDimensions)}; - -TEST(TensorBufferConversionTest, HostToGl) { - // Environment setup. - LITERT_ASSERT_OK_AND_ASSIGN(LiteRtEnvironmentT::Ptr litert_env, - LiteRtEnvironmentT::CreateWithOptions({})); - - LITERT_ASSERT_OK_AND_ASSIGN( - LiteRtTensorBufferT::Ptr tensor_buffer_host, - LiteRtTensorBufferT::CreateManaged(kLiteRtTensorBufferTypeHostMemory, - kTensorType, sizeof(kTensorData))); - // Write data to the host memory. 
- LITERT_ASSERT_OK_AND_ASSIGN(void* host_memory, - tensor_buffer_host->GetHostBuffer()); - std::memcpy(host_memory, kTensorData, sizeof(kTensorData)); - -#if LITERT_HAS_OPENGL_SUPPORT - LITERT_ASSERT_OK_AND_ASSIGN( - LiteRtTensorBufferT::Ptr tensor_buffer_gl, - litert::internal::TensorBufferConvertTo(kLiteRtTensorBufferTypeGlBuffer, - *tensor_buffer_host)); - - // Ensure that data was copied correctly from host to GL. - LITERT_ASSERT_OK_AND_ASSIGN(void* host_gl, tensor_buffer_gl->Lock()); - ASSERT_EQ(std::memcmp(host_gl, kTensorData, sizeof(kTensorData)), 0); -#else - // Since GL support is not enabled, the conversion should fail. - EXPECT_FALSE(litert::internal::TensorBufferConvertTo( - kLiteRtTensorBufferTypeGlBuffer, *tensor_buffer_host)); -#endif -} - -#if LITERT_HAS_OPENGL_SUPPORT && LITERT_HAS_AHWB_SUPPORT -TEST(TensorBufferConversionTest, GlToAhwb) { - std::unique_ptr env; - ASSERT_TRUE(tflite::gpu::gl::EglEnvironment::NewEglEnvironment(&env).ok()); - - LITERT_ASSERT_OK_AND_ASSIGN( - LiteRtTensorBufferT::Ptr tensor_buffer_gl, - LiteRtTensorBufferT::CreateManaged(kLiteRtTensorBufferTypeGlBuffer, - kTensorType, sizeof(kTensorData))); - // Write data to the GL buffer. - LITERT_ASSERT_OK_AND_ASSIGN(litert::internal::GlBuffer * gl_buffer, - tensor_buffer_gl->GetGlBuffer()); - LITERT_ASSERT_OK_AND_ASSIGN(float* data, gl_buffer->Lock()); - std::memcpy(data, kTensorData, sizeof(kTensorData)); - gl_buffer->Unlock(); - - // Convert. - LITERT_ASSERT_OK_AND_ASSIGN( - LiteRtTensorBufferT::Ptr tensor_buffer_ahwb, - litert::internal::TensorBufferConvertTo(kLiteRtTensorBufferTypeAhwb, - *tensor_buffer_gl)); - // Ensure that data was copied correctly from Gl to Ahwb. 
- LITERT_ASSERT_OK_AND_ASSIGN(void* host_ahwb, tensor_buffer_ahwb->Lock()); - ASSERT_EQ(std::memcmp(host_ahwb, kTensorData, sizeof(kTensorData)), 0); -} -#endif // LITERT_HAS_OPENGL_SUPPORT && LITERT_HAS_AHWB_SUPPORT - -#if LITERT_HAS_OPENGL_SUPPORT && LITERT_HAS_OPENCL_SUPPORT -TEST(TensorBufferConversionTest, GlToCl) { - // Environment setup. - LITERT_ASSERT_OK_AND_ASSIGN(LiteRtEnvironmentT::Ptr litert_env, - LiteRtEnvironmentT::CreateWithOptions({})); - if (!litert::internal::OpenClBuffer::IsSupported()) { - GTEST_SKIP() << "OpenCL buffers are not supported on this platform; " - "skipping the test"; - } - - std::unique_ptr env; - ASSERT_TRUE(tflite::gpu::gl::EglEnvironment::NewEglEnvironment(&env).ok()); - - LITERT_ASSERT_OK_AND_ASSIGN( - LiteRtTensorBufferT::Ptr tensor_buffer_gl, - LiteRtTensorBufferT::CreateManaged(kLiteRtTensorBufferTypeGlBuffer, - kTensorType, sizeof(kTensorData))); - // Write data to the GL buffer. - LITERT_ASSERT_OK_AND_ASSIGN(litert::internal::GlBuffer * gl_buffer, - tensor_buffer_gl->GetGlBuffer()); - LITERT_ASSERT_OK_AND_ASSIGN(float* data, gl_buffer->Lock()); - std::memcpy(data, kTensorData, sizeof(kTensorData)); - gl_buffer->Unlock(); - - // Convert. - LITERT_ASSERT_OK_AND_ASSIGN( - LiteRtTensorBufferT::Ptr tensor_buffer_cl, - litert::internal::TensorBufferConvertTo(kLiteRtTensorBufferTypeOpenCl, - *tensor_buffer_gl)); - - // Ensure that data was copied correctly from Gl to CL. 
- LITERT_ASSERT_OK_AND_ASSIGN(void* host_cl, tensor_buffer_cl->Lock()); - ASSERT_EQ(std::memcmp(host_cl, kTensorData, sizeof(kTensorData)), 0); -} -#endif // LITERT_HAS_OPENGL_SUPPORT && LITERT_HAS_OPENCL_SUPPORT - -} // namespace diff --git a/tensorflow/lite/experimental/litert/runtime/tensor_buffer_requirements.h b/tensorflow/lite/experimental/litert/runtime/tensor_buffer_requirements.h deleted file mode 100644 index 04f461966889b7..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/tensor_buffer_requirements.h +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_TENSOR_BUFFER_REQUIREMENTS_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_TENSOR_BUFFER_REQUIREMENTS_H_ - -#include -#include -#include -#include - -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" - -class LiteRtTensorBufferRequirementsT { - public: - LiteRtTensorBufferRequirementsT( - int num_supported_tensor_buffer_types, - const LiteRtTensorBufferType* supported_tensor_buffer_types, - size_t buffer_size, std::vector&& strides) - : supported_buffer_types_( - supported_tensor_buffer_types, - supported_tensor_buffer_types + num_supported_tensor_buffer_types), - buffer_size_(buffer_size), - strides_(std::move(strides)) {} - const std::vector& SupportedBufferTypes() const { - return supported_buffer_types_; - } - size_t BufferSize() const { return buffer_size_; } - const std::vector& Strides() const { return strides_; } - - private: - std::vector supported_buffer_types_; - size_t buffer_size_; - // Stride per each dimension. - std::vector strides_; -}; - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_TENSOR_BUFFER_REQUIREMENTS_H_ diff --git a/tensorflow/lite/experimental/litert/runtime/tfl_utils.cc b/tensorflow/lite/experimental/litert/runtime/tfl_utils.cc deleted file mode 100644 index 37e419c68b00ca..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/tfl_utils.cc +++ /dev/null @@ -1,97 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/runtime/tfl_utils.h" - -#include -#include - -#include "tensorflow/lite/c/c_api_opaque.h" -#include "tensorflow/lite/c/c_api_types.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_detail.h" -#include "tensorflow/lite/experimental/litert/cc/litert_element_type.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" - -namespace litert { -namespace internal { - -Expected ConvertElementType(TfLiteType tfl_type) { - switch (tfl_type) { - case kTfLiteNoType: - return ElementType::None; - case kTfLiteBool: - return ElementType::Bool; - case kTfLiteInt4: - return ElementType::Int4; - case kTfLiteInt8: - return ElementType::Int8; - case kTfLiteInt16: - return ElementType::Int16; - case kTfLiteInt32: - return ElementType::Int32; - case kTfLiteInt64: - return ElementType::Int64; - case kTfLiteUInt8: - return ElementType::UInt8; - case kTfLiteUInt16: - return ElementType::UInt16; - case kTfLiteUInt32: - return ElementType::UInt32; - case kTfLiteUInt64: - return ElementType::UInt64; - case kTfLiteFloat16: - return ElementType::Float16; - case kTfLiteBFloat16: - return ElementType::BFloat16; - case kTfLiteFloat32: - return ElementType::Float32; - case kTfLiteFloat64: - return ElementType::Float64; - case kTfLiteComplex64: - return ElementType::Complex64; - case kTfLiteComplex128: - return ElementType::Complex128; - case kTfLiteResource: - return ElementType::TfResource; - case kTfLiteString: - return ElementType::TfString; - case kTfLiteVariant: - return ElementType::TfVariant; - default: - return Unexpected(kLiteRtStatusErrorInvalidArgument, - "Unsupported TfLiteType"); - } -} - -Expected ConvertTensorType( - const TfLiteOpaqueTensor* tfl_opaque_tensor) { - auto 
tfl_type = TfLiteOpaqueTensorType(tfl_opaque_tensor); - auto element_type = ConvertElementType(tfl_type); - if (!element_type) { - return Unexpected(element_type.Error()); - } - - size_t rank = TfLiteOpaqueTensorNumDims(tfl_opaque_tensor); - Dimensions dimensions(rank); - for (size_t i = 0; i < rank; ++i) { - dimensions[i] = TfLiteOpaqueTensorDim(tfl_opaque_tensor, i); - } - - return RankedTensorType(*element_type, Layout(std::move(dimensions))); -} - -} // namespace internal -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/runtime/tfl_utils.h b/tensorflow/lite/experimental/litert/runtime/tfl_utils.h deleted file mode 100644 index 8874c7535f3d6a..00000000000000 --- a/tensorflow/lite/experimental/litert/runtime/tfl_utils.h +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_TFL_UTILS_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_TFL_UTILS_H_ - -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" - -struct TfLiteOpaqueTensor; - -namespace litert::internal { - -Expected ConvertElementType(TfLiteType tfl_type); - -Expected ConvertTensorType( - const TfLiteOpaqueTensor* tfl_opaque_tensor); - -} // namespace litert::internal - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_TFL_UTILS_H_ diff --git a/tensorflow/lite/experimental/litert/test/BUILD b/tensorflow/lite/experimental/litert/test/BUILD deleted file mode 100644 index 14b60339d2726b..00000000000000 --- a/tensorflow/lite/experimental/litert/test/BUILD +++ /dev/null @@ -1,173 +0,0 @@ -# Copyright 2024 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -load("//tensorflow/lite/experimental/litert/build_common:tfl_model_gen.bzl", "tfl_model_gen") - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = [ - # copybara:uncomment "//third_party/mediapipe/calculators/tensor:__subpackages__", - # copybara:uncomment "//third_party/odml/litert:__subpackages__", - "//tensorflow/lite/experimental/litert:__subpackages__", - ], -) - -tfl_model_gen( - name = "mlir_test_data", - srcs = glob(["testdata/*.mlir"]), -) - -filegroup( - name = "tflite_test_data", - srcs = glob(["testdata/*.tflite"]), -) - -cc_library( - name = "common", - testonly = 1, - srcs = [ - "common.cc", - ], - hdrs = [ - "common.h", - ], - deps = [ - "//tensorflow/lite:framework", - "//tensorflow/lite/c:c_api_opaque", - "//tensorflow/lite/c:common", - "//tensorflow/lite/core:cc_api_stable", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", - "//tensorflow/lite/experimental/litert/core:filesystem", - "//tensorflow/lite/experimental/litert/core/model:model_buffer", - "//tensorflow/lite/experimental/litert/core/util:flatbuffer_tools", - "//tensorflow/lite/kernels:builtin_ops", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/synchronization", - "@local_tsl//tsl/platform", - ], -) - -cc_library( - name = "simple_model", - testonly = 1, - hdrs = [ - "testdata/simple_model_test_vectors.h", - ], - data = [ - "testdata/simple_model.tflite", - ], - deps = [ - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_layout", - ], -) - -cc_library( - name = "simple_model_npu", - testonly = 1, - srcs 
= [], - hdrs = [ - "testdata/simple_model_test_vectors.h", - ], - data = [ - "testdata/simple_model_google_tensor.bin", - "testdata/simple_model_mtk.bin", - "testdata/simple_model_npu.tflite", - "testdata/simple_model_qualcomm.bin", - ], - deps = [ - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_layout", - ], -) - -cc_library( - name = "simple_cascade_model_npu", - testonly = 1, - srcs = [], - hdrs = [ - "testdata/simple_model_test_vectors.h", - ], - data = [ - "testdata/simple_cascade_model_npu.tflite", - "testdata/simple_model_google_tensor.bin", - "testdata/simple_model_mtk.bin", - "testdata/simple_model_qualcomm.bin", - ], - deps = [ - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_layout", - ], -) - -cc_library( - name = "test_models", - hdrs = ["test_models.h"], - deps = [ - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - ], -) - -cc_library( - name = "matchers", - testonly = True, - hdrs = ["matchers.h"], - deps = [ - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest", - ], -) - -cc_test( - name = "matchers_test", - srcs = ["matchers_test.cc"], - deps = [ - ":matchers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "@com_google_googletest//:gtest_main", - ], -) - -# Use this library if you want to enforce an OSS environment for your test. 
-cc_library( - name = "matchers_oss", - testonly = True, - hdrs = ["matchers.h"], - defines = ["LITERT_DEFINE_GTEST_STATUS_PRINTER"], - tags = ["avoid_dep"], - deps = [ - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest", - ], -) - -exports_files(srcs = [ - "testdata/mobilenet_v2_1.0_224.tflite", - "testdata/simple_model_google_tensor.bin", - "testdata/simple_model_mtk.bin", - "testdata/simple_model_qualcomm.bin", -]) diff --git a/tensorflow/lite/experimental/litert/test/common.cc b/tensorflow/lite/experimental/litert/test/common.cc deleted file mode 100644 index 8744bcc14cfbed..00000000000000 --- a/tensorflow/lite/experimental/litert/test/common.cc +++ /dev/null @@ -1,126 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/test/common.h" - -#include -#include -#include // NOLINT -#include -#include -#include -#include -#include -#include -#include - -#include "absl/base/attributes.h" -#include "absl/base/const_init.h" -#include "absl/strings/string_view.h" -#include "absl/synchronization/mutex.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model_predicates.h" -#include "tensorflow/lite/experimental/litert/core/filesystem.h" -#include "tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h" -#include "tensorflow/lite/interpreter.h" -#include "tensorflow/lite/kernels/register.h" -#include "tsl/platform/platform.h" - -namespace litert::testing { - -Expected UniqueTestDirectory::Create() { - constexpr size_t kMaxTries = 1000; - ABSL_CONST_INIT static absl::Mutex mutex(absl::kConstInit); - - // We don't want multiple threads to create the same directory. 
- absl::MutexLock l(&mutex); - - auto tmp_dir = std::filesystem::temp_directory_path(); - std::random_device dev; - std::mt19937 prng(dev()); - std::uniform_int_distribution rand(0); - std::stringstream ss; - - for (auto i = 0; i < kMaxTries; ++i) { - ss.clear(); - ss << std::hex << rand(prng); - auto path = tmp_dir / ss.str(); - if (std::filesystem::create_directory(path)) { - LITERT_LOG(LITERT_INFO, "Created unique temporary directory %s", - path.c_str()); - return UniqueTestDirectory(path); - } - } - - return Error(kLiteRtStatusErrorRuntimeFailure, - "Could not create a unique temporary directory"); -} - -UniqueTestDirectory::~UniqueTestDirectory() { - std::filesystem::remove_all(tmpdir_); -} - -std::string GetTestFilePath(absl::string_view filename) { - static constexpr absl::string_view kTestDataDir = - "tensorflow/lite/experimental/litert/" - "test/testdata/"; - - if constexpr (!tsl::kIsOpenSource) { - return internal::Join({"third_party", kTestDataDir, filename}); - } else { - return internal::Join({kTestDataDir, filename}); - } -} - -std::string GetTfliteFilePath(absl::string_view filename) { - static constexpr absl::string_view kTestDataDir = "tensorflow/lite/"; - - if constexpr (!tsl::kIsOpenSource) { - return internal::Join({"third_party", kTestDataDir, filename}); - } else { - return internal::Join({kTestDataDir, filename}); - } -} - -std::string GetLiteRtPath(absl::string_view rel_path) { - static constexpr absl::string_view kLiteRtRoot = - "tensorflow/lite/experimental/litert/"; - - if constexpr (!tsl::kIsOpenSource) { - return internal::Join({"third_party", kLiteRtRoot, rel_path}); - } else { - return internal::Join({kLiteRtRoot, rel_path}); - } -} - -Model LoadTestFileModel(absl::string_view filename) { - return *Model::CreateFromFile(GetTestFilePath(filename)); -} - -Expected TflRuntime::CreateFromFlatBuffer( - internal::FlatbufferWrapper::Ptr flatbuffer) { - ::tflite::Interpreter::Ptr interp; - tflite::ops::builtin::BuiltinOpResolver resolver; - 
tflite::InterpreterBuilder(flatbuffer->FlatbufferModel(), resolver)(&interp); - if (interp == nullptr) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure); - } - return TflRuntime::Ptr( - new TflRuntime(std::move(flatbuffer), std::move(interp))); -} - -} // namespace litert::testing diff --git a/tensorflow/lite/experimental/litert/test/common.h b/tensorflow/lite/experimental/litert/test/common.h deleted file mode 100644 index 6b7b20e989b7a3..00000000000000 --- a/tensorflow/lite/experimental/litert/test/common.h +++ /dev/null @@ -1,108 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TEST_COMMON_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TEST_COMMON_H_ - -#include -#include - -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/core/model/model_buffer.h" -#include "tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h" -#include "tensorflow/lite/interpreter.h" - -namespace litert::testing { - -// A x-platform compatible replacement for testing::UniqueTestDirectory. 
-class UniqueTestDirectory { - public: - static Expected Create(); - ~UniqueTestDirectory(); - - UniqueTestDirectory(const UniqueTestDirectory&) = delete; - UniqueTestDirectory(UniqueTestDirectory&&) = default; - UniqueTestDirectory& operator=(const UniqueTestDirectory&) = delete; - UniqueTestDirectory& operator=(UniqueTestDirectory&&) = default; - - absl::string_view Str() const { return tmpdir_; } - - private: - explicit UniqueTestDirectory(std::string&& tmpdir) - : tmpdir_(std::move(tmpdir)) {} - std::string tmpdir_; -}; - -// Gets the path to the given filename in the testdata directory. -std::string GetTestFilePath(absl::string_view filename); - -// Gets a path to the given filename in the tflite directory. -std::string GetTfliteFilePath(absl::string_view filename); - -// Gets a full path given a path relative to the litert directory. -std::string GetLiteRtPath(absl::string_view rel_path); - -Model LoadTestFileModel(absl::string_view filename); - -class TflRuntime { - public: - using Ptr = std::unique_ptr; - - static Expected CreateFromFlatBuffer( - internal::FlatbufferWrapper::Ptr flatbuffer); - - ::tflite::Interpreter& Interpreter() { return *interpreter_; } - - const internal::FlatbufferWrapper& Flatbuffer() const { return *flatbuffer_; } - - private: - TflRuntime(internal::FlatbufferWrapper::Ptr flatbuffer, - ::tflite::Interpreter::Ptr interpreter) - : flatbuffer_(std::move(flatbuffer)), - interpreter_(std::move(interpreter)) {} - - internal::FlatbufferWrapper::Ptr flatbuffer_; - ::tflite::Interpreter::Ptr interpreter_; -}; - -inline Expected MakeRuntimeFromTestFile( - absl::string_view filename) { - auto flatbuffer = - internal::FlatbufferWrapper::CreateFromTflFile(GetTestFilePath(filename)); - if (!flatbuffer) { - return flatbuffer.Error(); - } - return TflRuntime::CreateFromFlatBuffer(std::move(*flatbuffer)); -} - -inline Expected MakeRuntimeFromTestFileWithNpuModel( - absl::string_view filename, absl::string_view npu_filename) { - auto buf = 
internal::GetModelBufWithByteCode(GetTestFilePath(filename), - GetTestFilePath(npu_filename)); - if (!buf) { - return buf.Error(); - } - auto flatbuffer = - internal::FlatbufferWrapper::CreateFromBuffer(std::move(*buf)); - if (!flatbuffer) { - return flatbuffer.Error(); - } - return TflRuntime::CreateFromFlatBuffer(std::move(*flatbuffer)); -} - -} // namespace litert::testing - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TEST_COMMON_H_ diff --git a/tensorflow/lite/experimental/litert/test/matchers.h b/tensorflow/lite/experimental/litert/test/matchers.h deleted file mode 100644 index 7db2c43435c2ca..00000000000000 --- a/tensorflow/lite/experimental/litert/test/matchers.h +++ /dev/null @@ -1,359 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TEST_MATCHERS_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TEST_MATCHERS_H_ - -#include -#include -#include -#include - -#include -#include -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" - -// Is equivalent to `ASSERT_THAT(expr, testing::litert::IsOk())` -#define LITERT_ASSERT_OK(EXPR) ASSERT_THAT((EXPR), ::testing::litert::IsOk()) - -// Is equivalent to `EXPECT_THAT(expr, testing::litert::IsOk())` -#define LITERT_EXPECT_OK(EXPR) EXPECT_THAT((EXPR), ::testing::litert::IsOk()) - -// Checks that the result of `EXPR` (a `litert::Expected` object) is not an -// error and assigns the value it holds to `DECL` as if: -// ``` -// DECL = std::move(EXPR.Value()); -// ``` -// -// ```cpp -// Expected BuildSomething(); -// -// Will fail the test if `BuildSomething()`'s returned value holds an error. -// Otherwise defines and assigns the returned `Something` value to `smth` -// ASSERT_OK_AND_ASSIGN(Something smth, BuildSomething()); -// ``` -#define LITERT_ASSERT_OK_AND_ASSIGN(DECL, EXPR) \ - LITERT_ASSERT_OK_AND_ASSIGN_HELPER2(__LINE__, DECL, EXPR) - -#define LITERT_ASSERT_OK_AND_ASSIGN_HELPER1(LINE, DECL, EXPR) \ - auto&& litert_expected_value_or_error_##LINE = (EXPR); \ - LITERT_ASSERT_OK(litert_expected_value_or_error_##LINE); \ - DECL = std::move(litert_expected_value_or_error_##LINE.Value()); - -#define LITERT_ASSERT_OK_AND_ASSIGN_HELPER2(LINE, DECL, EXPR) \ - LITERT_ASSERT_OK_AND_ASSIGN_HELPER1(LINE, DECL, EXPR) - -namespace testing::litert { - -// Matches `litert::Expected` values that hold a success value and -// `LiteRtStatusOk`. -// -// See `IsOk()` function below for usage examples. -class IsOkMatcher { - public: - // Implicitly builds and wraps the matcher implementation in a GTest - // Matcher object. - template - // NOLINTNEXTLINE(*-explicit-constructor): This needs to be implicit. 
- operator testing::Matcher() const { - return testing::Matcher(new Impl()); - } - - template - class Impl : public testing::MatcherInterface { - template - bool MatchAndExplainImpl(const ::litert::Expected& value, - testing::MatchResultListener* listener) const { - return value.HasValue(); - } - - bool MatchAndExplainImpl(const ::litert::Unexpected& unexpected, - testing::MatchResultListener* listener) const { - return false; - } - - bool MatchAndExplainImpl(const ::litert::Error& e, - testing::MatchResultListener* listener) const { - return false; - } - - bool MatchAndExplainImpl(const LiteRtStatus& status, - testing::MatchResultListener* listener) const { - if (status != kLiteRtStatusOk) { - *listener << "status is " << LiteRtGetStatusString(status); - return false; - } - return true; - } - - public: - using is_gtest_matcher = void; - - bool MatchAndExplain( - V value, testing::MatchResultListener* listener) const override { - return MatchAndExplainImpl(value, listener); - } - - void DescribeTo(std::ostream* os) const override { - if (os) { - *os << "is ok."; - } - } - - void DescribeNegationTo(std::ostream* os) const override { - if (os) { - *os << "is not ok."; - } - } - }; -}; - -// Matches `litert::Expected` values that hold a success value and -// `LiteRtStatusOk`. -// -// Note: you might want to use the convenience macros: -// - `LITERT_EXPECT_OK(expr)` -// - `LITERT_ASSERT_OK(expr)` -// - `ASSERT_OK_AND_ASSIGN(type var, expr)` -// -// ```cpp -// LiteRtStatus DoSomething(); -// -// // Will fail the test if DoSomething() doesn't return kLiteRtStatusOk. -// EXPECT_THAT(DoSomething(), IsOk()); -// ``` -// -// This also works for `Expected` objects. -// -// Note: You probably want `ASSERT_OK_AND_ASSIGN` when working with `Expected`. -// -// ```cpp -// Expected BuildSomething(); -// -// // Will fail the test if BuildSomething()'s returned value holds an error. -// // Note that the returned value is unused. 
-// EXPECT_THAT(BuildSomething(), IsOk()); -// ``` -inline IsOkMatcher IsOk() { return IsOkMatcher(); } - -// Matches `litert::Expected` values that hold an error and -// `LiteRtStatusError*` values. -// -// See `IsError(...)` functions below for usage examples. -class IsErrorMatcher { - public: - IsErrorMatcher(std::optional status, - std::optional msg) - : impl_(status, msg) {} - - // Implicitly builds and wraps the matcher implementation in a GTest - // Matcher object. - template - // NOLINTNEXTLINE(*-explicit-constructor): This needs to be implicit. - operator testing::Matcher() const { - return testing::Matcher(new Impl(impl_)); - } - - private: - class ImplBase { - public: - ImplBase() = default; - - explicit ImplBase(std::optional status, - std::optional msg) - : status_(status), msg_(std::move(msg)) {}; - - protected: - bool MatchAndExplainImpl(const LiteRtStatus status, - const absl::string_view msg, - testing::MatchResultListener* listener) const { - if (status == kLiteRtStatusOk || - (status_.has_value() && status != status_.value())) { - if (listener) { - *listener << "status doesn't match"; - } - return false; - } - if (msg_.has_value() && msg != msg_.value()) { - if (listener) { - *listener << "message doesn't match"; - } - return false; - } - return true; - } - - template - bool MatchAndExplainImpl(const ::litert::Expected& value, - testing::MatchResultListener* listener) const { - if (value.HasValue()) { - *listener << "expected holds a value (but should hold an error)"; - return false; - } - return MatchAndExplainImpl(value.Error(), listener); - } - - bool MatchAndExplainImpl(const ::litert::Unexpected& e, - testing::MatchResultListener* listener) const { - return MatchAndExplainImpl(e.Error().Status(), e.Error().Message(), - listener); - } - - bool MatchAndExplainImpl(const ::litert::Error& e, - testing::MatchResultListener* listener) const { - return MatchAndExplainImpl(e.Status(), e.Message(), listener); - } - - bool MatchAndExplainImpl(const 
LiteRtStatus& status, - testing::MatchResultListener* listener) const { - return MatchAndExplainImpl(status, {}, listener); - } - - void DescribeImpl(std::ostream* os, const bool negation) const { - if (os) { - *os << "is" << (negation ? " not" : "") << " an error"; - const char* sep = " with "; - if (status_.has_value()) { - *os << sep << "status " << LiteRtGetStatusString(status_.value()); - sep = " and "; - } - if (msg_.has_value()) { - *os << sep << "message matching: '" << msg_.value() << "'"; - } - *os << "."; - } - } - - private: - std::optional status_; - std::optional msg_; - }; - - template - class Impl : public testing::MatcherInterface, ImplBase { - public: - using is_gtest_matcher = void; - - Impl() = default; - explicit Impl(const ImplBase& base) : ImplBase(base) {} - - bool MatchAndExplain( - V value, testing::MatchResultListener* listener) const override { - return MatchAndExplainImpl(value, listener); - } - - void DescribeTo(std::ostream* os) const override { - DescribeImpl(os, /*negation=*/false); - } - - void DescribeNegationTo(std::ostream* os) const override { - DescribeImpl(os, /*negation=*/true); - } - }; - - ImplBase impl_; -}; - -// Matches `litert::Expected`, `litert::Unexpected`, `litert::Error` and -// `LiteRtStatus` values that hold an error. -// -// Note: This will always match `true` for `litert::Unexpected` and -// `litert::Error`. This can be useful to test template code that might always -// return an error for certain specialisations. -// -// ```cpp -// LiteRtStatus DoSomething(); -// -// // Will fail the test if `DoSomething()` returns `kLiteRtStatusOk`. -// EXPECT_THAT(DoSomething(), IsError()); -// ``` -// -// This also works for `Expected` objects. -// -// ```cpp -// Expected BuildSomething(); -// -// // Will fail the test if BuildSomething()'s returned object holds a value. 
-// EXPECT_THAT(BuildSomething(), IsError()); -// ``` -inline IsErrorMatcher IsError() { - return IsErrorMatcher(/*status=*/std::nullopt, /*msg=*/std::nullopt); -} - -// Matches `litert::Expected`, `litert::Unexpected`, `litert::Error` and -// `LiteRtStatus` values that hold a specific error status. -// -// ```cpp -// Expected BuildSomething(); -// -// // Will fail the test if BuildSomething()'s returned object holds a value or -// // if the error status is not `kLiteRtStatusErrorSystemError`. -// EXPECT_THAT(BuildSomething(), IsError(kLiteRtStatusErrorSystemError)); -// ``` -inline IsErrorMatcher IsError(LiteRtStatus status) { - return IsErrorMatcher(status, /*msg=*/std::nullopt); -} - -// Matches `litert::Expected` and `LiteRtStatus` values that have a specific -// error status and error message. -// -// Warning: This will always return `false` for `LiteRtStatus` objects as those -// do not convey a message. -// -// ```cpp -// Expected BuildSomething(); -// -// // Will fail the test if BuildSomething()'s returned object holds a value. -// EXPECT_THAT(BuildSomething(), IsError(kLiteRtStatusErrorSystemError, -// "System is not initialised")); -// ``` -inline IsErrorMatcher IsError(LiteRtStatus status, std::string msg) { - return IsErrorMatcher(status, std::move(msg)); -} - -} // namespace testing::litert - -// GTest doesn't use `AbslStringify` if `GTEST_USE_ABSL` is not defined. This -// provides a fallback implementation. -// -// This is defined here instead of with `litert::Expected` because those -// functions should only be used for testing. -#if defined(LITERT_DEFINE_GTEST_STATUS_PRINTER) && !defined(GTEST_USE_ABSL) -#include "absl/strings/str_format.h" - -// GTest documentation explicitly states that functions the those below must -// live in the same namespace as the classes they are used with so that GTest -// can find them through ADL. 
-namespace litert { - -inline void PrintTo(const Error& error, std::ostream* os) { - *os << absl::StrFormat("%v", error); -} - -inline void PrintTo(const Unexpected& unexpected, std::ostream* os) { - *os << absl::StrFormat("%v", unexpected); -} - -template -void PrintTo(const Expected& expected, std::ostream* os) { - *os << absl::StrFormat("%v", expected); -} - -} // namespace litert - -#endif - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TEST_MATCHERS_H_ diff --git a/tensorflow/lite/experimental/litert/test/matchers_test.cc b/tensorflow/lite/experimental/litert/test/matchers_test.cc deleted file mode 100644 index 1acfdf282a1810..00000000000000 --- a/tensorflow/lite/experimental/litert/test/matchers_test.cc +++ /dev/null @@ -1,184 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/test/matchers.h" - -#include - -#include -#include -#include -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" - -using litert::Error; -using litert::Expected; -using litert::Unexpected; -using testing::Not; -using testing::StrEq; -using testing::litert::IsError; -using testing::litert::IsOk; - -namespace { - -struct CopyOnly { - CopyOnly() = default; - CopyOnly(const CopyOnly&) = default; - CopyOnly& operator=(const CopyOnly&) = default; -}; - -struct MoveOnly { - MoveOnly() = default; - MoveOnly(MoveOnly&&) = default; - MoveOnly& operator=(MoveOnly&&) = default; -}; - -TEST(IsOkMatcherTest, Works) { - const Expected error = Error(kLiteRtStatusErrorNotFound, "not found"); - EXPECT_THAT(kLiteRtStatusOk, IsOk()); - EXPECT_THAT(Expected(3), IsOk()); - - EXPECT_THAT(error, Not(IsOk())); - EXPECT_THAT(Unexpected(kLiteRtStatusErrorFileIO), Not(IsOk())); - EXPECT_THAT(Error(kLiteRtStatusErrorInvalidArgument), Not(IsOk())); - - EXPECT_THAT(kLiteRtStatusErrorUnsupported, Not(IsOk())); - - EXPECT_THAT(testing::DescribeMatcher>(IsOk()), StrEq("is ok.")); - EXPECT_THAT(testing::DescribeMatcher>(Not(IsOk())), - StrEq("is not ok.")); - - testing::StringMatchResultListener listener; - EXPECT_FALSE(testing::ExplainMatchResult( - IsOk(), kLiteRtStatusErrorUnsupported, &listener)); - EXPECT_THAT(listener.str(), StrEq("status is kLiteRtStatusErrorUnsupported")); - - listener.Clear(); - EXPECT_FALSE(testing::ExplainMatchResult(IsOk(), error, &listener)); - EXPECT_THAT(listener.str(), StrEq("")); - - listener.Clear(); - EXPECT_FALSE(testing::ExplainMatchResult(IsOk(), error.Error(), &listener)); - EXPECT_THAT(listener.str(), StrEq("")); -} - -// No, I'm not creating a templated test fixture just for that. This only -// contains non-fatal failures that are propagated to the test. 
-// -// The type of the error wrapper that fails is the test failure stack trace when -// debug options are specified. -template -void TestErrorWrapper() { - const ErrorWrapper error = Error(kLiteRtStatusErrorNotFound, "not found"); - EXPECT_THAT(error, IsError()); - EXPECT_THAT(error, IsError(kLiteRtStatusErrorNotFound)); - EXPECT_THAT(error, IsError(kLiteRtStatusErrorNotFound, "not found")); - // This checks against the wrong status. - EXPECT_THAT(error, Not(IsError(kLiteRtStatusErrorInvalidArgument))); - // This checks against the wrong message. - EXPECT_THAT(error, Not(IsError(kLiteRtStatusErrorNotFound, "oob"))); - - testing::StringMatchResultListener listener; - EXPECT_FALSE(testing::ExplainMatchResult( - IsError(kLiteRtStatusErrorInvalidArgument), error, &listener)); - EXPECT_THAT(listener.str(), StrEq("status doesn't match")); - - listener.Clear(); - EXPECT_FALSE(testing::ExplainMatchResult( - IsError(kLiteRtStatusErrorNotFound, "oob"), error, &listener)); - EXPECT_THAT(listener.str(), StrEq("message doesn't match")); -} - -TEST(IsErrorMatcherTest, Works) { - TestErrorWrapper>(); - TestErrorWrapper(); - TestErrorWrapper(); - - EXPECT_THAT(kLiteRtStatusErrorUnsupported, IsError()); - EXPECT_THAT(kLiteRtStatusOk, Not(IsError())); - EXPECT_THAT(Expected(3), Not(IsError())); - - EXPECT_THAT(testing::DescribeMatcher>(IsError()), - StrEq("is an error.")); - EXPECT_THAT(testing::DescribeMatcher>(Not(IsError())), - StrEq("is not an error.")); - EXPECT_THAT( - testing::DescribeMatcher>( - IsError(kLiteRtStatusErrorUnsupported)), - testing::StrEq("is an error with status kLiteRtStatusErrorUnsupported.")); - EXPECT_THAT(testing::DescribeMatcher>( - IsError(kLiteRtStatusErrorUnsupported, "unsupported")), - testing::StrEq("is an error with status " - "kLiteRtStatusErrorUnsupported and message " - "matching: 'unsupported'.")); - - testing::StringMatchResultListener listener; - EXPECT_FALSE( - testing::ExplainMatchResult(IsError(), kLiteRtStatusOk, &listener)); - 
EXPECT_THAT(listener.str(), StrEq("status doesn't match")); - - listener.Clear(); - EXPECT_FALSE( - testing::ExplainMatchResult(IsError(), Expected(3), &listener)); - EXPECT_THAT(listener.str(), - StrEq("expected holds a value (but should hold an error)")); -} - -TEST(LitertAssertOk, Works) { - LITERT_ASSERT_OK(Expected(3)); - LITERT_ASSERT_OK(kLiteRtStatusOk); - EXPECT_FATAL_FAILURE( - LITERT_ASSERT_OK(Error(kLiteRtStatusErrorInvalidArgument)), "is ok"); -} -TEST(LitertExpectOk, Works) { - LITERT_EXPECT_OK(Expected(3)); - LITERT_EXPECT_OK(kLiteRtStatusOk); - EXPECT_NONFATAL_FAILURE( - LITERT_EXPECT_OK(Error(kLiteRtStatusErrorInvalidArgument)), "is ok"); -} - -TEST(AssertOkAndAssign, DefineAVariableWorks) { - LITERT_ASSERT_OK_AND_ASSIGN(auto expected, Expected(3)); - static_assert(std::is_same_v, - "Type should be deduced to int."); - EXPECT_EQ(expected, 3); - - LITERT_ASSERT_OK_AND_ASSIGN([[maybe_unused]] auto copy_only, - Expected(CopyOnly())); - LITERT_ASSERT_OK_AND_ASSIGN([[maybe_unused]] auto move_only, - Expected(MoveOnly())); -} - -TEST(AssertOkAndAssign, AssignAVariableWorks) { - int expected = 0; - LITERT_ASSERT_OK_AND_ASSIGN(expected, Expected(3)); - EXPECT_EQ(expected, 3); - - [[maybe_unused]] CopyOnly copy_only; - [[maybe_unused]] MoveOnly move_only; - LITERT_ASSERT_OK_AND_ASSIGN(copy_only, Expected(CopyOnly())); - LITERT_ASSERT_OK_AND_ASSIGN(move_only, Expected(MoveOnly())); -} - -void TestAssertOkAndAssignFailure() { - LITERT_ASSERT_OK_AND_ASSIGN( - [[maybe_unused]] int expected, - Expected(Unexpected(kLiteRtStatusErrorInvalidArgument))); -} - -TEST(AssertOkAndAssign, FailuresStopsExecution) { - EXPECT_FATAL_FAILURE(TestAssertOkAndAssignFailure(), "is ok"); -} - -} // namespace diff --git a/tensorflow/lite/experimental/litert/test/test_models.h b/tensorflow/lite/experimental/litert/test/test_models.h deleted file mode 100644 index ddad473d40bb42..00000000000000 --- a/tensorflow/lite/experimental/litert/test/test_models.h +++ /dev/null @@ -1,126 +0,0 
@@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TEST_TEST_MODELS_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TEST_TEST_MODELS_H_ - -#include "absl/strings/string_view.h" -#include "absl/types/span.h" - -// /////////////////////////////////////////////////////////////////////////// -// FP32 models. -// /////////////////////////////////////////////////////////////////////////// - -// Attention sub-module of a toy model. -static constexpr absl::string_view kAttentionModel = "attention.tflite"; - -// Attention vector einsum sub-module of a toy LLM. -static constexpr absl::string_view kAttnVecEinsumModel = - "attn_vec_einsum.tflite"; - -// Feed forward sub-module of a toy LLM. -static constexpr absl::string_view kFeedForwardModel = "ff.tflite"; - -// Key einsume sub-module of a toy LLM. -static constexpr absl::string_view kKeyEinsumModel = "k_einsum.tflite"; - -// Value einsum sub-module of a toy LLM. -static constexpr absl::string_view kValueEinsumModel = "v_einsum.tflite"; - -// Query einsum sub-module of a toy LLM. -static constexpr absl::string_view kQueryEinsumModel = "q_einsum.tflite"; - -// RMS Normalization sub-module of a toy LLM. -static constexpr absl::string_view kRMSNormModel = "norm.tflite"; - -// ROPE sub-module of a toy LLM. -static constexpr absl::string_view kROPEModel = "rope.tflite"; - -// ROPE sub-module of a toy LLM, uses embedding_lookup op for sin/cos. 
-static constexpr absl::string_view kLookUpROPEModel = "lookup_rope.tflite"; - -// Scale dot product attentionsub-module of a toy LLM. -static constexpr absl::string_view kSDPAModel = "sdpa.tflite"; - -// Transformer block sub-module of a toy LLM. -static constexpr absl::string_view kTransformerBlockModel = - "transformer.tflite"; - -// /////////////////////////////////////////////////////////////////////////// -// Quantized models. -// /////////////////////////////////////////////////////////////////////////// - -// Quantized model with a single mul op. -// Mul: <8x100x32x4xint16>, <8x100x32x4xint16> -> <8x100x32x4xint16> -static constexpr absl::string_view kQSimpleMul16x16Model = "mul_quant.tflite"; - -// Quantized model with a mul op and a add op. -// Mul: <8x100x32x4xint16>, <8x100x32x4xint16> -> <8x100x32x4xint16> -// Add: <8x100x32x4xint16>, <8x100x32x4xint16> -> <8x100x32x4xint16> -static constexpr absl::string_view kQMulAdd16x16Model = - "simple_quantized_ops.tflite"; - -// Single add op i16 activations and i8 weights and dynamic shape. -// Add: , -> -static constexpr absl::string_view kQSingleDynAdd16x8Model = - "single_add_default_a16w8_recipe_quantized.tflite"; - -// Single add op i8 activations and i8 weights and dynamic shape. -// Add: , -> -static constexpr absl::string_view kQSingleDynAdd8x8Model = - "single_add_default_a8w8_recipe_quantized.tflite"; - -// Single mul op i16 activations and i8 weights and dynamic shape. -// Mul: , -> -static constexpr absl::string_view kQSingleDynMul16x8Model = - "single_mul_default_a16w8_recipe_quantized.tflite"; - -// Single mul op i8 activations and i8 weights and dynamic shape. -// Mul: , -> -static constexpr absl::string_view kQSingleDynMul8x8Model = - "single_mul_default_a8w8_recipe_quantized.tflite"; - -// Single rsqrt op i16 activations and i8 weights and dynamic shape. 
-// RSQRT: -> -static constexpr absl::string_view kQSingleDynRsqrt16x8Model = - "single_rsqrt_default_a16w8_recipe_quantized.tflite"; - -// Single rsqrt op i8 activations and i8 weights and dynamic shape. -// RSQRT: -> -static constexpr absl::string_view kQSingleDynRsqrt8x8Model = - "single_rsqrt_default_a8w8_recipe_quantized.tflite"; - -// Quantized einsum model with i16 activations and i8 weights. -static constexpr absl::string_view kQQueryEinsum16x8Model = - "static_w8_a16_quantized_q_einsum.tflite"; - -static constexpr absl::string_view kQKeyEinsum16x8Model = - "static_w8_a16_quantized_k_einsum.tflite"; - -static constexpr absl::string_view kQVauleEinsum16x8Model = - "static_w8_a16_quantized_v_einsum.tflite"; - -static constexpr absl::string_view kQAttnVecEinsum16x8Model = - "static_w8_a16_quantized_attn_vec_einsum.tflite"; - -static constexpr absl::string_view kQSDPAModel = - "static_a8w8_quantized_sdpa.tflite"; - -// All the quantized test models. -static constexpr auto kAllQModels = absl::MakeConstSpan((absl::string_view[]){ - kQSimpleMul16x16Model, kQMulAdd16x16Model, kQSingleDynAdd16x8Model, - kQSingleDynAdd8x8Model, kQSingleDynMul16x8Model, kQSingleDynMul8x8Model, - kQSingleDynRsqrt16x8Model, kQSingleDynRsqrt8x8Model}); - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TEST_TEST_MODELS_H_ diff --git a/tensorflow/lite/experimental/litert/test/testdata/add_cst.mlir b/tensorflow/lite/experimental/litert/test/testdata/add_cst.mlir deleted file mode 100644 index 502a32a7845190..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/add_cst.mlir +++ /dev/null @@ -1,7 +0,0 @@ -module { -func.func @main(%arg0: tensor<4xf32>) -> tensor<4xf32> { - %cst = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf32> - %0 = tfl.add %arg0, %cst {fused_activation_function = "NONE"} : tensor<4xf32> - return %0 : tensor<4xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/add_simple.mlir 
b/tensorflow/lite/experimental/litert/test/testdata/add_simple.mlir deleted file mode 100644 index 32945b4c8be23c..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/add_simple.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module { -func.func @main(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { - %0 = tfl.add %arg0, %arg0 {fused_activation_function = "NONE"} : tensor<2x2xf32> - return %0 : tensor<2x2xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/cos_mul.mlir b/tensorflow/lite/experimental/litert/test/testdata/cos_mul.mlir deleted file mode 100644 index e6f996a706f619..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/cos_mul.mlir +++ /dev/null @@ -1,7 +0,0 @@ -module { -func.func @main(%arg0: tensor<8x100x32x2xf32>, %arg1: tensor<8x100x1x2xf32>) -> tensor<8x100x32x2xf32> { - %0 = "tfl.cos"(%arg1) : (tensor<8x100x1x2xf32>) -> tensor<8x100x1x2xf32> - %1 = tfl.mul(%arg0, %0) <{fused_activation_function = "NONE"}> : (tensor<8x100x32x2xf32>, tensor<8x100x1x2xf32>) -> tensor<8x100x32x2xf32> - return %1 : tensor<8x100x32x2xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/cst_multi_subgraph.mlir b/tensorflow/lite/experimental/litert/test/testdata/cst_multi_subgraph.mlir deleted file mode 100644 index 8a11bf4f58ba4f..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/cst_multi_subgraph.mlir +++ /dev/null @@ -1,12 +0,0 @@ -module { - func.func @main(%arg0: tensor<4xf32>) -> tensor<4xf32> { - %0 = "tfl.pseudo_const"() <{value = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf32>}> : () -> tensor<4xf32> - %1 = tfl.mul %arg0, %0 {fused_activation_function = "NONE"} : tensor<4xf32> - return %1 : tensor<4xf32> - } - func.func @other(%arg0: tensor<4xf32>) -> tensor<4xf32> { - %0 = "tfl.pseudo_const"() <{value = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf32>}> : () 
-> tensor<4xf32> - %1 = tfl.mul %arg0, %0 {fused_activation_function = "NONE"} : tensor<4xf32> - return %1 : tensor<4xf32> - } -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/dynamic_shape_tensor.mlir b/tensorflow/lite/experimental/litert/test/testdata/dynamic_shape_tensor.mlir deleted file mode 100644 index 7024ce189b7745..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/dynamic_shape_tensor.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module { -func.func @main(%arg0: tensor) -> tensor { - %0 = tfl.add %arg0, %arg0 {fused_activation_function = "NONE"} : tensor - return %0 : tensor -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/fully_connected_3d.mlir b/tensorflow/lite/experimental/litert/test/testdata/fully_connected_3d.mlir deleted file mode 100644 index a3db1d9a887a65..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/fully_connected_3d.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module { -func.func @main(%arg0: tensor<8x100x128xf32>, %arg1: tensor<128x128xf32>, %arg2: none) -> tensor<8x100x128xf32> { - %0 = "tfl.fully_connected"(%arg0, %arg1, %arg2) <{asymmetric_quantize_inputs = false, fused_activation_function = "NONE", keep_num_dims = true, weights_format = "DEFAULT"}> : (tensor<8x100x128xf32>, tensor<128x128xf32>, none) -> tensor<8x100x128xf32> - return %0 : tensor<8x100x128xf32> -} -} diff --git a/tensorflow/lite/experimental/litert/test/testdata/mul_simple.mlir b/tensorflow/lite/experimental/litert/test/testdata/mul_simple.mlir deleted file mode 100644 index dd02656c2f370f..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/mul_simple.mlir +++ /dev/null @@ -1,7 +0,0 @@ -module { -func.func @main(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> { - %0 = tfl.mul %arg0, %arg0 {fused_activation_function = "NONE"} : tensor<2x2xf32> - %1 = tfl.mul %0, %arg1 {fused_activation_function = "NONE"} : 
tensor<2x2xf32> - return %1 : tensor<2x2xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/multi_composite.mlir b/tensorflow/lite/experimental/litert/test/testdata/multi_composite.mlir deleted file mode 100644 index 60a65cdfe4f38c..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/multi_composite.mlir +++ /dev/null @@ -1,21 +0,0 @@ -func.func @main(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> { - %0 = stablehlo.composite "odml.npu_call" %arg0, %arg1 {decomposition = @decomp1} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - %1 = stablehlo.composite "odml.regular_composite" %arg0, %0 {decomposition = @decomp2} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - %2 = stablehlo.composite "odml.npu_call" %arg0, %1 {decomposition = @decomp3} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - return %2 : tensor<2x2xf32> -} - -func.func private @decomp1(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> { - %0 = tfl.add %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<2x2xf32> - return %0 : tensor<2x2xf32> -} - -func.func private @decomp2(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> { - %0 = tfl.add %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<2x2xf32> - return %0 : tensor<2x2xf32> -} - -func.func private @decomp3(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> { - %0 = tfl.add %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<2x2xf32> - return %0 : tensor<2x2xf32> -} diff --git a/tensorflow/lite/experimental/litert/test/testdata/multi_op_multi_subgraph.mlir b/tensorflow/lite/experimental/litert/test/testdata/multi_op_multi_subgraph.mlir deleted file mode 100644 index 433d166fe3c1f5..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/multi_op_multi_subgraph.mlir +++ /dev/null @@ -1,9 +0,0 @@ -module { -func.func @main(%arg0: tensor<2x2xf32>) -> 
tensor<2x2xf32> { - %0 = tfl.mul %arg0, %arg0 {fused_activation_function = "NONE"} : tensor<2x2xf32> - %1 = tfl.mul %0, %0 {fused_activation_function = "NONE"} : tensor<2x2xf32> - %2 = tfl.sub %1, %0 {fused_activation_function = "NONE"} : tensor<2x2xf32> - %3 = tfl.sub %2, %0 {fused_activation_function = "NONE"} : tensor<2x2xf32> - return %3 : tensor<2x2xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/multi_subgraph.mlir b/tensorflow/lite/experimental/litert/test/testdata/multi_subgraph.mlir deleted file mode 100644 index 7c1f0fe4e0f5b0..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/multi_subgraph.mlir +++ /dev/null @@ -1,21 +0,0 @@ -module { - -func.func @main(%arg0: tensor<4xf32>) -> tensor<4xf32> { - %cst = arith.constant dense<[-1.0, -1.0, -1.0, -1.0]> : tensor<4xf32> - %0 = tfl.add %arg0, %cst {fused_activation_function = "NONE"} : tensor<4xf32> - return %0 : tensor<4xf32> -} - -func.func @func1(%arg0: tensor<4xf32>) -> tensor<4xf32> { - %cst = arith.constant dense<[1.0, 1.0, 1.0, 1.0]> : tensor<4xf32> - %0 = tfl.add %arg0, %cst {fused_activation_function = "NONE"} : tensor<4xf32> - return %0 : tensor<4xf32> -} - -func.func @func2(%arg0: tensor<4xf32>) -> tensor<4xf32> { - %cst = arith.constant dense<[2.0, 2.0, 2.0, 2.0]> : tensor<4xf32> - %0 = tfl.add %arg0, %cst {fused_activation_function = "NONE"} : tensor<4xf32> - return %0 : tensor<4xf32> -} - -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/multi_subgraph_mul.mlir b/tensorflow/lite/experimental/litert/test/testdata/multi_subgraph_mul.mlir deleted file mode 100644 index 607100dbc389b6..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/multi_subgraph_mul.mlir +++ /dev/null @@ -1,13 +0,0 @@ -module { - -func.func @main(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> { - %0 = tfl.mul %arg0, %arg1 {fused_activation_function = "NONE"} : 
tensor<2x2xf32> - return %0 : tensor<2x2xf32> -} - -func.func @func1(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>) -> tensor<4x4xf32> { - %0 = tfl.mul %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<4x4xf32> - return %0 : tensor<4x4xf32> -} - -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/multi_use_cst.mlir b/tensorflow/lite/experimental/litert/test/testdata/multi_use_cst.mlir deleted file mode 100644 index 617c27db761e44..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/multi_use_cst.mlir +++ /dev/null @@ -1,9 +0,0 @@ -module { -func.func @main(%arg0: tensor<4xf32>) -> tensor<4xf32> { - %cst = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf32> - %0 = tfl.add %arg0, %cst {fused_activation_function = "NONE"} : tensor<4xf32> - %1 = tfl.add %0, %0 {fused_activation_function = "NONE"} : tensor<4xf32> - %2 = tfl.add %1, %cst {fused_activation_function = "NONE"} : tensor<4xf32> - return %2 : tensor<4xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/nested_composite.mlir b/tensorflow/lite/experimental/litert/test/testdata/nested_composite.mlir deleted file mode 100644 index 32ca6e26f2bfc9..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/nested_composite.mlir +++ /dev/null @@ -1,14 +0,0 @@ -func.func @main(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> { - %0 = stablehlo.composite "odml.npu_call" %arg0, %arg1 {decomposition = @decomp1} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - return %0 : tensor<2x2xf32> -} - -func.func private @decomp1(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> { - %0 = stablehlo.composite "odml.regular_composite" %arg0, %arg1 {decomposition = @decomp2} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - return %0 : tensor<2x2xf32> -} - -func.func private @decomp2(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> 
tensor<2x2xf32> { - %0 = tfl.add %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<2x2xf32> - return %0 : tensor<2x2xf32> -} diff --git a/tensorflow/lite/experimental/litert/test/testdata/one_mul.mlir b/tensorflow/lite/experimental/litert/test/testdata/one_mul.mlir deleted file mode 100644 index afabf1903ee846..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/one_mul.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module { -func.func @main(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> { - %0 = tfl.mul %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<2x2xf32> - return %0 : tensor<2x2xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/rms_norm.mlir b/tensorflow/lite/experimental/litert/test/testdata/rms_norm.mlir deleted file mode 100644 index 476c9829a5bd92..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/rms_norm.mlir +++ /dev/null @@ -1,16 +0,0 @@ -module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.metadata = {CONVERSION_METADATA = "\10\00\00\00\00\00\00\00\08\00\0E\00\08\00\04\00\08\00\00\00\10\00\00\00$\00\00\00\00\00\06\00\08\00\04\00\06\00\00\00\04\00\00\00\00\00\00\00\0C\00\18\00\14\00\10\00\0C\00\04\00\0C\00\00\00zs\F5|\1F\CE)\0D\01\00\00\00\02\00\00\00\04\00\00\00\06\00\00\002.19.0\00\00", min_runtime_version = "1.10.0\00\00\00\00\00\00\00\00\00\00"}, tfl.schema_version = 3 : i32} { - func.func @main(%arg0: tensor<8x128x1024xf32> {tf_saved_model.index_path = ["args_0"]}) -> (tensor<8x128x1024xf32> {tf_saved_model.index_path = ["output_0"]}) attributes {tf.entry_function = {inputs = "serving_default_args_0:0", outputs = "StatefulPartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { - %0 = tfl.mul %arg0, %arg0 {fused_activation_function = "NONE"} : tensor<8x128x1024xf32> - %1 = "tfl.pseudo_const"() <{value = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32> - %2 = "tfl.sum"(%0, %1) 
<{keep_dims = false}> : (tensor<8x128x1024xf32>, tensor<1xi32>) -> tensor<8x128xf32> - %3 = "tfl.pseudo_const"() <{value = dense<1.024000e+03> : tensor}> : () -> tensor - %4 = tfl.div(%2, %3) <{fused_activation_function = "NONE"}> : (tensor<8x128xf32>, tensor) -> tensor<8x128xf32> - %5 = "tfl.pseudo_const"() <{value = dense<9.99999997E-7> : tensor}> : () -> tensor - %6 = tfl.add(%4, %5) <{fused_activation_function = "NONE"}> : (tensor<8x128xf32>, tensor) -> tensor<8x128xf32> - %7 = "tfl.pseudo_const"() <{value = dense<[8, 128, 1]> : tensor<3xi32>}> : () -> tensor<3xi32> - %8 = "tfl.reshape"(%6, %7) : (tensor<8x128xf32>, tensor<3xi32>) -> tensor<8x128x1xf32> - %9 = "tfl.rsqrt"(%8) : (tensor<8x128x1xf32>) -> tensor<8x128x1xf32> - %10 = tfl.mul(%arg0, %9) <{fused_activation_function = "NONE"}> : (tensor<8x128x1024xf32>, tensor<8x128x1xf32>) -> tensor<8x128x1024xf32> - return %10 : tensor<8x128x1024xf32> - } -} diff --git a/tensorflow/lite/experimental/litert/test/testdata/rms_norm_composite.mlir b/tensorflow/lite/experimental/litert/test/testdata/rms_norm_composite.mlir deleted file mode 100644 index 6995e4d739ab2a..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/rms_norm_composite.mlir +++ /dev/null @@ -1,23 +0,0 @@ -module { - -func.func @main(%arg0: tensor<1x128x2304xf32>, %arg1: tensor<2304xf32>) -> tensor<1x128x2304xf32> { - %0 = stablehlo.composite "odml.rms_norm" %arg0, %arg1 {composite_attributes = {epsilon = 9.99999997E-7 : f32}, decomposition = @odml.rms_norm.impl} : (tensor<1x128x2304xf32>, tensor<2304xf32>) -> tensor<1x128x2304xf32> - return %0 : tensor<1x128x2304xf32> -} - -func.func @odml.rms_norm.impl(%arg0: tensor<1x128x2304xf32>, %arg1: tensor<2304xf32>) -> tensor<1x128x2304xf32> { - %0 = tfl.mul %arg0, %arg0 {fused_activation_function = "NONE"} : tensor<1x128x2304xf32> - %1 = "tfl.pseudo_const"() <{value = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32> - %2 = "tfl.sum"(%0, %1) <{keep_dims = false}> : (tensor<1x128x2304xf32>, 
tensor<1xi32>) -> tensor<1x128xf32> - %3 = "tfl.pseudo_const"() <{value = dense<4.34027781E-4> : tensor}> : () -> tensor - %4 = tfl.mul(%2, %3) <{fused_activation_function = "NONE"}> : (tensor<1x128xf32>, tensor) -> tensor<1x128xf32> - %5 = "tfl.pseudo_const"() <{value = dense<9.99999997E-7> : tensor}> : () -> tensor - %6 = tfl.add(%4, %5) <{fused_activation_function = "NONE"}> : (tensor<1x128xf32>, tensor) -> tensor<1x128xf32> - %7 = "tfl.pseudo_const"() <{value = dense<[1, 128, 1]> : tensor<3xi32>}> : () -> tensor<3xi32> - %8 = "tfl.reshape"(%6, %7) : (tensor<1x128xf32>, tensor<3xi32>) -> tensor<1x128x1xf32> - %9 = "tfl.rsqrt"(%8) : (tensor<1x128x1xf32>) -> tensor<1x128x1xf32> - %10 = tfl.mul(%arg0, %9) <{fused_activation_function = "NONE"}> : (tensor<1x128x2304xf32>, tensor<1x128x1xf32>) -> tensor<1x128x2304xf32> - %11 = tfl.mul(%10, %arg1) <{fused_activation_function = "NONE"}> : (tensor<1x128x2304xf32>, tensor<2304xf32>) -> tensor<1x128x2304xf32> - return %11 : tensor<1x128x2304xf32> - } -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/scala_reshape.mlir b/tensorflow/lite/experimental/litert/test/testdata/scala_reshape.mlir deleted file mode 100644 index 0b655f704eb5d7..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/scala_reshape.mlir +++ /dev/null @@ -1,7 +0,0 @@ -module { -func.func @main(%arg0: tensor<1x1xf32>) -> tensor { - %cst = arith.constant dense<[]> : tensor<0xi32> - %0 = "tfl.reshape"(%arg0, %cst) : (tensor<1x1xf32>, tensor<0xi32>) -> tensor - return %0 : tensor -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/shared_input_cpu_npu.mlir b/tensorflow/lite/experimental/litert/test/testdata/shared_input_cpu_npu.mlir deleted file mode 100644 index 42a5059e8861dd..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/shared_input_cpu_npu.mlir +++ /dev/null @@ -1,7 +0,0 @@ -module { - func.func @main(%x: tensor<2xf32>, %y: 
tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) { - %cpu_out = tfl.add %x, %y {fused_activation_function = "NONE"} : tensor<2xf32> - %npu_out = "tfl.custom"(%x, %y) {custom_code = "DISPATCH_OP", custom_option = #tfl} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> - func.return %cpu_out, %npu_out : tensor<2xf32>, tensor<2xf32> - } -} diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_add_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_add_op.mlir deleted file mode 100644 index 0902f5966f8266..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_add_op.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module { -func.func @main(%arg0: tensor<1x128x1xf32>, %arg1: tensor<1x128x1xf32>) -> tensor<1x128x1xf32> { - %0 = tfl.add %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<1x128x1xf32> - return %0 : tensor<1x128x1xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_average_poll_2d.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_average_poll_2d.mlir deleted file mode 100644 index 979610cdaa0e1e..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_average_poll_2d.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module { -func.func @main(%arg0: tensor<1x1728x2304x3xf32>) -> tensor<1x432x576x3xf32> { - %0 = "tfl.average_pool_2d"(%arg0) <{filter_height = 4 : i32, filter_width = 4 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 4 : i32, stride_w = 4 : i32}> : (tensor<1x1728x2304x3xf32>) -> tensor<1x432x576x3xf32> - return %0 : tensor<1x432x576x3xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_batch_matmul_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_batch_matmul_op.mlir deleted file mode 100644 index e756a0dab87cbc..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_batch_matmul_op.mlir +++ 
/dev/null @@ -1,6 +0,0 @@ -module { -func.func @main(%arg0: tensor<1x4x256x128xf32>, %arg1: tensor<1x4x128x128xf32>) -> tensor<1x4x256x128xf32> { - %0 = "tfl.batch_matmul"(%arg0, %arg1) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : (tensor<1x4x256x128xf32>, tensor<1x4x128x128xf32>) -> tensor<1x4x256x128xf32> - return %0 : tensor<1x4x256x128xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_cascade_model_npu.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_cascade_model_npu.mlir deleted file mode 100644 index 5e262cb678714c..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_cascade_model_npu.mlir +++ /dev/null @@ -1,7 +0,0 @@ -module { - func.func @main(%x1: tensor<2xf32>, %x2: tensor<2xf32>, %x3: tensor<2xf32>) -> tensor<2xf32> { - %t1 = "tfl.custom"(%x1, %x2) {custom_code = "DISPATCH_OP_1", custom_option = #tfl} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> - %out = "tfl.custom"(%t1, %x3) {custom_code = "DISPATCH_OP_2", custom_option = #tfl} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> - func.return %out : tensor<2xf32> - } -} diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_cast_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_cast_op.mlir deleted file mode 100644 index 6066c665713bc3..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_cast_op.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module { -func.func @main(%arg0: tensor<8x100x1xi32>) -> tensor<8x100x1xf32> { - %0 = "tfl.cast"(%arg0) : (tensor<8x100x1xi32>) -> tensor<8x100x1xf32> - return %0 : tensor<8x100x1xf32> -} -} diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_composite.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_composite.mlir deleted file mode 100644 index 79c64f423039ba..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_composite.mlir +++ 
/dev/null @@ -1,11 +0,0 @@ -module attributes {tfl.description = "MLIR Converted.", tfl.metadata = {min_runtime_version = "1.5.0\00\00\00\00\00\00\00\00\00\00\00"}, tfl.schema_version = 3 : i32, tf_saved_model.semantics} { - func.func @main(%arg0: tensor<2x2xf32> { tf_saved_model.index_path = ["arg0"] }, %arg1: tensor<2x2xf32> { tf_saved_model.index_path = ["arg1"]}) -> (tensor<2x2xf32> {tf_saved_model.index_path = ["output"] }) attributes {tf.entry_function = {inputs = "arg0,arg1", outputs = "output"}, tf_saved_model.exported_names = ["serving_default"]} { - %0 = stablehlo.composite "odml.npu_call" %arg0, %arg1 {decomposition = @decomp} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - return %0 : tensor<2x2xf32> - } - func.func private @decomp(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> { - %0 = tfl.add %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<2x2xf32> - return %0 : tensor<2x2xf32> - } -} - diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_concatenation_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_concatenation_op.mlir deleted file mode 100644 index e1e9bd36ae01b0..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_concatenation_op.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module { -func.func @main(%arg0: tensor<128x4x1x256xf32>, %arg1: tensor<128x4x1x256xf32>) -> tensor<128x4x2x256xf32> { - %0 = "tfl.concatenation"(%arg0, %arg1) <{axis = 2 : i32, fused_activation_function = "NONE"}> : (tensor<128x4x1x256xf32>, tensor<128x4x1x256xf32>) -> tensor<128x4x2x256xf32> - return %0 : tensor<128x4x2x256xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_conv_2d_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_conv_2d_op.mlir deleted file mode 100644 index 4eb0e0a04d32c2..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_conv_2d_op.mlir +++ /dev/null @@ -1,6 
+0,0 @@ -module { -func.func @main(%arg0: tensor<1x216x288x24xf32>, %arg1: tensor<24x3x3x24xf32>, %arg2: tensor<24xf32>) -> tensor<1x216x288x24xf32> { - %0 = "tfl.conv_2d"(%arg0, %arg1, %arg2) <{dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32}> : (tensor<1x216x288x24xf32>, tensor<24x3x3x24xf32>, tensor<24xf32>) -> tensor<1x216x288x24xf32> - return %0 : tensor<1x216x288x24xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_cos_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_cos_op.mlir deleted file mode 100644 index 70ea46c1988b16..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_cos_op.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module { -func.func @main(%arg0: tensor<8x100x1x2xf32>) -> tensor<8x100x1x2xf32> { - %0 = "tfl.cos"(%arg0) : (tensor<8x100x1x2xf32>) -> tensor<8x100x1x2xf32> - return %0 : tensor<8x100x1x2xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_depth_to_space_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_depth_to_space_op.mlir deleted file mode 100644 index 2682b724b88c37..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_depth_to_space_op.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module { -func.func @main(%arg0: tensor<1x216x288x12xf32>) -> tensor<1x432x576x3xf32> { - %0 = "tfl.depth_to_space"(%arg0) <{block_size = 2 : i32}> : (tensor<1x216x288x12xf32>) -> tensor<1x432x576x3xf32> - return %0 : tensor<1x432x576x3xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_depthwise_conv_2d_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_depthwise_conv_2d_op.mlir deleted file mode 100644 index 706295d3e27076..00000000000000 --- 
a/tensorflow/lite/experimental/litert/test/testdata/simple_depthwise_conv_2d_op.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module { -func.func @main(%arg0: tensor<1x40x40x192xf32>, %arg1: tensor<1x3x3x192xf32>, %arg2: tensor<192xf32>) -> tensor<1x32x32x192xf32> { - %0 = "tfl.depthwise_conv_2d"(%arg0, %arg1, %arg2) <{depth_multiplier = 1 : i32, dilation_h_factor = 4 : i32, dilation_w_factor = 4 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32}> : (tensor<1x40x40x192xf32>, tensor<1x3x3x192xf32>, tensor<192xf32>) -> tensor<1x32x32x192xf32> - return %0 : tensor<1x32x32x192xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_div_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_div_op.mlir deleted file mode 100644 index 3748d45bcd5249..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_div_op.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module { -func.func @main(%arg0: tensor<1x128x8x128xf32>, %arg1: tensor<1x128x8x128xf32>) -> tensor<1x128x8x128xf32> { - %0 = tfl.div %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<1x128x8x128xf32> - return %0 : tensor<1x128x8x128xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_dynamic_update_slice_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_dynamic_update_slice_op.mlir deleted file mode 100644 index a10606eccd41f9..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_dynamic_update_slice_op.mlir +++ /dev/null @@ -1,7 +0,0 @@ -module { -func.func @main(%arg0: tensor<1x64x4x64xf32>, %arg1: tensor<1x1x4x64xf32>) -> tensor<1x64x4x64xf32> { - %cst = "tfl.pseudo_const"() <{value = dense<[0, 1, 0, 0]> : tensor<4xi32>}> : () -> tensor<4xi32> - %0 = "tfl.dynamic_update_slice"(%arg0, %arg1, %cst) : (tensor<1x64x4x64xf32>, tensor<1x1x4x64xf32>, tensor<4xi32>) -> tensor<1x64x4x64xf32> - return 
%0 : tensor<1x64x4x64xf32> -} -} diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_embedding_lookup_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_embedding_lookup_op.mlir deleted file mode 100644 index 75b8000bb97a35..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_embedding_lookup_op.mlir +++ /dev/null @@ -1,7 +0,0 @@ -module { -func.func @main(%arg0: tensor<5xi32>) -> tensor<5x1x2xf32> { - %table = "tfl.pseudo_const"() <{value = dense<"0x00010001000100010001000100010001000100010001000100010001000100010001000100010001000100010001000100010001000100010001000100010001000100010001000100010001000100010001000100010001000100010001000100010001000100010001000100010001000100010001000100010001000100010001000100010001000100010001000100010001000100010001000100010001"> : tensor<20x1x2xf32>}> : () -> tensor<20x1x2xf32> - %0 = "tfl.embedding_lookup"(%arg0, %table) : (tensor<5xi32>, tensor<20x1x2xf32>) -> tensor<5x1x2xf32> - return %0 : tensor<5x1x2xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_floor_mod_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_floor_mod_op.mlir deleted file mode 100644 index 6bd3f1fa79d77c..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_floor_mod_op.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module { -func.func @main(%arg0: tensor<5xf32>, %arg1: tensor<5xf32>) -> tensor<5xf32> { - %0 = "tfl.floor_mod"(%arg0, %arg1) : (tensor<5xf32>, tensor<5xf32>) -> tensor<5xf32> - return %0 : tensor<5xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_fully_connected_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_fully_connected_op.mlir deleted file mode 100644 index 5cad120662635e..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_fully_connected_op.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module 
{ -func.func @main(%arg0: tensor<128x2048xf32>, %arg1: tensor<2304x2048xf32>, %arg2: none) -> tensor<128x2304xf32> { - %0 = "tfl.fully_connected"(%arg0, %arg1, %arg2) <{asymmetric_quantize_inputs = false, fused_activation_function = "NONE", keep_num_dims = true, weights_format = "DEFAULT"}> : (tensor<128x2048xf32>, tensor<2304x2048xf32>, none) -> tensor<128x2304xf32> - return %0 : tensor<128x2304xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_gather_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_gather_op.mlir deleted file mode 100644 index 6b0375c77c24a9..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_gather_op.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module { -func.func @main(%arg0: tensor<2x3x6xf32>, %arg1: tensor<4x5xi32>) -> tensor<4x5x3x6xf32> { - %0 = "tfl.gather"(%arg0, %arg1) <{axis = 0 : i32, batch_dims = 0 : i32}> : (tensor<2x3x6xf32>, tensor<4x5xi32>) -> tensor<4x5x3x6xf32> - return %0 : tensor<4x5x3x6xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_gelu_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_gelu_op.mlir deleted file mode 100644 index 39ebcf24e972d0..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_gelu_op.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module { -func.func @main(%arg0: tensor<8x100x1xf32>) -> tensor<8x100x1xf32> { - %0 = "tfl.gelu"(%arg0) : (tensor<8x100x1xf32>) -> tensor<8x100x1xf32> - return %0 : tensor<8x100x1xf32> -} -} diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_greater_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_greater_op.mlir deleted file mode 100644 index b368def16d6e88..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_greater_op.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module { -func.func @main(%arg0: tensor<1x1x64xi32>, %arg1: 
tensor<1x1x64xi32>) -> tensor<1x1x64xi1> { - %0 = "tfl.greater"(%arg0, %arg1) : (tensor<1x1x64xi32>, tensor<1x1x64xi32>) -> tensor<1x1x64xi1> - return %0 : tensor<1x1x64xi1> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_hard_swish_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_hard_swish_op.mlir deleted file mode 100644 index 5c95ca2bb4e573..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_hard_swish_op.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module { -func.func @main(%arg0: tensor<1x216x288x48xf32>) -> tensor<1x216x288x48xf32> { - %0 = "tfl.hard_swish"(%arg0) : (tensor<1x216x288x48xf32>) -> tensor<1x216x288x48xf32> - return %0 : tensor<1x216x288x48xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_leaky_relu_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_leaky_relu_op.mlir deleted file mode 100644 index 13dacd3984493a..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_leaky_relu_op.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module { -func.func @main(%arg0: tensor<1x32x32x192xf32>) -> tensor<1x32x32x192xf32> { - %0 = "tfl.leaky_relu"(%arg0) <{alpha = 2.000000e-01 : f32}> : (tensor<1x32x32x192xf32>) -> tensor<1x32x32x192xf32> - return %0 : tensor<1x32x32x192xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_less_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_less_op.mlir deleted file mode 100644 index 06370a186ddc5b..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_less_op.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module { -func.func @main(%arg0: tensor<1x1x64xi32>, %arg1: tensor<1x1x64xi32>) -> tensor<1x1x64xi1> { - %0 = "tfl.less"(%arg0, %arg1) : (tensor<1x1x64xi32>, tensor<1x1x64xi32>) -> tensor<1x1x64xi1> - return %0 : tensor<1x1x64xi1> -} -} \ No newline at end 
of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_logical_and_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_logical_and_op.mlir deleted file mode 100644 index e58307caceb3be..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_logical_and_op.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module { -func.func @main(%arg0: tensor<1x64x64xi1>, %arg1: tensor<1x64x64xi1>) -> tensor<1x64x64xi1> { - %0 = "tfl.logical_and"(%arg0, %arg1) : (tensor<1x64x64xi1>, tensor<1x64x64xi1>) -> tensor<1x64x64xi1> - return %0 : tensor<1x64x64xi1> -} -} diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_mean_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_mean_op.mlir deleted file mode 100644 index 56b4fcb8a9f3e6..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_mean_op.mlir +++ /dev/null @@ -1,7 +0,0 @@ -module { -func.func @main(%arg0: tensor<2x2xf32>) -> tensor<2xf32> { - %cst = "tfl.pseudo_const"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32> - %0 = "tfl.mean"(%arg0, %cst) <{keep_dims = false}> : (tensor<2x2xf32>, tensor<1xi32>) -> tensor<2xf32> - return %0 : tensor<2xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_model.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_model.mlir deleted file mode 100644 index d88a5d5923c77e..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_model.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module { -func.func @main(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> { - %0 = tfl.add %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<2xf32> - return %0 : tensor<2xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_model_google_tensor.bin b/tensorflow/lite/experimental/litert/test/testdata/simple_model_google_tensor.bin deleted file mode 
100644 index 208cb983671510..00000000000000 Binary files a/tensorflow/lite/experimental/litert/test/testdata/simple_model_google_tensor.bin and /dev/null differ diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_model_mtk.bin b/tensorflow/lite/experimental/litert/test/testdata/simple_model_mtk.bin deleted file mode 100644 index b6702cc8b180a2..00000000000000 Binary files a/tensorflow/lite/experimental/litert/test/testdata/simple_model_mtk.bin and /dev/null differ diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_model_npu.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_model_npu.mlir deleted file mode 100644 index f4959fb63e6231..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_model_npu.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module { - func.func @main(%x: tensor<2xf32>, %y: tensor<2xf32>) -> tensor<2xf32> { - %out = "tfl.custom"(%x, %y) {custom_code = "DISPATCH_OP", custom_option = #tfl} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> - func.return %out : tensor<2xf32> - } -} diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_model_qualcomm.bin b/tensorflow/lite/experimental/litert/test/testdata/simple_model_qualcomm.bin deleted file mode 100644 index a66f76296d7698..00000000000000 Binary files a/tensorflow/lite/experimental/litert/test/testdata/simple_model_qualcomm.bin and /dev/null differ diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_model_test_vectors.h b/tensorflow/lite/experimental/litert/test/testdata/simple_model_test_vectors.h deleted file mode 100644 index 2068cb028b8a6d..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_model_test_vectors.h +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TEST_TESTDATA_SIMPLE_MODEL_TEST_VECTORS_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TEST_TESTDATA_SIMPLE_MODEL_TEST_VECTORS_H_ - -#include -#include - -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_layout.h" - -constexpr const char* kModelFileName = "simple_model.tflite"; -constexpr const char* kQualcommModelFileName = "simple_model_qualcomm.bin"; -constexpr const char* kGoogleTensorModelFileName = - "simple_model_google_tensor.bin"; -constexpr const char* kMediaTekModelFileName = "simple_model_mtk.bin"; - -constexpr const int32_t kTestInput0Dimensions[] = {2}; -constexpr const int32_t kNumTestInput0Dimensions = - sizeof(kTestInput0Dimensions) / sizeof(kTestInput0Dimensions[0]); -constexpr const int32_t kTestInput1Dimensions[] = {2}; -constexpr const int32_t kNumTestInput1Dimensions = - sizeof(kTestInput1Dimensions) / sizeof(kTestInput1Dimensions[0]); -constexpr const int32_t kTestOutputDimensions[] = {2}; -constexpr const int32_t kNumTestOutputDimensions = - sizeof(kTestOutputDimensions) / sizeof(kTestOutputDimensions[0]); - -constexpr const float kTestInput0Tensor[] = {1, 2}; -constexpr const float kTestInput1Tensor[] = {10, 20}; -constexpr const float kTestOutputTensor[] = {11, 22}; - -constexpr const float kTestInput0Tensor_2[] = {10, 20}; -constexpr const float kTestInput1Tensor_2[] = {100, 200}; -constexpr const float kTestOutputTensor_2[] = {110, 220}; - -constexpr const size_t kTestInput0Size = - 
sizeof(kTestInput0Tensor) / sizeof(kTestInput0Tensor[0]); -constexpr const size_t kTestInput1Size = - sizeof(kTestInput1Tensor) / sizeof(kTestInput1Tensor[0]); -constexpr const size_t kTestOutputSize = - sizeof(kTestOutputTensor) / sizeof(kTestOutputTensor[0]); - -constexpr const LiteRtRankedTensorType kInput0TensorType = { - /*.element_type=*/kLiteRtElementTypeFloat32, - ::litert::BuildLayout(kTestInput0Dimensions)}; - -constexpr const LiteRtRankedTensorType kInput1TensorType = { - /*.element_type=*/kLiteRtElementTypeFloat32, - ::litert::BuildLayout(kTestInput1Dimensions)}; - -constexpr const LiteRtRankedTensorType kOutputTensorType = { - /*.element_type=*/kLiteRtElementTypeFloat32, - ::litert::BuildLayout(kTestOutputDimensions)}; - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TEST_TESTDATA_SIMPLE_MODEL_TEST_VECTORS_H_ diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_mul_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_mul_op.mlir deleted file mode 100644 index 7fb5ac2d2187f0..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_mul_op.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module { -func.func @main(%arg0: tensor<1x128x2304xf32>, %arg1: tensor<1x128x2304xf32>) -> tensor<1x128x2304xf32> { - %0 = tfl.mul %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<1x128x2304xf32> - return %0 : tensor<1x128x2304xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_multi_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_multi_op.mlir deleted file mode 100644 index 07757fddec1b90..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_multi_op.mlir +++ /dev/null @@ -1,9 +0,0 @@ -module { -func.func @main(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { - %0 = tfl.add %arg0, %arg0 {fused_activation_function = "NONE"} : tensor<2x2xf32> - %1 = tfl.mul %0, %0 {fused_activation_function = "NONE"} : tensor<2x2xf32> - %2 = 
tfl.mul %1, %1 {fused_activation_function = "NONE"} : tensor<2x2xf32> - %3 = tfl.add %2, %2 {fused_activation_function = "NONE"} : tensor<2x2xf32> - return %3 : tensor<2x2xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_pack_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_pack_op.mlir deleted file mode 100644 index e94d4815d9545e..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_pack_op.mlir +++ /dev/null @@ -1,7 +0,0 @@ -module { -func.func @main(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor) -> tensor<4xi32> { - // %cst = "tfl.pseudo_const"() <{value = dense<0> : tensor}> : () -> tensor - %0 = "tfl.pack"(%arg0, %arg1, %arg2, %arg3) <{axis = 0 : i32, values_count = 4 : i32}> : (tensor, tensor, tensor, tensor) -> tensor<4xi32> - return %0 : tensor<4xi32> -} -} diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_relu6_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_relu6_op.mlir deleted file mode 100644 index 17bbc4ef2fdefe..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_relu6_op.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module { -func.func @main(%arg0: tensor<8x100x1xf32>) -> tensor<8x100x1xf32> { - %0 = "tfl.relu6"(%arg0) : (tensor<8x100x1xf32>) -> tensor<8x100x1xf32> - return %0 : tensor<8x100x1xf32> -} -} diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_relu_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_relu_op.mlir deleted file mode 100644 index 72306d2b9e6cd3..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_relu_op.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module { -func.func @main(%arg0: tensor<8x100x1xf32>) -> tensor<8x100x1xf32> { - %0 = "tfl.relu"(%arg0) : (tensor<8x100x1xf32>) -> tensor<8x100x1xf32> - return %0 : tensor<8x100x1xf32> -} -} diff --git 
a/tensorflow/lite/experimental/litert/test/testdata/simple_reshape_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_reshape_op.mlir deleted file mode 100644 index 515db6e424e6a7..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_reshape_op.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module { -func.func @main(%arg0: tensor<1x128x4x256xf32>, %arg1: tensor<4xi32>) -> tensor<128x4x1x256xf32> { - %0 = "tfl.reshape"(%arg0, %arg1) : (tensor<1x128x4x256xf32>, tensor<4xi32>) -> tensor<128x4x1x256xf32> - return %0 : tensor<128x4x1x256xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_resize_bilinear_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_resize_bilinear_op.mlir deleted file mode 100644 index 1cd9be9729f487..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_resize_bilinear_op.mlir +++ /dev/null @@ -1,7 +0,0 @@ -module { -func.func @main(%arg0: tensor<1x54x72x96xf32>) -> tensor<1x108x144x96xf32> { - %cst = "tfl.pseudo_const"() <{value = dense<[108, 144]> : tensor<2xi32>}> : () -> tensor<2xi32> - %0 = "tfl.resize_bilinear"(%arg0, %cst) <{align_corners = false, half_pixel_centers = true}> : (tensor<1x54x72x96xf32>, tensor<2xi32>) -> tensor<1x108x144x96xf32> - return %0 : tensor<1x108x144x96xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_resize_nearest_neighbor_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_resize_nearest_neighbor_op.mlir deleted file mode 100644 index a73eb7e60e0b4c..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_resize_nearest_neighbor_op.mlir +++ /dev/null @@ -1,7 +0,0 @@ -module { -func.func @main(%arg0: tensor<1x54x72x96xf32>) -> tensor<1x108x144x96xf32> { - %cst = "tfl.pseudo_const"() <{value = dense<[108, 144]> : tensor<2xi32>}> : () -> tensor<2xi32> - %0 = "tfl.resize_nearest_neighbor"(%arg0, 
%cst) <{align_corners = false, half_pixel_centers = true}> : (tensor<1x54x72x96xf32>, tensor<2xi32>) -> tensor<1x108x144x96xf32> - return %0 : tensor<1x108x144x96xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_rsqrt_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_rsqrt_op.mlir deleted file mode 100644 index 5083f3f3a30383..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_rsqrt_op.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module { -func.func @main(%arg0: tensor<1x128x1xf32>) -> tensor<1x128x1xf32> { - %0 = "tfl.rsqrt"(%arg0) : (tensor<1x128x1xf32>) -> tensor<1x128x1xf32> - return %0 : tensor<1x128x1xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_select_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_select_op.mlir deleted file mode 100644 index 2405e5d3626893..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_select_op.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module { -func.func @main(%arg0: tensor<1x128x8x128xi1>, %arg1: tensor<1x128x8x128xf32>, %arg2: tensor<1x128x8x128xf32>) -> tensor<1x128x8x128xf32> { - %0 = "tfl.select"(%arg0, %arg1, %arg2) : (tensor<1x128x8x128xi1>, tensor<1x128x8x128xf32>, tensor<1x128x8x128xf32>) -> tensor<1x128x8x128xf32> - return %0 : tensor<1x128x8x128xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_select_v2_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_select_v2_op.mlir deleted file mode 100644 index a8d80ecc80f970..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_select_v2_op.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module { -func.func @main(%arg0: tensor<8x1x1x100xi1>, %arg1: tensor<8x100x32x100xf32>, %arg2: tensor<8x100x32x100xf32>) -> tensor<8x100x32x100xf32> { - %0 = "tfl.select_v2"(%arg0, %arg1, %arg2) : 
(tensor<8x1x1x100xi1>, tensor<8x100x32x100xf32>, tensor<8x100x32x100xf32>) -> tensor<8x100x32x100xf32> - return %0 : tensor<8x100x32x100xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_sin_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_sin_op.mlir deleted file mode 100644 index 431d3b93065441..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_sin_op.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module { -func.func @main(%arg0: tensor<8x100x1x2xf32>) -> tensor<8x100x1x2xf32> { - %0 = "tfl.sin"(%arg0) : (tensor<8x100x1x2xf32>) -> tensor<8x100x1x2xf32> - return %0 : tensor<8x100x1x2xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_slice_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_slice_op.mlir deleted file mode 100644 index 4adfa00a204cfc..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_slice_op.mlir +++ /dev/null @@ -1,8 +0,0 @@ -module { -func.func @main(%arg0: tensor<1x128x8x256xf32>) -> tensor<1x128x4x128xf32> { - %cst_0 = "tfl.pseudo_const"() <{value = dense<0> : tensor<4xi32>}> : () -> tensor<4xi32> - %cst_1 = "tfl.pseudo_const"() <{value = dense<[1, 128, 4, 128]> : tensor<4xi32>}> : () -> tensor<4xi32> - %0 = "tfl.slice"(%arg0, %cst_0, %cst_1) : (tensor<1x128x8x256xf32>, tensor<4xi32>, tensor<4xi32>) -> tensor<1x128x4x128xf32> - return %0 : tensor<1x128x4x128xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_softmax_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_softmax_op.mlir deleted file mode 100644 index bb3a83a3787f6f..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_softmax_op.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module { -func.func @main(%arg0: tensor<8x128xf32>) -> tensor<8x128xf32> { - %0 = "tfl.softmax"(%arg0) <{beta = 1.000000e+00 : 
f32}> : (tensor<8x128xf32>) -> tensor<8x128xf32> - return %0 : tensor<8x128xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_space_to_depth_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_space_to_depth_op.mlir deleted file mode 100644 index 3e336816486285..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_space_to_depth_op.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module { -func.func @main(%arg0: tensor<1x432x576x6xf32>) -> tensor<1x216x288x24xf32> { - %0 = "tfl.space_to_depth"(%arg0) <{block_size = 2 : i32}> : (tensor<1x432x576x6xf32>) -> tensor<1x216x288x24xf32> - return %0 : tensor<1x216x288x24xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_split_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_split_op.mlir deleted file mode 100644 index 38c99095a01319..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_split_op.mlir +++ /dev/null @@ -1,7 +0,0 @@ -module { -func.func @main(%arg0: tensor<1x4x3x3xf32>) -> tensor<1x4x3x1xf32> { - %cst = "tfl.pseudo_const"() <{value = dense<3> : tensor}> : () -> tensor - %0:3 = "tfl.split"(%cst, %arg0) <{num_splits = 3 : i32}> : (tensor, tensor<1x4x3x3xf32>) -> (tensor<1x4x3x1xf32>, tensor<1x4x3x1xf32>, tensor<1x4x3x1xf32>) - return %0#0 : tensor<1x4x3x1xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_stablehlo_scatter_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_stablehlo_scatter_op.mlir deleted file mode 100644 index 9d098eb0b9f61d..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_stablehlo_scatter_op.mlir +++ /dev/null @@ -1,9 +0,0 @@ -module { -func.func @main(%arg0: tensor<1x128x4x256xf32>, %arg1: tensor<131072x4xi32>, %arg2: tensor<131072xf32>) -> tensor<1x128x4x256xf32> { - %0 = "stablehlo.scatter"(%arg0, 
%arg1, %arg2) <{scatter_dimension_numbers = #stablehlo.scatter}> ({ - ^bb0(%arg3: tensor, %arg4: tensor): - stablehlo.return %arg4 : tensor - }) : (tensor<1x128x4x256xf32>, tensor<131072x4xi32>, tensor<131072xf32>) -> tensor<1x128x4x256xf32> - return %0 : tensor<1x128x4x256xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_strided_slice_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_strided_slice_op.mlir deleted file mode 100644 index 373eff80ff3cd8..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_strided_slice_op.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module { -func.func @main(%arg0: tensor<1x128x4x256xf32>, %arg1: tensor<4xi32>, %arg2: tensor<4xi32>, %arg3: tensor<4xi32>) -> tensor<1x128x4x128xf32> { - %0 = "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) <{begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 0 : i32}> : (tensor<1x128x4x256xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<1x128x4x128xf32> - return %0 : tensor<1x128x4x128xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_sub_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_sub_op.mlir deleted file mode 100644 index e1483fed87d802..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_sub_op.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module { -func.func @main(%arg0: tensor<1x128x4x128xf32>, %arg1: tensor<1x128x4x128xf32>) -> tensor<1x128x4x128xf32> { - %0 = tfl.sub %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<1x128x4x128xf32> - return %0 : tensor<1x128x4x128xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_sum_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_sum_op.mlir deleted file mode 100644 index 
bb4613d5b4b6c5..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_sum_op.mlir +++ /dev/null @@ -1,7 +0,0 @@ -module { -func.func @main(%arg0: tensor<1x128x2304xf32>) -> tensor<1x128x1xf32> { - %cst = "tfl.pseudo_const"() <{value = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32> - %0 = "tfl.sum"(%arg0, %cst) <{keep_dims = true}> : (tensor<1x128x2304xf32>, tensor<1xi32>) -> tensor<1x128x1xf32> - return %0 : tensor<1x128x1xf32> -} -} diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_tanh_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_tanh_op.mlir deleted file mode 100644 index ce1d0302c8a838..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_tanh_op.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module { -func.func @main(%arg0: tensor<1x128x8x128xf32>) -> tensor<1x128x8x128xf32> { - %0 = "tfl.tanh"(%arg0) : (tensor<1x128x8x128xf32>) -> tensor<1x128x8x128xf32> - return %0 : tensor<1x128x8x128xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_transpose_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_transpose_op.mlir deleted file mode 100644 index f24d72216897fd..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/simple_transpose_op.mlir +++ /dev/null @@ -1,7 +0,0 @@ -module { -func.func @main(%arg0: tensor<128x4x2x128xf32>) -> tensor<128x2x4x128xf32> { - %cst = "tfl.pseudo_const"() <{value = dense<[0, 2, 1, 3]> : tensor<4xi32>}> : () -> tensor<4xi32> - %0 = "tfl.transpose"(%arg0, %cst) : (tensor<128x4x2x128xf32>, tensor<4xi32>) -> tensor<128x2x4x128xf32> - return %0 : tensor<128x2x4x128xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/two_adds.mlir b/tensorflow/lite/experimental/litert/test/testdata/two_adds.mlir deleted file mode 100644 index 463dd456dc5c5d..00000000000000 --- 
a/tensorflow/lite/experimental/litert/test/testdata/two_adds.mlir +++ /dev/null @@ -1,7 +0,0 @@ -module { -func.func @main(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { - %0 = tfl.add %arg0, %arg0 {fused_activation_function = "NONE"} : tensor<2x2xf32> - %1 = tfl.add %0, %0 {fused_activation_function = "NONE"} : tensor<2x2xf32> - return %1 : tensor<2x2xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/two_partition.mlir b/tensorflow/lite/experimental/litert/test/testdata/two_partition.mlir deleted file mode 100644 index 738c8309110318..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/two_partition.mlir +++ /dev/null @@ -1,9 +0,0 @@ -module { -func.func @main(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { - %0 = tfl.add %arg0, %arg0 {fused_activation_function = "NONE"} : tensor<2x2xf32> - %1 = tfl.mul %0, %0 {fused_activation_function = "NONE"} : tensor<2x2xf32> - %2 = tfl.add %1, %1 {fused_activation_function = "NONE"} : tensor<2x2xf32> - %3 = tfl.mul %2, %2 {fused_activation_function = "NONE"} : tensor<2x2xf32> - return %3 : tensor<2x2xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/unranked_tensor.mlir b/tensorflow/lite/experimental/litert/test/testdata/unranked_tensor.mlir deleted file mode 100644 index 4e2403a7fadbd8..00000000000000 --- a/tensorflow/lite/experimental/litert/test/testdata/unranked_tensor.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module { -func.func @main(%arg0: tensor<*xf32>) -> tensor<*xf32> { - %0 = tfl.add %arg0, %arg0 {fused_activation_function = "NONE"} : tensor<*xf32> - return %0 : tensor<*xf32> -} -} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/tools/BUILD b/tensorflow/lite/experimental/litert/tools/BUILD deleted file mode 100644 index 88e02b5246d4d0..00000000000000 --- a/tensorflow/lite/experimental/litert/tools/BUILD +++ /dev/null @@ -1,263 +0,0 @@ -# Copyright 2024 Google LLC. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -load("//tensorflow/lite/experimental/litert/vendors/qualcomm:qualcomm_build_defs.bzl", "litert_cc_bin_with_qnn") - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], -) - -cc_library( - name = "apply_plugin", - srcs = ["apply_plugin.cc"], - hdrs = ["apply_plugin.h"], - deps = [ - ":dump", - ":outstream", - ":tool_display", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_buffer_ref", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/compiler/plugin:compiler_flags", - "//tensorflow/lite/experimental/litert/compiler/plugin:compiler_plugin", - "//tensorflow/lite/experimental/litert/core/model:model_serialize", - "//tensorflow/lite/experimental/litert/core/util:flatbuffer_tools", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log:absl_check", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - ], -) - -cc_test( - name = "apply_plugin_test", - srcs 
= ["apply_plugin_test.cc"], - data = [ - "//tensorflow/lite/experimental/litert/test:mlir_test_data", - "//tensorflow/lite/experimental/litert/vendors/examples:example_plugin_so", - ], - tags = [ - "noasan", - "nomsan", - "nosan", - "notsan", - ], - deps = [ - ":apply_plugin", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/cc:litert_buffer_ref", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/core:build_stamp", - "//tensorflow/lite/experimental/litert/core:dispatch_op_schema", - "//tensorflow/lite/experimental/litert/core/model", - "//tensorflow/lite/experimental/litert/test:common", - "//tensorflow/lite/experimental/litert/test:matchers", - "@com_google_absl//absl/log:absl_check", - "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest_main", - ], -) - -# TODO - @lukeboyer: Figure out some selective inclusiion of the data deps, some are very large. 
-litert_cc_bin_with_qnn( - name = "apply_plugin_main", - srcs = ["apply_plugin_main.cc"], - data = [ - # copybara:uncomment_begin(google-only) - # "//platforms/darwinn/compiler:compiler_api_wrapper", - # copybara:uncomment_end - "//tensorflow/lite/experimental/litert/vendors/examples:example_plugin_so", - "//tensorflow/lite/experimental/litert/vendors/google_tensor/compiler:google_tensor_compiler_plugin_so", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler:qnn_compiler_plugin_so", - "//tensorflow/lite/experimental/litert/vendors/mediatek/compiler:compiler_plugin_so", - ], - export_litert_only = 1, - include_system = 1, - linkstatic = 1, - # copybara:uncomment malloc = "//base:system_malloc", - tags = [ - "noasan", - "nobuilder", - "nomsan", - "nosan", - ], - ungrte = True, - deps = [ - ":apply_plugin", - ":outstream", - "//tensorflow/lite/experimental/litert/compiler/plugin:compiler_flags", - "//tensorflow/lite/experimental/litert/core:build_stamp", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@llvm-project//llvm:Support", - ], -) - -# Fork of "apply_plugin_main" without the "ungrte" so this tool can be used as part of larger -# integration test pipelines with example_plugin. 
-cc_binary( - name = "apply_plugin_main_for_test", - testonly = 1, - srcs = ["apply_plugin_main.cc"], - data = [ - "//tensorflow/lite/experimental/litert/vendors/examples:example_plugin_so", - ], - linkstatic = 1, - tags = [ - "noasan", - "nomsan", - "nosan", - ], - deps = [ - ":apply_plugin", - ":outstream", - "//tensorflow/lite/experimental/litert/compiler/plugin:compiler_flags", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@llvm-project//llvm:Support", - ], -) - -cc_library( - name = "tool_display", - srcs = ["tool_display.cc"], - hdrs = ["tool_display.h"], - deps = [ - ":outstream", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - ], -) - -cc_test( - name = "tool_display_test", - srcs = ["tool_display_test.cc"], - data = ["//tensorflow/lite/experimental/litert/test:mlir_test_data"], - deps = [ - ":tool_display", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "dump", - srcs = ["dump.cc"], - hdrs = ["dump.h"], - deps = [ - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/compiler/plugin:compiler_plugin", - "//tensorflow/lite/experimental/litert/core/model", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - ], -) - -cc_test( - name = "dump_test", - srcs = ["dump_test.cc"], - data = ["//tensorflow/lite/experimental/litert/test:mlir_test_data"], - deps = [ - ":dump", - "//tensorflow/lite/experimental/litert/core/model", - "//tensorflow/lite/experimental/litert/test:common", - "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "outstream", - hdrs = ["outstream.h"], - deps = [ - "//tensorflow/lite/experimental/litert/c:litert_logging", - 
"@com_google_absl//absl/strings:string_view", - ], -) - -cc_library( - name = "benchmark_litert_model", - srcs = ["benchmark_litert_model.cc"], - hdrs = ["benchmark_litert_model.h"], - deps = [ - "//tensorflow/lite/c:c_api_types", - "//tensorflow/lite/c:common", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/cc:litert_compilation_options", - "//tensorflow/lite/experimental/litert/cc:litert_compiled_model", - "//tensorflow/lite/experimental/litert/cc:litert_environment", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_tensor_buffer", - "//tensorflow/lite/tools:utils", - "//tensorflow/lite/tools/benchmark:benchmark_model_lib", - "//tensorflow/lite/tools/benchmark:benchmark_params", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - ], -) - -cc_test( - name = "benchmark_litert_model_test", - srcs = ["benchmark_litert_model_test.cc"], - data = [ - "//tensorflow/lite/experimental/litert/test:testdata/mobilenet_v2_1.0_224.tflite", - ], - env = { - "ASAN_OPTIONS": "detect_odr_violation=0", - }, - tags = [ - "manual", - "notap", - "requires-gpu-nvidia", - ], - deps = - [ - ":benchmark_litert_model", - "@com_google_googletest//:gtest_main", - # copybara:uncomment_begin(google-only) - # "//third_party/odml/infra/ml_drift_delegate/litert:ml_drift_cl_accelerator", # buildcleaner: keep - # copybara:uncomment_end - "//tensorflow/lite/core/c:c_api_types", - "//tensorflow/lite/tools/benchmark:benchmark_model_lib", - "//tensorflow/lite/tools/benchmark:benchmark_params", - ], -) - -# We create a library for benchmark_main.cc to faciliate the creation of a -# customized benchmark model binary that only needs linking with extra -# dependency, e.g., enabling creating of benchmark binaries with a custom -# delegate provider. 
-cc_library( - name = "benchmark_model_main", - srcs = [ - "benchmark_litert_model_main.cc", - ], - deps = [ - ":benchmark_litert_model", - "//tensorflow/lite/c:c_api_types", - "//tensorflow/lite/tools:logging", - ], -) diff --git a/tensorflow/lite/experimental/litert/tools/README.md b/tensorflow/lite/experimental/litert/tools/README.md deleted file mode 100644 index 400f9c1f9a5b19..00000000000000 --- a/tensorflow/lite/experimental/litert/tools/README.md +++ /dev/null @@ -1,24 +0,0 @@ -## run_model - -This is a simple tool to run a model with the CompiledModel API. - -``` -run_model --graph= -``` - -### Use NPU via Dispatch API - -If you're using the Dispatch API, you need to pass the Dispatch library -(libLiteRtDispatch_xxx.so) location via `--dispatch_library_dir` - -``` -run_model --graph= --dispatch_library_dir= -``` - -### Use GPU - -If you run a model with GPU accelerator, use `--use_gpu` flag. - -``` -run_model --graph= --use_gpu -``` diff --git a/tensorflow/lite/experimental/litert/tools/apply_plugin.cc b/tensorflow/lite/experimental/litert/tools/apply_plugin.cc deleted file mode 100644 index 36db8e81000fe4..00000000000000 --- a/tensorflow/lite/experimental/litert/tools/apply_plugin.cc +++ /dev/null @@ -1,515 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/tools/apply_plugin.h" - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "absl/log/absl_check.h" -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/compiler/plugin/compiler_flags.h" -#include "tensorflow/lite/experimental/litert/compiler/plugin/compiler_plugin.h" -#include "tensorflow/lite/experimental/litert/core/model/model_serialize.h" -#include "tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h" -#include "tensorflow/lite/experimental/litert/tools/dump.h" -#include "tensorflow/lite/experimental/litert/tools/tool_display.h" - -namespace litert::tools { - -using ::litert::BufferRef; -using ::litert::internal::CompilerFlags; -using ::litert::internal::CompilerPlugin; -using ::litert::internal::Dump; -using ::litert::internal::PartitionResult; -using ::litert::internal::SerializeModel; -using ::litert::internal::VerifyFlatbuffer; -using ::litert::tools::ApplyPluginRun; - -#define LITERT_ENSURE_CONFIG(expr) \ - if (!(expr)) { \ - return kLiteRtStatusErrorInvalidToolConfig; \ - } - -namespace { - -class Context { - public: - using Ptr = std::unique_ptr; - - explicit Context(ApplyPluginRun::Ptr run) - : run_(std::move(run)), - display_(ToolDisplay(std::move(run_->dump_out), - Context::CmdStr(run_->cmd))) {} - - ApplyPluginRun::Cmd Cmd() const { return run_->cmd; } - - absl::Span LibSearchPaths() 
const { - return absl::MakeConstSpan(run_->lib_search_paths.data(), - run_->lib_search_paths.size()); - } - - absl::string_view SocModelTarget() const { - ABSL_CHECK_EQ(run_->soc_models.size(), 1); - return run_->soc_models.front(); - } - - absl::string_view SocManufacturer() const { - return run_->soc_manufacturer.value(); - } - - std::ostream& Out(size_t out_ind = 0) { - ABSL_CHECK_GE(run_->outs.size(), 1); - return run_->outs.at(out_ind); - } - - const CompilerFlags& Flags() const { return run_->compiler_flags; } - - OutStream SwapOut(OutStream out) { - ABSL_CHECK_EQ(run_->outs.size(), 1); - auto res = run_->outs.front(); - run_->outs.at(0) = out; - return res; - } - - uint32_t NumOuts() const { return run_->outs.size(); } - - const ApplyPluginRun& Run() const { return *run_; } - ApplyPluginRun& Run() { return *run_; } - - ToolDisplay& Dump() { return display_; } - - static absl::string_view CmdStr(ApplyPluginRun::Cmd cmd); - - private: - ApplyPluginRun::Ptr run_; - ToolDisplay display_; -}; - -void DumpSubgraphs(ToolDisplay& display, absl::string_view label, - absl::Span subgraphs) { - for (auto* subgraph : subgraphs) { - display.Labeled(); - display.Indented() << absl::StreamFormat("(%s graph)", label); - Dump(*subgraph, display.Display()); - } -} - -void DumpCompilationRequest(ToolDisplay& display, absl::string_view soc_model, - size_t num_subgraphs, const CompilerFlags& flags) { - display.Labeled() << absl::StreamFormat( - "Requesting compilation for target `%s` on %lu " - "partitions with flags: ", - soc_model, num_subgraphs) - << flags << "\n"; -} - -void DumpCompilationResult(ToolDisplay& display, size_t byte_code_size, - size_t num_entry_points) { - display.Labeled() << absl::StreamFormat( - "Compiled %lu partitions into %lu bytes\n", num_entry_points, - byte_code_size); -} - -void DumpModelStats(ToolDisplay& display, BufferRef buf) { - display.Labeled() << absl::StreamFormat( - "Serialized a model of size %lu bytes\n", buf.Size()); -} - -void 
DumpPartitionResult(ToolDisplay& display, const PartitionResult& result) { - display.Labeled() << absl::StreamFormat( - "Partitioning yielded %lu new subgraphs\n", result.second.Size()); - - DumpSubgraphs(display, "new subgraphs", result.second.Elements()); -} - -absl::string_view Context::CmdStr(ApplyPluginRun::Cmd cmd) { - switch (cmd) { - case ApplyPluginRun::Cmd::INFO: - return "INFO"; - case ApplyPluginRun::Cmd::NOOP: - return "NOOP"; - case ApplyPluginRun::Cmd::PARTITION: - return "PARTITION"; - case ApplyPluginRun::Cmd::COMPILE: - return "COMPILE"; - case ApplyPluginRun::Cmd::APPLY: - return "APPLY"; - } -} - -Expected> LoadAllPlugins(Context& ctx) { - ctx.Dump().Start("Load Plugins"); - ctx.Dump().Labeled() << "Loading plugins from: "; - const auto paths = ctx.LibSearchPaths(); - for (auto it = paths.begin(); it < paths.end(); ++it) { - ctx.Dump().Display() << *it; - if (it < paths.end() - 1) { - ctx.Dump().Display() << ", "; - } - } - ctx.Dump().Display() << "\n"; - - auto plugins = CompilerPlugin::LoadPlugins(ctx.LibSearchPaths()); - if (!plugins.HasValue()) { - ctx.Dump().Fail(); - return plugins; - } - ctx.Dump().Labeled() << "Found plugins\n"; - ctx.Dump().Labeled() << absl::StreamFormat("Loaded %lu plugins\n", - plugins.Value().size()); - - ctx.Dump().Done(); - return plugins; -} - -Expected LoadPlugin(Context& ctx) { - auto plugins = LoadAllPlugins(ctx); - if (!plugins) { - return plugins.Error(); - } - - ctx.Dump().Start("Select Plugin"); - - for (auto& plugin : *plugins) { - if (plugin.SocManufacturer() == ctx.Run().soc_manufacturer) { - ctx.Dump().Labeled() << absl::StreamFormat("Selected plugin for: %s\n", - plugin.SocManufacturer()); - ctx.Dump().Done(); - return std::move(plugin); - } - } - - ctx.Dump().Fail(); - return Unexpected(kLiteRtStatusErrorNotFound); -} - -Expected LoadModel(Context& ctx) { - ctx.Dump().Start("Load Model"); - ctx.Dump().Labeled() << absl::StreamFormat("Loading model from: %s\n", - ctx.Run().model.value()); - auto 
model_result = Model::CreateFromFile(ctx.Run().model->data()); - if (!model_result.HasValue()) { - ctx.Dump().Labeled() << "Failed to load model from file."; - ctx.Dump().Fail(); - return model_result; - } - - ctx.Dump().Labeled(); - Dump(*model_result.Value().Get(), ctx.Dump().Display()); - ctx.Dump().Done(); - - return model_result; -} - -// -// INFO Command -// - -LiteRtStatus ValidateInfoRun(const ApplyPluginRun& run) { - LITERT_ENSURE_CONFIG(!run.lib_search_paths.empty()); - LITERT_ENSURE_CONFIG(run.outs.size() == 1); - return kLiteRtStatusOk; -} - -LiteRtStatus Info(Context& ctx) { - auto plugins = LoadAllPlugins(ctx); - if (!plugins) { - return plugins.Error().Status(); - } - - for (auto& plugin : *plugins) { - ctx.Out() << absl::StreamFormat("< LiteRtCompilerPlugin > \"%s\" | ", - plugin.SocManufacturer()); - const auto& models = plugin.SocModels(); - for (auto it = models.begin(); it < models.end(); ++it) { - ctx.Out() << absl::StreamFormat("\"%s\"", *it); - if (it < models.end() - 1) { - ctx.Out() << ", "; - } - } - ctx.Out() << "\n"; - } - return kLiteRtStatusOk; -} - -// -// NOOP Command -// - -LiteRtStatus ValidateNoopRun(const ApplyPluginRun& run) { - LITERT_ENSURE_CONFIG(run.model.has_value()); - LITERT_ENSURE_CONFIG(run.outs.size() == 1); - return kLiteRtStatusOk; -} - -LiteRtStatus Noop(Context& ctx) { - auto model = LoadModel(ctx); - if (!model) { - return model.Error().Status(); - } - - auto serialized = SerializeModel(std::move(*model->Get())); - if (!serialized) { - return serialized.Error().Status(); - } - LITERT_ENSURE(VerifyFlatbuffer(serialized->Span()), - kLiteRtStatusErrorInvalidFlatbuffer, - "Failed to invalidate flatbuffer"); - serialized->WriteStr(ctx.Out()); - return kLiteRtStatusOk; -} - -// -// PARTITION Command -// - -LiteRtStatus ValidatePartitionRun(const ApplyPluginRun& run) { - LITERT_ENSURE_CONFIG(!run.lib_search_paths.empty()); - LITERT_ENSURE_CONFIG(run.model.has_value() && !run.model.value().empty()); - 
LITERT_ENSURE_CONFIG(run.soc_manufacturer.has_value()); - LITERT_ENSURE_CONFIG(!run.outs.empty()); - return kLiteRtStatusOk; -} - -LiteRtStatus Partition(Context& ctx) { - auto plugin = LoadPlugin(ctx); - if (!plugin) { - return plugin.Error().Status(); - } - - auto model_wrap = LoadModel(ctx); - if (!model_wrap) { - return model_wrap.Error().Status(); - } - auto& model = *model_wrap->Get(); - - ctx.Dump().Start("Partitioning model"); - auto partition_result = PartitionModel(*plugin, model, ctx.Run().subgraphs); - if (!partition_result) { - return partition_result.Error().Status(); - } - ctx.Dump().Done(); - DumpPartitionResult(ctx.Dump(), *partition_result); - - auto& new_subgraphs = partition_result->second; - model.TransferSubgraphsFrom(std::move(new_subgraphs)); - - ctx.Dump().Start("Serializing model"); - auto serialized = SerializeModel(std::move(model)); - DumpModelStats(ctx.Dump(), *serialized); - ctx.Dump().Done(); - - ctx.Dump().Start("Verifying flatbuffer"); - LITERT_ENSURE(VerifyFlatbuffer(serialized->Span()), - kLiteRtStatusErrorInvalidFlatbuffer, - "Failed to invalidate flatbuffer"); - ctx.Dump().Done(); - - ctx.Dump().Start("Writing to out"); - serialized->WriteStr(ctx.Out()); - ctx.Dump().Done(); - - return kLiteRtStatusOk; -} - -// -// COMPILE Command -// - -LiteRtStatus ValidateCompileRun(const ApplyPluginRun& run) { - LITERT_ENSURE_CONFIG(!run.lib_search_paths.empty()); - LITERT_ENSURE_CONFIG(run.model.has_value()); - LITERT_ENSURE_CONFIG(run.soc_manufacturer.has_value()); - // TODO: implement multi target compilation. 
- LITERT_ENSURE_SUPPORTED(run.soc_models.size() == 1, - "Multi target compilation not implemented."); - return kLiteRtStatusOk; -} - -LiteRtStatus Compile(Context& ctx) { - auto model_wrap = LoadModel(ctx); - if (!model_wrap) { - return model_wrap.Error().Status(); - } - auto& model = *model_wrap->Get(); - - auto plugin = LoadPlugin(ctx); - if (!plugin) { - return plugin.Error().Status(); - } - - ctx.Dump().Start("Compiling"); - DumpCompilationRequest(ctx.Dump(), ctx.SocModelTarget(), model.NumSubgraphs(), - ctx.Flags()); - plugin->SetFlags(ctx.Flags()); - auto compilation_result = plugin->Compile(&model, ctx.SocModelTarget()); - if (!compilation_result) { - ctx.Dump().Fail(); - return compilation_result.Error().Status(); - } - - auto num_byte_code = compilation_result->NumByteCodeModules(); - if (*num_byte_code < 1) { - ctx.Dump().Fail(); - return compilation_result.Error().Status(); - } - if (!num_byte_code) { - ctx.Dump().Fail(); - return compilation_result.Error().Status(); - } - for (int i = 0; i < ctx.NumOuts(); ++i) { - auto byte_code = compilation_result->ByteCode(i); - if (!byte_code) { - ctx.Dump().Fail(); - return compilation_result.Error().Status(); - } - auto num_calls = compilation_result->NumCalls(); - if (!num_calls) { - ctx.Dump().Fail(); - return compilation_result.Error().Status(); - } - - DumpCompilationResult(ctx.Dump(), byte_code->Size(), *num_calls); - byte_code->WriteStr(ctx.Out(i)); - } - ctx.Dump().Done(); - - return kLiteRtStatusOk; -} - -// -// APPLY Command -// - -LiteRtStatus ValidateApplyRun(const ApplyPluginRun& run) { - LITERT_ENSURE_CONFIG(!run.lib_search_paths.empty()); - LITERT_ENSURE_CONFIG(run.model.has_value()); - LITERT_ENSURE_CONFIG(run.soc_manufacturer.has_value()); - LITERT_ENSURE_CONFIG(run.outs.size() == run.soc_models.size()); - // TODO: implement multi target compilation. 
- LITERT_ENSURE_SUPPORTED(run.soc_models.size() == 1, - "Multi target compilation not implemented."); - return kLiteRtStatusOk; -} - -LiteRtStatus Apply(Context& ctx) { - auto model_wrap = LoadModel(ctx); - if (!model_wrap) { - return model_wrap.Error().Status(); - } - auto& model = *model_wrap->Get(); - - auto plugin = LoadPlugin(ctx); - if (!plugin) { - return plugin.Error().Status(); - } - - ctx.Dump().Start("Applying plugin"); - plugin->SetFlags(ctx.Flags()); - if (auto status = litert::internal::ApplyPlugin( - *plugin, model, ctx.SocModelTarget(), ctx.Run().subgraphs); - !status) { - LITERT_LOG(LITERT_ERROR, "%s", status.Error().Message().c_str()); - return status.Error().Status(); - } - ctx.Dump().Done(); - - ctx.Dump().Start("Serializing model"); - auto serialized = SerializeModel(std::move(model)); - DumpModelStats(ctx.Dump(), *serialized); - ctx.Dump().Done(); - - ctx.Dump().Start("Verifying flatbuffer"); - LITERT_ENSURE(VerifyFlatbuffer(serialized->Span()), - kLiteRtStatusErrorInvalidFlatbuffer, - "Failed to invalidate flatbuffer"); - ctx.Dump().Done(); - - ctx.Dump().Start("Writing to out"); - serialized->WriteStr(ctx.Out()); - ctx.Dump().Done(); - - return kLiteRtStatusOk; -} - -} // namespace - -LiteRtStatus ApplyPlugin(ApplyPluginRun::Ptr run) { - Context context(std::move(run)); - DumpPreamble(context.Dump()); - - switch (context.Cmd()) { - case ApplyPluginRun::Cmd::INFO: - if (auto stat = ValidateInfoRun(context.Run()); stat != kLiteRtStatusOk) { - context.Dump().Labeled() << "Invalid arguments for INFO command\n"; - return stat; - } - return Info(context); - - case ApplyPluginRun::Cmd::PARTITION: - if (auto stat = ValidatePartitionRun(context.Run()); - stat != kLiteRtStatusOk) { - context.Dump().Labeled() << "Invalid arguments for PARTITION command\n"; - return stat; - } - return Partition(context); - - case ApplyPluginRun::Cmd::COMPILE: - if (auto stat = ValidateCompileRun(context.Run()); - stat != kLiteRtStatusOk) { - context.Dump().Labeled() << 
"Invalid arguments for COMPILE command\n"; - return stat; - } - return Compile(context); - - case ApplyPluginRun::Cmd::APPLY: - if (auto stat = ValidateApplyRun(context.Run()); - stat != kLiteRtStatusOk) { - context.Dump().Labeled() << "Invalid arguments for APPLY command\n"; - return stat; - } - return Apply(context); - - case ApplyPluginRun::Cmd::NOOP: - - if (auto stat = ValidateNoopRun(context.Run()); stat != kLiteRtStatusOk) { - context.Dump().Labeled() << "Invalid arguments for NOP command\n"; - return stat; - } - return Noop(context); - - default: - return kLiteRtStatusErrorInvalidArgument; - } - - return kLiteRtStatusOk; -} - -} // namespace litert::tools diff --git a/tensorflow/lite/experimental/litert/tools/apply_plugin.h b/tensorflow/lite/experimental/litert/tools/apply_plugin.h deleted file mode 100644 index 8d105836eb8422..00000000000000 --- a/tensorflow/lite/experimental/litert/tools/apply_plugin.h +++ /dev/null @@ -1,160 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TOOLS_APPLY_PLUGIN_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TOOLS_APPLY_PLUGIN_H_ - -#include -#include -#include -#include -#include - -#include "absl/container/flat_hash_set.h" -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/compiler/plugin/compiler_flags.h" -#include "tensorflow/lite/experimental/litert/tools/outstream.h" - -namespace litert::tools { - -using ::litert::internal::CompilerFlags; - -struct ApplyPluginRun { - // NOTE: All StrFlagT are expected to have static storage duration. - using Ptr = std::unique_ptr; - - // A specific command implemented by the tool to run. - enum class Cmd { - // Displays info about all plugins found in given search paths. - // - // FLAG SEMANTICS: - // "lib_search_paths": Required, at least one. - // "model": Ignored. - // "soc_manufacturer": Optional, filters plugins to display. - // "soc_models": Ignored. - // "outs": Required, must be size one. - // "dump_out": Optional. - INFO, - - // Does nothing and simply de-serializes and re-serializes the given model. - // This is intended for testing and internal debugging only. - // - // FLAG SEMANTICS: - // "lib_search_paths": Ignored. - // "model": Required. - // "soc_manufacturer": Ignored. - // "soc_models": Ignored. - // "outs": Required, must be size one. - // "dump_out": Optional. - NOOP, - - // Runs the entire end to end flow. This is the standard compiler plugin - // usage. A seperate compilation step will occur for each sco_model tag that - // is supported by the loaded plugin, and a new output model will be - // generated for each. Partitioning is invariant accross different soc_model - // targets from the same manufacturer, so only one compilation step will - // occur even if multiple targest are requested. - // - // FLAG SEMANTICS: - // "lib_search_paths": Required, at least one. - // "model": Required. 
- // "soc_manufacturer": Required. - // "soc_models": Required, at least one. - // "outs": Required, must be size equal to "soc_models". - // "dump_out": Optional. - // - // TODO: Support multi target compilation. - APPLY, - - // Only run the partiion step and skip compilation. Writes a ".tflite" model - // to "out" where selected partitions are manifested as new standard - // flatbuffer subgraphs added to the input model. - // The partitions original locations are replaced with a single custom op - // the contains an identifier to the corresponding partition (new subgraph). - // This is intended for testing and development. - // - // FLAG SEMANTICS: - // "lib_search_paths": Required, at least one. - // "model": Required. - // "soc_manufacturer": Required. - // "soc_models": Ignored. - // "outs": Required, must be size one. - // "dump_out": Optional. - PARTITION, - - // Skip partitioning and run the entire input model through compilation - // directly. Fails if any ops in the input model are unsupported by the - // plugin. Writes the raw compiled result to the "out" stream without any - // wrapping flatbuffer. Runs multi-target compilation as in "APPLY", - // Intended for testing and development. - // - // FLAG SEMANTICS: - // "lib_search_paths": Required, at least one. - // "model": Required. - // "soc_manufacturer": Required. - // "soc_models": Required, at least one. - // "out": Required, must be size equal to "soc_models". - // "dump_out": Optional. - // - // TODO: Support multi target compilation. - COMPILE, - }; - - // A command to run, see above. - Cmd cmd; - - // Collection of paths on local files system dictating where the tool should - // look for suitable LiteRtCompilerPlugin shared libraries. The tool will - // select the first ".so" file found with prefix "libLiteRtPlugin" that has - // the "soc_manufacturer" tag passed. Providing more than one plugin shared - // library for the same manufacturer results in an error. 
- std::vector lib_search_paths = {}; - - // Path to ".tflite" model the tool should operated on. - std::optional model = {}; - - // A tag representing a manufacturer the tool should target for compilation. - // This is used to select the appropriate plugin if multiple plugins are found - // in "lib_search_paths". - std::optional soc_manufacturer = {}; - - // Collection of soc models tags the tool should target for compilation. - std::vector soc_models = {}; - - // Where the tool should write its result file(s) to. If the command runs - // compilation, an "out" stream should be passed for each "soc_model" target - // requested for compilation. Output for the "ith" target will be written to - // the "ith" outs stream. - std::vector outs = {std::cout}; - - // Where to direct logging for this run. Passing nullopt here indicates - // "silent" behavior and should only be used when this tool is part of a - // larger pipeline like an end2end test. - UserStream dump_out; - - // Compiler flags to pass to the plugin. Only relevant for "APPLY" and - // "COMPILE" commands. - CompilerFlags compiler_flags; - - // If provided, only the subgraphs with the given indices are applied with the - // plugin. - absl::flat_hash_set subgraphs = {}; -}; - -LiteRtStatus ApplyPlugin(ApplyPluginRun::Ptr run); - -} // namespace litert::tools - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TOOLS_APPLY_PLUGIN_H_ diff --git a/tensorflow/lite/experimental/litert/tools/apply_plugin_main.cc b/tensorflow/lite/experimental/litert/tools/apply_plugin_main.cc deleted file mode 100644 index 261ecd494e855f..00000000000000 --- a/tensorflow/lite/experimental/litert/tools/apply_plugin_main.cc +++ /dev/null @@ -1,157 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either expruns or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include -#include -#include -#include - -#include "absl/strings/str_format.h" -#include "llvm/ADT/ArrayRef.h" -#include "llvm/Support/CommandLine.h" -#include "tensorflow/lite/experimental/litert/compiler/plugin/compiler_flags.h" -#include "tensorflow/lite/experimental/litert/tools/apply_plugin.h" -#include "tensorflow/lite/experimental/litert/tools/outstream.h" - -using ::litert::tools::ApplyPlugin; -using ::litert::tools::ApplyPluginRun; -using ::litert::tools::UserStream; - -// NOLINTNEXTLINE -static llvm::cl::opt cmd( - llvm::cl::Positional, - llvm::cl::desc("Routine to run (apply, partition, compile, info, noop)."), - llvm::cl::init("partition")); - -// NOLINTNEXTLINE -static llvm::cl::opt model( - "model", llvm::cl::desc("Path to flatbuffer file."), llvm::cl::init("")); - -// TODO: b/366821557 - Support path to pre-compiled plugin in flags. -// NOLINTNEXTLINE -static llvm::cl::opt soc_manufacturer( - "soc_man", - llvm::cl::desc("String identifier of SoC manufacturer (e.g., GoogleTensor, " - "Qualcomm)."), - llvm::cl::init("ExampleSocManufacturer")); - -// TODO: Support multi target compilation. 
-// NOLINTNEXTLINE -static llvm::cl::opt soc_model("soc_model", - llvm::cl::desc("Target SoC model."), - llvm::cl::init("ExampleSocModel")); - -// NOLINTNEXTLINE -static llvm::cl::list libs( - "libs", - llvm::cl::desc("List of directories in which to search for suitable " - "compiler plugin shared libraries."), - llvm::cl::list_init(llvm::ArrayRef{ - "third_party/tensorflow/lite/experimental/litert/vendors/examples", - "third_party/tensorflow/lite/experimental/litert/vendors/qualcomm/" - "compiler", - "third_party/tensorflow/lite/experimental/litert/vendors/mediatek/" - "compiler", - "third_party/tensorflow/lite/experimental/litert/vendors/" - "google_tensor/compiler"})); - -// NOLINTNEXTLINE -static llvm::cl::list outs( - "o", - llvm::cl::desc("Path to files for output, \"-\" indicates standard out, " - "\"--\" for standard err, \"none\" for null stream."), - llvm::cl::list_init(llvm::ArrayRef{"-"})); - -// NOLINTNEXTLINE -static llvm::cl::opt err( - "err", - llvm::cl::desc("Path to file for err output, \"-\" indicates standard out, " - "\"--\" for standard err, \"none\" for null stream."), - llvm::cl::init("--")); - -// NOLINTNEXTLINE -static llvm::cl::opt compiler_flags( - "compiler-flags", - llvm::cl::desc("List of comma separated (no space) compiler flags. Flags " - "may be key-value pairs " - "in the format of \"key=value\", or just \"key\". E.g. 
" - "\"--compiler-flags=key1=value1,key2\"")); - -// NOLINTNEXTLINE -static llvm::cl::list subgraphs( - "subgraphs", - llvm::cl::desc("If provides, only the subgraphs with the given indices " - "are applied with the plugin."), - llvm::cl::list_init(llvm::ArrayRef{})); - -ApplyPluginRun::Ptr ParseFlags() { - auto res = std::make_unique(); - - if (!model.empty()) { - res->model = model; - } - - res->compiler_flags = *litert::internal::ParseCompilerFlags(compiler_flags); - - res->soc_manufacturer = soc_manufacturer; - res->soc_models.push_back(soc_model); - - res->lib_search_paths.assign(libs.begin(), libs.end()); - - if (cmd == "apply") { - res->cmd = ApplyPluginRun::Cmd::APPLY; - } else if (cmd == "partition") { - res->cmd = ApplyPluginRun::Cmd::PARTITION; - } else if (cmd == "compile") { - res->cmd = ApplyPluginRun::Cmd::COMPILE; - } else if (cmd == "info") { - res->cmd = ApplyPluginRun::Cmd::INFO; - } else if (cmd == "noop") { - res->cmd = ApplyPluginRun::Cmd::NOOP; - } else { - return nullptr; - } - - for (auto subgraph_idx : subgraphs) { - res->subgraphs.insert(subgraph_idx); - } - - return res; -} - -int main(int argc, char* argv[]) { - llvm::cl::ParseCommandLineOptions(argc, argv); - - auto run = ParseFlags(); - if (run == nullptr) { - return 1; - } - - run->outs.clear(); - std::vector> oss; - for (const auto& out : outs) { - oss.push_back(std::make_unique( - UserStream::MakeFromFlag(out))); - run->outs.push_back(oss.back()->Get()); - } - - run->dump_out = UserStream::MakeFromFlag(err); - - run->dump_out.Get() << absl::StreamFormat( - "CMD: %s\nMODEL: %s\nSOC_MANUFACTURER: %s\nSOC_MODEL: %s\n", cmd, model, - soc_manufacturer, soc_model); - - return ApplyPlugin(std::move(run)); -} diff --git a/tensorflow/lite/experimental/litert/tools/apply_plugin_test.cc b/tensorflow/lite/experimental/litert/tools/apply_plugin_test.cc deleted file mode 100644 index b86bc5ec19f874..00000000000000 --- a/tensorflow/lite/experimental/litert/tools/apply_plugin_test.cc +++ /dev/null 
@@ -1,194 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/tools/apply_plugin.h" - -#include -#include -#include -#include - -#include -#include -#include "absl/log/absl_check.h" -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/core/build_stamp.h" -#include "tensorflow/lite/experimental/litert/core/dispatch_op_schema.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" -#include "tensorflow/lite/experimental/litert/test/common.h" -#include "tensorflow/lite/experimental/litert/test/matchers.h" - -namespace litert::tools { -namespace { - -using ::litert::internal::kLiteRtBuildStampKey; -using ::litert::internal::ParseBuildStamp; -using ::testing::HasSubstr; -using ::testing::litert::IsError; - -static constexpr absl::string_view kPluginSearchPath = - "third_party/tensorflow/lite/experimental/litert/vendors/examples"; - -static constexpr absl::string_view kSocManufacturer = "ExampleSocManufacturer"; - -static constexpr absl::string_view kSocModel = "ExampleSocModel"; - -absl::string_view 
TestModelPath(absl::string_view filename) { - static char kModelPath[512] = {}; - const auto model_path = ::litert::testing::GetTestFilePath(filename); - ABSL_CHECK(model_path.size() < 512); - model_path.copy(kModelPath, model_path.size(), 0); - return kModelPath; -} - -ApplyPluginRun::Ptr MakeBaseRun( - ApplyPluginRun::Cmd cmd, absl::string_view model_path = "one_mul.tflite") { - auto run = std::make_unique(); - run->cmd = cmd; - run->lib_search_paths.push_back(kPluginSearchPath); - run->model.emplace(TestModelPath(model_path)); - run->soc_manufacturer.emplace(kSocManufacturer); - run->soc_models.push_back(kSocModel); - run->outs.clear(); - return run; -} - -TEST(TestApplyPluginTool, TestInfoBadConfig) { - auto run = MakeBaseRun(ApplyPluginRun::Cmd::INFO); - run->lib_search_paths.clear(); - EXPECT_THAT(ApplyPlugin(std::move(run)), - IsError(kLiteRtStatusErrorInvalidToolConfig)); -} - -TEST(TestApplyPluginTool, TestInfo) { - auto run = MakeBaseRun(ApplyPluginRun::Cmd::INFO); - std::stringstream out; - run->outs.push_back(out); - LITERT_ASSERT_OK(ApplyPlugin(std::move(run))); - EXPECT_THAT(out.str(), - ::testing::HasSubstr( - "< LiteRtCompilerPlugin > \"ExampleSocManufacturer\" | " - "\"ExampleSocModel\"")); -} - -TEST(TestApplyPluginTool, TestNoopBadConfig) { - auto run = MakeBaseRun(ApplyPluginRun::Cmd::NOOP); - run->model.reset(); - EXPECT_THAT(ApplyPlugin(std::move(run)), - IsError(kLiteRtStatusErrorInvalidToolConfig)); -} - -TEST(TestApplyPluginTool, TestNoop) { - auto run = MakeBaseRun(ApplyPluginRun::Cmd::NOOP); - std::stringstream out; - run->outs.push_back(out); - LITERT_ASSERT_OK(ApplyPlugin(std::move(run))); - - auto model = Model::CreateFromBuffer( - BufferRef(out.view().data(), out.view().size())); - EXPECT_EQ(model->Get()->NumSubgraphs(), 1); -} - -TEST(TestApplyPluginTool, TestPartitionBadConfig) { - auto run = MakeBaseRun(ApplyPluginRun::Cmd::PARTITION); - run->model.reset(); - EXPECT_THAT(ApplyPlugin(std::move(run)), - 
IsError(kLiteRtStatusErrorInvalidToolConfig)); -} - -TEST(TestApplyPluginTool, TestPartition) { - auto run = MakeBaseRun(ApplyPluginRun::Cmd::PARTITION); - std::stringstream out; - run->outs.push_back(out); - LITERT_ASSERT_OK(ApplyPlugin(std::move(run))); - EXPECT_FALSE(out.str().empty()); -} - -TEST(TestApplyPluginTool, TestCompileBadConfig) { - auto run = MakeBaseRun(ApplyPluginRun::Cmd::COMPILE); - run->model.reset(); - EXPECT_THAT(ApplyPlugin(std::move(run)), - IsError(kLiteRtStatusErrorInvalidToolConfig)); -} - -TEST(TestApplyPluginTool, TestCompile) { - auto run = MakeBaseRun(ApplyPluginRun::Cmd::COMPILE); - std::stringstream out; - run->outs.push_back(out); - LITERT_ASSERT_OK(ApplyPlugin(std::move(run))); - EXPECT_FALSE(out.str().empty()); - EXPECT_THAT(out.str(), HasSubstr("Partition_0_with_1_muls")); -} - -TEST(TestApplyPluginTool, TestApplyBadConfig) { - auto run = MakeBaseRun(ApplyPluginRun::Cmd::APPLY); - run->model.reset(); - EXPECT_THAT(ApplyPlugin(std::move(run)), - IsError(kLiteRtStatusErrorInvalidToolConfig)); -} - -TEST(TestApplyPluginTool, TestApply) { - auto run = MakeBaseRun(ApplyPluginRun::Cmd::APPLY); - std::stringstream out; - run->outs.push_back(out); - LITERT_ASSERT_OK(ApplyPlugin(std::move(run))); - - const auto out_str = out.str(); - BufferRef serialized(out_str.data(), out_str.size()); - - auto model = Model::CreateFromBuffer(serialized); - EXPECT_EQ(model->Get()->NumSubgraphs(), 1); - - { - auto stamp_buffer = model->Get()->FindMetadata(kLiteRtBuildStampKey); - auto stamp = ParseBuildStamp(*stamp_buffer); - auto [man, soc_model] = *stamp; - EXPECT_EQ(man, kSocManufacturer); - EXPECT_EQ(soc_model, kSocModel); - } - - auto* op = model->Get()->MainSubgraph()->Ops().front(); - ASSERT_EQ(op->OpCode(), kLiteRtOpCodeTflCustom); - - const auto options = internal::GetDispatchOpOptions(op->CustomOptions()); - const auto& [size, offset, name] = options; - EXPECT_EQ(name, "Partition_0"); - ASSERT_LE(offset + size, serialized.Size()); - - 
EXPECT_THAT(serialized.StrView().substr(offset, size), - HasSubstr("Partition_0_with_1_muls")); -} - -TEST(TestApplyPluginTool, TestCompileToMultiByteCode) { - auto run = - MakeBaseRun(ApplyPluginRun::Cmd::COMPILE, "multi_subgraph_mul.tflite"); - std::stringstream out_0; - std::stringstream out_1; - run->outs.push_back(out_0); - run->outs.push_back(out_1); - - LITERT_ASSERT_OK(ApplyPlugin(std::move(run))); - EXPECT_FALSE(out_0.str().empty()); - EXPECT_FALSE(out_1.str().empty()); - EXPECT_THAT(out_0.str(), HasSubstr("Partition_0_with_1_muls")); - EXPECT_THAT(out_1.str(), HasSubstr("Partition_1_with_1_muls")); -} - -} // namespace -} // namespace litert::tools diff --git a/tensorflow/lite/experimental/litert/tools/benchmark_litert_model.cc b/tensorflow/lite/experimental/litert/tools/benchmark_litert_model.cc deleted file mode 100644 index b82fdc18fc0228..00000000000000 --- a/tensorflow/lite/experimental/litert/tools/benchmark_litert_model.cc +++ /dev/null @@ -1,93 +0,0 @@ -/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ -#include "tensorflow/lite/experimental/litert/tools/benchmark_litert_model.h" - -#include -#include -#include -#include - -#include "tensorflow/lite/c/c_api_types.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/cc/litert_compilation_options.h" -#include "tensorflow/lite/experimental/litert/cc/litert_compiled_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer.h" - -namespace litert::benchmark { -namespace { -using ::litert::CompilationOptions; -using ::litert::CompiledModel; -using ::litert::TensorBuffer; - -CompilationOptions CreateCompiledModelOptions(const BenchmarkParams& params) { - auto use_gpu = params.Get("use_gpu"); - CompilationOptions compilation_options = - *litert::CompilationOptions::Create(); - if (use_gpu) { - compilation_options.SetHardwareAccelerators( - LiteRtHwAccelerators::kLiteRtHwAcceleratorGpu); - } - return compilation_options; -} -} // namespace - -TfLiteStatus BenchmarkLiteRtModel::Init() { - std::string fd_or_graph_path = params_.Get("graph"); - LITERT_LOG(LITERT_INFO, "Loading model from: %s", fd_or_graph_path.c_str()); - model_ = *litert::Model::CreateFromFile(fd_or_graph_path); - if (!model_) { - LITERT_LOG(LITERT_ERROR, "Failed to load model: %s", - fd_or_graph_path.c_str()); - return kTfLiteError; - } - - auto env = Environment::Create({}); - if (!env) { - LITERT_LOG(LITERT_ERROR, "Failed to create litert environment."); - return kTfLiteError; - } - - auto compilation_options = CreateCompiledModelOptions(params_); - auto compiled_model_result = - litert::CompiledModel::Create(*env, model_, compilation_options); - if (!compiled_model_result) { - LITERT_LOG(LITERT_ERROR, "Failed to create compiled model."); - return 
kTfLiteError; - } - - compiled_model_ = std::make_unique( - std::move(*compiled_model_result)); - auto signature = params_.Get("signature_to_run_for"); - auto input_buffers_result = compiled_model_->CreateInputBuffers(signature); - if (!input_buffers_result) { - LITERT_LOG(LITERT_ERROR, "Failed to create input buffers."); - return kTfLiteError; - } - input_buffers_ = std::make_unique>( - std::move(*input_buffers_result)); - - auto output_buffers_result = compiled_model_->CreateOutputBuffers(signature); - if (!output_buffers_result) { - LITERT_LOG(LITERT_ERROR, "Failed to create output buffers."); - return kTfLiteError; - } - output_buffers_ = std::make_unique>( - std::move(*output_buffers_result)); - - return kTfLiteOk; -} -} // namespace litert::benchmark diff --git a/tensorflow/lite/experimental/litert/tools/benchmark_litert_model.h b/tensorflow/lite/experimental/litert/tools/benchmark_litert_model.h deleted file mode 100644 index 8534efddee78ab..00000000000000 --- a/tensorflow/lite/experimental/litert/tools/benchmark_litert_model.h +++ /dev/null @@ -1,151 +0,0 @@ -/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TOOLS_BENCHMARK_LITERT_MODEL_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TOOLS_BENCHMARK_LITERT_MODEL_H_ - -#include -#include -#include -#include -#include -#include -#include - -#include "absl/strings/numbers.h" -#include "absl/strings/str_split.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "tensorflow/lite/c/c_api_types.h" -#include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/cc/litert_compiled_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_environment.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer.h" -#include "tensorflow/lite/tools/benchmark/benchmark_model.h" -#include "tensorflow/lite/tools/benchmark/benchmark_params.h" -#include "tensorflow/lite/tools/utils.h" - -namespace litert { -namespace benchmark { - -using ::litert::CompiledModel; -using ::litert::Environment; -using ::litert::Model; -using ::litert::TensorBuffer; -using ::tflite::benchmark::BenchmarkModel; -using ::tflite::benchmark::BenchmarkParam; -using ::tflite::benchmark::BenchmarkParams; -using ::tflite::utils::InputTensorData; - -class BenchmarkLiteRtModel : public BenchmarkModel { - public: - BenchmarkLiteRtModel() = default; - explicit BenchmarkLiteRtModel(BenchmarkParams params) - : BenchmarkModel(std::move(params)) {} - ~BenchmarkLiteRtModel() override = default; - static BenchmarkParams DefaultParams() { - BenchmarkParams default_params = BenchmarkModel::DefaultParams(); - default_params.AddParam("graph", BenchmarkParam::Create("")); - default_params.AddParam("signature_to_run_for", - BenchmarkParam::Create("")); - default_params.AddParam("use_xnnpack", BenchmarkParam::Create(true)); - 
default_params.AddParam("use_gpu", BenchmarkParam::Create(false)); - - return default_params; - } - - TfLiteStatus Init() override; - - int64_t MayGetModelFileSize() override { - std::string fd_or_graph_path = params_.Get("graph"); - // Path can be one of the following: - // 1) File descriptor path: path must be in the format of - // "fd:%model_fd%:%model_offset%:%model_size%". - // 2) File path: path to the model file. - // Please see tensorflow/lite/tools/model_loader.h for more information. - std::vector parts = - absl::StrSplit(fd_or_graph_path, ':'); - if (!parts.empty() && parts[0] == "fd") { - int64_t model_size = -1; - if (parts.size() != 4 || !absl::SimpleAtoi(parts[3], &model_size)) { - LITERT_LOG(LITERT_ERROR, "Failed to parse model file size: %s", - fd_or_graph_path.c_str()); - } - return model_size; - } - std::ifstream in_file(fd_or_graph_path, std::ios::binary | std::ios::ate); - return in_file.tellg(); - } - - TfLiteStatus RunImpl() override { - if (!compiled_model_) { - LITERT_LOG(LITERT_ERROR, "Compiled model not initialized"); - return kTfLiteError; - } - auto signature = params_.Get("signature_to_run_for"); - if (compiled_model_->Run(signature, *input_buffers_, *output_buffers_)) { - return kTfLiteOk; - } else { - LITERT_LOG(LITERT_ERROR, "Run failed"); - return kTfLiteError; - } - } - - uint64_t ComputeInputBytes() override { - uint64_t total_bytes = 0; - for (const auto& buffer : *input_buffers_) { - total_bytes += *buffer.Size(); - } - return total_bytes; - } - - InputTensorData CreateRandomTensorData(const litert::TensorBuffer& t, - std::string name) { - float low_range = 0; - float high_range = 0; - tflite::utils::GetDataRangesForType( - static_cast(t.TensorType()->ElementType()), &low_range, - &high_range); - return tflite::utils::CreateRandomTensorData( - name, static_cast(t.TensorType()->ElementType()), *t.Size(), - low_range, high_range); - } - - TfLiteStatus PrepareInputData() override { - int index = 0; - for (auto& buffer : 
*input_buffers_) { - auto t_data = - CreateRandomTensorData(buffer, "input_" + std::to_string(index)); - buffer.Write(absl::MakeSpan( - reinterpret_cast(t_data.data.get()), t_data.bytes)); - ++index; - } - return kTfLiteOk; - } - - TfLiteStatus ResetInputsAndOutputs() override { return kTfLiteOk; } - - private: - Model model_; - std::unique_ptr compiled_model_; - std::unique_ptr> input_buffers_; - std::unique_ptr> output_buffers_; -}; - -} // namespace benchmark -} // namespace litert - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TOOLS_BENCHMARK_LITERT_MODEL_H_ diff --git a/tensorflow/lite/experimental/litert/tools/benchmark_litert_model_main.cc b/tensorflow/lite/experimental/litert/tools/benchmark_litert_model_main.cc deleted file mode 100644 index 8cf1891085f5b9..00000000000000 --- a/tensorflow/lite/experimental/litert/tools/benchmark_litert_model_main.cc +++ /dev/null @@ -1,35 +0,0 @@ -/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include - -#include "tensorflow/lite/c/c_api_types.h" -#include "tensorflow/lite/experimental/litert/tools/benchmark_litert_model.h" -#include "tensorflow/lite/tools/logging.h" - -namespace litert::benchmark { - -int Main(int argc, char** argv) { - TFLITE_LOG(INFO) << "STARTING!"; - BenchmarkLiteRtModel benchmark; - if (benchmark.Run(argc, argv) != kTfLiteOk) { - TFLITE_LOG(ERROR) << "Benchmarking failed."; - return EXIT_FAILURE; - } - return EXIT_SUCCESS; -} -} // namespace litert::benchmark - -int main(int argc, char** argv) { return litert::benchmark::Main(argc, argv); } diff --git a/tensorflow/lite/experimental/litert/tools/benchmark_litert_model_test.cc b/tensorflow/lite/experimental/litert/tools/benchmark_litert_model_test.cc deleted file mode 100644 index 08634b7aff4d41..00000000000000 --- a/tensorflow/lite/experimental/litert/tools/benchmark_litert_model_test.cc +++ /dev/null @@ -1,86 +0,0 @@ -/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/lite/experimental/litert/tools/benchmark_litert_model.h" - -#include -#include - -#include -#include - -#include -#include "tensorflow/lite/core/c/c_api_types.h" -#include "tensorflow/lite/tools/benchmark/benchmark_model.h" -#include "tensorflow/lite/tools/benchmark/benchmark_params.h" - -namespace litert { -namespace benchmark { -namespace { -using ::litert::benchmark::BenchmarkLiteRtModel; -using ::tflite::benchmark::BenchmarkListener; -using ::tflite::benchmark::BenchmarkParams; -using ::tflite::benchmark::BenchmarkResults; - -static constexpr char kModelPath[] = - "third_party/tensorflow/lite/experimental/litert/test/testdata/" - "mobilenet_v2_1.0_224.tflite"; -static constexpr char kSignatureToRunFor[] = ""; - -class TestBenchmarkListener : public BenchmarkListener { - public: - void OnBenchmarkEnd(const BenchmarkResults& results) override { - results_ = results; - } - - BenchmarkResults results_; -}; - -TEST(BenchmarkLiteRtModelTest, GetModelSizeFromPathSucceeded) { - BenchmarkParams params = BenchmarkLiteRtModel::DefaultParams(); - params.Set("graph", kModelPath); - params.Set("signature_to_run_for", kSignatureToRunFor); - params.Set("num_runs", 1); - params.Set("warmup_runs", 0); - params.Set("use_xnnpack", true); - params.Set("use_gpu", false); - BenchmarkLiteRtModel benchmark = BenchmarkLiteRtModel(std::move(params)); - TestBenchmarkListener listener; - benchmark.AddListener(&listener); - - benchmark.Run(); - - EXPECT_GE(listener.results_.model_size_mb(), 0); -} - -TEST(BenchmarkLiteRtModelTest, GPUAcceleration) { - // MSAN does not support GPU tests. 
-#if defined(MEMORY_SANITIZER) || defined(THREAD_SANITIZER) - GTEST_SKIP() << "GPU tests are not supported In msan"; -#endif - BenchmarkParams params = BenchmarkLiteRtModel::DefaultParams(); - params.Set("graph", kModelPath); - params.Set("signature_to_run_for", kSignatureToRunFor); - params.Set("use_xnnpack", false); - params.Set("use_gpu", true); - - BenchmarkLiteRtModel benchmark = BenchmarkLiteRtModel(std::move(params)); - - EXPECT_EQ(benchmark.Run(), kTfLiteOk); -} - -} // namespace -} // namespace benchmark -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/tools/dump.cc b/tensorflow/lite/experimental/litert/tools/dump.cc deleted file mode 100644 index e6c84d631773d5..00000000000000 --- a/tensorflow/lite/experimental/litert/tools/dump.cc +++ /dev/null @@ -1,442 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/tools/dump.h" - -#include - -#ifndef __ANDROID__ -#if __has_include() -#include -#endif -#endif - -#include -#include -#include - -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/compiler/plugin/compiler_plugin.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" - -namespace litert::internal { - -namespace { - -static constexpr int kMaxDisplayCount = 16; - -void DumpNode(const LiteRtTensorT& tensor, std::ostream& out) { - switch (tensor.Type().first) { - case kLiteRtRankedTensorType: - Dump(tensor.Type().second.ranked_tensor_type, out); - break; - case kLiteRtUnrankedTensorType: - Dump(tensor.Type().second.unranked_tensor_type.element_type, out); - break; - default: - out << "UKNOWN_TENSOR_TYPE" << tensor.Type().first; - } - Dump(tensor.Qparams(), out); -} - -void DumpNode(const LiteRtOpT& op, std::ostream& out) { - Dump(op.OpCode(), out); -} - -void DumpSignature(const std::vector& ins, - const std::vector& outs, std::ostream& out) { - out << "("; - for (auto it = ins.begin(); it < ins.end(); ++it) { - DumpNode(**it, out); - if (it != ins.end() - 1) { - out << ", "; - } - } - out << ")"; - - out << " -> "; - const bool paren_outs = outs.size() != 1; - if (paren_outs) { - out << "("; - } - for (auto it = outs.begin(); it < outs.end(); ++it) { - DumpNode(**it, out); - if (it != outs.end() - 1) { - out << ", "; - } - } - if (paren_outs) { - out << ")"; - } -} - -} // namespace - -void Dump(LiteRtOpCode code, std::ostream& out) { - switch (code) { - case kLiteRtOpCodeTflAdd: - out << "TFL_ADD"; - break; - case kLiteRtOpCodeTflMul: - out << "TFL_MUL"; - break; - case kLiteRtOpCodeTflCustom: - out << "TFL_CUSTOM_OP"; - break; - case kLiteRtOpCodeTflSlice: - out << "TFL_SLICE"; - break; - case 
kLiteRtOpCodeTflDiv: - out << "TFL_DIV"; - break; - case kLiteRtOpCodeTflRsqrt: - out << "TFL_RSQRT"; - break; - case kLiteRtOpCodeTflTanh: - out << "TFL_TANH"; - break; - case kLiteRtOpCodeTflSub: - out << "TFL_SUB"; - break; - case kLiteRtOpCodeTflReshape: - out << "TFL_RESHAPE"; - break; - case kLiteRtOpCodeTflBatchMatmul: - out << "TFL_BATCH_MATMUL"; - break; - case kLiteRtOpCodeTflSum: - out << "TFL_SUM"; - break; - case kLiteRtOpCodeTflConcatenation: - out << "TFL_CONCATENATION"; - break; - case kLiteRtOpCodeTflSoftmax: - out << "TFL_SOFTMAX"; - break; - case kLiteRtOpCodeTflCast: - out << "TFL_CAST"; - break; - case kLiteRtOpCodeTflTranspose: - out << "TFL_TRANSPOSE"; - break; - case kLiteRtOpCodeTflSin: - out << "TFL_SIN"; - break; - case kLiteRtOpCodeTflCos: - out << "TFL_COS"; - break; - case kLiteRtOpCodeTflSelect: - out << "TFL_SELECT"; - break; - case kLiteRtOpCodeTflSelectV2: - out << "TFL_SELECT_V2"; - break; - case kLiteRtOpCodeTflFullyConnected: - out << "TFL_FULLY_CONNECTED"; - break; - case kLiteRtOpCodeTflEmbeddingLookup: - out << "TFL_EMBEDDING_LOOKUP"; - break; - case kLiteRtOpCodeTflLogicalAnd: - out << "TFL_LOGICAL_AND"; - break; - case kLiteRtOpCodeTflLess: - out << "TFL_LESS"; - break; - case kLiteRtOpCodeTflGreater: - out << "TFL_GREATER"; - break; - case kLiteRtOpCodeTflGelu: - out << "TFL_GELU"; - break; - case kLiteRtOpCodeTflDynamicUpdateSlice: - out << "TFL_DYNAMIC_UPDATE_SLICE"; - break; - case kLiteRtOpCodeTflPack: - out << "TFL_PACK"; - break; - case kLiteRtOpCodeTflQuantize: - out << "TFL_QUANTIZE"; - break; - case kLiteRtOpCodeTflLeakyRelu: - out << "TFL_LEAKY_RELU"; - break; - case kLiteRtOpCodeTflHardSwish: - out << "TFL_HARD_SWISH"; - break; - case kLiteRtOpCodeTflAveragePool2d: - out << "AVERAGE_POOL_2D"; - break; - case kLiteRtOpCodeTflDepthwiseConv2d: - out << "DEPTHWISE_CONV_2D"; - break; - case kLiteRtOpCodeTflSpaceToDepth: - out << "SPACE_TO_DEPTH"; - break; - case kLiteRtOpCodeTflDepthToSpace: - out << 
"DEPTH_TO_SPACE"; - break; - case kLiteRtOpCodeTflConv2d: - out << "CONV_2D"; - break; - case kLiteRtOpCodeTflResizeBilinear: - out << "RESIZE_BILINEAR"; - break; - case kLiteRtOpCodeTflMinimum: - out << "MINIMUM"; - break; - case kLiteRtOpCodeTflMaximum: - out << "MAXIMUM"; - break; - case kLiteRtOpCodeTflResizeNearestNeighbor: - out << "RESIZE_NEAREST_NEIGHBOR"; - break; - case kLiteRtOpCodeTflRelu: - out << "TFL_RELU"; - break; - case kLiteRtOpCodeTflRelu6: - out << "TFL_RELU6"; - break; - default: - out << "UKNOWN_OP_CODE: " << code; - break; - } -}; - -// Dump details about the given LiteRtElementType to the given stream. -void Dump(LiteRtElementType type, std::ostream& out) { - switch (type) { - case kLiteRtElementTypeFloat32: - out << "f32"; - break; - case kLiteRtElementTypeInt32: - out << "i32"; - break; - case kLiteRtElementTypeFloat64: - out << "f64"; - break; - case kLiteRtElementTypeInt64: - out << "i64"; - break; - case kLiteRtElementTypeFloat16: - out << "f16"; - break; - case kLiteRtElementTypeInt16: - out << "i16"; - break; - case kLiteRtElementTypeInt8: - out << "i8"; - break; - case kLiteRtElementTypeUInt8: - out << "ui8"; - break; - case kLiteRtElementTypeBool: - out << "i1"; - break; - default: - out << "UKNNOWN_ELEMENT_TYPE: " << type; - } -} - -void Dump(const LiteRtRankedTensorType& type, std::ostream& out) { - out << "<"; - for (int i = 0; i < type.layout.rank; ++i) { - out << type.layout.dimensions[i] << "x"; - } - Dump(type.element_type, out); - out << ">"; -} - -void Dump(const LiteRtTensorT& tensor, std::ostream& out) { - out << "LiteRtTensor : "; - DumpNode(tensor, out); - out << " [ "; - if (tensor.DefiningOp() == nullptr) { - out << "*"; - } else { - DumpNode(*tensor.DefiningOp(), out); - } - out << " ] "; - - out << "("; - for (auto it = tensor.Users().begin(); it < tensor.Users().end(); ++it) { - DumpNode(**it, out); - if (it != tensor.Users().end() - 1) { - out << ", "; - } - } - out << ")"; - out << "\n"; -} - -void Dump(const 
LiteRtOpT& op, std::ostream& out) { - out << "LiteRtOp : [ "; - DumpNode(op, out); - out << " ] "; - DumpSignature(op.Inputs(), op.Outputs(), out); - out << "\n"; -} - -void Dump(const LiteRtSubgraphT& subgraph, std::ostream& out) { - constexpr absl::string_view kSubgraphTpl = - "LiteRtSubgraph : [ #ops=%d #tensors=%d ] "; - out << absl::StreamFormat(kSubgraphTpl, subgraph.Ops().size(), - subgraph.Tensors().size()); - DumpSignature(subgraph.Inputs(), subgraph.Outputs(), out); - out << "\n"; -} - -void Dump(const CompilerPlugin& plugin, std::ostream& out) { - constexpr absl::string_view kPluginDumpTpl = - "SocManufacturer: %s\nSocModels: { "; - out << absl::StreamFormat(kPluginDumpTpl, plugin.SocManufacturer()); - - for (auto it = plugin.SocModels().begin(); it < plugin.SocModels().end(); - ++it) { - out << *it; - if (it != plugin.SocModels().end() - 1) { - out << ","; - } - out << " "; - } - - out << "}\n"; -} - -void Dump(const LiteRtModelT& model, std::ostream& out) { - out << absl::StreamFormat("LiteRtModel : [ #subgraphs=%d ]\n", - model.Subgraphs().size()); -} - -void DumpOptions(const LiteRtOpT& op, std::ostream& out) { - auto& opts = litert::internal::GetTflOptions(op); - if (opts.value == nullptr) { - out << "null options\n"; - return; - } - switch (op.OpCode()) { - case kLiteRtOpCodeTflAdd: - out << "fused_activation_function: " - << opts.AsAddOptions()->fused_activation_function << "\n"; - break; - case kLiteRtOpCodeTflMul: - out << "fused_activation_function: " - << opts.AsMulOptions()->fused_activation_function << "\n"; - break; - case kLiteRtOpCodeTflBatchMatmul: - out << "adj_x: " << opts.AsBatchMatMulOptions()->adj_x << "\n"; - out << "adj_y: " << opts.AsBatchMatMulOptions()->adj_y << "\n"; - out << "asymmetric_quantize_input: " - << opts.AsBatchMatMulOptions()->asymmetric_quantize_inputs << "\n"; - break; - case kLiteRtOpCodeTflConcatenation: - out << "axis: " << opts.AsConcatenationOptions()->axis << "\n"; - out << "fused_activation_function: " - 
<< opts.AsConcatenationOptions()->fused_activation_function << "\n"; - break; - case kLiteRtOpCodeTflDiv: - out << "fused_activation_function: " - << opts.AsDivOptions()->fused_activation_function << "\n"; - break; - case kLiteRtOpCodeTflFullyConnected: - out << "weights_format: " - << opts.AsFullyConnectedOptions()->weights_format << "\n"; - out << "keep_num_dims: " << opts.AsFullyConnectedOptions()->keep_num_dims - << "\n"; - out << "quantized_bias_type: " - << opts.AsFullyConnectedOptions()->quantized_bias_type << "\n"; - out << "asymmetric_quantize_input: " - << opts.AsFullyConnectedOptions()->asymmetric_quantize_inputs << "\n"; - out << "fused_activation_function: " - << opts.AsFullyConnectedOptions()->fused_activation_function << "\n"; - break; - case kLiteRtOpCodeTflSoftmax: - out << "beta: " << opts.AsSoftmaxOptions()->beta << "\n"; - break; - case kLiteRtOpCodeTflStridedSlice: - out << "begin_mask: " << opts.AsStridedSliceOptions()->begin_mask << "\n"; - out << "end_mask: " << opts.AsStridedSliceOptions()->end_mask << "\n"; - out << "ellipsis_mask: " << opts.AsStridedSliceOptions()->ellipsis_mask - << "\n"; - out << "new_axis_mask: " << opts.AsStridedSliceOptions()->new_axis_mask - << "\n"; - out << "shrink_axis_mask: " - << opts.AsStridedSliceOptions()->shrink_axis_mask << "\n"; - out << "offset: " << opts.AsStridedSliceOptions()->offset << "\n"; - break; - case kLiteRtOpCodeTflSub: - out << "fused_activation_function: " - << opts.AsSubOptions()->fused_activation_function << "\n"; - break; - case kLiteRtOpCodeTflReshape: - out << "new_shape: "; - if (opts.AsReshapeOptions() != nullptr) { - const int32_t* new_shape = opts.AsReshapeOptions()->new_shape.data(); - int32_t new_shape_size = opts.AsReshapeOptions()->new_shape.size(); - for (int i = 0; i < new_shape_size; ++i) { - out << new_shape[i] << " "; - } - } - break; - case kLiteRtOpCodeTflSum: - out << "keepdims: " << opts.AsReducerOptions()->keep_dims << "\n"; - break; - case kLiteRtOpCodeTflPack: - out 
<< "axis: " << opts.AsPackOptions()->axis << "\n"; - break; - default: - out << "No options for op code: " << op.OpCode(); - break; - } -} - -void Dump(Quantization quantization, std::ostream& out) { - int max_display_count; - switch (quantization.first) { - case kLiteRtQuantizationNone: - return; - case kLiteRtQuantizationPerTensor: - out << absl::StreamFormat(" ", - quantization.second.per_tensor.zero_point, - quantization.second.per_tensor.scale); - return; - case kLiteRtQuantizationPerChannel: - max_display_count = - kMaxDisplayCount < quantization.second.per_channel.num_channels - ? kMaxDisplayCount - : quantization.second.per_channel.num_channels; - out << absl::StreamFormat(" ", quantization.second.per_channel.quantized_dimension); - return; - default: - out << " "; - return; - } -} - -} // namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/tools/dump.h b/tensorflow/lite/experimental/litert/tools/dump.h deleted file mode 100644 index 89254ae48e29a6..00000000000000 --- a/tensorflow/lite/experimental/litert/tools/dump.h +++ /dev/null @@ -1,68 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TOOLS_DUMP_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TOOLS_DUMP_H_ - -#include -#include -#include - -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/compiler/plugin/compiler_plugin.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" - -namespace litert::internal { - -// -// LiteRt IR -// - -// Dump details about the given LiteRtOpT to the given stream. -void Dump(const LiteRtOpT& op, std::ostream& out = std::cerr); - -// Dump details about the given LiteRtSubgraphT to the given stream. -void Dump(const LiteRtSubgraphT& subgraph, std::ostream& out = std::cerr); - -// Dump details about the given LiteRtTensorT to the given stream. -void Dump(const LiteRtTensorT& tensor, std::ostream& out = std::cerr); - -// Dump details about the given LiteRtOpCode to the given stream. -void Dump(LiteRtOpCode code, std::ostream& out = std::cerr); - -// Dump details about the given LiteRtElementType to the given stream. -void Dump(LiteRtElementType type, std::ostream& out = std::cerr); - -// Dump details about the given LiteRtRankedTensorType to the given stream. -void Dump(const LiteRtRankedTensorType& type, std::ostream& out = std::cerr); - -// Dump details about the given LiteRtModel to the given stream. -void Dump(const LiteRtModelT& model, std::ostream& out = std::cerr); - -// Dump details about the given quantization params. -void Dump(Quantization quantization, std::ostream& out = std::cerr); - -// Dump details about options -void DumpOptions(const LiteRtOpT& op, std::ostream& out = std::cerr); - -// -// Library Utilities -// - -// Dumps details about the loaded LiteRtCompilerPlugin library. 
-void Dump(const CompilerPlugin& plugin, std::ostream& out = std::cerr); - -} // namespace litert::internal - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TOOLS_DUMP_H_ diff --git a/tensorflow/lite/experimental/litert/tools/dump_test.cc b/tensorflow/lite/experimental/litert/tools/dump_test.cc deleted file mode 100644 index ff89547c2350aa..00000000000000 --- a/tensorflow/lite/experimental/litert/tools/dump_test.cc +++ /dev/null @@ -1,131 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/tools/dump.h" - -#include -#include -#include -#include - -#include -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" -#include "tensorflow/lite/experimental/litert/test/common.h" - -namespace { - -using ::litert::internal::Dump; -using ::litert::internal::DumpOptions; -using ::litert::testing::LoadTestFileModel; - -TEST(DumpTest, TestDump) { - auto model = LoadTestFileModel("one_mul.tflite"); - - { - std::ostringstream model_dump; - Dump(*model.Get(), model_dump); - EXPECT_EQ(model_dump.view(), "LiteRtModel : [ #subgraphs=1 ]\n"); - } - - { - const LiteRtTensorT& in_tensor = model.Get()->Subgraph(0).Input(0); - std::ostringstream in_tensor_dump; - Dump(in_tensor, in_tensor_dump); - EXPECT_EQ(in_tensor_dump.view(), - "LiteRtTensor : <2x2xf32> [ * ] (TFL_MUL)\n"); - } - - { - const LiteRtTensorT& out_tensor = model.Get()->Subgraph(0).Output(0); - std::ostringstream out_tensor_dump; - Dump(out_tensor, out_tensor_dump); - EXPECT_EQ(out_tensor_dump.view(), - "LiteRtTensor : <2x2xf32> [ TFL_MUL ] ()\n"); - } - - { - const LiteRtOpT& op = model.Get()->Subgraph(0).Op(0); - std::ostringstream op_dump; - Dump(op, op_dump); - EXPECT_EQ(op_dump.view(), - "LiteRtOp : [ TFL_MUL ] (<2x2xf32>, <2x2xf32>) -> <2x2xf32>\n"); - } - - { - const LiteRtSubgraphT& subgraph = model.Get()->Subgraph(0); - std::ostringstream subgraph_dump; - Dump(subgraph, subgraph_dump); - EXPECT_EQ( - subgraph_dump.view(), - "LiteRtSubgraph : [ #ops=1 #tensors=3 ] (<2x2xf32>, <2x2xf32>) -> " - "<2x2xf32>\n"); - } -} - -TEST(DumpTest, TestDumpOptions) { - auto model = LoadTestFileModel("simple_strided_slice_op.tflite"); - const LiteRtOpT& op = model.Get()->Subgraph(0).Op(0); - std::ostringstream op_dump; - DumpOptions(op, op_dump); - EXPECT_EQ(op_dump.view(), - "begin_mask: 0\n" - "end_mask: 0\n" - "ellipsis_mask: 0\n" - "new_axis_mask: 
0\n" - "shrink_axis_mask: 0\n" - "offset: 0\n"); -} - -TEST(DumpTest, TestDumpPerTensorQuantization) { - QuantizationDetail per_tensor_detail; - per_tensor_detail.per_tensor.scale = 1.0; - per_tensor_detail.per_tensor.zero_point = 2; - std::ostringstream q_dump; - Dump(std::make_pair(kLiteRtQuantizationPerTensor, per_tensor_detail), q_dump); - EXPECT_EQ(q_dump.view(), " "); -} - -TEST(DumpTest, TestDumpPerChannelQuantization) { - static constexpr size_t kRank = 2; - static constexpr size_t kQuantizedDimension = 1; - static constexpr float kScales[kRank] = {1.0, 2.0}; - static constexpr int64_t kZps[kRank] = {2, 3}; - QuantizationDetail per_channel_detail; - per_channel_detail.per_channel.scales = const_cast(kScales); - per_channel_detail.per_channel.zero_points = const_cast(kZps); - per_channel_detail.per_channel.quantized_dimension = kQuantizedDimension; - per_channel_detail.per_channel.num_channels = kRank; - std::ostringstream q_dump; - Dump(std::make_pair(kLiteRtQuantizationPerChannel, per_channel_detail), - q_dump); - EXPECT_FALSE(q_dump.view().empty()); -} - -TEST(DumpTest, TestDumpNoQuantization) { - QuantizationDetail none_detail; - std::ostringstream q_dump; - Dump(std::make_pair(kLiteRtQuantizationNone, none_detail), q_dump); - EXPECT_TRUE(q_dump.view().empty()); -} - -TEST(DumpTest, TestDumpUnknownQuantization) { - QuantizationDetail detail; - std::ostringstream q_dump; - Dump(std::make_pair(kLiteRtQuantizationBlockWise, detail), q_dump); - EXPECT_EQ(q_dump.view(), " "); -} - -} // namespace diff --git a/tensorflow/lite/experimental/litert/tools/outstream.h b/tensorflow/lite/experimental/litert/tools/outstream.h deleted file mode 100644 index a920f21839592b..00000000000000 --- a/tensorflow/lite/experimental/litert/tools/outstream.h +++ /dev/null @@ -1,83 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TOOLS_OUTSTREAM_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TOOLS_OUTSTREAM_H_ - -#include -#include -#include -#include -#include -#include - -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" - -namespace litert::tools { - -using OutStream = std::reference_wrapper; -using OutStreamPtr = std::unique_ptr; - -// Out stream configured by a user by flag. -class UserStream { - public: - // Parse the flag and get a configured stream. - static UserStream MakeFromFlag(absl::string_view flag) { - if (flag == kCerr) { - LITERT_LOG(LITERT_INFO, "Setup cerr stream\n", ""); - return UserStream(std::cerr); - } else if (flag == kCout) { - LITERT_LOG(LITERT_INFO, "Setup cout stream\n", ""); - return UserStream(std::cout); - } else if (flag == kNone) { - LITERT_LOG(LITERT_INFO, "Setup null stream\n", ""); - return UserStream(); - } else { - // File stream. - LITERT_LOG(LITERT_INFO, "Setup file stream\n", ""); - auto ofstream = std::make_unique(); - ofstream->open(flag.data()); - return UserStream(std::move(ofstream)); - } - } - - // Get the actual stream to write to. - OutStream Get() { return used_; } - - // Silent stream. - UserStream() - : stored_(std::make_unique(nullptr)), used_(*stored_) {} - // From reference to external stream (cerr, cout) - explicit UserStream(OutStream ostream) : stored_(nullptr), used_(ostream) {} - // From stream to internalize. 
- explicit UserStream(OutStreamPtr ostream) - : stored_(std::move(ostream)), used_(*stored_) {} - - UserStream(UserStream&&) = default; - UserStream& operator=(UserStream&&) = default; - - private: - // These are used in the various CLI's flags that configure output streams. - static constexpr absl::string_view kCerr = "--"; - static constexpr absl::string_view kCout = "-"; - static constexpr absl::string_view kNone = "none"; - - OutStreamPtr stored_; - OutStream used_; -}; - -} // namespace litert::tools - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TOOLS_OUTSTREAM_H_ diff --git a/tensorflow/lite/experimental/litert/tools/run_model.cc b/tensorflow/lite/experimental/litert/tools/run_model.cc deleted file mode 100644 index c360beae0f2a4b..00000000000000 --- a/tensorflow/lite/experimental/litert/tools/run_model.cc +++ /dev/null @@ -1,113 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include -#include -#include -#include -#include - -#include "absl/flags/flag.h" -#include "absl/flags/parse.h" -#include "absl/log/absl_log.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/cc/litert_compiled_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_environment.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/profiling/time.h" - -ABSL_FLAG(std::string, graph, "", "Model filename to use for testing."); -ABSL_FLAG(std::string, dispatch_library_dir, "", - "Path to the dispatch library."); -ABSL_FLAG(bool, use_gpu, false, "Use GPU Accelerator."); - -namespace litert { -namespace { - -Expected RunModel() { - if (absl::GetFlag(FLAGS_graph).empty()) { - return Error(kLiteRtStatusErrorInvalidArgument, - "Model filename is empty. Use --graph to provide it."); - } - - ABSL_LOG(INFO) << "Model: " << absl::GetFlag(FLAGS_graph); - LITERT_ASSIGN_OR_RETURN(auto model, - Model::CreateFromFile(absl::GetFlag(FLAGS_graph))); - - const std::string dispatch_library_dir = - absl::GetFlag(FLAGS_dispatch_library_dir); - - std::vector environment_options = {}; - if (!dispatch_library_dir.empty()) { - environment_options.push_back(litert::Environment::Option{ - litert::Environment::OptionTag::DispatchLibraryDir, - absl::string_view(dispatch_library_dir)}); - }; - - LITERT_ASSIGN_OR_RETURN( - auto env, - litert::Environment::Create(absl::MakeConstSpan(environment_options))); - - ABSL_LOG(INFO) << "Create CompiledModel"; - auto accelerator = absl::GetFlag(FLAGS_use_gpu) ? 
kLiteRtHwAcceleratorGpu - : kLiteRtHwAcceleratorNone; - if (accelerator == kLiteRtHwAcceleratorGpu) { - ABSL_LOG(INFO) << "Using GPU Accelerator"; - } - LITERT_ASSIGN_OR_RETURN(auto compiled_model, - CompiledModel::Create(env, model, accelerator)); - - LITERT_ASSIGN_OR_RETURN(auto signatures, model.GetSignatures()); - size_t signature_index = 0; - - ABSL_LOG(INFO) << "Prepare input buffers"; - - LITERT_ASSIGN_OR_RETURN(auto input_buffers, - compiled_model.CreateInputBuffers(signature_index)); - - ABSL_LOG(INFO) << "Prepare output buffers"; - - LITERT_ASSIGN_OR_RETURN(auto output_buffers, - compiled_model.CreateOutputBuffers(signature_index)); - - ABSL_LOG(INFO) << "Run model"; - uint64_t start = tflite::profiling::time::NowMicros(); - auto status = - compiled_model.Run(signature_index, input_buffers, output_buffers); - uint64_t end = tflite::profiling::time::NowMicros(); - LITERT_LOG(LITERT_INFO, "Run took %lu microseconds", end - start); - - ABSL_LOG(INFO) << "Model run completed"; - - return status; -} - -} // namespace -} // namespace litert - -int main(int argc, char** argv) { - absl::ParseCommandLine(argc, argv); - - auto res = litert::RunModel(); - if (!res) { - LITERT_LOG(LITERT_ERROR, "%s", res.Error().Message().c_str()); - return EXIT_FAILURE; - } - return EXIT_SUCCESS; -} diff --git a/tensorflow/lite/experimental/litert/tools/tool_display.cc b/tensorflow/lite/experimental/litert/tools/tool_display.cc deleted file mode 100644 index 2067d7826adb66..00000000000000 --- a/tensorflow/lite/experimental/litert/tools/tool_display.cc +++ /dev/null @@ -1,86 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/tools/tool_display.h" - -#include -#include - -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/tools/outstream.h" - -namespace litert::tools { - -std::string ToolDisplay::MakeLabel(absl::string_view tool_label) { - return absl::StrFormat( - "[LITERT_TOOLS%s] ", - tool_label.empty() ? tool_label : absl::StrFormat(":%s", tool_label)); -} - -std::ostream& ToolDisplay::Display() { return ostream_.Get(); } - -std::ostream& ToolDisplay::Labeled() { - Display() << label_; - return Display(); -} - -std::ostream& ToolDisplay::Indented() { - Display() << "\t"; - return Display(); -} - -void ToolDisplay::Start(const absl::string_view scope_name) { - static constexpr absl::string_view kStartFmt = "Starting %s...\n"; - Labeled() << absl::StreamFormat(kStartFmt, scope_name); -} - -void ToolDisplay::Done(const absl::string_view scope_name) { - static constexpr absl::string_view kDoneFmt = "%s Done!\n"; - Labeled() << ""; - Indented() << absl::StreamFormat(kDoneFmt, scope_name); -} - -void ToolDisplay::Fail() { - Labeled() << ""; - Indented() << "Failed\n"; -} - -ToolDisplay::LoggedScope ToolDisplay::StartS(absl::string_view scope_name) { - return LoggedScope(*this, scope_name); -} - -void ToolDisplay::LoggedScope::Start() { parent_.Start(scope_name_); } - -void ToolDisplay::LoggedScope::Done() { parent_.Done(scope_name_); } - -ToolDisplay::LoggedScope::~LoggedScope() { Done(); } - -ToolDisplay::LoggedScope::LoggedScope(ToolDisplay& parent, - 
absl::string_view scope_name) - : parent_(parent), scope_name_(scope_name) { - Start(); -} - -static constexpr absl::string_view kArt = R"( - __ _ __ ____ __ - / / (_/ /____ / __ \/ /_ - / / / / __/ _ \/ /_/ / __/ - / /___/ / /_/ __/ _, _/ /_ -/_____/_/\__/\___/_/ |_|\__/ -)"; - -void DumpPreamble(ToolDisplay& display) { display.Display() << kArt << "\n"; } - -} // namespace litert::tools diff --git a/tensorflow/lite/experimental/litert/tools/tool_display.h b/tensorflow/lite/experimental/litert/tools/tool_display.h deleted file mode 100644 index 583d07ee3480f6..00000000000000 --- a/tensorflow/lite/experimental/litert/tools/tool_display.h +++ /dev/null @@ -1,102 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TOOLS_TOOL_DISPLAY_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TOOLS_TOOL_DISPLAY_H_ - -#include -#include -#include -#include - -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/tools/outstream.h" - -namespace litert::tools { - -// Utility class for interactive logging for usage in command line tools only. -// Allows user to explicitly set target stream. -class ToolDisplay { - public: - using Ptr = std::unique_ptr; - // Construct configured ToolDisplay. Label is used for prefixing dumps - // in "LabeledStream". 
- explicit ToolDisplay(UserStream&& ostream, absl::string_view tool_label = "") - : label_(MakeLabel(tool_label)), - ostream_(std::forward(ostream)) {} - explicit ToolDisplay(OutStream ostream, absl::string_view tool_label = "") - : label_(MakeLabel(tool_label)), ostream_(UserStream(ostream)) {} - - ToolDisplay(const ToolDisplay&) = delete; - ToolDisplay& operator=(const ToolDisplay&) = delete; - ToolDisplay(ToolDisplay&&) = delete; - ToolDisplay& operator=(ToolDisplay&&) = delete; - - // Get out stream. - std::ostream& Display(); - - // Get Display with label prefix. - std::ostream& Labeled(); - - // Get Display with indent. - std::ostream& Indented(); - - // Log string indicating a sub rountine is beginning. - void Start(absl::string_view scope_name); - - // Log string indicating a sub rountine is done and succeeded. - void Done(absl::string_view scope_name = ""); - - // Log string indicating a sub rountine is done and failed. - void Fail(); - - // Logs "start/finish" messages automatically. - class LoggedScope { - friend class ToolDisplay; - - public: - LoggedScope(const LoggedScope&) = delete; - LoggedScope& operator=(const LoggedScope&) = delete; - LoggedScope(LoggedScope&&) = delete; - LoggedScope& operator=(LoggedScope&&) = delete; - - ~LoggedScope(); - - private: - explicit LoggedScope(ToolDisplay& parent, absl::string_view scope_name); - - void Start(); - void Done(); - - ToolDisplay& parent_; - // These should all be from literals. - absl::string_view scope_name_; - }; - - // Get object that prints a start message and an exit message - // automatically when it goes out of scope. - [[maybe_unused]] LoggedScope StartS(absl::string_view scope_name); - - private: - static std::string MakeLabel(absl::string_view tool_label); - std::string label_; - UserStream ostream_; -}; - -// Print art and info at cli startup. 
-void DumpPreamble(ToolDisplay& display); - -} // namespace litert::tools - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TOOLS_TOOL_DISPLAY_H_ diff --git a/tensorflow/lite/experimental/litert/tools/tool_display_test.cc b/tensorflow/lite/experimental/litert/tools/tool_display_test.cc deleted file mode 100644 index 94027f663c301c..00000000000000 --- a/tensorflow/lite/experimental/litert/tools/tool_display_test.cc +++ /dev/null @@ -1,99 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/tools/tool_display.h" - -#include - -#include -#include -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" - -namespace { - -using ::litert::tools::ToolDisplay; -using ::testing::EndsWith; -using ::testing::StartsWith; - -static constexpr absl::string_view kToolName = "test-tool"; -static constexpr absl::string_view kLabel = "[LITERT_TOOLS:test-tool]"; -static constexpr absl::string_view kStartLabel = "Test Routine"; -static constexpr absl::string_view kDisplayInfo = "info"; - -TEST(TestToolDisplay, Display) { - std::stringstream out; - ToolDisplay display(out, kToolName); - display.Display() << kDisplayInfo; - EXPECT_EQ(out.view(), kDisplayInfo); -} - -TEST(TestToolDisplay, Indented) { - std::stringstream out; - ToolDisplay display(out, kToolName); - display.Indented() << kDisplayInfo; - EXPECT_EQ(out.view(), absl::StrFormat("\t%s", kDisplayInfo)); -} - -TEST(TestToolDisplay, Labeled) { - std::stringstream out; - ToolDisplay display(out, kToolName); - display.Labeled() << kDisplayInfo; - EXPECT_EQ(out.view(), absl::StrFormat("%s %s", kLabel, kDisplayInfo)); -} - -TEST(TestToolDisplay, LabeledNoToolName) { - std::stringstream out; - ToolDisplay display(out); - display.Labeled() << kDisplayInfo; - EXPECT_EQ(out.view(), - absl::StrFormat("%s %s", "[LITERT_TOOLS]", kDisplayInfo)); -} - -TEST(TestToolDisplay, Start) { - std::stringstream out; - ToolDisplay display(out, kToolName); - display.Start(kStartLabel); - EXPECT_EQ(out.view(), - absl::StrFormat("%s Starting %s...\n", kLabel, kStartLabel)); -} - -TEST(TestToolDisplay, Done) { - std::stringstream out; - ToolDisplay display(out, kToolName); - display.Done(kStartLabel); - EXPECT_EQ(out.view(), - absl::StrFormat("%s \t%s Done!\n", kLabel, kStartLabel)); -} - -TEST(TestToolDisplay, Fail) { - std::stringstream out; - ToolDisplay display(out, kToolName); - display.Fail(); - EXPECT_EQ(out.view(), absl::StrFormat("%s \tFailed\n", kLabel)); -} - 
-TEST(TestLoggedScope, EnterExit) { - std::stringstream out; - ToolDisplay display(out, kToolName); - { - auto s = display.StartS(kStartLabel); - } - EXPECT_THAT(out.view(), StartsWith(absl::StrFormat("%s Starting %s...\n", - kLabel, kStartLabel))); - EXPECT_THAT(out.view(), EndsWith(absl::StrFormat("%s \t%s Done!\n", kLabel, - kStartLabel))); -} - -} // namespace diff --git a/tensorflow/lite/experimental/litert/vendors/c/BUILD b/tensorflow/lite/experimental/litert/vendors/c/BUILD deleted file mode 100644 index 0692c1f0cd4a11..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/c/BUILD +++ /dev/null @@ -1,67 +0,0 @@ -# Copyright 2024 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], -) - -cc_library( - name = "litert_compiler_plugin", - hdrs = ["litert_compiler_plugin.h"], - deps = [ - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_model", - ], -) - -cc_library( - name = "litert_compiler_plugin_api", - hdrs = ["litert_compiler_plugin_api.h"], - deps = [ - ":litert_compiler_plugin", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_model", - "@com_google_absl//absl/strings:string_view", - ], -) - -# This library is used to build the C API header files for the vendor dispatch API. -# All the vendor dispatch .so tragets should depend on this library. -cc_library( - name = "litert_dispatch_c_api", - hdrs = [ - "litert_dispatch.h", - "litert_dispatch_api.h", - ], - deps = [ - # only depend on the headers, not the implementation. - "//tensorflow/lite/experimental/litert/c:litert_dispatch_headers", - ], -) - -# This test verifies that the C API header files can build via C compiler. -cc_test( - name = "litert_vendor_c_api_common_test", - srcs = ["litert_vendor_c_api_common_test.c"], - copts = ["--std=c11"], - linkopts = ["-ldl"], - deps = [ - ":litert_compiler_plugin", - ":litert_compiler_plugin_api", - ":litert_dispatch_c_api", - ], -) - -exports_files(srcs = glob(["litert_*.h"])) diff --git a/tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin.h b/tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin.h deleted file mode 100644 index 926c4f98d469c0..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin.h +++ /dev/null @@ -1,118 +0,0 @@ -// Copyright 2024 Google LLC. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_C_LITERT_COMPILER_PLUGIN_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_C_LITERT_COMPILER_PLUGIN_H_ - -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" - -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus - -LITERT_DEFINE_HANDLE(LiteRtCompilerPlugin); - -// Artifact produced from compiling a selected partition of ops. -LITERT_DEFINE_HANDLE(LiteRtCompiledResult); - -// -// Plugin -// - -LiteRtStatus LiteRtGetCompilerPluginVersion(LiteRtApiVersion* api_version); - -// Name associated with the manufacturer this plugin relates to (e.g, -// GoogleTensor, Qualcomm). -const char* LiteRtGetCompilerPluginSocManufacturer(); - -LiteRtStatus LiteRtCreateCompilerPlugin(LiteRtCompilerPlugin* compiler_plugin); - -void LiteRtDestroyCompilerPlugin(LiteRtCompilerPlugin compiler_plugin); - -// Return the HW supported by this plugin (e.g., GPU, NPU) -LiteRtStatus LiteRtGetCompilerPluginSupportedHardware( - LiteRtCompilerPlugin compiler_plugin, - LiteRtHwAccelerators* supported_hardware); - -// Number of SoC models supported by this plugin. -LiteRtStatus LiteRtGetNumCompilerPluginSupportedSocModels( - LiteRtCompilerPlugin compiler_plugin, - LiteRtParamIndex* num_supported_soc_models); - -// Gets the name of the SoC model at the given index. 
The memory -// associated with the returned name is owned by the plugin. -LiteRtStatus LiteRtGetCompilerPluginSupportedSocModel( - LiteRtCompilerPlugin compiler_plugin, LiteRtParamIndex soc_model_idx, - const char** soc_model_name); - -// Select desired ops for compilation. This will only be called once -// per subgraph, plugins should select all supportable ops. -LiteRtStatus LiteRtCompilerPluginPartition(LiteRtCompilerPlugin compiler_plugin, - const char* soc_model, - LiteRtSubgraph subgraph, - LiteRtOpList selected_ops); - -// Prepare result to pass to the runtime for given model containing partitioned -// subgraphs. Optionally, handles a SoC model (parameter `soc_model` can be NULL -// to specify a default SoC model). -LiteRtStatus LiteRtCompilerPluginCompile(LiteRtCompilerPlugin compiler_plugin, - const char* soc_model, - LiteRtModel partitions, - LiteRtCompiledResult* compiled_result); - -// Set any flags for the compiler do use during compilation. Flag data may be -// released or reused after this function returns. Flags are string key -> -// optional string value pairs. A non-existent value is represented by an empty -// string. Calling this function will unset any previously set flags. -LiteRtStatus LiteRtCompilerPluginSetFlags(LiteRtCompilerPlugin compiler_plugin, - LiteRtParamIndex num_flags, - const char** keys, - const char** values); - -// -// Compiled Partition -// - -void LiteRtDestroyCompiledResult(LiteRtCompiledResult result); - -// Get the buffer for the compiled byte code for the given index. -LiteRtStatus LiteRtGetCompiledResultByteCode( - LiteRtCompiledResult compiled_result, LiteRtParamIndex byte_code_idx, - const void** byte_code, size_t* byte_code_size); - -// The number of individual byte code modules. 
-LiteRtStatus LiteRtCompiledResultNumByteCodeModules( - LiteRtCompiledResult compiled_result, LiteRtParamIndex* num_byte_code); - -// Get per-op info related to a particular compiled partition as well as the -// index of the respective byte code buffer. -LiteRtStatus LiteRtGetCompiledResultCallInfo( - LiteRtCompiledResult compiled_result, LiteRtParamIndex call_idx, - const void** call_info, size_t* call_info_size, - LiteRtParamIndex* byte_code_idx); - -// Get the number of calls that will be made to the HAL for this graph. -// This should equal the number of partitions given for compilation which -// is equal to the number of custom ops in the final model. -LiteRtStatus LiteRtGetNumCompiledResultCalls( - LiteRtCompiledResult compiled_result, LiteRtParamIndex* num_calls); - -#ifdef __cplusplus -} -#endif // __cplusplus - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_C_LITERT_COMPILER_PLUGIN_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin_api.h b/tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin_api.h deleted file mode 100644 index 8555933e6a9890..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin_api.h +++ /dev/null @@ -1,156 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_C_LITERT_COMPILER_PLUGIN_API_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_C_LITERT_COMPILER_PLUGIN_API_H_ - -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin.h" - -// Wrapper for dynamically loaded LiteRtCompilerPlugin library. See -// "litert_compiler_plugin.h". - -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus - -// -// Api Interface -// - -typedef LiteRtStatus (*LiteRtGetCompilerPluginVersionT)(LiteRtApiVersion*); - -typedef const char* (*LiteRtGetCompilerPluginSocManufacturerT)(); - -typedef LiteRtStatus (*LiteRtCreateCompilerPluginT)(LiteRtCompilerPlugin*); - -typedef void (*LiteRtDestroyCompilerPluginT)(LiteRtCompilerPlugin); - -typedef LiteRtStatus (*LiteRtGetCompilerPluginSupportedHardwareT)( - LiteRtCompilerPlugin, LiteRtHwAccelerators*); - -typedef LiteRtStatus (*LiteRtGetNumCompilerPluginSupportedSocModelsT)( - LiteRtCompilerPlugin, LiteRtParamIndex*); - -typedef LiteRtStatus (*LiteRtGetCompilerPluginSupportedSocModelT)( - LiteRtCompilerPlugin, LiteRtParamIndex soc_model_idx, - const char** soc_moel_idx); - -typedef LiteRtStatus (*LiteRtCompilerPluginPartitionT)( - LiteRtCompilerPlugin, const char* soc_model, LiteRtSubgraph subgraph, - LiteRtOpList selected_ops); - -typedef LiteRtStatus (*LiteRtCompilerPluginCompileT)( - LiteRtCompilerPlugin, const char* soc_model, LiteRtModel partitions, - LiteRtCompiledResult* compiled_result); - -typedef void (*LiteRtDestroyCompiledResultT)(LiteRtCompiledResult); - -typedef LiteRtStatus (*LiteRtGetCompiledResultByteCodeT)( - LiteRtCompiledResult, LiteRtParamIndex byte_code_idx, - const void** byte_code, size_t* byte_code_size); - -typedef LiteRtStatus (*LiteRtCompiledResultNumByteCodeModulesT)( - LiteRtCompiledResult, LiteRtParamIndex* num_byte_code); - -typedef LiteRtStatus 
(*LiteRtGetCompiledResultCallInfoT)( - LiteRtCompiledResult, LiteRtParamIndex call_idx, const void** call_info, - size_t* call_info_size, LiteRtParamIndex* byte_code_idx); - -typedef LiteRtStatus (*LiteRtGetNumCompiledResultCallsT)( - LiteRtCompiledResult, LiteRtParamIndex* num_calls); - -typedef LiteRtStatus (*LiteRtCompilerPluginSetFlagsT)( - LiteRtCompilerPlugin compiler_plugin, LiteRtParamIndex num_flags, - const char** keys, const char** values); - -// -// Function Pointer Container -// - -// Wraps all resolved functions from api interface. -struct LiteRtCompilerPluginApi { - LiteRtGetCompilerPluginVersionT get_compiler_plugin_version; - LiteRtGetCompilerPluginSocManufacturerT get_compiler_plugin_soc_manufacturer; - LiteRtCreateCompilerPluginT create_compiler_plugin; - LiteRtDestroyCompilerPluginT destroy_compiler_plugin; - - LiteRtGetCompilerPluginSupportedHardwareT - get_compiler_plugin_supported_hardware; - LiteRtGetNumCompilerPluginSupportedSocModelsT - get_num_compiler_plugin_supported_models; - LiteRtGetCompilerPluginSupportedSocModelT - get_compiler_plugin_supported_soc_model; - - LiteRtCompilerPluginPartitionT compiler_plugin_partition; - LiteRtCompilerPluginCompileT compiler_plugin_compile; - - LiteRtDestroyCompiledResultT destroy_compiled_result; - LiteRtGetCompiledResultByteCodeT get_compiled_result_byte_code; - LiteRtCompiledResultNumByteCodeModulesT get_compiled_result_num_byte_code; - LiteRtGetCompiledResultCallInfoT get_compiled_result_call_info; - LiteRtGetNumCompiledResultCallsT get_compiled_result_num_calls; - - LiteRtCompilerPluginSetFlagsT set_flags; -}; - -#ifdef __cplusplus -} - -#include "absl/strings/string_view.h" - -static constexpr absl::string_view kLiteRtGetCompilerPluginVersion = - "LiteRtGetCompilerPluginVersion"; - -static constexpr absl::string_view kLiteRtGetCompilerPluginSupportedHardware = - "LiteRtGetCompilerPluginSupportedHardware"; - -static constexpr absl::string_view kLiteRtGetCompilerPluginSocManufacturer = - 
"LiteRtGetCompilerPluginSocManufacturer"; -static constexpr absl::string_view - kLiteRtGetNumCompilerPluginSupportedSocModels = - "LiteRtGetNumCompilerPluginSupportedSocModels"; -static constexpr absl::string_view kLiteRtGetCompilerPluginSupportedSocModel = - "LiteRtGetCompilerPluginSupportedSocModel"; - -static constexpr absl::string_view kLiteRtCreateCompilerPlugin = - "LiteRtCreateCompilerPlugin"; -static constexpr absl::string_view kLiteRtDestroyCompilerPlugin = - "LiteRtDestroyCompilerPlugin"; - -static constexpr absl::string_view kLiteRtCompilerPluginPartition = - "LiteRtCompilerPluginPartition"; -static constexpr absl::string_view kLiteRtCompilerPluginCompile = - "LiteRtCompilerPluginCompile"; - -static constexpr absl::string_view kLiteRtDestroyCompiledResult = - "LiteRtDestroyCompiledResult"; -static constexpr absl::string_view kLiteRtGetCompiledResultByteCode = - "LiteRtGetCompiledResultByteCode"; -static constexpr absl::string_view kLiteRtCompiledResultNumByteCodeModules = - "LiteRtCompiledResultNumByteCodeModules"; -static constexpr absl::string_view kLiteRtGetCompiledResultCallInfo = - "LiteRtGetCompiledResultCallInfo"; -static constexpr absl::string_view kLiteRtGetNumCompiledResultCalls = - "LiteRtGetNumCompiledResultCalls"; - -static constexpr absl::string_view kLiteRtCompilerPluginSetFlags = - "LiteRtCompilerPluginSetFlags"; - -#endif // __cplusplus - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_C_LITERT_COMPILER_PLUGIN_API_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h b/tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h deleted file mode 100644 index 7487daf9c9ae22..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h +++ /dev/null @@ -1,309 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_C_LITERT_DISPATCH_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_C_LITERT_DISPATCH_H_ - -#include -#include -#include - -#include "tensorflow/lite/experimental/litert/c/litert_any.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_event.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h" - -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus - -// ///////////////////////////////////////////////////////////////////////////// -// Basic Execution API -// ///////////////////////////////////////////////////////////////////////////// - -LITERT_DEFINE_HANDLE(LiteRtDispatchDeviceContext); -LITERT_DEFINE_HANDLE(LiteRtDispatchInvocationContext); -LITERT_DEFINE_HANDLE(LiteRtDispatchMetrics); - -typedef uint64_t LiteRtTensorBufferHandle; - -typedef enum LiteRtDispatchCapabilities { - kLiteRtDispatchCapabilitiesNone = 0, - kLiteRtDispatchCapabilitiesBasic = 1, // The vendor supports the Basic API - kLiteRtDispatchCapabilitiesAsync = 2, // The vendor supports the Async API - kLiteRtDispatchCapabilitiesGraph = 4, // The vendor supports the Graph API -} LiteRtDispatchCapabilities; - -// Types of executable that can run on the HW accelerators. 
-typedef enum LiteRtDispatchExecutableType { - kLiteRtDispatchExecutableTypeUnknown = 0, - kLiteRtDispatchExecutableTypeDspLibrary = 1, // DSP library - kLiteRtDispatchExecutableTypeMlModel = 2, // Vendor-specific ML model -} LiteRtDispatchExecutableType; - -typedef struct LiteRtDispatchOption { - const char* name; - LiteRtAny value; -} LiteRtDispatchOption; - -typedef struct LiteRtMetric { - const char* name; - LiteRtAny value; -} LiteRtMetric; - -typedef struct LiteRtMemBuffer { - int fd; // File descriptor for an mmapped buffer, -1 if unused. - const void* base_addr; // Base address of the buffer. - size_t offset; // Offset of the buffer from the base address. - size_t size; // Buffer size. -} LiteRtMemBuffer; - -// This option can be used to specify a directory from where to load shared -// libraries. -static const char* kDispatchOptionSharedLibraryDir = "shared_library_dir"; - -// Initialize the Dispatch API runtime. -// -// This function should be called before calling any other Dispatch API -// functions. -LiteRtStatus LiteRtDispatchInitialize(const LiteRtDispatchOption* options, - int num_options); - -// Return the version of the Dispatch API runtime. -LiteRtStatus LiteRtDispatchGetApiVersion(LiteRtApiVersion* api_version); - -// Return the vendor id of the Dispatch API runtime. -// -// This function returns a pointer to a statically allocated string that is the -// ID of vendor providing the Dispatch API runtime. -LiteRtStatus LiteRtDispatchGetVendorId(const char** vendor_id); - -// Return the build ID of the Dispatch API runtime. -// -// This function returns a pointer to a statically allocated string that is the -// ID of the Dispatch API runtime build. -LiteRtStatus LiteRtDispatchGetBuildId(const char** build_id); - -// Return the capabilities supported by the Dispatch API runtime as a set of the -// values specified in LiteRtDispatchCapabilities. 
-LiteRtStatus LiteRtDispatchGetCapabilities(int* capabilities); - -// Create a `LiteRtDispatchDeviceContext` object. -// -// The returned object is used to talk with the underlying HW. The caller owns -// the memory associated with the context and should call -// LiteRtDispatchDeviceContextDestroy() to release it. Return NULL in case of -// error. -LiteRtStatus LiteRtDispatchDeviceContextCreate( - LiteRtDispatchDeviceContext* device_context); - -// Release a `LiteRtDispatchDeviceContext` object. -// -// The given context should be release only after releasing all associated -// objects. -LiteRtStatus LiteRtDispatchDeviceContextDestroy( - LiteRtDispatchDeviceContext device_context); - -// Given a tensor type for an invocation context input, obtain the attributes -// the HW requires for the associated tensor buffer. The returned -// `tensor_buffer_requirements` object is owned by the caller. -LiteRtStatus LiteRtDispatchGetInputRequirements( - LiteRtDispatchInvocationContext invocation_context, int input_index, - const LiteRtRankedTensorType* tensor_type, - LiteRtTensorBufferRequirements* tensor_buffer_requirements); - -// Given a tensor type for an invocation context output, obtain the attributes -// the HW requires for the associated tensor buffer. The returned -// `tensor_buffer_requirements` object is owned by the caller. -LiteRtStatus LiteRtDispatchGetOutputRequirements( - LiteRtDispatchInvocationContext invocation_context, int output_index, - const LiteRtRankedTensorType* tensor_type, - LiteRtTensorBufferRequirements* tensor_buffer_requirements); - -// Registers a buffer with the given device context. -// Note: The memory backing the buffer should be valid until -// `LiteRtDispatchUnregisterTensorBuffer` is called. 
-LiteRtStatus LiteRtDispatchRegisterTensorBuffer( - LiteRtDispatchDeviceContext device_context, - LiteRtTensorBuffer tensor_buffer, - LiteRtTensorBufferHandle* tensor_buffer_handle); - -// Unregisters the registered buffer associated with the given -// `LiteRtTensorBufferHandle`. -// Note: The registered `LiteRtTensorBufferHandle` is supposed to be -// unregistered with this function before the associated `ThrContext` is deleted -// by calling `LiteRtDispatchDeviceContextDestroy`. -LiteRtStatus LiteRtDispatchUnregisterTensorBuffer( - LiteRtDispatchDeviceContext device_context, - LiteRtTensorBufferHandle tensor_buffer_handle); - -// Create an invocation context to run a given function from a given -// executable. Parameter `function_name` is required if the provided executable -// includes multiple functions. -LiteRtStatus LiteRtDispatchInvocationContextCreate( - LiteRtDispatchDeviceContext device_context, - LiteRtDispatchExecutableType exec_type, - const LiteRtMemBuffer* exec_bytecode_buffer, const char* function_name, - int num_inputs, int num_outputs, - LiteRtDispatchInvocationContext* invocation_context); - -LiteRtStatus LiteRtDispatchInvocationContextDestroy( - LiteRtDispatchInvocationContext invocation_context); - -LiteRtStatus LiteRtDispatchAttachInput( - LiteRtDispatchInvocationContext invocation_context, int graph_input_index, - LiteRtTensorBufferHandle tensor_buffer_handle); - -LiteRtStatus LiteRtDispatchAttachOutput( - LiteRtDispatchInvocationContext invocation_context, int graph_output_index, - LiteRtTensorBufferHandle tensor_buffer_handle); - -LiteRtStatus LiteRtDispatchDetachInput( - LiteRtDispatchInvocationContext invocation_context, int graph_input_index, - LiteRtTensorBufferHandle tensor_buffer_handle); - -LiteRtStatus LiteRtDispatchDetachOutput( - LiteRtDispatchInvocationContext invocation_context, int graph_output_index, - LiteRtTensorBufferHandle tensor_buffer_handle); - -LiteRtStatus LiteRtDispatchInvoke( - LiteRtDispatchInvocationContext 
invocation_context); - -// Start collection of HW-specific metrics at a specific level of detail (>= 0). -LiteRtStatus LiteRtDispatchStartMetricsCollection( - LiteRtDispatchInvocationContext invocation_context, int detail_level); - -// Stop collection of HW-specific metrics and report the collected -// metrics. Note: The caller is responsible for deallocating the returned -// metrics by calling `LiteRtDispatchDestroyMetrics`. -LiteRtStatus LiteRtDispatchStopMetricsCollection( - LiteRtDispatchInvocationContext invocation_context, - LiteRtDispatchMetrics* metrics); - -LiteRtStatus LiteRtDispatchGetNumMetrics(LiteRtDispatchMetrics metrics, - int* num_metrics); - -// Fetch a specific metric. The runtime owns the returned object. -LiteRtStatus LiteRtDispatchGetMetric(LiteRtDispatchMetrics metrics, - int metric_index, LiteRtMetric* metric); - -LiteRtStatus LiteRtDispatchDestroyMetrics(LiteRtDispatchMetrics metrics); - -// ///////////////////////////////////////////////////////////////////////////// -// Async Execution API -// ///////////////////////////////////////////////////////////////////////////// - -LiteRtStatus LiteRtDispatchAttachInputEvent( - LiteRtDispatchInvocationContext invocation_context, int graph_input_index, - LiteRtEvent input_event); - -LiteRtStatus LiteRtDispatchInvokeAsync( - LiteRtDispatchInvocationContext invocation_context, int num_output_events, - LiteRtEvent* output_events); - -// ///////////////////////////////////////////////////////////////////////////// -// Graph Execution API -// ///////////////////////////////////////////////////////////////////////////// - -typedef uint64_t LiteRtDispatchNodeId; -typedef uint64_t LiteRtDispatchEdgeId; -typedef uint64_t LiteRtDispatchExecutableHandle; - -LITERT_DEFINE_HANDLE(LiteRtDispatchGraph); - -// Types of graph nodes. 
-typedef enum LiteRtDispatchNodeType { - kLiteRtDispatchNodeTypeUnknown = 0, - kLiteRtDispatchNodeTypeDsp = - 1, // Can execute both ML models and Dsp libraries - kLiteRtDispatchNodeTypeNpu = 2, // Can execute only ML models -} LiteRtDispatchNodeType; - -LiteRtStatus LiteRtDispatchGraphCreate( - LiteRtDispatchDeviceContext device_context, LiteRtDispatchGraph** graph); - -LiteRtStatus LiteRtDispatchGraphDestroy(LiteRtDispatchGraph* graph); - -// Add a compute node to a given graph. Parameter node_id should be unique to -// the graph. -LiteRtStatus LiteRtDispatchAddNode(LiteRtDispatchGraph* graph, - LiteRtDispatchNodeId node_id, - LiteRtDispatchNodeType node_type); - -// Add an edge a given graph. Parameter edge_id should be unique to the graph. -LiteRtStatus LiteRtDispatchAddEdge(LiteRtDispatchGraph* graph, - LiteRtDispatchEdgeId edge_id); - -// Connect a given node's input. -LiteRtStatus LiteRtDispatchConnectNodeInput(LiteRtDispatchGraph* graph, - LiteRtDispatchNodeId node_id, - int input_index, - LiteRtDispatchEdgeId edge_id); - -// Connect a given node's output. -LiteRtStatus LiteRtDispatchConnectNodeOutput(LiteRtDispatchGraph* graph, - LiteRtDispatchNodeId node_id, - int output_index, - LiteRtDispatchEdgeId edge_id); - -// Connect a given graph's input. -LiteRtStatus LiteRtDispatchConnectGraphInput(LiteRtDispatchGraph* graph, - int input_index, - LiteRtDispatchEdgeId edge_id); - -// Connect a given graph's output. -LiteRtStatus LiteRtDispatchConnectGraphOutput(LiteRtDispatchGraph* graph, - int output_index, - LiteRtDispatchEdgeId edge_id); - -LiteRtStatus LiteRtDispatchLoadExecutable( - LiteRtDispatchDeviceContext device_context, - LiteRtDispatchExecutableType type, const LiteRtMemBuffer* bytecode_buffer, - LiteRtDispatchExecutableHandle* exec_handle); - -LiteRtStatus LiteRtDispatchUnloadExecutable( - LiteRtDispatchDeviceContext device_context, - LiteRtDispatchExecutableHandle exec_handle); - -// Assign an executable function to a graph node. 
Parameter `function_name` is -// mandatory if the given executable includes multiple functions. -LiteRtStatus LiteRtDispatchAssignNodeFunction( - LiteRtDispatchGraph* graph, LiteRtDispatchNodeId node_id, - LiteRtDispatchExecutableHandle exec_handle, const char* function_name); - -// Add an annotation to an entire graph. -LiteRtStatus LiteRtDispatchAnnotateGraph(LiteRtDispatchGraph* graph, - const char* key, const char* value); - -// Add an annotation to a specified node. -LiteRtStatus LiteRtDispatchAnnotateNode(LiteRtDispatchGraph* graph, - LiteRtDispatchNodeId node_id, - const char* key, const char* value); - -// Add an annotation to a specified edge. -LiteRtStatus LiteRtDispatchAnnotateEdge(LiteRtDispatchGraph* graph, - LiteRtDispatchEdgeId edge_id, - const char* key, const char* value); - -LiteRtStatus LiteRtDispatchInvocationContextCreateFromGraph( - LiteRtDispatchDeviceContext device_context, LiteRtDispatchGraph* graph, - LiteRtDispatchInvocationContext* invocation_context); - -#ifdef __cplusplus -} -#endif // __cplusplus - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_C_LITERT_DISPATCH_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/c/litert_dispatch_api.h b/tensorflow/lite/experimental/litert/vendors/c/litert_dispatch_api.h deleted file mode 100644 index 527a19c2630e09..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/c/litert_dispatch_api.h +++ /dev/null @@ -1,245 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_C_LITERT_DISPATCH_API_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_C_LITERT_DISPATCH_API_H_ - -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_event.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" - -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus - -// ///////////////////////////////////////////////////////////////////////////// - -typedef LiteRtStatus (*LiteRtDispatchInitializeT)( - const LiteRtDispatchOption* options, int num_options); - -typedef LiteRtStatus (*LiteRtDispatchGetVendorIdT)(const char** vendor_id); - -typedef LiteRtStatus (*LiteRtDispatchGetBuildIdT)(const char** build_id); - -typedef LiteRtStatus (*LiteRtDispatchGetCapabilitiesT)(int* capabilities); - -typedef LiteRtStatus (*LiteRtDispatchDeviceContextCreateT)( - LiteRtDispatchDeviceContext* device_context); - -typedef LiteRtStatus (*LiteRtDispatchDeviceContextDestroyT)( - LiteRtDispatchDeviceContext device_context); - -typedef LiteRtStatus (*LiteRtDispatchGetInputRequirementsT)( - LiteRtDispatchInvocationContext invocation_context, int input_index, - const LiteRtRankedTensorType* tensor_type, - LiteRtTensorBufferRequirements* tensor_buffer_requirements); - -typedef LiteRtStatus (*LiteRtDispatchGetOutputRequirementsT)( - LiteRtDispatchInvocationContext invocation_context, int output_index, - const LiteRtRankedTensorType* tensor_type, - LiteRtTensorBufferRequirements* tensor_buffer_requirements); - -typedef LiteRtStatus (*LiteRtDispatchRegisterTensorBufferT)( - 
LiteRtDispatchDeviceContext device_context, - LiteRtTensorBuffer tensor_buffer, - LiteRtTensorBufferHandle* tensor_buffer_handle); - -typedef LiteRtStatus (*LiteRtDispatchUnregisterTensorBufferT)( - LiteRtDispatchDeviceContext device_context, - LiteRtTensorBufferHandle handle); - -typedef LiteRtStatus (*LiteRtDispatchInvocationContextCreateT)( - LiteRtDispatchDeviceContext device_context, - LiteRtDispatchExecutableType exec_type, - const LiteRtMemBuffer* exec_bytecode_buffer, const char* function_name, - int num_inputs, int num_outputs, - LiteRtDispatchInvocationContext* invocation_context); - -typedef LiteRtStatus (*LiteRtDispatchInvocationContextDestroyT)( - LiteRtDispatchInvocationContext invocation_context); - -typedef LiteRtStatus (*LiteRtDispatchAttachInputT)( - LiteRtDispatchInvocationContext invocation_context, int graph_input_index, - LiteRtTensorBufferHandle tensor_buffer_handle); - -typedef LiteRtStatus (*LiteRtDispatchAttachOutputT)( - LiteRtDispatchInvocationContext invocation_context, int graph_output_index, - LiteRtTensorBufferHandle tensor_buffer_handle); - -typedef LiteRtStatus (*LiteRtDispatchDetachInputT)( - LiteRtDispatchInvocationContext invocation_context, int graph_input_index, - LiteRtTensorBufferHandle tensor_buffer_handle); - -typedef LiteRtStatus (*LiteRtDispatchDetachOutputT)( - LiteRtDispatchInvocationContext invocation_context, int graph_output_index, - LiteRtTensorBufferHandle tensor_buffer_handle); - -typedef LiteRtStatus (*LiteRtDispatchInvokeT)( - LiteRtDispatchInvocationContext invocation_context); - -typedef LiteRtStatus (*LiteRtDispatchStartMetricsCollectionT)( - LiteRtDispatchInvocationContext invocation_context, int detail_level); - -typedef LiteRtStatus (*LiteRtDispatchStopMetricsCollectionT)( - LiteRtDispatchInvocationContext invocation_context, - LiteRtDispatchMetrics* metrics); - -typedef LiteRtStatus (*LiteRtDispatchGetNumMetricsT)( - LiteRtDispatchMetrics metrics, int* num_metrics); - -typedef LiteRtStatus 
(*LiteRtDispatchGetMetricT)(LiteRtDispatchMetrics metrics, - int metric_index, - LiteRtMetric* metric); - -typedef LiteRtStatus (*LiteRtDispatchDestroyMetricsT)( - LiteRtDispatchMetrics metrics); - -typedef struct LiteRtDispatchInterface { - LiteRtDispatchInitializeT initialize; - LiteRtDispatchGetVendorIdT get_vendor_id; - LiteRtDispatchGetBuildIdT get_build_id; - LiteRtDispatchGetCapabilitiesT get_capabilities; - LiteRtDispatchDeviceContextCreateT device_context_create; - LiteRtDispatchDeviceContextDestroyT device_context_destroy; - LiteRtDispatchGetInputRequirementsT get_input_requirements; - LiteRtDispatchGetOutputRequirementsT get_output_requirements; - LiteRtDispatchRegisterTensorBufferT register_tensor_buffer; - LiteRtDispatchUnregisterTensorBufferT unregister_tensor_buffer; - LiteRtDispatchInvocationContextCreateT invocation_context_create; - LiteRtDispatchInvocationContextDestroyT invocation_context_destroy; - LiteRtDispatchAttachInputT attach_input; - LiteRtDispatchAttachOutputT attach_output; - LiteRtDispatchDetachInputT detach_input; - LiteRtDispatchDetachOutputT detach_output; - LiteRtDispatchInvokeT invoke; - LiteRtDispatchStartMetricsCollectionT start_metrics_collection; - LiteRtDispatchStopMetricsCollectionT stop_metrics_collection; - LiteRtDispatchGetNumMetricsT get_num_metrics; - LiteRtDispatchGetMetricT get_metric; - LiteRtDispatchDestroyMetricsT destroy_metrics; -} LiteRtDispatchInterface; - -// ///////////////////////////////////////////////////////////////////////////// - -typedef LiteRtStatus (*LiteRtDispatchAttachInputEventT)( - LiteRtDispatchInvocationContext invocation_context, int graph_input_index, - LiteRtEvent input_event); - -typedef LiteRtStatus (*LiteRtDispatchInvokeAsyncT)( - LiteRtDispatchInvocationContext invocation_context, int num_output_events, - LiteRtEvent* output_events); - -typedef struct LiteRtDispatchAsyncInterface { - LiteRtDispatchAttachInputEventT attach_input_event; - LiteRtDispatchInvokeAsyncT invoke_async; -} 
LiteRtDispatchAsyncInterface; - -// ///////////////////////////////////////////////////////////////////////////// - -typedef LiteRtStatus (*LiteRtDispatchGraphCreateT)( - LiteRtDispatchDeviceContext device_context, LiteRtDispatchGraph* graph); - -typedef LiteRtStatus (*LiteRtDispatchGraphDestroyT)(LiteRtDispatchGraph graph); - -typedef LiteRtStatus (*LiteRtDispatchAddNodeT)( - LiteRtDispatchGraph graph, LiteRtDispatchNodeId node_id, - LiteRtDispatchNodeType node_type); - -typedef LiteRtStatus (*LiteRtDispatchAddEdgeT)(LiteRtDispatchGraph graph, - LiteRtDispatchEdgeId edge_id); - -typedef LiteRtStatus (*LiteRtDispatchConnectNodeInputT)( - LiteRtDispatchGraph graph, LiteRtDispatchNodeId node_id, int input_index, - LiteRtDispatchEdgeId edge_id); - -typedef LiteRtStatus (*LiteRtDispatchConnectNodeOutputT)( - LiteRtDispatchGraph graph, LiteRtDispatchNodeId node_id, int output_index, - LiteRtDispatchEdgeId edge_id); - -typedef LiteRtStatus (*LiteRtDispatchConnectGraphInputT)( - LiteRtDispatchGraph graph, int input_index, LiteRtDispatchEdgeId edge_id); - -typedef LiteRtStatus (*LiteRtDispatchConnectGraphOutputT)( - LiteRtDispatchGraph graph, int output_index, LiteRtDispatchEdgeId edge_id); - -typedef LiteRtStatus (*LiteRtDispatchLoadExecutableT)( - LiteRtDispatchDeviceContext device_context, - LiteRtDispatchExecutableType type, const LiteRtMemBuffer* bytecode_buffer, - LiteRtDispatchExecutableHandle* exec_handle); - -typedef LiteRtStatus (*LiteRtDispatchUnloadExecutableT)( - LiteRtDispatchDeviceContext device_context, - LiteRtDispatchExecutableHandle exec_handle); - -typedef LiteRtStatus (*LiteRtDispatchAssignNodeFunctionT)( - LiteRtDispatchGraph graph, LiteRtDispatchNodeId node_id, - LiteRtDispatchExecutableHandle exec_handle, const char* function_name); - -typedef LiteRtStatus (*LiteRtDispatchInvocationContextCreateFromGraphT)( - LiteRtDispatchDeviceContext device_context, LiteRtDispatchGraph graph, - LiteRtDispatchInvocationContext* invocation_context); - -typedef 
LiteRtStatus (*LiteRtDispatchAnnotateGraphT)(LiteRtDispatchGraph graph, - const char* key, - const char* value); - -typedef LiteRtStatus (*LiteRtDispatchAnnotateNodeT)( - LiteRtDispatchGraph graph, LiteRtDispatchNodeId node_id, const char* key, - const char* value); - -typedef LiteRtStatus (*LiteRtDispatchAnnotateEdgeT)( - LiteRtDispatchGraph graph, LiteRtDispatchEdgeId edge_id, const char* key, - const char* value); - -typedef struct LiteRtDispatchGraphInterface { - LiteRtDispatchGraphCreateT graph_create; - LiteRtDispatchGraphDestroyT graph_destroy; - LiteRtDispatchAddNodeT add_node; - LiteRtDispatchAddEdgeT add_edge; - LiteRtDispatchConnectNodeInputT connect_node_input; - LiteRtDispatchConnectNodeOutputT connect_node_output; - LiteRtDispatchConnectGraphInputT connect_graph_input; - LiteRtDispatchConnectGraphOutputT connect_graph_output; - LiteRtDispatchLoadExecutableT load_executable; - LiteRtDispatchUnloadExecutableT unload_executable; - LiteRtDispatchAssignNodeFunctionT assign_node_function; - LiteRtDispatchAnnotateGraphT annotate_graph; - LiteRtDispatchAnnotateNodeT annotate_node; - LiteRtDispatchAnnotateEdgeT annotate_edge; - LiteRtDispatchInvocationContextCreateFromGraphT - invocation_context_create_from_graph; -} LiteRtDispatchGraphInterface; - -// ///////////////////////////////////////////////////////////////////////////// - -// FIXME See Vulkan and OpenCL extensions. 
-typedef struct LiteRtDispatchApi { - LiteRtApiVersion version; - LiteRtDispatchInterface* interface; - LiteRtDispatchAsyncInterface* async_interface; - LiteRtDispatchGraphInterface* graph_interface; -} LiteRtDispatchApi; - -LiteRtStatus LiteRtDispatchGetApi(LiteRtDispatchApi* api); - -#ifdef __cplusplus -} -#endif // __cplusplus - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_C_LITERT_DISPATCH_API_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/c/litert_vendor_c_api_common_test.c b/tensorflow/lite/experimental/litert/vendors/c/litert_vendor_c_api_common_test.c deleted file mode 100644 index 60cedbb927035a..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/c/litert_vendor_c_api_common_test.c +++ /dev/null @@ -1,28 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// This file exists to verify that the below header files can build, link, -// and run as C code. -#ifdef __cplusplus -#error "This file should be compiled as C code, not as C++." -#endif - -// Include all the header files in the litert/c directory. 
-#include "tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin.h" // NOLINT -#include "tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin_api.h" // NOLINT -#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" // NOLINT -#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch_api.h" // NOLINT - -int main(void) { - return 0; -} diff --git a/tensorflow/lite/experimental/litert/vendors/cc/BUILD b/tensorflow/lite/experimental/litert/vendors/cc/BUILD deleted file mode 100644 index 25e6c26462cfab..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/cc/BUILD +++ /dev/null @@ -1,126 +0,0 @@ -# Copyright 2024 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], -) - -cc_library( - name = "litert_compiler_plugin", - hdrs = ["litert_compiler_plugin.h"], - deps = [ - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/vendors/c:litert_compiler_plugin", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:string_view", - ], -) - -cc_library( - name = "conversion", - hdrs = ["conversion.h"], - deps = [ - ":backend_ir", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "@com_google_absl//absl/container:flat_hash_map", - ], -) - -cc_library( - name = "backend_ir", - hdrs = ["backend_ir.h"], - deps = ["//tensorflow/lite/experimental/litert/c:litert_common"], -) - -cc_library( - name = "partition_with_capabilities", - hdrs = ["partition_with_capabilities.h"], - deps = [ - ":conversion", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_model", - ], -) - -cc_library( - name = "convert_graph", - hdrs = ["convert_graph.h"], - deps = [ - ":conversion", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_model", - ], -) - -cc_library( - name = "ir_types", - hdrs = ["ir_types.h"], - deps = [ - ":backend_ir", - ":conversion", - 
"//tensorflow/lite/experimental/litert/cc:litert_expected", - ], -) - -cc_test( - name = "partition_with_capabilities_test", - srcs = ["partition_with_capabilities_test.cc"], - deps = [ - ":partition_with_capabilities", - "//tensorflow/compiler/mlir/lite/schema:schema_fbs", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/core/model", - "//tensorflow/lite/experimental/litert/core/model:model_graph", - "//tensorflow/lite/experimental/litert/core/util:flatbuffer_tools", - "//tensorflow/lite/experimental/litert/vendors/examples:example_conversion_impl", - "//tensorflow/lite/experimental/litert/vendors/examples:example_ir", - "@com_google_googletest//:gtest_main", - ], -) - -cc_test( - name = "convert_graph_test", - srcs = ["convert_graph_test.cc"], - deps = [ - ":backend_ir", - ":convert_graph", - "//tensorflow/compiler/mlir/lite/schema:schema_fbs", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/cc:litert_buffer_ref", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/core/model", - "//tensorflow/lite/experimental/litert/core/model:model_graph", - "//tensorflow/lite/experimental/litert/core/util:flatbuffer_tools", - "//tensorflow/lite/experimental/litert/test:matchers", - "//tensorflow/lite/experimental/litert/vendors/examples:example_conversion_impl", - "//tensorflow/lite/experimental/litert/vendors/examples:example_ir", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest_main", - ], -) diff --git a/tensorflow/lite/experimental/litert/vendors/cc/backend_ir.h 
b/tensorflow/lite/experimental/litert/vendors/cc/backend_ir.h deleted file mode 100644 index 34cf95bd3643e6..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/cc/backend_ir.h +++ /dev/null @@ -1,79 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_CC_BACKEND_IR_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_CC_BACKEND_IR_H_ - -#include -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" - -namespace litert { - -// Interfaces and types for managing backend IR to be targeted by LiteRt for -// compilation. - -// Memory Management -//===--------------------------------------------------------------------------- - -// Callable for allocating a new instance of a backend IR type. This facilitates -// external memory management for the backend IR implementented by the backend. -// It is encouraged for implementations provide pointer stability (consider -// std::list for storage). -template -using BackendIrAllocator = std::function; - -// Allocator for backend tensors. -template -using TensorAllocator = BackendIrAllocator; - -// Allocator for backend ops. -template -using OpAllocator = BackendIrAllocator; - -// Graph Construction -//===--------------------------------------------------------------------------- - -// Wrapper for an in memory graph for a particular backend. 
Implementations -// should contain an instance of a backend graph that can be iteratively -// constructed via calls to this interface. -template -class BackendGraphBuilder { - public: - // Hook called to initialize state for a new backend graph with a name. This - // will be called once per-instance before any other method. - virtual void InitGraph(std::string graph_name) = 0; - - // Hook called to register a backend tensor once it - // has been converted. This will be called once per tensor. - virtual LiteRtStatus RegisterTensor(BackendTensor& tensor) = 0; - - // Hook called to register a backend op once it has been converted. This will - // be called once per op (in a toplogogical order). All input/output tensors - // will have been registered before called. - virtual LiteRtStatus RegisterOp(BackendOp& op) = 0; - - // Hook called to register a graph when graph - // conversion is completed. Backend graph context should be stored as internal - // state. This will be called once per instance after all ops/tensors have - // been finalized. - virtual LiteRtStatus FinalizeGraph() = 0; - - virtual ~BackendGraphBuilder() = default; -}; - -} // namespace litert - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_CC_BACKEND_IR_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/cc/conversion.h b/tensorflow/lite/experimental/litert/vendors/cc/conversion.h deleted file mode 100644 index 139ba594bb1e8a..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/cc/conversion.h +++ /dev/null @@ -1,262 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Utility types for mapping LiteRt IR to arbitrary backend specific -// types. Implementations of these types define mapping for ops and tensors -// that may be used in a stndalone fashion. They also may be composed -// to create lowerings of entire graphs with topology. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_CC_CONVERSION_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_CC_CONVERSION_H_ - -#include -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/cc/backend_ir.h" - -namespace litert { - -// Interfaces and types for implementing "conversions" that map LiteRt IR to -// backend IR. -// NOTE: Conversions depend on external memory management for the backend IR -// types. User defined conversions are usually expected to leverage callbacks -// to allocate backend IR types rather than constructing them directly. - -// Conversion Result Type -//===--------------------------------------------------------------------------- - -// Result of a one->many general mapping from LiteRt op to any number of -// backend specific ops. Does not own the memory of the backend ops or tensors. 
-template -struct GeneralConversionResult { - // Ops emitted from translation pattern. - std::vector ops; - - // Any backend tensors used within the results ops. Not relevant when - // size of backend ops == 1. This does not include input/output tensors of the - // op being converted. - std::vector intermediate_tensors; -}; - -// The result of a one->one specialized mapping from LiteRt op to backend op. -template -using SimpleConversionResult = BackendOp*; - -// A tag-type for a conversion result that is a non-error non-match. -struct NoMatch {}; - -// Type union for conversion results. -// TODO(lukeboyer): Update conversion result types to handle the case where -// backend ops add extra inputs. -template -using ConversionResult = - std::variant, - GeneralConversionResult, NoMatch>; - -// Short hand for holds_alternative. -template -bool ConversionIsA(const ConversionResult& result) { - return std::holds_alternative(result); -} - -// Short hand for holds_alternative. -template -bool ConversionMatched( - const ConversionResult& result) { - return !std::holds_alternative(result); -} - -// Short hand for holds_alternative. -template -bool IsSimpleResult(const ConversionResult& result) { - return ConversionIsA>(result); -} - -// Short hand for holds_alternative. -template -bool IsGeneralResult(const ConversionResult& result) { - return ConversionIsA>( - result); -} - -// Short hand for std::get. Also checks if match and wraps in expected. -template -Expected GetConversionResult( - const ConversionResult& result) { - if (ConversionMatched(result)) { - return Expected(std::get(result)); - } - return Error(kLiteRtStatusLegalizeNoMatch); -} - -// Get simple result if there was a match. -template -Expected> GetSimpleConversionResult( - const ConversionResult& result) { - if (!IsSimpleResult(result)) { - return Error(kLiteRtStatusErrorInvalidArgument); - } - return GetConversionResult>(result); -} - -// Get general result if there was a match. 
-template -Expected> -GetGeneralConversionResult( - const ConversionResult& result) { - if (!IsGeneralResult(result)) { - return Error(kLiteRtStatusErrorInvalidArgument); - } - return GetConversionResult>( - result); -} - -// Common IR Conversion -//===--------------------------------------------------------------------------- - -// User defined callback for converting a LiteRt tensor to a backend tensor. -// These are leveraged in various higher-level conversion routines. -// TensorConverters should not stack allocate memory for the backend tensor. In -// most situations, these will be bound to an external allocator. -template -using TensorConverter = - std::function(const Tensor& litert_tensor)>; - -// User defined callback for creating a TensorConverter. This facilitates -// TensoConverters that are bound to an external allocator. -template -using TensorConverterFactory = std::function( - TensorAllocator alloc)>; - -// Mapping from LiteRt tensor to backend tensor, used during iterative graph -// conversions to store current scope. -template -using TensorMap = absl::flat_hash_map; - -// User-defined hook that calls backend to determine if an op is supported. -template -using Capability = std::function; - -// Legalization -//===--------------------------------------------------------------------------- - -// A legalization is a particlar type of user-defined conversion that is -// scheduled for execution on a particular type of LiteRtOp. They may be -// one-to-one or one-to-many conversions. -template -class Legalization { - private: - using Self = Legalization; - - public: - using Result = ConversionResult; - using TensorConverter = TensorConverter; - using TensorConverterFactory = TensorConverterFactory; - using Ptr = std::unique_ptr; - using TensorAllocator = TensorAllocator; - using OpAllocator = OpAllocator; - using Tensors = std::vector; - - // The type of op to schedule on. 
- virtual LiteRtOpCode OpToMatch() const = 0; - - // Invoke this legalization on the given LiteRt op. All new backend IR will be - // allocated via given allocators. NOTE: In most cases, input and output - // converters will be the same. They are separated here for compatibility with - // graph-level conversions routines. - Expected Legalize(const Op& litert_op, - TensorConverterFactory input_converter, - TensorConverterFactory output_converter, - TensorAllocator tensor_allocator, - OpAllocator op_allocator) const { - const auto litert_inputs = litert_op.Inputs(); - Tensors inputs(litert_inputs.size()); - auto convert_input = input_converter(tensor_allocator); - - for (size_t i = 0; i < litert_inputs.size(); ++i) { - const auto& litert_input = litert_inputs[i]; - auto result = convert_input(litert_input); - if (!result) { - return result.Error(); - } - inputs[i] = *result; - } - - const auto litert_outputs = litert_op.Outputs(); - Tensors outputs(litert_outputs.size()); - auto convert_output = output_converter(tensor_allocator); - - for (size_t i = 0; i < litert_outputs.size(); ++i) { - const auto& litert_output = litert_outputs[i]; - auto result = convert_output(litert_output); - if (!result) { - return result.Error(); - } - outputs[i] = *result; - } - - return LegalizeImpl(litert_op, inputs, outputs, tensor_allocator, - op_allocator); - } - - virtual ~Legalization() = default; - - private: - // The user defined implementation of a legalization. Users must use the - // given allocators to allocate any new backend IR types (e.g. intermediate - // ops/tensors in the case of a one-to-many legalization). BackendTensors - // corresponding to LiteRt inputs and outputs have been pre-converted. - virtual Expected LegalizeImpl(const Op& litert_op, - const Tensors& inputs, - const Tensors& outputs, - TensorAllocator tensor_allocator, - OpAllocator op_allocator) const = 0; -}; - -// Collection of legalizations for a specific backend. 
-template -using Legalizations = - std::vector::Ptr>; - -// Map for instance lookup by op code. -template -using LegalizationMap = - absl::flat_hash_map*>; - -// Construct a LegalizationMap from a collection of legalizations. -// TODO: Consider wrapping the legalization map in a class to avoid -// re-constructing it & better syntax. -template -LegalizationMap MakeLegalizationMap( - const Legalizations& legalizations) { - LegalizationMap map; - for (const auto& l : legalizations) { - map.insert({l->OpToMatch(), l.get()}); - } - return map; -} - -} // namespace litert - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_CC_CONVERSION_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/cc/convert_graph.h b/tensorflow/lite/experimental/litert/vendors/cc/convert_graph.h deleted file mode 100644 index cd7221c7bba028..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/cc/convert_graph.h +++ /dev/null @@ -1,177 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_CC_CONVERT_GRAPH_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_CC_CONVERT_GRAPH_H_ - -#include -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/cc/conversion.h" - -namespace litert { - -// Performs iterative graph conversion with user provided hooks. This function -// traverses the IR in toplogical order, converting ops and tensors with given -// tensor converter and legalizations. Registers converted ops and tensors with -// the backend graph builder after they have been converted. The following are -// true: -// * Each tensor and op will be converted & registered at most once. -// * An ops input and output tensors will be registered before the op is -// converted (and before its registered). -// * The graph builder will be initialized before any registration. -// * The graph builder will be finalized after all registration. -template -LiteRtStatus ConvertGraph( - const Subgraph& subgraph, std::string graph_name, - typename Ir::TensorConverterFactory tensor_converter_factory, - typename Ir::TensorAllocator tensor_alloc, - typename Ir::OpAllocator op_alloc, - const typename Ir::Legalizations& legalizations, - typename Ir::GraphBuilder& builder) { - // Store mapping between evaluated litert tensors and corresponding backend - // tensors. - typename Ir::TensorMap tensor_map; - - // Initialize backend graph builder. - builder.InitGraph(std::move(graph_name)); - - // Convert tensor, add to scope and register in backend graph builder. 
- auto handle_tensor = [&tensor_map, &builder]( - const auto& litert_tensor, - auto tensor_converter) -> Ir::TensorResult { - auto converted = tensor_converter(litert_tensor); - if (!converted) { - LITERT_LOG(LITERT_ERROR, "Failed to convert tensor %lu", - litert_tensor.Get()); - return converted.Error(); - } - - if (auto status = builder.RegisterTensor(**converted); - status != kLiteRtStatusOk) { - LITERT_LOG(LITERT_ERROR, "Failed to register tensor %lu, with status %d", - litert_tensor.Get(), status); - return Error(status); - } - - tensor_map.insert({litert_tensor.Get(), *converted}); - return *converted; - }; - - // Wrap provided tensor conversion logic for converting subgraph or op input - // tensors. We want functionality that provides user-defined conversions with - // tensors to be aware of the tensor map and graph builder registration. - auto input_tensor_convert_factory = [tensor_converter_factory, &tensor_map, - handle_tensor](auto tensor_alloc) { - return [tensor_alloc, tensor_converter_factory, &tensor_map, - handle_tensor](const Tensor& litert_tensor) -> Ir::TensorResult { - auto tensor_converter = tensor_converter_factory(tensor_alloc); - - // Check if tensor has been converted already. - auto it = tensor_map.find(litert_tensor.Get()); - const auto in_scope = it != tensor_map.end(); - if (in_scope) { - LITERT_LOG(LITERT_VERBOSE, "Tensor %lu is in scope", - litert_tensor.Get()); - return it->second; - } - - // If its a subgraph input or constant, we can convert it and add to - // scope. - const auto is_cst = litert_tensor.IsConstant(); - const auto is_sg_input = litert_tensor.IsSubgraphInput(); - if (is_sg_input || is_cst) { - return handle_tensor(litert_tensor, tensor_converter); - } - - // Tensor must be added to scope before conversion, or not have a parent - // (e.g. subgraph input or constant) so error at this point. 
- LITERT_LOG(LITERT_ERROR, "Tensor %lu not handled", litert_tensor.Get()); - return Error(kLiteRtStatusErrorInvalidArgument); - }; - }; - - // Wrap provided tensor conversion logic for op output tensors. Adds to map - // and backend graph after conversion. - auto output_tensor_convert_factory = [tensor_converter_factory, - handle_tensor](auto tensor_alloc) { - return [tensor_alloc, tensor_converter_factory, - handle_tensor](const Tensor& litert_tensor) { - auto tensor_converter = tensor_converter_factory(tensor_alloc); - return handle_tensor(litert_tensor, tensor_converter); - }; - }; - - // Convert all ops in subgraph in toplogical order. - auto legalization_map = Ir::MakeLegalizationMap(legalizations); - for (const auto& op : subgraph.Ops()) { - auto it = legalization_map.find(op.Code()); - if (it == legalization_map.end()) { - LITERT_LOG(LITERT_ERROR, "No legalization found for op %d", op.Code()); - return kLiteRtStatusErrorUnsupported; - } - - auto result = it->second->Legalize(op, input_tensor_convert_factory, - output_tensor_convert_factory, - tensor_alloc, op_alloc); - if (!result) { - LITERT_LOG(LITERT_ERROR, "Failed to legalize op %d, with status %d", - op.Code(), result.Error().Status()); - return result.Error().Status(); - } - - auto simple_result = GetSimpleConversionResult(*result); - if (simple_result) { - if (auto stat = builder.RegisterOp(**simple_result); - stat != kLiteRtStatusOk) { - LITERT_LOG(LITERT_ERROR, "Failed to register op %d, with status %d", - op.Code(), stat); - return stat; - } - } - - auto general_result = GetGeneralConversionResult(*result); - if (general_result) { - for (auto* tensor : general_result->intermediate_tensors) { - if (auto stat = builder.RegisterTensor(*tensor); - stat != kLiteRtStatusOk) { - LITERT_LOG(LITERT_ERROR, - "Failed to register tensor %d, with status %d", tensor->id, - stat); - return stat; - } - } - - for (auto* op : general_result->ops) { - if (auto stat = builder.RegisterOp(*op); stat != kLiteRtStatusOk) 
{ - LITERT_LOG(LITERT_ERROR, "Failed to register op %d, with status %d", - op->op_code, stat); - return stat; - } - } - } - } - - builder.FinalizeGraph(); - - return kLiteRtStatusOk; -} - -} // namespace litert - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_CC_CONVERT_GRAPH_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/cc/convert_graph_test.cc b/tensorflow/lite/experimental/litert/vendors/cc/convert_graph_test.cc deleted file mode 100644 index 9ad0e0e644e66f..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/cc/convert_graph_test.cc +++ /dev/null @@ -1,390 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/vendors/cc/convert_graph.h" - -#include -#include -#include -#include - -#include -#include -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" -#include "tensorflow/lite/experimental/litert/core/model/model_graph.h" -#include "tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h" -#include "tensorflow/lite/experimental/litert/test/matchers.h" -#include "tensorflow/lite/experimental/litert/vendors/cc/backend_ir.h" -#include "tensorflow/lite/experimental/litert/vendors/examples/example_conversion_impl.h" -#include "tensorflow/lite/experimental/litert/vendors/examples/example_ir.h" - -namespace litert { -namespace { - -using ::litert::example::ExampleOpAllocator; -using ::litert::example::ExampleOpType; -using ::litert::example::ExampleTensorAllocator; -using ::litert::example::ExampleTypes; -using ::litert::example::MakeAllLegalizations; -using ::litert::example::MakeTensorConverter; -using ::testing::AllOf; -using ::testing::ElementsAreArray; -using ::testing::Expectation; -using ::testing::ExpectationSet; -using ::testing::Field; -using ::testing::Return; - -static constexpr std::array kDims = {2, 2}; -static constexpr auto kElementType = kLiteRtElementTypeFloat32; -static constexpr absl::string_view kGraphName = "graph_name"; - -TensorType GetTestTensorType() { - return MakeRankedTensorType(kElementType, absl::MakeConstSpan(kDims)); -} - -class MockGraphBuilder - : public BackendGraphBuilder { - public: - MOCK_METHOD(void, 
InitGraph, (std::string name), (override)); - MOCK_METHOD(LiteRtStatus, RegisterTensor, (ExampleTypes::Tensor & tensor), - (override)); - MOCK_METHOD(LiteRtStatus, RegisterOp, (ExampleTypes::Op & op), (override)); - MOCK_METHOD(LiteRtStatus, FinalizeGraph, (), (override)); -}; - -TEST(ConvertGraphTest, ConvertSingleSimpleConversion) { - LiteRtSubgraphT subgraph; - - auto& op = subgraph.EmplaceOp(); - op.SetOpCode(kLiteRtOpCodeTflMul); - - auto& input1 = subgraph.EmplaceTensor(); - input1.SetType(GetTestTensorType()); - input1.SetName("input1"); - - auto& input2 = subgraph.EmplaceTensor(); - input2.SetType(GetTestTensorType()); - input2.SetName("input2"); - - auto& output = subgraph.EmplaceTensor(); - output.SetType(GetTestTensorType()); - output.SetName("output"); - - internal::AttachInput(&input1, op); - internal::AttachInput(&input2, op); - internal::AttachOutput(&output, op); - - subgraph.Inputs().push_back(&input1); - subgraph.Inputs().push_back(&input2); - subgraph.Outputs().push_back(&output); - - Subgraph litert_subgraph(&subgraph); - - ExampleOpAllocator op_alloc; - ExampleTensorAllocator tensor_alloc; - - MockGraphBuilder builder; - - Expectation init_graph = - EXPECT_CALL(builder, InitGraph(std::string(kGraphName))).Times(1); - - ExpectationSet reg_inputs; - reg_inputs += - EXPECT_CALL(builder, RegisterTensor(Field(&ExampleTypes::Tensor::name, - input1.Name()))) - .Times(1) - .After(init_graph) - .WillOnce(Return(kLiteRtStatusOk)); - reg_inputs += - EXPECT_CALL(builder, RegisterTensor(Field(&ExampleTypes::Tensor::name, - input2.Name()))) - .Times(1) - .After(init_graph) - .WillOnce(Return(kLiteRtStatusOk)); - - ExpectationSet reg_outputs; - reg_outputs += - EXPECT_CALL(builder, RegisterTensor(Field(&ExampleTypes::Tensor::name, - output.Name()))) - .Times(1) - .After(init_graph) - .WillOnce(Return(kLiteRtStatusOk)); - - auto match_reg_op_args = - AllOf(Field(&ExampleTypes::Op::op_code, ExampleOpType::MUL), - Field(&ExampleTypes::Op::input_names, - 
ElementsAreArray({input1.Name(), input2.Name()})), - Field(&ExampleTypes::Op::output_names, - ElementsAreArray({output.Name()}))); - - Expectation reg_op = EXPECT_CALL(builder, RegisterOp(match_reg_op_args)) - .Times(1) - .After(reg_inputs, reg_outputs) - .WillOnce(Return(kLiteRtStatusOk)); - - Expectation finalize_graph = EXPECT_CALL(builder, FinalizeGraph()) - .Times(1) - .After(reg_op) - .WillOnce(Return(kLiteRtStatusOk)); - - auto stat = ConvertGraph( - litert_subgraph, std::string(kGraphName), MakeTensorConverter, - tensor_alloc, op_alloc, MakeAllLegalizations(), builder); - - LITERT_ASSERT_OK(stat); -} - -TEST(ConvertGraphTest, ConvertSingleGeneralConversion) { - LiteRtSubgraphT subgraph; - - auto& op = subgraph.EmplaceOp(); - op.SetOpCode(kLiteRtOpCodeTflAdd); - - tflite::AddOptionsT add_opts; - add_opts.fused_activation_function = tflite::ActivationFunctionType_RELU; - internal::TflOptions tfl_opts; - tfl_opts.Set(std::move(add_opts)); - litert::internal::SetTflOptions(op, std::move(tfl_opts)); - - auto& input1 = subgraph.EmplaceTensor(); - input1.SetType(GetTestTensorType()); - input1.SetName("input1"); - - auto& input2 = subgraph.EmplaceTensor(); - input2.SetType(GetTestTensorType()); - input2.SetName("input2"); - - auto& output = subgraph.EmplaceTensor(); - output.SetType(GetTestTensorType()); - output.SetName("output"); - - internal::AttachInput(&input1, op); - internal::AttachInput(&input2, op); - internal::AttachOutput(&output, op); - - subgraph.Inputs().push_back(&input1); - subgraph.Inputs().push_back(&input2); - subgraph.Outputs().push_back(&output); - - Subgraph litert_subgraph(&subgraph); - - ExampleOpAllocator op_alloc; - ExampleTensorAllocator tensor_alloc; - - MockGraphBuilder builder; - - Expectation init_graph = - EXPECT_CALL(builder, InitGraph(std::string(kGraphName))).Times(1); - - ExpectationSet reg_inputs; - reg_inputs += - EXPECT_CALL(builder, RegisterTensor(Field(&ExampleTypes::Tensor::name, - input1.Name()))) - .Times(1) - 
.After(init_graph) - .WillOnce(Return(kLiteRtStatusOk)); - reg_inputs += - EXPECT_CALL(builder, RegisterTensor(Field(&ExampleTypes::Tensor::name, - input2.Name()))) - .Times(1) - .After(init_graph) - .WillOnce(Return(kLiteRtStatusOk)); - - ExpectationSet reg_intermediates; - reg_intermediates += - EXPECT_CALL(builder, - RegisterTensor(Field(&ExampleTypes::Tensor::name, - example::kIntermediateTensorName))) - .Times(1) - .After(init_graph) - .WillOnce(Return(kLiteRtStatusOk)); - - ExpectationSet reg_outputs; - reg_outputs += - EXPECT_CALL(builder, RegisterTensor(Field(&ExampleTypes::Tensor::name, - output.Name()))) - .Times(1) - .After(init_graph) - .WillOnce(Return(kLiteRtStatusOk)); - - auto match_reg_add_args = - AllOf(Field(&ExampleTypes::Op::op_code, ExampleOpType::ADD), - Field(&ExampleTypes::Op::input_names, - ElementsAreArray({input1.Name(), input2.Name()})), - Field(&ExampleTypes::Op::output_names, - ElementsAreArray({example::kIntermediateTensorName}))); - - Expectation reg_add = EXPECT_CALL(builder, RegisterOp(match_reg_add_args)) - .Times(1) - .After(reg_inputs, reg_intermediates) - .WillOnce(Return(kLiteRtStatusOk)); - - auto match_reg_relu_args = - AllOf(Field(&ExampleTypes::Op::op_code, ExampleOpType::RELU), - Field(&ExampleTypes::Op::input_names, - ElementsAreArray({example::kIntermediateTensorName})), - Field(&ExampleTypes::Op::output_names, - ElementsAreArray({output.Name()}))); - - Expectation reg_relu = EXPECT_CALL(builder, RegisterOp(match_reg_relu_args)) - .Times(1) - .After(reg_add, reg_intermediates, reg_outputs) - .WillOnce(Return(kLiteRtStatusOk)); - - Expectation finalize_graph = EXPECT_CALL(builder, FinalizeGraph()) - .Times(1) - .After(reg_relu) - .WillOnce(Return(kLiteRtStatusOk)); - - auto stat = ConvertGraph( - litert_subgraph, std::string(kGraphName), MakeTensorConverter, - tensor_alloc, op_alloc, MakeAllLegalizations(), builder); - - LITERT_ASSERT_OK(stat); -} - -TEST(ConvertGraphTest, ConvertMultipleOps) { - LiteRtSubgraphT 
subgraph; - - auto& op = subgraph.EmplaceOp(); - op.SetOpCode(kLiteRtOpCodeTflMul); - - auto& input1 = subgraph.EmplaceTensor(); - input1.SetType(GetTestTensorType()); - input1.SetName("input1"); - - auto& input2 = subgraph.EmplaceTensor(); - input2.SetType(GetTestTensorType()); - input2.SetName("input2"); - - auto& output1 = subgraph.EmplaceTensor(); - output1.SetType(GetTestTensorType()); - output1.SetName("output1"); - - auto& cst = subgraph.EmplaceTensor(); - OwningBufferRef weights(8); - SetWeightsFromUnownedBuffer(cst.Weights(), weights); - cst.SetName("cst"); - cst.SetType(GetTestTensorType()); - - auto& op2 = subgraph.EmplaceOp(); - op2.SetOpCode(kLiteRtOpCodeTflAdd); - - auto& output2 = subgraph.EmplaceTensor(); - output2.SetType(GetTestTensorType()); - output2.SetName("output2"); - - internal::AttachInput(&input1, op); - internal::AttachInput(&input2, op); - internal::AttachOutput(&output1, op); - - internal::AttachInput(&output1, op2); - internal::AttachInput(&cst, op2); - internal::AttachOutput(&output2, op2); - - subgraph.Inputs().push_back(&input1); - subgraph.Inputs().push_back(&input2); - subgraph.Outputs().push_back(&output2); - - Subgraph litert_subgraph(&subgraph); - - ExampleOpAllocator op_alloc; - ExampleTensorAllocator tensor_alloc; - - MockGraphBuilder builder; - - Expectation init_graph = - EXPECT_CALL(builder, InitGraph(std::string(kGraphName))).Times(1); - - ExpectationSet reg_inputs; - reg_inputs += - EXPECT_CALL(builder, RegisterTensor(Field(&ExampleTypes::Tensor::name, - input1.Name()))) - .Times(1) - .After(init_graph) - .WillOnce(Return(kLiteRtStatusOk)); - reg_inputs += - EXPECT_CALL(builder, RegisterTensor(Field(&ExampleTypes::Tensor::name, - input2.Name()))) - .Times(1) - .After(init_graph) - .WillOnce(Return(kLiteRtStatusOk)); - - Expectation reg_output1 = - EXPECT_CALL(builder, RegisterTensor(Field(&ExampleTypes::Tensor::name, - output1.Name()))) - .Times(1) - .After(init_graph) - .WillOnce(Return(kLiteRtStatusOk)); - - 
Expectation reg_cst = - EXPECT_CALL(builder, RegisterTensor( - Field(&ExampleTypes::Tensor::name, cst.Name()))) - .Times(1) - .After(init_graph) - .WillOnce(Return(kLiteRtStatusOk)); - - Expectation reg_output2 = - EXPECT_CALL(builder, RegisterTensor(Field(&ExampleTypes::Tensor::name, - output2.Name()))) - .Times(1) - .After(init_graph) - .WillOnce(Return(kLiteRtStatusOk)); - - auto match_reg_op1_args = - AllOf(Field(&ExampleTypes::Op::op_code, ExampleOpType::MUL), - Field(&ExampleTypes::Op::input_names, - ElementsAreArray({input1.Name(), input2.Name()})), - Field(&ExampleTypes::Op::output_names, - ElementsAreArray({output1.Name()}))); - - Expectation reg_op1 = EXPECT_CALL(builder, RegisterOp(match_reg_op1_args)) - .Times(1) - .After(reg_inputs, reg_output1) - .WillOnce(Return(kLiteRtStatusOk)); - - auto match_reg_op2_args = - AllOf(Field(&ExampleTypes::Op::op_code, ExampleOpType::ADD), - Field(&ExampleTypes::Op::input_names, - ElementsAreArray({output1.Name(), cst.Name()})), - Field(&ExampleTypes::Op::output_names, - ElementsAreArray({output2.Name()}))); - - Expectation reg_op2 = EXPECT_CALL(builder, RegisterOp(match_reg_op2_args)) - .Times(1) - .After(reg_op1, reg_cst, reg_output2, reg_output1) - .WillOnce(Return(kLiteRtStatusOk)); - - Expectation finalize_graph = EXPECT_CALL(builder, FinalizeGraph()) - .Times(1) - .After(reg_op2) - .WillOnce(Return(kLiteRtStatusOk)); - - auto stat = ConvertGraph( - litert_subgraph, std::string(kGraphName), MakeTensorConverter, - tensor_alloc, op_alloc, MakeAllLegalizations(), builder); - - LITERT_ASSERT_OK(stat); -} - -} // namespace -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/vendors/cc/ir_types.h b/tensorflow/lite/experimental/litert/vendors/cc/ir_types.h deleted file mode 100644 index a1da917de18a74..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/cc/ir_types.h +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright 2024 Google LLC. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_CC_IR_TYPES_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_CC_IR_TYPES_H_ - -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/vendors/cc/backend_ir.h" -#include "tensorflow/lite/experimental/litert/vendors/cc/conversion.h" - -namespace litert { - -// Holds particular backends IR template aliases for convenience. 
-template -struct IrTypes { - using Op = BackendOp; - using Tensor = BackendTensor; - using OpAllocator = OpAllocator; - using TensorAllocator = TensorAllocator; - using GraphBuilder = BackendGraphBuilder; - using GeneralConversionResult = GeneralConversionResult; - using SimpleConversionResult = SimpleConversionResult; - using ConversionResult = Expected>; - using Legalization = Legalization; - using Legalizations = Legalizations; - using LegalizationMap = LegalizationMap; - using TensorConverter = TensorConverter; - using TensorResult = Expected; - using TensorConverterFactory = TensorConverterFactory; - using TensorMap = TensorMap; - using Capability = Capability; - // NOLINTNEXTLINE - inline static auto MakeLegalizationMap = - litert::MakeLegalizationMap; -}; - -} // namespace litert - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_CC_IR_TYPES_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/cc/litert_compiler_plugin.h b/tensorflow/lite/experimental/litert/vendors/cc/litert_compiler_plugin.h deleted file mode 100644 index 654457f0f75e24..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/cc/litert_compiler_plugin.h +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_CC_LITERT_COMPILER_PLUGIN_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_CC_LITERT_COMPILER_PLUGIN_H_ - -#include - -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin.h" - -namespace litert { - -// Deleter for incomplete compiler plugin type. -struct LiteRtCompilerPluginDeleter { - void operator()(LiteRtCompilerPlugin plugin) { - if (plugin != nullptr) { - LiteRtDestroyCompilerPlugin(plugin); - } - } -}; - -// Smart pointer wrapper for incomplete plugin type. -using PluginPtr = - std::unique_ptr; - -// Initialize a plugin via c-api and wrap result in smart pointer. -inline PluginPtr CreatePlugin() { - LiteRtCompilerPlugin plugin; - LITERT_CHECK_STATUS_OK(LiteRtCreateCompilerPlugin(&plugin)); - return PluginPtr(plugin); -} - -} // namespace litert - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_CC_LITERT_COMPILER_PLUGIN_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/cc/partition_with_capabilities.h b/tensorflow/lite/experimental/litert/vendors/cc/partition_with_capabilities.h deleted file mode 100644 index a462d1744c3886..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/cc/partition_with_capabilities.h +++ /dev/null @@ -1,97 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_CC_PARTITION_WITH_CAPABILITIES_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_CC_PARTITION_WITH_CAPABILITIES_H_ - -#include - -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/cc/conversion.h" - -namespace litert { - -// Higher-level functions for partitioning by leveraging user-defined -// conversions. This method selects ops for partitioning via a callback that -// checks if an op is supported by the backend. - -// Selects ops for partitioning from given subgraph based on given Capability -// check. Returns all ops in the given supbgraph that are supported by the -// backend. Suitable for use in implementing LiteRtCompilerPluginPartition. Any -// allocations of new backend ir types will be done through given external -// allocators. -// NOTE: A missing legalization or any legalization failure will result in -// an op not being supported, rather than a failure of this function. -template -Expected> PartitionWithCapabilities( - const typename Ir::Legalizations& legalizations, - typename Ir::Capability capability, - typename Ir::TensorConverterFactory convert_tensor_fact, - typename Ir::TensorAllocator tensor_allocator, - typename Ir::OpAllocator op_allocator, const Subgraph& litert_subgraph) { - std::vector results; - - // Build map for legalization lookup by op code. - auto map = Ir::MakeLegalizationMap(legalizations); - - // Convert all ops from the given subgraph and check backend support. 
- for (const auto& litert_op : litert_subgraph.Ops()) { - const auto code = litert_op.Code(); - LITERT_LOG(LITERT_INFO, "Checking support for LiteRtOp: %d", code); - - auto it = map.find(code); - if (it == map.end()) { - LITERT_LOG(LITERT_WARNING, "No legalization found for LiteRtOp: %d", - code); - continue; - } - - // Call user-defined conversion. - auto result = it->second->Legalize(litert_op, convert_tensor_fact, - convert_tensor_fact, tensor_allocator, - op_allocator); - if (!result) { - LITERT_LOG(LITERT_WARNING, "Failed to legalize LiteRtOp: %d", code); - continue; - } - - if (auto simple_result = GetSimpleConversionResult(*result)) { - if (capability(*simple_result)) { - LITERT_LOG(LITERT_INFO, "Selected LiteRtOp: %d", litert_op.Code()); - results.push_back(litert_op.Get()); - } - continue; - } - - // Check all ops emitted from a one-to-many conversion are supported. - if (auto gen_result = GetGeneralConversionResult(*result)) { - const auto b_ops_start = gen_result->ops.cbegin(); - const auto b_ops_end = gen_result->ops.cend(); - if (std::all_of(b_ops_start, b_ops_end, capability)) { - LITERT_LOG(LITERT_INFO, "Selected LiteRtOp: %d", litert_op.Code()); - results.push_back(litert_op.Get()); - } - continue; - } - } - - return results; -} - -} // namespace litert - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_CC_PARTITION_WITH_CAPABILITIES_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/cc/partition_with_capabilities_test.cc b/tensorflow/lite/experimental/litert/vendors/cc/partition_with_capabilities_test.cc deleted file mode 100644 index cece5adb48dca9..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/cc/partition_with_capabilities_test.cc +++ /dev/null @@ -1,207 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Utility types for mapping LiteRt IR to arbitrary backend specific -// types. Implementations of these types define mapping for ops and tensors -// that may be used in a stndalone fashion. They also may be composed -// to create lowerings of entire graphs with topology. - -#include "tensorflow/lite/experimental/litert/vendors/cc/partition_with_capabilities.h" - -#include -#include -#include - -#include -#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" -#include "tensorflow/lite/experimental/litert/core/model/model_graph.h" -#include "tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h" -#include "tensorflow/lite/experimental/litert/vendors/examples/example_conversion_impl.h" -#include "tensorflow/lite/experimental/litert/vendors/examples/example_ir.h" - -namespace litert { -namespace { - -using ::litert::example::ExampleLegalizeAdd; -using ::litert::example::ExampleLegalizeMul; -using ::litert::example::ExampleOpAllocator; -using ::litert::example::ExampleOpType; -using ::litert::example::ExampleTensorAllocator; -using ::litert::example::ExampleTypes; -using ::litert::example::MakeTensorConverter; - -bool ExampleCapability(const ExampleTypes::Op* op) { - return op->op_code == ExampleOpType::ADD || - op->op_code == ExampleOpType::RELU; -} - 
-TEST(PartitionWithCapabilitiesTest, EmptyGraph) { - ExampleTypes::Legalizations legalizations; - legalizations.push_back(ExampleLegalizeAdd::Make()); - - LiteRtSubgraphT subgraph; - Subgraph litert_subgraph(&subgraph); - - ExampleTensorAllocator tensor_alloc; - ExampleOpAllocator op_alloc; - - auto ops = PartitionWithCapabilities( - legalizations, ExampleCapability, MakeTensorConverter, tensor_alloc, - op_alloc, litert_subgraph); - ASSERT_TRUE(ops); - EXPECT_TRUE(ops->empty()); -} - -TEST(PartitionWithCapabilitiesTest, SingleSelectedOp) { - static constexpr std::array kDims = {2, 2}; - - ExampleTypes::Legalizations legalizations; - legalizations.push_back(ExampleLegalizeAdd::Make()); - - LiteRtSubgraphT subgraph; - - const auto type = MakeRankedTensorType(kLiteRtElementTypeFloat32, kDims); - - auto& input1 = subgraph.EmplaceTensor(); - input1.SetType(type); - - auto& input2 = subgraph.EmplaceTensor(); - input2.SetType(type); - - auto& output = subgraph.EmplaceTensor(); - output.SetType(type); - - auto& op = subgraph.EmplaceOp(); - op.SetOpCode(kLiteRtOpCodeTflAdd); - - internal::AttachInput(&input1, op); - internal::AttachInput(&input2, op); - internal::AttachOutput(&output, op); - - Subgraph litert_subgraph(&subgraph); - - ExampleTensorAllocator tensor_alloc; - ExampleOpAllocator op_alloc; - - auto ops = PartitionWithCapabilities( - legalizations, ExampleCapability, MakeTensorConverter, tensor_alloc, - op_alloc, litert_subgraph); - - ASSERT_TRUE(ops); - EXPECT_EQ(ops->size(), 1); -} - -TEST(PartitionWithCapabilitiesTest, MultiSelectedOp) { - static constexpr std::array kDims = {2, 2}; - - ExampleTypes::Legalizations legalizations; - legalizations.push_back(ExampleLegalizeAdd::Make()); - - LiteRtSubgraphT subgraph; - - const auto type = MakeRankedTensorType(kLiteRtElementTypeFloat32, kDims); - - auto& add1_input = subgraph.EmplaceTensor(); - add1_input.SetType(type); - auto& add1_output = subgraph.EmplaceTensor(); - add1_output.SetType(type); - auto& add1 = 
subgraph.EmplaceOp(); - add1.SetOpCode(kLiteRtOpCodeTflAdd); - - internal::AttachInput(&add1_input, add1); - internal::AttachInput(&add1_input, add1); - internal::AttachOutput(&add1_output, add1); - - auto& mul_output = subgraph.EmplaceTensor(); - mul_output.SetType(type); - auto& mul = subgraph.EmplaceOp(); - mul.SetOpCode(kLiteRtOpCodeTflMul); - - internal::AttachInput(&add1_output, mul); - internal::AttachOutput(&mul_output, mul); - - auto& add2_output = subgraph.EmplaceTensor(); - add2_output.SetType(type); - auto& add2 = subgraph.EmplaceOp(); - add2.SetOpCode(kLiteRtOpCodeTflAdd); - - internal::AttachInput(&mul_output, add2); - internal::AttachInput(&mul_output, add2); - internal::AttachOutput(&add2_output, add2); - - Subgraph litert_subgraph(&subgraph); - - ExampleTensorAllocator tensor_alloc; - ExampleOpAllocator op_alloc; - - auto ops = PartitionWithCapabilities( - legalizations, ExampleCapability, MakeTensorConverter, tensor_alloc, - op_alloc, litert_subgraph); - - ASSERT_TRUE(ops); - - ASSERT_EQ(ops->size(), 2); - EXPECT_EQ(ops->front(), &add1); - EXPECT_EQ(ops->back(), &add2); -} - -TEST(PartitionWithCapabilitiesTest, WithGeneralResult) { - static constexpr std::array kDims = {2, 2}; - - ExampleTypes::Legalizations legalizations; - legalizations.push_back(ExampleLegalizeAdd::Make()); - - LiteRtSubgraphT subgraph; - - const auto type = MakeRankedTensorType(kLiteRtElementTypeFloat32, kDims); - - auto& add1_input = subgraph.EmplaceTensor(); - add1_input.SetType(type); - auto& add1_output = subgraph.EmplaceTensor(); - add1_output.SetType(type); - auto& add1 = subgraph.EmplaceOp(); - add1.SetOpCode(kLiteRtOpCodeTflAdd); - - internal::AttachInput(&add1_input, add1); - internal::AttachInput(&add1_input, add1); - internal::AttachOutput(&add1_output, add1); - - tflite::AddOptionsT add_opts; - add_opts.fused_activation_function = tflite::ActivationFunctionType_RELU; - internal::TflOptions tfl_opts; - tfl_opts.Set(std::move(add_opts)); - 
litert::internal::SetTflOptions(add1, std::move(tfl_opts)); - - Subgraph litert_subgraph(&subgraph); - - ExampleTensorAllocator tensor_alloc; - ExampleOpAllocator op_alloc; - - auto ops = PartitionWithCapabilities( - legalizations, ExampleCapability, MakeTensorConverter, tensor_alloc, - op_alloc, litert_subgraph); - - ASSERT_TRUE(ops); - - ASSERT_EQ(ops->size(), 1); - EXPECT_EQ(ops->front(), &add1); -} - -} // namespace - -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/vendors/examples/BUILD b/tensorflow/lite/experimental/litert/vendors/examples/BUILD deleted file mode 100644 index 16427953b936fe..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/examples/BUILD +++ /dev/null @@ -1,160 +0,0 @@ -# Copyright 2024 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -load("//tensorflow/lite/experimental/litert/build_common:litert_build_defs.bzl", "litert_dynamic_lib") - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//visibility:private"], -) - -litert_dynamic_lib( - name = "example_plugin", - srcs = [ - "example_plugin.cc", - "example_plugin_common.cc", - "example_plugin_common.h", - ], - hdrs = ["//tensorflow/lite/experimental/litert/vendors/c:litert_compiler_plugin.h"], - export_litert_only = True, - linkstatic = 1, - shared_lib_name = "example_plugin_so", - so_name = "libLiteRtCompilerPlugin_Example.so", - visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], - deps = [ - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", - "//tensorflow/lite/experimental/litert/cc:litert_op_options", - ], -) - -cc_test( - name = "example_plugin_test", - srcs = [ - "example_plugin_test.cc", - ], - data = ["//tensorflow/lite/experimental/litert/test:mlir_test_data"], - deps = [ - ":example_plugin", # buildcleaner: keep - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/core/model", - "//tensorflow/lite/experimental/litert/test:common", - "//tensorflow/lite/experimental/litert/test:matchers", - "//tensorflow/lite/experimental/litert/vendors/cc:litert_compiler_plugin", - 
"@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "example_conversion_impl", - srcs = ["example_conversion_impl.cc"], - hdrs = ["example_conversion_impl.h"], - visibility = ["//tensorflow/lite/experimental/litert/vendors/cc:__pkg__"], - deps = [ - ":example_ir", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_element_type", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/vendors/cc:backend_ir", - "//tensorflow/lite/experimental/litert/vendors/cc:conversion", - "//tensorflow/lite/experimental/litert/vendors/cc:ir_types", - "@com_google_absl//absl/log:absl_check", - "@com_google_absl//absl/strings:string_view", - ], -) - -cc_test( - name = "example_conversion_impl_test", - srcs = ["example_conversion_impl_test.cc"], - deps = [ - ":example_conversion_impl", - ":example_ir", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/core/model", - "//tensorflow/lite/experimental/litert/core/model:model_graph", - "//tensorflow/lite/experimental/litert/core/util:flatbuffer_tools", - "//tensorflow/lite/experimental/litert/test:matchers", - "//tensorflow/lite/experimental/litert/vendors/cc:conversion", - "//tensorflow/lite/schema:schema_fbs", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "example_ir", - srcs = ["example_ir.cc"], - hdrs = ["example_ir.h"], - visibility = ["//tensorflow/lite/experimental/litert/vendors/cc:__pkg__"], - deps = [ - "//tensorflow/lite/experimental/litert/c:litert_common", - 
"//tensorflow/lite/experimental/litert/vendors/cc:backend_ir", - "//tensorflow/lite/experimental/litert/vendors/cc:ir_types", - ], -) - -cc_library( - name = "example_plugin_with_conversions", - srcs = [ - "example_plugin_common.cc", - "example_plugin_common.h", - "example_plugin_with_conversions.cc", - ], - deps = [ - ":example_conversion_impl", - ":example_ir", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/vendors/c:litert_compiler_plugin", - "//tensorflow/lite/experimental/litert/vendors/cc:convert_graph", - "//tensorflow/lite/experimental/litert/vendors/cc:partition_with_capabilities", - "@com_google_absl//absl/strings:str_format", - ], -) - -cc_test( - name = "example_plugin_with_conversions_test", - srcs = ["example_plugin_with_conversions_test.cc"], - data = ["//tensorflow/lite/experimental/litert/test:mlir_test_data"], - deps = [ - ":example_plugin_with_conversions", # buildcleaner: keep - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/core/model", - "//tensorflow/lite/experimental/litert/test:common", - "//tensorflow/lite/experimental/litert/test:matchers", - "//tensorflow/lite/experimental/litert/vendors/c:litert_compiler_plugin", - "//tensorflow/lite/experimental/litert/vendors/cc:litert_compiler_plugin", - "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest_main", - ], -) diff --git a/tensorflow/lite/experimental/litert/vendors/examples/example_conversion_impl.cc b/tensorflow/lite/experimental/litert/vendors/examples/example_conversion_impl.cc deleted file mode 100644 index fa6e163aee4b77..00000000000000 --- 
a/tensorflow/lite/experimental/litert/vendors/examples/example_conversion_impl.cc +++ /dev/null @@ -1,64 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/vendors/examples/example_conversion_impl.h" - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_element_type.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/cc/backend_ir.h" -#include "tensorflow/lite/experimental/litert/vendors/cc/conversion.h" -#include "tensorflow/lite/experimental/litert/vendors/examples/example_ir.h" - -namespace litert::example { - -TensorConverter MakeTensorConverter( - TensorAllocator alloc) { - return [alloc](const Tensor& litert_tensor) -> Expected { - auto& tensor = *alloc(); - tensor.name = litert_tensor.Name(); - - auto litert_type = litert_tensor.RankedTensorType(); - if (!litert_type) { - return Error(litert_type.Error().Status()); - } - - const auto litert_dims = litert_type->Layout().Dimensions(); - - tensor.dims.assign(litert_dims.cbegin(), litert_dims.cend()); - - switch (litert_tensor.RankedTensorType()->ElementType()) { - case ElementType::Float32: - tensor.type = ExampleTensorType::FLOAT; - break; - case ElementType::Int32: - tensor.type = ExampleTensorType::INT; - 
break; - default: - return Error(kLiteRtStatusErrorInvalidArgument); - } - - return &tensor; - }; -} - -ExampleTypes::Legalizations MakeAllLegalizations() { - ExampleTypes::Legalizations legalizations; - legalizations.push_back(ExampleLegalizeMul::Make()); - legalizations.push_back(ExampleLegalizeAdd::Make()); - return legalizations; -} - -} // namespace litert::example diff --git a/tensorflow/lite/experimental/litert/vendors/examples/example_conversion_impl.h b/tensorflow/lite/experimental/litert/vendors/examples/example_conversion_impl.h deleted file mode 100644 index 64f3199bc363df..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/examples/example_conversion_impl.h +++ /dev/null @@ -1,128 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_EXAMPLES_EXAMPLE_CONVERSION_IMPL_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_EXAMPLES_EXAMPLE_CONVERSION_IMPL_H_ - -#include -#include -#include - -#include "absl/log/absl_check.h" -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/c/litert_options.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/cc/conversion.h" -#include "tensorflow/lite/experimental/litert/vendors/cc/ir_types.h" -#include "tensorflow/lite/experimental/litert/vendors/examples/example_ir.h" - -namespace litert::example { - -// Conversion type implementations for the fictional "example" backend. - -ExampleTypes::TensorConverter MakeTensorConverter( - ExampleTypes::TensorAllocator alloc); - -static constexpr absl::string_view kIntermediateTensorName = - "intermediate_bin_output"; - -// Example legalization for simple binary ops. -template -class ExampleBinOpLegalization : public Legalization { - private: - using Self = ExampleBinOpLegalization; - - public: - using Ptr = std::unique_ptr; - - static Ptr Make() { return std::make_unique(); } - - // Return the litert op code to match on. - LiteRtOpCode OpToMatch() const override { return LiteRtOpType; } - - // Determines if the given litert op has a fused relu attribute. - bool HasFusedRelu(const Op& litert_op) const { - if constexpr (LiteRtOpType != kLiteRtOpCodeTflAdd) { - return false; - } - uint32_t faf; - if (LiteRtGetAddFusedActivationOption(litert_op.Get(), &faf) != - kLiteRtStatusOk) { - return false; - } - return faf == 1; - } - - // Transforms LiteRtAdd op into example op definition using the tensor - // converter to map tensors within. 
- ExampleTypes::ConversionResult LegalizeImpl( - const Op& litert_op, const Tensors& inputs, const Tensors& outputs, - ExampleTypes::TensorAllocator tensor_allocator, - ExampleTypes::OpAllocator op_allocator) const override { - ABSL_DCHECK_EQ(litert_op.Code(), LiteRtOpType); - - auto& bin_op = *op_allocator(); - bin_op.op_code = BackendOpType; - - if (inputs.size() != 2 || outputs.size() != 1) { - return Error(kLiteRtStatusErrorInvalidArgument); - } - - for (const auto* input : inputs) { - bin_op.inputs.push_back(input->id); - bin_op.input_names.push_back(input->name); - } - - auto& output_tensor = *outputs.front(); - if (!HasFusedRelu(litert_op)) { - bin_op.outputs.push_back(output_tensor.id); - bin_op.output_names.push_back(output_tensor.name); - return Expected(&bin_op); - } - - auto* bin_output = tensor_allocator(); - bin_output->dims = output_tensor.dims; - bin_output->type = output_tensor.type; - bin_output->name = std::string(kIntermediateTensorName); - bin_op.outputs.push_back(bin_output->id); - bin_op.output_names.push_back(bin_output->name); - - auto& relu = *op_allocator(); - relu.op_code = ExampleOpType::RELU; - relu.inputs.push_back(bin_output->id); - relu.input_names.push_back(bin_output->name); - relu.outputs.push_back(output_tensor.id); - relu.output_names.push_back(output_tensor.name); - - ExampleTypes::GeneralConversionResult result; - result.ops.push_back(&bin_op); - result.ops.push_back(&relu); - result.intermediate_tensors.push_back(bin_output); - - return ExampleTypes::ConversionResult(result); - } -}; - -using ExampleLegalizeAdd = - ExampleBinOpLegalization; -using ExampleLegalizeMul = - ExampleBinOpLegalization; - -ExampleTypes::Legalizations MakeAllLegalizations(); - -} // namespace litert::example - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_EXAMPLES_EXAMPLE_CONVERSION_IMPL_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/examples/example_conversion_impl_test.cc 
b/tensorflow/lite/experimental/litert/vendors/examples/example_conversion_impl_test.cc deleted file mode 100644 index 8baf028313eda3..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/examples/example_conversion_impl_test.cc +++ /dev/null @@ -1,213 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/vendors/examples/example_conversion_impl.h" - -#include -#include -#include - -#include -#include -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" -#include "tensorflow/lite/experimental/litert/core/model/model_graph.h" -#include "tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h" -#include "tensorflow/lite/experimental/litert/test/matchers.h" -#include "tensorflow/lite/experimental/litert/vendors/cc/conversion.h" -#include "tensorflow/lite/experimental/litert/vendors/examples/example_ir.h" -#include "tensorflow/lite/schema/schema_generated.h" - -namespace litert::example { -namespace { - -using ::testing::ElementsAreArray; -using ::testing::HasSubstr; - -TEST(ExampleConversionImplTest, ConvertTensor) { - static constexpr std::array kDims = {2, 2}; - static constexpr 
absl::string_view kName = "foo"; - - LiteRtTensorT litert_tensor; - litert_tensor.SetType(MakeRankedTensorType(kLiteRtElementTypeFloat32, - absl::MakeConstSpan(kDims))); - litert_tensor.SetName(std::string(kName)); - - ExampleTensorAllocator tensor_alloc; - auto tensor_convert = MakeTensorConverter(tensor_alloc); - - auto& example_tensor = **tensor_convert(Tensor(&litert_tensor)); - EXPECT_EQ(example_tensor.type, ExampleTensorType::FLOAT); - EXPECT_THAT(example_tensor.dims, ElementsAreArray(kDims)); - EXPECT_EQ(example_tensor.name, kName); -} - -TEST(ExampleConversionImplTest, ExampleGraphBuilder) { - ExampleTensor input; - input.type = ExampleTensorType::FLOAT; - input.dims = {2, 2}; - input.id = 1; - - ExampleTensor output; - output.type = ExampleTensorType::INT; - output.dims = {3, 3}; - output.id = 2; - - ExampleOp op; - op.op_code = ExampleOpType::ADD; - op.inputs = {1}; - op.outputs = {2}; - - ExampleGraphBuilder builder; - static constexpr absl::string_view kName = "FOO_GRAPH"; - - builder.InitGraph(std::string(kName)); - LITERT_ASSERT_OK(builder.RegisterTensor(input)); - LITERT_ASSERT_OK(builder.RegisterOp(op)); - LITERT_ASSERT_OK(builder.RegisterTensor(output)); - LITERT_ASSERT_OK(builder.FinalizeGraph()); - - const auto serialized = builder.Serialize(); - EXPECT_THAT(serialized, HasSubstr("1FLOAT[2, 2]")); - EXPECT_THAT(serialized, HasSubstr("2INT[3, 3]")); - EXPECT_THAT(serialized, HasSubstr("ADD(1)->(2)")); - EXPECT_THAT(serialized, HasSubstr("FINALIZED")); - EXPECT_THAT(serialized, HasSubstr(kName)); -} - -TEST(ExampleConversionImplTest, LegalizeAddSimpleResult) { - static constexpr std::array kDims = {2, 2}; - - LiteRtTensorT input1; - input1.SetType(MakeRankedTensorType(kLiteRtElementTypeFloat32, - absl::MakeConstSpan(kDims))); - input1.SetName("input1"); - - LiteRtTensorT input2; - input2.SetType(MakeRankedTensorType(kLiteRtElementTypeFloat32, - absl::MakeConstSpan(kDims))); - input2.SetName("input2"); - - LiteRtTensorT output; - 
output.SetType(MakeRankedTensorType(kLiteRtElementTypeFloat32, - absl::MakeConstSpan(kDims))); - output.SetName("output"); - - LiteRtOpT op; - op.SetOpCode(kLiteRtOpCodeTflAdd); - internal::AttachInput(&input1, op); - internal::AttachInput(&input2, op); - internal::AttachOutput(&output, op); - - tflite::AddOptionsT add_opts; - add_opts.fused_activation_function = tflite::ActivationFunctionType_NONE; - internal::TflOptions tfl_opts; - tfl_opts.Set(std::move(add_opts)); - litert::internal::SetTflOptions(op, std::move(tfl_opts)); - - ExampleTensorAllocator tensor_alloc; - ExampleOpAllocator op_alloc; - - ExampleLegalizeAdd legalize_add; - EXPECT_EQ(legalize_add.OpToMatch(), kLiteRtOpCodeTflAdd); - - auto legalized = - legalize_add.Legalize(Op(&op), MakeTensorConverter, MakeTensorConverter, - tensor_alloc, op_alloc); - - ASSERT_TRUE(legalized); - - auto simple_result = GetSimpleConversionResult(*legalized); - ASSERT_TRUE(simple_result); - auto& example_op = **simple_result; - - EXPECT_EQ(example_op.op_code, ExampleOpType::ADD); - EXPECT_THAT(example_op.inputs, ElementsAreArray({0, 1})); - EXPECT_THAT(example_op.input_names, - ElementsAreArray({input1.Name(), input2.Name()})); - EXPECT_THAT(example_op.outputs, ElementsAreArray({2})); - EXPECT_THAT(example_op.output_names, ElementsAreArray({output.Name()})); -} - -TEST(ExampleConversionImplTest, LegalizeAddGeneralResult) { - static constexpr std::array kDims = {2, 2}; - LiteRtTensorT input1; - input1.SetType(MakeRankedTensorType(kLiteRtElementTypeFloat32, - absl::MakeConstSpan(kDims))); - input1.SetName("input1"); - - LiteRtTensorT input2; - input2.SetType(MakeRankedTensorType(kLiteRtElementTypeFloat32, - absl::MakeConstSpan(kDims))); - input2.SetName("input2"); - - LiteRtTensorT output; - output.SetType(MakeRankedTensorType(kLiteRtElementTypeFloat32, - absl::MakeConstSpan(kDims))); - output.SetName("output"); - - LiteRtOpT op; - op.SetOpCode(kLiteRtOpCodeTflAdd); - internal::AttachInput(&input1, op); - 
internal::AttachInput(&input2, op); - internal::AttachOutput(&output, op); - - tflite::AddOptionsT add_opts; - add_opts.fused_activation_function = tflite::ActivationFunctionType_RELU; - internal::TflOptions tfl_opts; - tfl_opts.Set(std::move(add_opts)); - litert::internal::SetTflOptions(op, std::move(tfl_opts)); - - ExampleTensorAllocator tensor_alloc; - ExampleOpAllocator op_alloc; - - auto legalize_add = ExampleLegalizeAdd::Make(); - EXPECT_EQ(legalize_add->OpToMatch(), kLiteRtOpCodeTflAdd); - - auto legalized = - legalize_add->Legalize(Op(&op), MakeTensorConverter, MakeTensorConverter, - tensor_alloc, op_alloc); - ASSERT_TRUE(legalized); - - auto gen_result = GetGeneralConversionResult(*legalized); - ASSERT_TRUE(gen_result); - - ASSERT_EQ(gen_result->ops.size(), 2); - EXPECT_EQ(gen_result->ops[0]->op_code, ExampleOpType::ADD); - EXPECT_THAT(gen_result->ops[0]->inputs, ElementsAreArray({0, 1})); - EXPECT_THAT(gen_result->ops[0]->input_names, - ElementsAreArray({input1.Name(), input2.Name()})); - EXPECT_THAT(gen_result->ops[0]->outputs, ElementsAreArray({3})); - EXPECT_THAT(gen_result->ops[0]->output_names, - ElementsAreArray({kIntermediateTensorName})); - EXPECT_EQ(gen_result->ops[1]->op_code, ExampleOpType::RELU); - EXPECT_THAT(gen_result->ops[1]->inputs, ElementsAreArray({3})); - EXPECT_THAT(gen_result->ops[1]->input_names, - ElementsAreArray({kIntermediateTensorName})); - EXPECT_THAT(gen_result->ops[1]->outputs, ElementsAreArray({2})); - EXPECT_THAT(gen_result->ops[1]->output_names, - ElementsAreArray({output.Name()})); - EXPECT_EQ(gen_result->intermediate_tensors.size(), 1); - EXPECT_EQ(gen_result->intermediate_tensors.front()->id, 3); - EXPECT_EQ(gen_result->intermediate_tensors.front()->name, - kIntermediateTensorName); -} - -} // namespace - -} // namespace litert::example diff --git a/tensorflow/lite/experimental/litert/vendors/examples/example_ir.cc b/tensorflow/lite/experimental/litert/vendors/examples/example_ir.cc deleted file mode 100644 index 
da06b617d9f15b..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/examples/example_ir.cc +++ /dev/null @@ -1,87 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/vendors/examples/example_ir.h" - -#include -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" - -namespace litert::example { - -namespace { - -template -void PrintWithCommas(It start, It end, std::ostream& out) { - for (auto it = start; it < end; ++it) { - out << std::to_string(*it); - if (it != end - 1) { - out << ", "; - } - } -} - -} // namespace - -LiteRtStatus ExampleGraphBuilder::RegisterOp(ExampleOp& op) { - switch (op.op_code) { - case ExampleOpType::ADD: - example_graph_ << "ADD"; - break; - case ExampleOpType::MUL: - example_graph_ << "MUL"; - break; - case ExampleOpType::RELU: - example_graph_ << "RELU"; - break; - } - example_graph_ << "("; - PrintWithCommas(op.inputs.cbegin(), op.inputs.cend(), example_graph_); - example_graph_ << ")->("; - PrintWithCommas(op.outputs.cbegin(), op.outputs.cend(), example_graph_); - example_graph_ << ")"; - return kLiteRtStatusOk; -} - -LiteRtStatus ExampleGraphBuilder::RegisterTensor(ExampleTensor& tensor) { - example_graph_ << std::to_string(tensor.id); - switch (tensor.type) { - case ExampleTensorType::FLOAT: - example_graph_ << "FLOAT"; - break; - case ExampleTensorType::INT: - example_graph_ << "INT"; - break; - } - 
example_graph_ << "["; - PrintWithCommas(tensor.dims.cbegin(), tensor.dims.cend(), example_graph_); - example_graph_ << "]"; - return kLiteRtStatusOk; -} - -LiteRtStatus ExampleGraphBuilder::FinalizeGraph() { - example_graph_ << "FINALIZED"; - return kLiteRtStatusOk; -} - -void ExampleGraphBuilder::InitGraph(std::string graph_name) { - example_graph_ << "name=" << graph_name << "\n"; -} - -std::string ExampleGraphBuilder::Serialize() const { - return example_graph_.str(); -} - -} // namespace litert::example diff --git a/tensorflow/lite/experimental/litert/vendors/examples/example_ir.h b/tensorflow/lite/experimental/litert/vendors/examples/example_ir.h deleted file mode 100644 index e423a53f382b8d..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/examples/example_ir.h +++ /dev/null @@ -1,153 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_EXAMPLES_EXAMPLE_IR_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_EXAMPLES_EXAMPLE_IR_H_ - -#include -#include -#include -#include -#include - -#include "tensorflow/lite/experimental/litert/vendors/cc/backend_ir.h" -#include "tensorflow/lite/experimental/litert/vendors/cc/ir_types.h" - -namespace litert::example { - -// Example IR wrapper types for an imaginary backend. - -// Example backend knows only float and int 32. 
-enum class ExampleTensorType { - FLOAT, - INT, -}; - -// Example backend tensor wrapper that stores the type and shape and unique ID. -struct ExampleTensor { - using Id = int32_t; - ExampleTensorType type; - std::vector dims; - std::string name; - Id id = -1; -}; - -// Example backend knows only a few simple ops. -enum class ExampleOpType { - ADD, - MUL, - RELU, -}; - -// Example backend op that stores op type as well as input and output tensor -// IDs and names. -struct ExampleOp { - ExampleOpType op_code; - std::vector inputs; - std::vector input_names; - std::vector outputs; - std::vector output_names; -}; - -// Simple allocator(s) for example example IR types that provides pointer -// stability. -template -class ExampleIrAllocatorBase { - public: - ExampleIrAllocatorBase(const ExampleIrAllocatorBase&) = delete; - ExampleIrAllocatorBase& operator=(const ExampleIrAllocatorBase&) = delete; - ExampleIrAllocatorBase() = default; - - protected: - std::list ir_; -}; - -// Allocator for example tensors that provides pointer stability and unique IDs. -class ExampleTensorAllocator : public ExampleIrAllocatorBase { - private: - using Alloc = BackendIrAllocator; - - public: - ExampleTensor* operator()() { - auto& tensor = this->ir_.emplace_back(); - tensor.id = this->next_id_++; - return &tensor; - } - - // Return lambda instead of implicit copy construction when converting to - // function type. - // NOLINTNEXTLINE - operator Alloc() { - return [this]() { return this->operator()(); }; - } - - ExampleTensorAllocator(const ExampleTensorAllocator&) = delete; - ExampleTensorAllocator& operator=(const ExampleTensorAllocator&) = delete; - ExampleTensorAllocator() = default; - - private: - uint32_t next_id_ = 0; -}; - -// Allocator for example ops that provides pointer stability. 
-class ExampleOpAllocator : public ExampleIrAllocatorBase { - private: - using Alloc = BackendIrAllocator; - - public: - ExampleOp* operator()() { return &this->ir_.emplace_back(); } - - // Return lambda instead of implicit copy construction when converting to - // function type. - // NOLINTNEXTLINE - operator Alloc() { - return [this]() { return this->operator()(); }; - } - - ExampleOpAllocator(const ExampleOpAllocator&) = delete; - ExampleOpAllocator& operator=(const ExampleOpAllocator&) = delete; - ExampleOpAllocator() = default; -}; - -// Builder for graph conversion to example IR. The internal example IR graph is -// simply a string representation of the graph. -class ExampleGraphBuilder - : public BackendGraphBuilder { - public: - // Prefixes ir string. - void InitGraph(std::string graph_name) override; - - // Registers tensor into the currrent graph by simply appending its string - // representation. - LiteRtStatus RegisterTensor(ExampleTensor& tensor) override; - - // Registers op into the currrent graph by simply appending its string - // representation. - LiteRtStatus RegisterOp(ExampleOp& op) override; - - // Simply appends tag to IR string. - LiteRtStatus FinalizeGraph() override; - - // Gets the serialized IR representation. - std::string Serialize() const; - - private: - std::stringstream example_graph_; -}; - -using ExampleTypes = IrTypes; - -} // namespace litert::example - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_EXAMPLES_EXAMPLE_IR_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/examples/example_plugin.cc b/tensorflow/lite/experimental/litert/vendors/examples/example_plugin.cc deleted file mode 100644 index dff15a4490ec9d..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/examples/example_plugin.cc +++ /dev/null @@ -1,120 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_op_options.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin.h" -#include "tensorflow/lite/experimental/litert/vendors/examples/example_plugin_common.h" - -// A simple compiler plugin example that implements everything directly. -// This plugin matches on mul ops, and emits "byte code" that is simply -// a string representative of the ops consumed. - -// Plugins can hold state. 
-struct LiteRtCompilerPluginT {}; - -LiteRtStatus LiteRtCreateCompilerPlugin(LiteRtCompilerPlugin* compiler_plugin) { - *compiler_plugin = new LiteRtCompilerPluginT; - return kLiteRtStatusOk; -} - -void LiteRtDestroyCompilerPlugin(LiteRtCompilerPlugin compiler_plugin) { - delete compiler_plugin; -} - -LiteRtStatus LiteRtCompilerPluginPartition(LiteRtCompilerPlugin compiler_plugin, - const char* soc_model, - LiteRtSubgraph subgraph, - LiteRtOpList selected_ops) { - ::litert::Subgraph main_subgraph(subgraph); - for (const auto& op : main_subgraph.Ops()) { - if (op.Code() == kLiteRtOpCodeTflMul) { - LITERT_RETURN_IF_ERROR(LiteRtPushOp(selected_ops, op.Get(), 0)); - } else if (op.Code() == kLiteRtOpCodeTflSub) { - LITERT_RETURN_IF_ERROR(LiteRtPushOp(selected_ops, op.Get(), 1)); - } else if (op.Code() == kLiteRtOpCodeShloComposite) { - const auto opts = - litert::GetOptionsAs(op.Get()); - if (!opts) { - return opts.Error().Status(); - } - if (opts->name == "odml.rms_norm") { - LITERT_RETURN_IF_ERROR(LiteRtPushOp(selected_ops, op.Get(), 0)); - } - } - } - return kLiteRtStatusOk; -} - -namespace { - -LiteRtStatus CompileSinglePartition(LiteRtParamIndex partition_index, - LiteRtSubgraph subgraph, - LiteRtCompiledResultT& result, - int byte_code_idx) { - const litert::Subgraph sg(subgraph); - int num_muls_in_partition = 0; - for (const auto& op : sg.Ops()) { - if (op.Code() != kLiteRtOpCodeTflMul && op.Code() != kLiteRtOpCodeTflSub) { - return kLiteRtStatusErrorUnsupported; - } - if (op.Code() == kLiteRtOpCodeTflMul) { - ++num_muls_in_partition; - } - } - - { - char* byte_code_append; - (void)asprintf(&byte_code_append, - "Partition_%lu_with_%d_muls:", partition_index, - num_muls_in_partition); - result.byte_code[byte_code_idx].append(byte_code_append); - free(byte_code_append); - } - - { - char* per_op_data; - (void)asprintf(&per_op_data, "Partition_%lu", partition_index); - result.per_op_data.push_back(per_op_data); - free(per_op_data); - } - - return kLiteRtStatusOk; -} 
- -} // namespace - -LiteRtStatus LiteRtCompilerPluginCompile( - LiteRtCompilerPlugin compiler_plugin, const char* soc_model, - LiteRtModel partitions, LiteRtCompiledResult* compiled_result) { - auto model = litert::Model::CreateFromNonOwnedHandle(partitions); - const auto num_partitions = model.NumSubgraphs(); - auto result = std::make_unique(); - result->byte_code.resize(num_partitions); - for (auto i = 0; i < num_partitions; ++i) { - LITERT_RETURN_IF_ERROR( - CompileSinglePartition(i, model.Subgraph(i)->Get(), *result, i)); - } - - *compiled_result = result.release(); - - return kLiteRtStatusOk; -} diff --git a/tensorflow/lite/experimental/litert/vendors/examples/example_plugin_common.cc b/tensorflow/lite/experimental/litert/vendors/examples/example_plugin_common.cc deleted file mode 100644 index 19c84dc55e7869..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/examples/example_plugin_common.cc +++ /dev/null @@ -1,137 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/vendors/examples/example_plugin_common.h" - -#include -#include -#include -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin.h" - -// -// Configurations -// - -namespace litert::example { -namespace { - -constexpr char kPluginManufacturer[] = "ExampleSocManufacturer"; -constexpr char kPluginSocModel[] = "ExampleSocModel"; - -} // namespace -} // namespace litert::example - -LiteRtStatus LiteRtCompilerPluginSetFlags(LiteRtCompilerPlugin compiler_plugin, - LiteRtParamIndex num_flags, - const char** keys, - const char** values) { - // IMPLEMENT ME - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetCompilerPluginVersion(LiteRtApiVersion* api_version) { - if (!api_version) { - return kLiteRtStatusErrorInvalidArgument; - } - api_version->major = LITERT_API_VERSION_MAJOR; - api_version->minor = LITERT_API_VERSION_MINOR; - api_version->patch = LITERT_API_VERSION_PATCH; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetCompilerPluginSupportedHardware( - LiteRtCompilerPlugin compiler_plugin, - LiteRtHwAccelerators* supported_hardware) { - if (!compiler_plugin || !supported_hardware) { - return kLiteRtStatusErrorInvalidArgument; - } - *supported_hardware = kLiteRtHwAcceleratorCpu; - return kLiteRtStatusOk; -} - -const char* LiteRtGetCompilerPluginSocManufacturer() { - return litert::example::kPluginManufacturer; -} - -LiteRtStatus LiteRtGetNumCompilerPluginSupportedSocModels( - LiteRtCompilerPlugin compiler_plugin, - LiteRtParamIndex* num_supported_soc_models) { - if (!compiler_plugin || !num_supported_soc_models) { - return kLiteRtStatusErrorInvalidArgument; - } - *num_supported_soc_models = 1; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetCompilerPluginSupportedSocModel( - LiteRtCompilerPlugin compiler_plugin, LiteRtParamIndex soc_model_idx, - const char** soc_model_name) { - if (!compiler_plugin || 
!soc_model_name) { - return kLiteRtStatusErrorInvalidArgument; - } else if (soc_model_idx != 0) { - return kLiteRtStatusErrorUnsupported; - } - *soc_model_name = litert::example::kPluginSocModel; - return kLiteRtStatusOk; -} - -// -// Compiled Result Definition -// - -LiteRtStatus LiteRtGetCompiledResultByteCode( - LiteRtCompiledResult compiled_result, LiteRtParamIndex byte_code_idx, - const void** byte_code, size_t* byte_code_size) { - if (!compiled_result) { - return kLiteRtStatusErrorInvalidArgument; - } - *byte_code = compiled_result->byte_code[byte_code_idx].data(); - *byte_code_size = compiled_result->byte_code[byte_code_idx].size(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetCompiledResultCallInfo( - LiteRtCompiledResult compiled_result, LiteRtParamIndex call_idx, - const void** call_info, size_t* call_info_size, - LiteRtParamIndex* byte_code_idx) { - if (call_idx >= compiled_result->per_op_data.size()) { - return kLiteRtStatusErrorIndexOOB; - } - *call_info = compiled_result->per_op_data.at(call_idx).data(); - *call_info_size = compiled_result->per_op_data.at(call_idx).size(); - *byte_code_idx = 0; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetNumCompiledResultCalls( - LiteRtCompiledResult compiled_result, LiteRtParamIndex* num_calls) { - if (!compiled_result) { - return kLiteRtStatusErrorInvalidArgument; - } - *num_calls = compiled_result->per_op_data.size(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtCompiledResultNumByteCodeModules( - LiteRtCompiledResult compiled_result, LiteRtParamIndex* num_byte_code) { - *num_byte_code = compiled_result->byte_code.size(); - return kLiteRtStatusOk; -} - -void LiteRtDestroyCompiledResult(LiteRtCompiledResult compiled_result) { - delete compiled_result; -} diff --git a/tensorflow/lite/experimental/litert/vendors/examples/example_plugin_common.h b/tensorflow/lite/experimental/litert/vendors/examples/example_plugin_common.h deleted file mode 100644 index cc7c0f60df4e85..00000000000000 --- 
a/tensorflow/lite/experimental/litert/vendors/examples/example_plugin_common.h +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_EXAMPLES_EXAMPLE_PLUGIN_COMMON_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_EXAMPLES_EXAMPLE_PLUGIN_COMMON_H_ - -#include -#include - -// Simple compiled result def holds byte code and per op data. -struct LiteRtCompiledResultT { - std::vector byte_code; - std::vector per_op_data; -}; - -namespace litert::example {} // namespace litert::example - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_EXAMPLES_EXAMPLE_PLUGIN_COMMON_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/examples/example_plugin_test.cc b/tensorflow/lite/experimental/litert/vendors/examples/example_plugin_test.cc deleted file mode 100644 index 3b1b098ff62bfa..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/examples/example_plugin_test.cc +++ /dev/null @@ -1,98 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include - -#include -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" -#include "tensorflow/lite/experimental/litert/test/common.h" -#include "tensorflow/lite/experimental/litert/test/matchers.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin.h" -#include "tensorflow/lite/experimental/litert/vendors/cc/litert_compiler_plugin.h" - -namespace litert { -namespace { - -TEST(TestDummyPlugin, GetConfigInfo) { - ASSERT_STREQ(LiteRtGetCompilerPluginSocManufacturer(), - "ExampleSocManufacturer"); - - auto plugin = CreatePlugin(); - - LiteRtParamIndex num_supported_soc_models; - LITERT_ASSERT_OK(LiteRtGetNumCompilerPluginSupportedSocModels( - plugin.get(), &num_supported_soc_models)); - ASSERT_EQ(num_supported_soc_models, 1); - - const char* soc_model_name; - LITERT_ASSERT_OK(LiteRtGetCompilerPluginSupportedSocModel(plugin.get(), 0, - &soc_model_name)); - ASSERT_STREQ(soc_model_name, "ExampleSocModel"); -} - -TEST(TestCallDummyPlugin, PartitionSimpleMultiAdd) { - auto plugin = CreatePlugin(); - auto model = testing::LoadTestFileModel("simple_multi_op.tflite"); - - LiteRtOpListT selected_op_list; - LITERT_ASSERT_OK(LiteRtCompilerPluginPartition( - plugin.get(), /*soc_model=*/nullptr, 
model.Subgraph(0)->Get(), - &selected_op_list)); - const auto selected_ops = selected_op_list.Values(); - - ASSERT_EQ(selected_ops.size(), 2); - ASSERT_EQ(selected_ops[0].first->OpCode(), kLiteRtOpCodeTflMul); - ASSERT_EQ(selected_ops[1].first->OpCode(), kLiteRtOpCodeTflMul); -} - -TEST(TestCallDummyPlugin, CompileMulSubgraph) { - auto plugin = CreatePlugin(); - auto model = testing::LoadTestFileModel("mul_simple.tflite"); - - LiteRtCompiledResult compiled; - LITERT_ASSERT_OK(LiteRtCompilerPluginCompile( - plugin.get(), /*soc_model=*/nullptr, model.Get(), &compiled)); - - const void* byte_code; - size_t byte_code_size; - - LITERT_ASSERT_OK(LiteRtGetCompiledResultByteCode( - compiled, /*byte_code_idx=*/0, &byte_code, &byte_code_size)); - - absl::string_view byte_code_string(reinterpret_cast(byte_code), - byte_code_size); - ASSERT_EQ(byte_code_string, "Partition_0_with_2_muls:"); - - LiteRtParamIndex byte_code_idx; - const void* op_data; - size_t op_data_size; - - LITERT_ASSERT_OK(LiteRtGetCompiledResultCallInfo( - compiled, /*call_idx=*/0, &op_data, &op_data_size, &byte_code_idx)); - - absl::string_view op_data_string(reinterpret_cast(op_data), - op_data_size); - ASSERT_EQ(op_data_string, "Partition_0"); - - LiteRtDestroyCompiledResult(compiled); -} - -} // namespace -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/vendors/examples/example_plugin_with_conversions.cc b/tensorflow/lite/experimental/litert/vendors/examples/example_plugin_with_conversions.cc deleted file mode 100644 index 22f11167c2cc5a..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/examples/example_plugin_with_conversions.cc +++ /dev/null @@ -1,139 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include -#include - -#include "absl/strings/str_format.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin.h" -#include "tensorflow/lite/experimental/litert/vendors/cc/convert_graph.h" -#include "tensorflow/lite/experimental/litert/vendors/cc/partition_with_capabilities.h" -#include "tensorflow/lite/experimental/litert/vendors/examples/example_conversion_impl.h" -#include "tensorflow/lite/experimental/litert/vendors/examples/example_ir.h" -#include "tensorflow/lite/experimental/litert/vendors/examples/example_plugin_common.h" - -using ::litert::PartitionWithCapabilities; -using ::litert::example::ExampleGraphBuilder; -using ::litert::example::ExampleOpAllocator; -using ::litert::example::ExampleOpType; -using ::litert::example::ExampleTensorAllocator; -using ::litert::example::ExampleTypes; -using ::litert::example::MakeAllLegalizations; -using ::litert::example::MakeTensorConverter; - -// Example plugin implementations that leverage the pluggable conversion -// infrastructure. Implementations of common interfaces are provided in -// example_conversion_impl.h. These are passed to higher-level litert functions -// to perform the actual conversion. 
-// The primary benifit of this approach is the re-use of conversion logic -// between the partition and compile phases. - -// Plugins can hold state. -struct LiteRtCompilerPluginT { - ExampleTypes::Legalizations legalizations; -}; - -namespace { - -bool MulCapability(const ExampleTypes::Op* op) { - return op->op_code == ExampleOpType::MUL; -} - -} // namespace - -// Initialize example plugin and register legalizations. -LiteRtStatus LiteRtCreateCompilerPlugin(LiteRtCompilerPlugin* compiler_plugin) { - auto* plugin = new LiteRtCompilerPluginT; - plugin->legalizations = MakeAllLegalizations(); - *compiler_plugin = plugin; - return kLiteRtStatusOk; -} - -void LiteRtDestroyCompilerPlugin(LiteRtCompilerPlugin compiler_plugin) { - delete compiler_plugin; -} - -// Leverage the convert_type PartitionViaCapabilties algorithm for partitioning -// implementation. -LiteRtStatus LiteRtCompilerPluginPartition(LiteRtCompilerPlugin compiler_plugin, - const char* soc_model, - LiteRtSubgraph subgraph, - LiteRtOpList selected_ops) { - ExampleTensorAllocator tensor_alloc; - ExampleOpAllocator op_alloc; - - auto ops = PartitionWithCapabilities( - compiler_plugin->legalizations, MulCapability, MakeTensorConverter, - tensor_alloc, op_alloc, ::litert::Subgraph(subgraph)); - if (!ops) { - return ops.Error().Status(); - } - - for (auto* op : *ops) { - LITERT_RETURN_IF_ERROR(LiteRtPushOp(selected_ops, op, 0)); - } - - return kLiteRtStatusOk; -} - -namespace { - -LiteRtStatus CompileSinglePartition( - const ExampleTypes::Legalizations& legalizations, std::string name, - LiteRtSubgraph subgraph, LiteRtCompiledResultT& result) { - ::litert::Subgraph litert_subgraph(subgraph); - - ExampleTensorAllocator tensor_alloc; - ExampleOpAllocator op_alloc; - - ExampleGraphBuilder builder; - - LITERT_RETURN_IF_ERROR(::litert::ConvertGraph( - litert_subgraph, name, MakeTensorConverter, tensor_alloc, op_alloc, - legalizations, builder)); - - // This example plugin only supports a single byte code module. 
- result.byte_code[0].append(builder.Serialize()); - result.per_op_data.push_back(std::move(name)); - - return kLiteRtStatusOk; -} - -} // namespace - -// Plugin compiler implementation that leverages the pluggable convert_types -// infrastructure. -LiteRtStatus LiteRtCompilerPluginCompile( - LiteRtCompilerPlugin compiler_plugin, const char* soc_model, - LiteRtModel partitions, LiteRtCompiledResult* compiled_result) { - auto model = litert::Model::CreateFromNonOwnedHandle(partitions); - const auto num_partitions = model.NumSubgraphs(); - auto result = std::make_unique(); - result->byte_code.resize(num_partitions); - for (auto i = 0; i < num_partitions; ++i) { - auto name = absl::StrFormat("partition_%lu", i); - LITERT_RETURN_IF_ERROR( - CompileSinglePartition(compiler_plugin->legalizations, std::move(name), - model.Subgraph(i)->Get(), *result)); - } - - *compiled_result = result.release(); - - return kLiteRtStatusOk; -} diff --git a/tensorflow/lite/experimental/litert/vendors/examples/example_plugin_with_conversions_test.cc b/tensorflow/lite/experimental/litert/vendors/examples/example_plugin_with_conversions_test.cc deleted file mode 100644 index 10c7928cab629f..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/examples/example_plugin_with_conversions_test.cc +++ /dev/null @@ -1,111 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include -#include - -#include -#include -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" -#include "tensorflow/lite/experimental/litert/test/common.h" -#include "tensorflow/lite/experimental/litert/test/matchers.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin.h" -#include "tensorflow/lite/experimental/litert/vendors/cc/litert_compiler_plugin.h" - -namespace litert { -namespace { - -using ::testing::HasSubstr; - -TEST(ExamplePluginWithConvertTypesTest, GetConfigInfo) { - ASSERT_STREQ(LiteRtGetCompilerPluginSocManufacturer(), - "ExampleSocManufacturer"); - - auto plugin = CreatePlugin(); - - LiteRtParamIndex num_supported_soc_models; - LITERT_ASSERT_OK(LiteRtGetNumCompilerPluginSupportedSocModels( - plugin.get(), &num_supported_soc_models)); - ASSERT_EQ(num_supported_soc_models, 1); - - const char* soc_model_name; - LITERT_ASSERT_OK(LiteRtGetCompilerPluginSupportedSocModel(plugin.get(), 0, - &soc_model_name)); - ASSERT_STREQ(soc_model_name, "ExampleSocModel"); -} - -TEST(ExamplePluginWithConvertTypesTest, PartitionSimpleMultiAdd) { - auto plugin = CreatePlugin(); - auto model = litert::testing::LoadTestFileModel("simple_multi_op.tflite"); - - LiteRtOpListT selected_op_list; - LITERT_ASSERT_OK(LiteRtCompilerPluginPartition( - plugin.get(), /*soc_model=*/nullptr, model.Get()->MainSubgraph(), - &selected_op_list)); - const auto selected_ops = selected_op_list.Values(); - - ASSERT_EQ(selected_ops.size(), 2); - ASSERT_EQ(selected_ops[0].first->OpCode(), kLiteRtOpCodeTflMul); - ASSERT_EQ(selected_ops[1].first->OpCode(), kLiteRtOpCodeTflMul); -} - -TEST(ExamplePluginWithConvertTypesTest, CompileMulSubgraph) { - static constexpr absl::string_view kName = "partition_0"; - - auto plugin = 
CreatePlugin(); - auto model = litert::testing::LoadTestFileModel("mul_simple.tflite"); - - LiteRtCompiledResult compiled; - LITERT_ASSERT_OK(LiteRtCompilerPluginCompile( - plugin.get(), /*soc_model=*/nullptr, model.Get(), &compiled)); - - const void* byte_code; - size_t byte_code_size; - LITERT_ASSERT_OK(LiteRtGetCompiledResultByteCode( - compiled, /*byte_code_idx=*/0, &byte_code, &byte_code_size)); - absl::string_view byte_code_str(reinterpret_cast(byte_code), - byte_code_size); - - EXPECT_THAT(byte_code_str, HasSubstr(kName)); - EXPECT_THAT(byte_code_str, HasSubstr("0FLOAT[2, 2]")); - EXPECT_THAT(byte_code_str, HasSubstr("1FLOAT[2, 2]")); - EXPECT_THAT(byte_code_str, HasSubstr("2FLOAT[2, 2]")); - EXPECT_THAT(byte_code_str, HasSubstr("MUL")); - EXPECT_THAT(byte_code_str, HasSubstr("FINALIZED")); - - LiteRtParamIndex num_call_infos; - LITERT_ASSERT_OK(LiteRtGetNumCompiledResultCalls(compiled, &num_call_infos)); - - ASSERT_EQ(num_call_infos, 1); - - const void* op_data; - size_t op_data_size; - LiteRtParamIndex byte_code_idx; - - LITERT_ASSERT_OK(LiteRtGetCompiledResultCallInfo( - compiled, 0, &op_data, &op_data_size, &byte_code_idx)); - - absl::string_view op_data_str(reinterpret_cast(op_data), - op_data_size); - EXPECT_EQ(op_data_str, kName); - - LiteRtDestroyCompiledResult(compiled); -} - -} // namespace -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/vendors/google_tensor/adapter.cc b/tensorflow/lite/experimental/litert/vendors/google_tensor/adapter.cc deleted file mode 100644 index b0e1c25c591e57..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/google_tensor/adapter.cc +++ /dev/null @@ -1,95 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/vendors/google_tensor/adapter.h" - -#include - -#include -#include -#include - -#include "absl/strings/str_cat.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" - -namespace litert { -namespace google_tensor { - -Adapter::Adapter() : api_(new Api) {} - -Adapter::~Adapter() { - if (dlib_handle_) { - dlclose(dlib_handle_); // Use dlclose directly - } -} - -litert::Expected Adapter::Create( - std::optional shared_library_dir) { - Ptr adapter(new Adapter); - auto status = adapter->LoadSymbols(shared_library_dir); - if (!status.HasValue()) { - LITERT_LOG(LITERT_ERROR, "Failed to create Adapter: %s", - status.Error().Message().c_str()); - return status.Error(); - } - return adapter; -} - -litert::Expected Adapter::LoadSymbols( - std::optional shared_library_dir) { - constexpr auto kLibTensorTPUCompiler = "libcompiler_api_wrapper.so"; - - const std::vector so_paths = { - shared_library_dir.has_value() - ? 
absl::StrCat(*shared_library_dir, "/", kLibTensorTPUCompiler) - : kLibTensorTPUCompiler}; - - // Use dlopen directly - for (const auto& path : so_paths) { - dlib_handle_ = dlopen(path.c_str(), RTLD_LAZY | RTLD_LOCAL); - if (dlib_handle_) { - void* init_func = dlsym(dlib_handle_, "Initialize"); - if (init_func) { - (*reinterpret_cast(init_func))(); - } - break; // Found the library - } - } - - if (!dlib_handle_) { - const std::string error_message = - "Failed to load Tensor TPU compiler library: " + std::string(dlerror()); - LITERT_LOG(LITERT_ERROR, "Failed to load Tensor TPU compiler library: %s", - error_message.c_str()); // Include dlerror() for more info - return litert::Unexpected(kLiteRtStatusErrorRuntimeFailure, error_message); - } - - api_->compile = - reinterpret_cast(dlsym(dlib_handle_, "CompileFlatbuffer")); - if (!api_->compile) { - const std::string error_message = - "Failed to load Tensor TPU compiler API: " + std::string(dlerror()); - LITERT_LOG(LITERT_ERROR, "Failed to load Tensor TPU compiler API: %s", - error_message.c_str()); // Include dlerror() - return litert::Unexpected(kLiteRtStatusErrorRuntimeFailure, error_message); - } - - LITERT_LOG(LITERT_INFO, "Tensor TPU compiler API symbols loaded"); - return {}; -} - -} // namespace google_tensor -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/vendors/google_tensor/adapter.h b/tensorflow/lite/experimental/litert/vendors/google_tensor/adapter.h deleted file mode 100644 index 37a88a840c793f..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/google_tensor/adapter.h +++ /dev/null @@ -1,68 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_GOOGLE_TENSOR_ADAPTER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_GOOGLE_TENSOR_ADAPTER_H_ -#include -#include -#include -#include -#include - -#include "absl/log/log.h" -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" - -namespace litert::google_tensor { - -// Flags is a vector of key-value pairs. where key is the flag name and value is -// the flag value. eg. {{"enable_reference", "true"}} -using Flags = std::vector>; -typedef absl::Status (*Compile)(absl::string_view serialized_tfl_buffer, - absl::string_view soc_model, const Flags& flags, - std::string* compiled_code); - -// This class adapts the google tensor compiler API for dynamic loading. -class Adapter { - public: - // A smart pointer for managing TensorAdapter objects. - using Ptr = std::unique_ptr; - struct Api; - - Adapter(); - ~Adapter(); - - // Creates a new TensorAdapter and loads the compiler API symbols. - static litert::Expected Create( - std::optional shared_library_dir); - - // Returns a reference to the loaded API. - const Api& api() const { return *api_; } - - private: - // Loads the symbols from the compiler library. - litert::Expected LoadSymbols( - std::optional shared_library_dir); - - void* dlib_handle_ = nullptr; - std::unique_ptr api_; -}; - -struct Adapter::Api { - // The function pointer to the compiler wrapper API. 
- Compile compile = nullptr; -}; - -} // namespace litert::google_tensor - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_GOOGLE_TENSOR_ADAPTER_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/google_tensor/adapter_test.cc b/tensorflow/lite/experimental/litert/vendors/google_tensor/adapter_test.cc deleted file mode 100644 index 55872dfeb1160b..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/google_tensor/adapter_test.cc +++ /dev/null @@ -1,92 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/vendors/google_tensor/adapter.h" - -#include - -#include -#include - -#include -#include -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/test/common.h" - -namespace litert { -namespace google_tensor { - -TEST(AdapterTest, CreateSuccess) { - auto adapter_result = Adapter::Create(/*shared_library_dir=*/ - std::nullopt); - if (!adapter_result.HasValue()) { - LITERT_LOG(LITERT_ERROR, "Failed to create Adapter: %s", - adapter_result.Error().Message().c_str()); - } - ASSERT_TRUE(adapter_result.HasValue()); -} - -TEST(AdapterTest, CreateFailure) { - auto kLibDarwinnCompilerNoLib = "libcompiler_api_wrapper_no_lib.so"; - auto adapter_result = Adapter::Create(kLibDarwinnCompilerNoLib); - ASSERT_FALSE(adapter_result.HasValue()); -} - -TEST(AdapterTest, CompileSuccess) { - auto adapter_result = Adapter::Create(/*shared_library_dir=*/ - std::nullopt); - if (!adapter_result.HasValue()) { - LITERT_LOG(LITERT_ERROR, "Failed to create Adapter: %s", - adapter_result.Error().Message().c_str()); - } - - auto model = litert::testing::LoadTestFileModel("mul_simple.tflite"); - LiteRtModel litert_model = model.Get(); - - LITERT_LOG(LITERT_INFO, "%s", "Serializing model"); - litert::OwningBufferRef buf; - - // Using weak pointer to link the data to the buffer. 
- auto [data, size, offset] = buf.GetWeak(); - - const auto opts = litert::SerializationOptions::Defaults(); - auto status = - LiteRtSerializeModel(litert_model, &data, &size, &offset, false, opts); - if (status != kLiteRtStatusOk) { - LITERT_LOG(LITERT_ERROR, "Failed to serialize model"); - } - - absl::string_view buffer_str(reinterpret_cast(buf.Data()), - buf.Size()); - - ASSERT_FALSE(buffer_str.empty()); - LITERT_LOG(LITERT_INFO, "buffer_str size: %d", buffer_str.size()); - LITERT_LOG(LITERT_INFO, "Compling model..."); - absl::string_view soc_model = "P25"; - litert::google_tensor::Flags flags; - flags.clear(); - std::string compiled_code; - auto compile_status = adapter_result.Value()->api().compile( - buffer_str, soc_model, flags, &compiled_code); - ASSERT_OK(compile_status); - ASSERT_FALSE(compiled_code.empty()); -} - -} // namespace google_tensor -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/vendors/google_tensor/compiler/BUILD b/tensorflow/lite/experimental/litert/vendors/google_tensor/compiler/BUILD deleted file mode 100644 index eaf46ce867d1b6..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/google_tensor/compiler/BUILD +++ /dev/null @@ -1,93 +0,0 @@ -# Copyright 2024 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -load("//tensorflow/lite/experimental/litert/build_common:litert_build_defs.bzl", "litert_dynamic_lib") - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//visibility:private"], -) - -litert_dynamic_lib( - name = "compiler_plugin", - srcs = ["compiler_plugin.cc"], - hdrs = ["//tensorflow/lite/experimental/litert/vendors/c:litert_compiler_plugin.h"], - export_litert_only = True, - linkstatic = 1, - shared_lib_name = "google_tensor_compiler_plugin_so", - so_name = "libLiteRtCompilerPlugin_google_tensor.so", - tags = [ - # Don't build/test in OS until google tensor is available. - "nobuilder", - "no_oss", - "notap", - ], - visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], - deps = [ - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/cc:litert_buffer_ref", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/vendors/google_tensor:adapter", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - ], -) - -cc_test( - name = "compiler_plugin_test", - srcs = [ - "compiler_plugin_test.cc", - ], - data = [ - "//tensorflow/lite/experimental/litert/test:mlir_test_data", - "//tensorflow/lite/experimental/litert/test:tflite_test_data", - ], - linkstatic = True, - tags = [ - # Tests with ungrte deps do not currently work on forge. - "no-remote-exec", - "notap", - # Don't build/test in OS until google tensor is available. 
- "nobuilder", - "no_oss", - # Sanatizer runtime doesn't work with anything that loads a shared library. - "nosan", - "manual", - ], - # This test can only be run on Android and Linux. - target_compatible_with = select({ - "@platforms//os:android": [], - "@platforms//os:linux": [], - "//conditions:default": ["@platforms//:incompatible"], - }), - deps = [ - ":compiler_plugin", # buildcleaner: keep - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/test:common", - "//tensorflow/lite/experimental/litert/test:matchers", - "//tensorflow/lite/experimental/litert/vendors/cc:litert_compiler_plugin", - "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest_main", - ], -) diff --git a/tensorflow/lite/experimental/litert/vendors/google_tensor/compiler/compiler_plugin.cc b/tensorflow/lite/experimental/litert/vendors/google_tensor/compiler/compiler_plugin.cc deleted file mode 100644 index 0b5854f03a18b9..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/google_tensor/compiler/compiler_plugin.cc +++ /dev/null @@ -1,360 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include - -#include -#include -#include -#include -#include -#include -#include - -#include "absl/status/status.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin.h" -#include "tensorflow/lite/experimental/litert/vendors/google_tensor/adapter.h" - -// -// Configurations -// - -namespace google_tensor { - -constexpr char kPluginManufacturer[] = "GoogleTensor"; - -constexpr const char* kPluginSocModels[] = { - "P25", -}; // get the name for plugin soc model - -constexpr LiteRtOpCode kUnSupportedOps[] = { - kLiteRtOpCodeTflAssignVariable, - kLiteRtOpCodeTflBidirectionalSequenceLstm, - kLiteRtOpCodeTflBroadcastArgs, - kLiteRtOpCodeTflBucketize, - kLiteRtOpCodeTflCallOnce, - kLiteRtOpCodeTflComplexAbs, - kLiteRtOpCodeTflConv3d, - kLiteRtOpCodeTflConv3dTranspose, - kLiteRtOpCodeTflDensify, - kLiteRtOpCodeTflFakeQuant, - kLiteRtOpCodeTflHashtable, - kLiteRtOpCodeTflHashtableFind, - kLiteRtOpCodeTflHashtableImport, - kLiteRtOpCodeTflHashtableSize, - kLiteRtOpCodeTflImag, - kLiteRtOpCodeTflLocalResponseNormalization, - kLiteRtOpCodeTflMatrixDiag, - kLiteRtOpCodeTflMatrixSetDiag, - kLiteRtOpCodeTflMultinomial, - kLiteRtOpCodeTflNonMaxSuppressionV4, - kLiteRtOpCodeTflNonMaxSuppressionV5, - kLiteRtOpCodeTflRandomStandardNormal, - kLiteRtOpCodeTflRandomUniform, - kLiteRtOpCodeTflRank, - kLiteRtOpCodeTflReadVariable, - kLiteRtOpCodeTflReal, - kLiteRtOpCodeTflReduceProd, - 
kLiteRtOpCodeTflReverseSequence, - kLiteRtOpCodeTflRfft2d, - kLiteRtOpCodeTflSegmentSum, - kLiteRtOpCodeTflShape, - kLiteRtOpCodeTflSparseToDense, - kLiteRtOpCodeTflSvdf, - kLiteRtOpCodeTflUnidirectionalSequenceRnn, - kLiteRtOpCodeTflUnique, - kLiteRtOpCodeTflUnsortedSegmentMax, - kLiteRtOpCodeTflUnsortedSegmentMin, - kLiteRtOpCodeTflUnsortedSegmentProd, - kLiteRtOpCodeTflUnsortedSegmentSum, - kLiteRtOpCodeTflVarHandle, - kLiteRtOpCodeTflWhere, -}; -// clang format on - -constexpr auto kNumPluginSocModels = - sizeof(kPluginSocModels) / sizeof(kPluginSocModels[0]); - -} // namespace google_tensor - -LiteRtStatus LiteRtGetCompilerPluginVersion(LiteRtApiVersion* api_version) { - if (api_version == nullptr) { - LITERT_LOG(LITERT_ERROR, "%s", "api_version is nullptr"); - return kLiteRtStatusErrorInvalidArgument; - } - api_version->major = LITERT_API_VERSION_MAJOR; - api_version->minor = LITERT_API_VERSION_MINOR; - api_version->patch = LITERT_API_VERSION_PATCH; - return kLiteRtStatusOk; -} - -const char* LiteRtGetCompilerPluginSocManufacturer() { - return google_tensor::kPluginManufacturer; -} - -LiteRtStatus LiteRtGetCompilerPluginSupportedHardware( - LiteRtCompilerPlugin compiler_plugin, - LiteRtHwAccelerators* supported_hardware) { - if (!compiler_plugin || !supported_hardware) { - LITERT_LOG(LITERT_ERROR, "%s", - "compiler_plugin or supported_hardware is nullptr"); - return kLiteRtStatusErrorInvalidArgument; - } - *supported_hardware = kLiteRtHwAcceleratorNpu; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetNumCompilerPluginSupportedSocModels( - LiteRtCompilerPlugin compiler_plugin, - LiteRtParamIndex* num_supported_soc_models) { - if (compiler_plugin == nullptr || num_supported_soc_models == nullptr) { - LITERT_LOG(LITERT_ERROR, "%s", - "compiler_plugin or num_supported_soc_models is nullptr"); - return kLiteRtStatusErrorInvalidArgument; - } - *num_supported_soc_models = google_tensor::kNumPluginSocModels; - return kLiteRtStatusOk; -} - -LiteRtStatus 
LiteRtGetCompilerPluginSupportedSocModel( - LiteRtCompilerPlugin compiler_plugin, LiteRtParamIndex soc_model_idx, - const char** soc_model_name) { - if (compiler_plugin == nullptr || - soc_model_idx >= google_tensor::kNumPluginSocModels || - soc_model_name == nullptr) { - LITERT_LOG(LITERT_ERROR, "%s", - "compiler_plugin or soc_model_idx or soc_model_name is nullptr"); - return kLiteRtStatusErrorInvalidArgument; - } - *soc_model_name = google_tensor::kPluginSocModels[soc_model_idx]; - return kLiteRtStatusOk; -} - -// -// Compiled Result Definition -// - -// TODO (abhirs): Revisit this struct after updating the compiler api wrapper to -// return multiple bytecodes. -struct LiteRtCompiledResultT { - std::string byte_code; - std::vector per_op_data; -}; - -LiteRtStatus LiteRtGetCompiledResultByteCode( - LiteRtCompiledResult compiled_result, LiteRtParamIndex byte_code_idx, - const void** byte_code, size_t* byte_code_size) { - if (!compiled_result || !byte_code || !byte_code_size || - (byte_code_idx != 0)) { - LITERT_LOG(LITERT_ERROR, "%s", - "compiled_result or byte_code or byte_code_size" - "or byte_code_idx is nullptr"); - return kLiteRtStatusErrorInvalidArgument; - } - *byte_code = compiled_result->byte_code.data(); - *byte_code_size = compiled_result->byte_code.size(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtCompiledResultNumByteCodeModules( - LiteRtCompiledResult compiled_result, LiteRtParamIndex* num_byte_code) { - if (!compiled_result || !num_byte_code) { - LITERT_LOG(LITERT_ERROR, "%s", - "compiled_result or num_byte_code is nullptr"); - return kLiteRtStatusErrorInvalidArgument; - } - *num_byte_code = 1; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetCompiledResultCallInfo( - LiteRtCompiledResult compiled_result, LiteRtParamIndex call_idx, - const void** call_info, size_t* call_info_size, - LiteRtParamIndex* byte_code_idx) { - if (!compiled_result || !call_info || !call_info_size) { - LITERT_LOG(LITERT_ERROR, "%s", - "compiled_result or 
call_info or call_info_size is nullptr"); - return kLiteRtStatusErrorInvalidArgument; - } else if (call_idx >= compiled_result->per_op_data.size()) { - LITERT_LOG(LITERT_ERROR, "%s", "call_idx is out of bounds"); - return kLiteRtStatusErrorIndexOOB; - } - - *call_info = compiled_result->per_op_data.at(call_idx).data(); - *call_info_size = compiled_result->per_op_data.at(call_idx).size(); - *byte_code_idx = 0; - - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetNumCompiledResultCalls( - LiteRtCompiledResult compiled_result, LiteRtParamIndex* num_calls) { - if (!compiled_result || !num_calls) { - LITERT_LOG(LITERT_ERROR, "%s", "compiled_result or num_calls is nullptr"); - return kLiteRtStatusErrorInvalidArgument; - } - *num_calls = compiled_result->per_op_data.size(); - return kLiteRtStatusOk; -} - -void LiteRtDestroyCompiledResult(LiteRtCompiledResult compiled_result) { - delete compiled_result; -} - -// -// Plugin Definition -// - -// Plugins can hold state. -struct LiteRtCompilerPluginT { - using Flag = std::pair; - std::vector flags; -}; - -LiteRtStatus LiteRtCompilerPluginSetFlags(LiteRtCompilerPlugin compiler_plugin, - LiteRtParamIndex num_flags, - const char** keys, - const char** values) { - auto& flags = compiler_plugin->flags; - if (flags.size() != 0) { - LITERT_LOG(LITERT_INFO, "Overwriting existing flags"); - flags.clear(); - } - flags.resize(num_flags); - for (int i = 0; i < num_flags; ++i) { - auto& flag = flags[i]; - flag.first = std::string(keys[i]); - flag.second = std::string(values[i]); - LITERT_LOG(LITERT_INFO, "Setting Flag: %s = %s", flag.first.c_str(), - flag.second.c_str()); - } - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtCreateCompilerPlugin(LiteRtCompilerPlugin* compiler_plugin) { - *compiler_plugin = new LiteRtCompilerPluginT; - return kLiteRtStatusOk; -} - -void LiteRtDestroyCompilerPlugin(LiteRtCompilerPlugin compiler_plugin) { - if (compiler_plugin == nullptr) { - return; - } - delete compiler_plugin; -} - -namespace 
google_tensor { -// TODO(abhirs): update the function to use the darwinn inbuilt way of -// finding supportedops -bool IsOpSupported(const litert::Op& op) { - for (auto unsupported_op : kUnSupportedOps) { - if (unsupported_op == op.Code()) { - return false; - } - } - return true; -} - -} // namespace google_tensor - -LiteRtStatus LiteRtCompilerPluginPartition(LiteRtCompilerPlugin compiler_plugin, - const char* soc_model, - LiteRtSubgraph subgraph, - LiteRtOpList selected_ops) { - ::litert::Subgraph graph(subgraph); - for (const auto& op : graph.Ops()) { - if (!google_tensor::IsOpSupported(op)) { - continue; - } - - LITERT_RETURN_IF_ERROR(LiteRtPushOp(selected_ops, op.Get(), 0)); - } - - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtCompilerPluginCompile( - LiteRtCompilerPlugin compiler_plugin, const char* soc_model, - LiteRtModel partitions, LiteRtCompiledResult* compiled_result) { - if (compiler_plugin == nullptr || soc_model == nullptr || - partitions == nullptr || compiled_result == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - auto model = litert::Model::CreateFromNonOwnedHandle(partitions); - const auto num_partitions = model.NumSubgraphs(); - LITERT_LOG(LITERT_INFO, - "Starting GoogleTensor Compilation for %d subgraphs, soc_model=%s", - num_partitions, soc_model); - - // Serialize model. 
- LITERT_LOG(LITERT_INFO, "%s", "Serializing model"); - litert::OwningBufferRef buf; - auto [data, size, offset] = buf.GetWeak(); - const auto opts = litert::SerializationOptions::Defaults(); - LITERT_RETURN_IF_ERROR( - LiteRtSerializeModel(partitions, &data, &size, &offset, false, opts)); - // TODO(abhirs): add support for serializing subgraphs - - absl::string_view buffer_str(reinterpret_cast(buf.Data()), - buf.Size()); - - // Loading Google Tensor Compiler Adapter - LITERT_LOG(LITERT_INFO, "%s", "Loading Google Tensor Compiler Adapter"); - auto adapter_result = litert::google_tensor::Adapter::Create( - /*shared_library_dir=*/std::nullopt); - if (!adapter_result.HasValue()) { - const auto& error_message = adapter_result.Error().Message(); - LITERT_LOG(LITERT_ERROR, "Failed to create adapter: %s", - error_message.c_str()); - return kLiteRtStatusErrorRuntimeFailure; - } - - // Compile model. - LITERT_LOG(LITERT_INFO, "%s", "Compiling model..."); - // TODO(b/398984678): add support for multiple bytecodes - absl::string_view soc_model_view(soc_model); - std::string compiled; - auto compile_status = adapter_result.Value()->api().compile( - buffer_str, soc_model_view, compiler_plugin->flags, &compiled); - - if (!compile_status.ok()) { - LITERT_LOG( - LITERT_ERROR, "%s", - absl::StrCat("Failed to compile model: ", compile_status.message()) - .c_str()); - return kLiteRtStatusErrorRuntimeFailure; - } - - // Result - auto result = std::make_unique(); - - result->byte_code = std::string(compiled.data(), compiled.size()); - // Generate per_op_data. 
- for (auto i = 0; i < num_partitions; ++i) { - result->per_op_data.emplace_back( - absl::StrFormat("Partition_%d", static_cast(i))); - } - *compiled_result = result.release(); - return kLiteRtStatusOk; -} diff --git a/tensorflow/lite/experimental/litert/vendors/google_tensor/compiler/compiler_plugin_test.cc b/tensorflow/lite/experimental/litert/vendors/google_tensor/compiler/compiler_plugin_test.cc deleted file mode 100644 index 7f6ca4aaf95a74..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/google_tensor/compiler/compiler_plugin_test.cc +++ /dev/null @@ -1,92 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include -#include - -#include -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/test/common.h" -#include "tensorflow/lite/experimental/litert/test/matchers.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin.h" -#include "tensorflow/lite/experimental/litert/vendors/cc/litert_compiler_plugin.h" - -namespace litert { -namespace { - -TEST(TestGoogleTensorPlugin, GetConfigInfo) { - ASSERT_STREQ(LiteRtGetCompilerPluginSocManufacturer(), "GoogleTensor"); - - auto plugin = CreatePlugin(); - - LiteRtParamIndex num_supported_soc_models; - LITERT_ASSERT_OK(LiteRtGetNumCompilerPluginSupportedSocModels( - plugin.get(), &num_supported_soc_models)); - ASSERT_EQ(num_supported_soc_models, 1); - - const char* soc_model_name; - LITERT_ASSERT_OK(LiteRtGetCompilerPluginSupportedSocModel(plugin.get(), 0, - &soc_model_name)); - ASSERT_STREQ(soc_model_name, "P25"); -} - -TEST(TestCallGoogleTensorPlugin, PartitionSimpleMultiAdd) { - auto plugin = CreatePlugin(); - auto model = testing::LoadTestFileModel("simple_multi_op.tflite"); - - LiteRtOpListT selected_op_list; - LITERT_ASSERT_OK(LiteRtCompilerPluginPartition( - plugin.get(), /*soc_model=*/nullptr, model.Subgraph(0)->Get(), - &selected_op_list)); - const auto selected_ops = selected_op_list.Values(); - - ASSERT_EQ(selected_ops.size(), 4); - ASSERT_EQ(selected_ops[0].first->OpCode(), kLiteRtOpCodeTflAdd); - ASSERT_EQ(selected_ops[1].first->OpCode(), kLiteRtOpCodeTflMul); -} - -TEST(TestCallGoogleTensorPlugin, CompileMulSubgraph) { - auto plugin = CreatePlugin(); - auto model = testing::LoadTestFileModel("mul_simple.tflite"); - - LiteRtCompiledResult compiled; - LITERT_ASSERT_OK( - 
LiteRtCompilerPluginCompile(plugin.get(), "P25", model.Get(), &compiled)); - - const void* byte_code; - size_t byte_code_size; - LITERT_ASSERT_OK(LiteRtGetCompiledResultByteCode(compiled, 0, &byte_code, - &byte_code_size)); - absl::string_view byte_code_string(reinterpret_cast(byte_code), - byte_code_size); - ASSERT_FALSE(byte_code_string.empty()); - - const void* op_data; - size_t op_data_size; - LiteRtParamIndex byte_code_idx; - LITERT_ASSERT_OK(LiteRtGetCompiledResultCallInfo( - compiled, 0, &op_data, &op_data_size, &byte_code_idx)); - absl::string_view op_data_string(reinterpret_cast(op_data), - op_data_size); - ASSERT_EQ("Partition_0", op_data_string); - - LiteRtDestroyCompiledResult(compiled); -} - -} // namespace -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/BUILD b/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/BUILD deleted file mode 100644 index 6267f882339fed..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/BUILD +++ /dev/null @@ -1,153 +0,0 @@ -# Copyright 2024 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -load("//tensorflow/lite/experimental/litert/build_common:litert_build_defs.bzl", "copy_file", "litert_dynamic_lib") - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], -) - -litert_dynamic_lib( - name = "dispatch_api", - srcs = [ - "dispatch_api.cc", - "litert_dispatch_device_context.cc", - "litert_dispatch_graph.cc", - "litert_dispatch_invocation_context.cc", - "southbound.cc", - ], - hdrs = [ - "dispatch_api.h", - "litert_dispatch_device_context.h", - "litert_dispatch_graph.h", - "litert_dispatch_invocation_context.h", - "litert_dispatch_metrics.h", - "southbound.h", - # copybara:uncomment "//third_party/odml/infra/southbound:sb_api.h", - ], - copts = [ - "-Os", - "-fno-exceptions", - "-fno-unwind-tables", - "-fno-asynchronous-unwind-tables", - "-ffunction-sections", - "-fdata-sections", - ], - export_litert_only = True, - linkopts = select({ - "//tensorflow:android": ["-landroid"], - "//conditions:default": [], - }) + [ - "-Wl,-soname=libLiteRtDispatch_GoogleTensor.so", - "-Wl,-lc++abi", - ], - shared_lib_name = "dispatch_api_so", - so_name = "libLiteRtDispatch_GoogleTensor.so", - tags = [ - # Don't build/test in OSS until Southbound is available. - "nobuilder", - ], - visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], - deps = [ - "//tensorflow/lite/experimental/litert/c:litert_any", - "//tensorflow/lite/experimental/litert/c:litert_runtime_c_api_shared_lib", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/core/util:tensor_type_util", - "//tensorflow/lite/experimental/litert/vendors/c:litert_dispatch_c_api", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/container:node_hash_set", - "@com_google_absl//absl/strings:string_view", - ], -) - -# This is cc_library target for `libLiteRtDispatch_GoogleTensor.so`. 
-cc_library( - name = "dispatch_api_shared_lib", - srcs = [":dispatch_api_so"], - visibility = ["//visibility:public"], -) - -# Copies the shared library so that it is available for use in test data as libLiteRtDispatch_GoogleTensor.so. -copy_file( - name = "copy_dispatch_api_so", - src = "//tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch:dispatch_api_so", - target = "libLiteRtDispatch_GoogleTensor.so", -) - -cc_test( - name = "dispatch_api_google_tensor_test", - srcs = [ - "dispatch_api_google_tensor_test.cc", - ], - data = [ - ":dispatch_api_so", - ], - linkopts = select({ - "//tensorflow:android": ["-landroid"], - "//conditions:default": [], - }), - tags = [ - # Don't build/test in OSS until Southbound is available. - "nobuilder", - "no_oss", - ], - deps = [ - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_tensor_buffer", - "//tensorflow/lite/experimental/litert/cc:litert_any", - "//tensorflow/lite/experimental/litert/core:filesystem", - "//tensorflow/lite/experimental/litert/test:common", - "//tensorflow/lite/experimental/litert/test:simple_model_npu", - "//tensorflow/lite/experimental/litert/vendors/c:litert_dispatch_c_api", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:absl_log", - "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest_main", - ], -) - -cc_test( - name = "dispatch_api_async_google_tensor_test", - srcs = [ - "dispatch_api_async_google_tensor_test.cc", - ], - data = [ - ":dispatch_api_so", - ], - linkopts = select({ - "//tensorflow:android": ["-landroid"], - "//conditions:default": [], - }), - tags = [ - # Don't build/test in OSS until Southbound is available. 
- "nobuilder", - "no_oss", - ], - deps = [ - "@com_google_googletest//:gtest_main", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:absl_log", - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/darwinn/driver_shared/fence:fence_test_util", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_event", - "//tensorflow/lite/experimental/litert/c:litert_tensor_buffer", - "//tensorflow/lite/experimental/litert/cc:litert_any", - "//tensorflow/lite/experimental/litert/core:filesystem", - "//tensorflow/lite/experimental/litert/test:common", - "//tensorflow/lite/experimental/litert/test:simple_model_npu", - "//tensorflow/lite/experimental/litert/vendors/c:litert_dispatch_c_api", - ], -) diff --git a/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/dispatch_api.cc b/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/dispatch_api.cc deleted file mode 100644 index f1aede35c98885..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/dispatch_api.cc +++ /dev/null @@ -1,644 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/dispatch_api.h" - -#include -#include -#include -#include -#include - -#if LITERT_HAS_AHWB_SUPPORT -#include -#endif - -#include "absl/strings/string_view.h" -#include "third_party/odml/infra/southbound/sb_api.h" -#include "tensorflow/lite/experimental/litert/c/litert_any.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_event.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch_api.h" -#include "tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_device_context.h" -#include "tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_graph.h" -#include "tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_invocation_context.h" -#include "tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_metrics.h" -#include "tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/southbound.h" - -namespace { - -litert::google_tensor::Southbound* TheSouthbound; -char BuildId[256]; - -} // namespace - -namespace litert { -namespace google_tensor { - -// ///////////////////////////////////////////////////////////////////////////// -// Basic Execution API -// ///////////////////////////////////////////////////////////////////////////// - -const char* GetSharedLibraryDir(const LiteRtDispatchOption* options, - int num_options) { - for (auto i = 0; i < 
num_options; ++i) { - auto& option = options[i]; - if (!strcmp(option.name, kDispatchOptionSharedLibraryDir)) { - return option.value.str_value; - } - } - return nullptr; -} - -LiteRtStatus Initialize(const LiteRtDispatchOption* options, int num_options) { - auto* shared_library_dir = GetSharedLibraryDir(options, num_options); - std::optional shared_library_dir_opt = - shared_library_dir ? std::make_optional(std::string(shared_library_dir)) - : std::nullopt; - - if (auto southbound = - litert::google_tensor::Southbound::Create(shared_library_dir_opt); - !southbound) { - LITERT_LOG(LITERT_INFO, "Initialization failure: %s", - southbound.Error().Message().c_str()); - return southbound.Error().Status(); - } else { - TheSouthbound = southbound->release(); - } - - auto thr_initialize = TheSouthbound->api().thr_initialize; - if (!thr_initialize) { - LITERT_LOG(LITERT_INFO, "thr_initialize not found"); - return kLiteRtStatusErrorRuntimeFailure; - } - if (auto status = thr_initialize(); status != kThrStatusSuccess) { - LITERT_LOG(LITERT_INFO, "thr_initialize failed: %d", status); - return kLiteRtStatusErrorRuntimeFailure; - } - - auto thr_get_vendor_api_version = - TheSouthbound->api().thr_get_vendor_api_version; - const char* sb_api_version = - thr_get_vendor_api_version ? thr_get_vendor_api_version() : "N.A."; - auto thr_get_vendor_id = TheSouthbound->api().thr_get_vendor_id; - const char* sb_vendor_id = thr_get_vendor_id ? 
thr_get_vendor_id() : "N.A."; - snprintf( - BuildId, sizeof(BuildId), - "GoogleTensor Dispatch API version %d.%d.%d, Darwinn API version %s, " - "vendor id: %s", - LITERT_API_VERSION_MAJOR, LITERT_API_VERSION_MINOR, - LITERT_API_VERSION_PATCH, sb_api_version, sb_vendor_id); - BuildId[sizeof(BuildId) - 1] = 0; - - return kLiteRtStatusOk; -} - -LiteRtStatus GetVendorId(const char** vendor_id) { - *vendor_id = "Google"; - return kLiteRtStatusOk; -} - -LiteRtStatus GetBuildId(const char** build_id) { - *build_id = BuildId; - return kLiteRtStatusOk; -} - -LiteRtStatus GetCapabilities(int* capabilities) { - *capabilities = kLiteRtDispatchCapabilitiesBasic | - kLiteRtDispatchCapabilitiesAsync | - kLiteRtDispatchCapabilitiesGraph; - return kLiteRtStatusOk; -} - -LiteRtStatus DeviceContextCreate(LiteRtDispatchDeviceContext* device_context) { - if (auto result = LiteRtDispatchDeviceContextT::Create(*TheSouthbound); - result) { - *device_context = result->release(); - return kLiteRtStatusOk; - } else { - LITERT_LOG(LITERT_ERROR, "Failed to create device context: %s", - result.Error().Message().c_str()); - return result.Error().Status(); - } -} - -LiteRtStatus DeviceContextDestroy(LiteRtDispatchDeviceContext device_context) { - delete device_context; - return kLiteRtStatusOk; -} - -LiteRtStatus GetInputRequirements( - LiteRtDispatchInvocationContext invocation_context, int input_index, - const LiteRtRankedTensorType* tensor_type, - LiteRtTensorBufferRequirements* tensor_buffer_requirements) { - if (auto result = - invocation_context->GetInputRequirements(input_index, *tensor_type); - result) { - *tensor_buffer_requirements = *result; - return kLiteRtStatusOk; - } else { - LITERT_LOG(LITERT_ERROR, "Failed to get input requirements: %s", - result.Error().Message().c_str()); - return result.Error().Status(); - } -} - -LiteRtStatus GetOutputRequirements( - LiteRtDispatchInvocationContext invocation_context, int output_index, - const LiteRtRankedTensorType* tensor_type, - 
LiteRtTensorBufferRequirements* tensor_buffer_requirements) { - if (auto result = - invocation_context->GetOutputRequirements(output_index, *tensor_type); - result) { - *tensor_buffer_requirements = *result; - return kLiteRtStatusOk; - } else { - LITERT_LOG(LITERT_ERROR, "Failed to get output requirements: %s", - result.Error().Message().c_str()); - return result.Error().Status(); - } -} - -LiteRtStatus RegisterTensorBuffer( - LiteRtDispatchDeviceContext device_context, LiteRtTensorBuffer buffer, - LiteRtTensorBufferHandle* tensor_buffer_handle) { - if (auto status = device_context->RegisterTensorBuffer(buffer); status) { - *tensor_buffer_handle = *status; - return kLiteRtStatusOk; - } else { - LITERT_LOG(LITERT_ERROR, "Failed to register buffer: %s", - status.Error().Message().c_str()); - return status.Error().Status(); - } -} - -LiteRtStatus UnregisterTensorBuffer(LiteRtDispatchDeviceContext device_context, - LiteRtTensorBufferHandle handle) { - if (auto status = device_context->UnregisterTensorBuffer(handle); status) { - return kLiteRtStatusOk; - } else { - LITERT_LOG(LITERT_ERROR, "Failed to unregister buffer: %s", - status.Error().Message().c_str()); - return status.Error().Status(); - } -} - -LiteRtStatus InvocationContextCreate( - LiteRtDispatchDeviceContext device_context, - LiteRtDispatchExecutableType exec_type, - const LiteRtMemBuffer* exec_bytecode_buffer, const char* function_name, - int num_inputs, int num_outputs, - LiteRtDispatchInvocationContext* invocation_context) { - function_name = ""; - if (auto result = LiteRtDispatchInvocationContextT::CreateFromBytecode( - *TheSouthbound, device_context, exec_type, exec_bytecode_buffer, - function_name, num_inputs, num_outputs); - result) { - *invocation_context = result->release(); - return kLiteRtStatusOk; - } else { - LITERT_LOG(LITERT_ERROR, "Failed to create invocation context: %s", - result.Error().Message().c_str()); - return result.Error().Status(); - } -} - -LiteRtStatus InvocationContextDestroy( - 
LiteRtDispatchInvocationContext invocation_context) { - delete invocation_context; - return kLiteRtStatusOk; -} - -LiteRtStatus AttachInput(LiteRtDispatchInvocationContext invocation_context, - int graph_input_index, - LiteRtTensorBufferHandle tensor_buffer_handle) { - if (auto result = invocation_context->AttachInput(graph_input_index, - tensor_buffer_handle); - result) { - return kLiteRtStatusOk; - } else { - LITERT_LOG(LITERT_ERROR, "Failed to attach input: %s", - result.Error().Message().c_str()); - return result.Error().Status(); - } -} - -LiteRtStatus AttachOutput(LiteRtDispatchInvocationContext invocation_context, - int graph_output_index, - LiteRtTensorBufferHandle tensor_buffer_handle) { - if (auto result = invocation_context->AttachOutput(graph_output_index, - tensor_buffer_handle); - result) { - return kLiteRtStatusOk; - } else { - LITERT_LOG(LITERT_ERROR, "Failed to attach output: %s", - result.Error().Message().c_str()); - return result.Error().Status(); - } -} - - -LiteRtStatus DetachInput(LiteRtDispatchInvocationContext invocation_context, - int graph_input_index, - LiteRtTensorBufferHandle tensor_buffer_handle) { - if (auto result = invocation_context->DetachInput(graph_input_index, - tensor_buffer_handle); - result) { - return kLiteRtStatusOk; - } else { - LITERT_LOG(LITERT_ERROR, "Failed to detatch input: %s", - result.Error().Message().c_str()); - return result.Error().Status(); - } -} - -LiteRtStatus DetachOutput(LiteRtDispatchInvocationContext invocation_context, - int graph_output_index, - LiteRtTensorBufferHandle tensor_buffer_handle) { - if (auto result = invocation_context->DetachOutput(graph_output_index, - tensor_buffer_handle); - result) { - return kLiteRtStatusOk; - } else { - LITERT_LOG(LITERT_ERROR, "Failed to detatch output: %s", - result.Error().Message().c_str()); - return result.Error().Status(); - } -} - -LiteRtStatus Invoke(LiteRtDispatchInvocationContext invocation_context) { - if (auto result = invocation_context->Invoke(); 
result) { - return kLiteRtStatusOk; - } else { - LITERT_LOG(LITERT_ERROR, "Failed to invoke: %s", - result.Error().Message().c_str()); - return result.Error().Status(); - } -} - -// ///////////////////////////////////////////////////////////////////////////// -// Async Execution API -// ///////////////////////////////////////////////////////////////////////////// - -LiteRtStatus AttachInputEvent( - LiteRtDispatchInvocationContext invocation_context, int graph_input_index, - LiteRtEvent input_event) { - if (auto result = - invocation_context->AttachInputEvent(graph_input_index, input_event); - result) { - return kLiteRtStatusOk; - } else { - LITERT_LOG(LITERT_ERROR, "Failed to attach input event: %s", - result.Error().Message().c_str()); - return result.Error().Status(); - } -} - -LiteRtStatus InvokeAsync(LiteRtDispatchInvocationContext invocation_context, - int num_output_events, LiteRtEvent* output_events) { - if (auto result = - invocation_context->InvokeAsync(num_output_events, output_events); - result) { - return kLiteRtStatusOk; - } else { - LITERT_LOG(LITERT_ERROR, "Failed to invoke asynchronously: %s", - result.Error().Message().c_str()); - return result.Error().Status(); - } -} - -// ///////////////////////////////////////////////////////////////////////////// -// Metrics API -// ///////////////////////////////////////////////////////////////////////////// - -LiteRtStatus StartMetricsCollection( - LiteRtDispatchInvocationContext invocation_context, int detail_level) { - if (auto result = invocation_context->StartMetricsCollection(detail_level); - result) { - return kLiteRtStatusOk; - } else { - LITERT_LOG(LITERT_ERROR, "Failed to start metrics collection: %s", - result.Error().Message().c_str()); - return result.Error().Status(); - } -} - -LiteRtStatus StopMetricsCollection( - LiteRtDispatchInvocationContext invocation_context, - LiteRtDispatchMetrics* metrics) { - if (auto result = invocation_context->StopMetricsCollection(metrics); - result) { - return 
kLiteRtStatusOk; - } else { - LITERT_LOG(LITERT_ERROR, "Failed to stop metrics collection: %s", - result.Error().Message().c_str()); - return result.Error().Status(); - } -} - -LiteRtStatus GetNumMetrics(LiteRtDispatchMetrics metrics, int* num_metrics) { - if (metrics == nullptr) { - LITERT_LOG(LITERT_ERROR, - "GetNumMetrics failed: metrics should not be null"); - return kLiteRtStatusErrorInvalidArgument; - } - *num_metrics = metrics->GetNumMetrics(); - return kLiteRtStatusOk; -} - -LiteRtStatus GetMetric(LiteRtDispatchMetrics metrics, int metric_index, - LiteRtMetric* metric) { - if (metrics == nullptr) { - LITERT_LOG(LITERT_ERROR, "GetMetric failed: metrics should not be null"); - return kLiteRtStatusErrorInvalidArgument; - } - *metric = metrics->GetMetric(metric_index); - return kLiteRtStatusOk; -} - -LiteRtStatus DestroyMetrics(LiteRtDispatchMetrics metrics) { - if (metrics) { - delete metrics; - } - return kLiteRtStatusOk; -} - -// ///////////////////////////////////////////////////////////////////////////// -// Graph Execution API -// ///////////////////////////////////////////////////////////////////////////// - -LiteRtStatus GraphCreate(LiteRtDispatchDeviceContext device_context, - LiteRtDispatchGraph* graph) { - if (auto result = device_context->CreateGraph(); result) { - *graph = *result; - return kLiteRtStatusOk; - } else { - LITERT_LOG(LITERT_ERROR, "Failed to create graph: %s", - result.Error().Message().c_str()); - return result.Error().Status(); - } -} - -LiteRtStatus GraphDestroy(LiteRtDispatchGraph graph) { - if (auto result = graph->device_context()->DestroyGraph(graph); result) { - return kLiteRtStatusOk; - } else { - LITERT_LOG(LITERT_ERROR, "Failed to delete graph: %s", - result.Error().Message().c_str()); - return result.Error().Status(); - } -} - -LiteRtStatus AddNode(LiteRtDispatchGraph graph, LiteRtDispatchNodeId node_id, - LiteRtDispatchNodeType node_type) { - if (auto result = graph->AddNode(node_id, node_type); result) { - return 
kLiteRtStatusOk; - } else { - LITERT_LOG(LITERT_ERROR, "Failed to add node: %s", - result.Error().Message().c_str()); - return result.Error().Status(); - } -} - -LiteRtStatus AddEdge(LiteRtDispatchGraph graph, LiteRtDispatchEdgeId edge_id) { - if (auto result = graph->AddEdge(edge_id); result) { - return kLiteRtStatusOk; - } else { - LITERT_LOG(LITERT_ERROR, "Failed to add edge: %s", - result.Error().Message().c_str()); - return result.Error().Status(); - } -} - -LiteRtStatus ConnectNodeInput(LiteRtDispatchGraph graph, - LiteRtDispatchNodeId node_id, int input_index, - LiteRtDispatchEdgeId edge_id) { - if (auto result = graph->ConnectNodeInput(node_id, input_index, edge_id); - result) { - return kLiteRtStatusOk; - } else { - LITERT_LOG(LITERT_ERROR, "Failed to connect node input: %s", - result.Error().Message().c_str()); - return result.Error().Status(); - } -} - -LiteRtStatus ConnectNodeOutput(LiteRtDispatchGraph graph, - LiteRtDispatchNodeId node_id, int output_index, - LiteRtDispatchEdgeId edge_id) { - if (auto result = graph->ConnectNodeOutput(node_id, output_index, edge_id); - result) { - return kLiteRtStatusOk; - } else { - LITERT_LOG(LITERT_ERROR, "Failed to connect node output: %s", - result.Error().Message().c_str()); - return result.Error().Status(); - } -} - -LiteRtStatus ConnectGraphInput(LiteRtDispatchGraph graph, int input_index, - LiteRtDispatchEdgeId edge_id) { - if (auto result = graph->ConnectGraphInput(input_index, edge_id); result) { - return kLiteRtStatusOk; - } else { - LITERT_LOG(LITERT_ERROR, "Failed to connect graph input: %s", - result.Error().Message().c_str()); - return result.Error().Status(); - } -} - -LiteRtStatus ConnectGraphOutput(LiteRtDispatchGraph graph, int output_index, - LiteRtDispatchEdgeId edge_id) { - if (auto result = graph->ConnectGraphOutput(output_index, edge_id); result) { - return kLiteRtStatusOk; - } else { - LITERT_LOG(LITERT_ERROR, "Failed to connect graph output: %s", - result.Error().Message().c_str()); - return 
result.Error().Status(); - } -} - -LiteRtStatus LoadExecutable(LiteRtDispatchDeviceContext device_context, - LiteRtDispatchExecutableType type, - const LiteRtMemBuffer* bytecode_buffer, - LiteRtDispatchExecutableHandle* exec_handle) { - if (auto result = device_context->LoadExecutable(type, bytecode_buffer); - result) { - *exec_handle = *result; - return kLiteRtStatusOk; - } else { - LITERT_LOG(LITERT_ERROR, "Failed to load executable: %s", - result.Error().Message().c_str()); - return result.Error().Status(); - } -} - -LiteRtStatus UnloadExecutable(LiteRtDispatchDeviceContext device_context, - LiteRtDispatchExecutableHandle exec_handle) { - if (auto result = device_context->UnloadExecutable(exec_handle); result) { - return kLiteRtStatusOk; - } else { - LITERT_LOG(LITERT_ERROR, "Failed to unload executable: %s", - result.Error().Message().c_str()); - return result.Error().Status(); - } -} - -LiteRtStatus AssignNodeFunction(LiteRtDispatchGraph graph, - LiteRtDispatchNodeId node_id, - LiteRtDispatchExecutableHandle exec_handle, - const char* function_name) { - // TODO - b/397771624: Southbound currently doesn't support function names, so - // overriding function names to empty strings as a temporary fix. We need to - // investigate with the CoreML team to find a more robust solution. 
- function_name = ""; - if (auto result = - graph->AssignNodeFunction(node_id, exec_handle, function_name); - result) { - return kLiteRtStatusOk; - } else { - LITERT_LOG(LITERT_ERROR, "Failed to assign node function: %s", - result.Error().Message().c_str()); - return result.Error().Status(); - } -} - -LiteRtStatus AnnotateGraph(LiteRtDispatchGraph graph, const char* key, - const char* value) { - if (auto result = graph->AnnotateGraph(key, value); result) { - return kLiteRtStatusOk; - } else { - LITERT_LOG(LITERT_ERROR, "Failed to annotate graph: %s", - result.Error().Message().c_str()); - return result.Error().Status(); - } -} - -LiteRtStatus AnnotateNode(LiteRtDispatchGraph graph, - LiteRtDispatchNodeId node_id, const char* key, - const char* value) { - if (auto result = graph->AnnotateNode(node_id, key, value); result) { - return kLiteRtStatusOk; - } else { - LITERT_LOG(LITERT_ERROR, "Failed to annotate node: %s", - result.Error().Message().c_str()); - return result.Error().Status(); - } -} - -LiteRtStatus AnnotateEdge(LiteRtDispatchGraph graph, - LiteRtDispatchEdgeId edge_id, const char* key, - const char* value) { - if (auto result = graph->AnnotateEdge(edge_id, key, value); result) { - return kLiteRtStatusOk; - } else { - LITERT_LOG(LITERT_ERROR, "Failed to annotate edge: %s", - result.Error().Message().c_str()); - return result.Error().Status(); - } -} - -LiteRtStatus InvocationContextCreateFromGraph( - LiteRtDispatchDeviceContext device_context, LiteRtDispatchGraph graph, - LiteRtDispatchInvocationContext* invocation_context) { - if (auto result = LiteRtDispatchInvocationContextT::CreateFromGraph( - *TheSouthbound, device_context, graph); - result) { - *invocation_context = result->release(); - return kLiteRtStatusOk; - } else { - LITERT_LOG(LITERT_ERROR, "Failed to create invocation context: %s", - result.Error().Message().c_str()); - return result.Error().Status(); - } -} - -} // namespace google_tensor -} // namespace litert - -// 
///////////////////////////////////////////////////////////////////////////// - -namespace { - -LiteRtDispatchInterface TheInterface = { - .initialize = litert::google_tensor::Initialize, - .get_vendor_id = litert::google_tensor::GetVendorId, - .get_build_id = litert::google_tensor::GetBuildId, - .get_capabilities = litert::google_tensor::GetCapabilities, - .device_context_create = litert::google_tensor::DeviceContextCreate, - .device_context_destroy = litert::google_tensor::DeviceContextDestroy, - .get_input_requirements = litert::google_tensor::GetInputRequirements, - .get_output_requirements = litert::google_tensor::GetOutputRequirements, - .register_tensor_buffer = litert::google_tensor::RegisterTensorBuffer, - .unregister_tensor_buffer = litert::google_tensor::UnregisterTensorBuffer, - .invocation_context_create = litert::google_tensor::InvocationContextCreate, - .invocation_context_destroy = - litert::google_tensor::InvocationContextDestroy, - .attach_input = litert::google_tensor::AttachInput, - .attach_output = litert::google_tensor::AttachOutput, - .detach_input = litert::google_tensor::DetachInput, - .detach_output = litert::google_tensor::DetachOutput, - .invoke = litert::google_tensor::Invoke, - .start_metrics_collection = litert::google_tensor::StartMetricsCollection, - .stop_metrics_collection = litert::google_tensor::StopMetricsCollection, - .get_num_metrics = litert::google_tensor::GetNumMetrics, - .get_metric = litert::google_tensor::GetMetric, - .destroy_metrics = litert::google_tensor::DestroyMetrics, -}; - -LiteRtDispatchAsyncInterface TheAsyncInterface = { - .attach_input_event = litert::google_tensor::AttachInputEvent, - .invoke_async = litert::google_tensor::InvokeAsync, -}; - -LiteRtDispatchGraphInterface TheGraphInterface = { - .graph_create = litert::google_tensor::GraphCreate, - .graph_destroy = litert::google_tensor::GraphDestroy, - .add_node = litert::google_tensor::AddNode, - .add_edge = litert::google_tensor::AddEdge, - 
.connect_node_input = litert::google_tensor::ConnectNodeInput, - .connect_node_output = litert::google_tensor::ConnectNodeOutput, - .connect_graph_input = litert::google_tensor::ConnectGraphInput, - .connect_graph_output = litert::google_tensor::ConnectGraphOutput, - .load_executable = litert::google_tensor::LoadExecutable, - .unload_executable = litert::google_tensor::UnloadExecutable, - .assign_node_function = litert::google_tensor::AssignNodeFunction, - .annotate_graph = litert::google_tensor::AnnotateGraph, - .annotate_node = litert::google_tensor::AnnotateNode, - .annotate_edge = litert::google_tensor::AnnotateEdge, - .invocation_context_create_from_graph = - litert::google_tensor::InvocationContextCreateFromGraph, -}; - -LiteRtDispatchApi TheApi = { - .version = {.major = LITERT_API_VERSION_MAJOR, - .minor = LITERT_API_VERSION_MINOR, - .patch = LITERT_API_VERSION_PATCH}, - .interface = &TheInterface, - .async_interface = &TheAsyncInterface, - .graph_interface = &TheGraphInterface, -}; - -} // namespace - -LiteRtStatus LiteRtDispatchGetApi(LiteRtDispatchApi* api) { - *api = TheApi; - return kLiteRtStatusOk; -} diff --git a/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/dispatch_api.h b/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/dispatch_api.h deleted file mode 100644 index 00392c06efe163..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/dispatch_api.h +++ /dev/null @@ -1,65 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_GOOGLE_TENSOR_DISPATCH_DISPATCH_API_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_GOOGLE_TENSOR_DISPATCH_DISPATCH_API_H_ - -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" - -namespace litert::google_tensor { - -LiteRtStatus GraphCreate(LiteRtDispatchDeviceContext device_context, - LiteRtDispatchGraph* graph); -LiteRtStatus GraphDestroy(LiteRtDispatchGraph graph); -LiteRtStatus AddNode(LiteRtDispatchGraph graph, LiteRtDispatchNodeId node_id, - LiteRtDispatchNodeType node_type); -LiteRtStatus AddEdge(LiteRtDispatchGraph graph, LiteRtDispatchEdgeId edge_id); -LiteRtStatus ConnectNodeInput(LiteRtDispatchGraph graph, - LiteRtDispatchNodeId node_id, int input_index, - LiteRtDispatchEdgeId edge_id); -LiteRtStatus ConnectNodeOutput(LiteRtDispatchGraph graph, - LiteRtDispatchNodeId node_id, int output_index, - LiteRtDispatchEdgeId edge_id); -LiteRtStatus ConnectGraphInput(LiteRtDispatchGraph graph, int input_index, - LiteRtDispatchEdgeId edge_id); -LiteRtStatus ConnectGraphOutput(LiteRtDispatchGraph graph, int output_index, - LiteRtDispatchEdgeId edge_id); -LiteRtStatus LoadExecutable(LiteRtDispatchDeviceContext device_context, - LiteRtDispatchExecutableType type, - const void* bytecode, size_t bytecode_size, - LiteRtDispatchExecutableHandle* exec_handle); -LiteRtStatus UnloadExecutable(LiteRtDispatchDeviceContext device_context, - LiteRtDispatchExecutableHandle exec_handle); -LiteRtStatus AssignNodeFunction(LiteRtDispatchGraph graph, - LiteRtDispatchNodeId node_id, - LiteRtDispatchExecutableHandle exec_handle, - const char* function_name); -LiteRtStatus AnnotateGraph(LiteRtDispatchGraph graph, const char* key, - const char* value); -LiteRtStatus AnnotateNode(LiteRtDispatchGraph graph, - 
LiteRtDispatchNodeId node_id, const char* key, - const char* value); -LiteRtStatus AnnotateEdge(LiteRtDispatchGraph graph, - LiteRtDispatchEdgeId edge_id, const char* key, - const char* value); -LiteRtStatus InvocationContextCreateFromGraph( - LiteRtDispatchDeviceContext device_context, LiteRtDispatchGraph graph, - LiteRtDispatchInvocationContext* invocation_context); - -} // namespace litert::google_tensor - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_GOOGLE_TENSOR_DISPATCH_DISPATCH_API_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/dispatch_api_async_google_tensor_test.cc b/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/dispatch_api_async_google_tensor_test.cc deleted file mode 100644 index 762792a135e0ca..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/dispatch_api_async_google_tensor_test.cc +++ /dev/null @@ -1,340 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include -#include -#include -#include - -#if defined(__ANDROID__) -#include "platforms/darwinn/tachyon/core/fence/fence.h" -#endif -#include -#include -#include "absl/log/absl_log.h" -#include "absl/log/log.h" -#include "absl/types/span.h" -#include "third_party/darwinn/driver_shared/fence/fence_test_util.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_event.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h" -#include "tensorflow/lite/experimental/litert/cc/litert_any.h" -#include "tensorflow/lite/experimental/litert/core/filesystem.h" -#include "tensorflow/lite/experimental/litert/test/common.h" -#include "tensorflow/lite/experimental/litert/test/testdata/simple_model_test_vectors.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" - -using ::testing::Pointwise; -using Fence = std::shared_ptr; - -TEST(DispatchApiAsync, GoogleTensor) { -#if !defined(__ANDROID__) - GTEST_SKIP() - << "This test is specific to Android devices with a GoogleTensor eTPU"; -#endif - - LiteRtDispatchOption dispatch_option = { - /*.name=*/kDispatchOptionSharedLibraryDir, - /*.value=*/*litert::ToLiteRtAny(std::any("/data/local/tmp")), - }; - ASSERT_EQ( - LiteRtDispatchInitialize(/*options=*/&dispatch_option, /*num_options=*/1), - kLiteRtStatusOk); - - const char* vendor_id; - EXPECT_EQ(LiteRtDispatchGetVendorId(&vendor_id), kLiteRtStatusOk); - ABSL_LOG(INFO) << "vendor_id: " << vendor_id; - - const char* build_id; - EXPECT_EQ(LiteRtDispatchGetBuildId(&build_id), kLiteRtStatusOk); - ABSL_LOG(INFO) << "build_id: " << build_id; - - LiteRtApiVersion api_version; - EXPECT_EQ(LiteRtDispatchGetApiVersion(&api_version), kLiteRtStatusOk); - ABSL_LOG(INFO) << "api_version: " << api_version.major << "." - << api_version.minor << "." 
<< api_version.patch; - - int capabilities; - EXPECT_EQ(LiteRtDispatchGetCapabilities(&capabilities), kLiteRtStatusOk); - ABSL_LOG(INFO) << "capabilities: " << capabilities; - - LiteRtDispatchDeviceContext device_context = nullptr; - EXPECT_EQ(LiteRtDispatchDeviceContextCreate(&device_context), - kLiteRtStatusOk); - ABSL_LOG(INFO) << "device_context: " << device_context; - - auto model_file_name = - litert::testing::GetTestFilePath(kGoogleTensorModelFileName); - auto model = litert::internal::LoadBinaryFile(model_file_name); - EXPECT_TRUE(model) << model.Error(); - ABSL_LOG(INFO) << "Loaded model " << model_file_name << ", " << model->Size() - << " bytes"; - - // /////////////////////////////////////////////////////////////////////////// - // Set up an invocation context for a given model. - // /////////////////////////////////////////////////////////////////////////// - - LiteRtMemBuffer exec_bytecode_buffer = {/*.fd=*/-1, - /*.base_addr=*/model->Data(), - /*.offset=*/0, - /*.size=*/model->Size()}; - LiteRtDispatchInvocationContext invocation_context = nullptr; - EXPECT_EQ(LiteRtDispatchInvocationContextCreate( - device_context, kLiteRtDispatchExecutableTypeMlModel, - &exec_bytecode_buffer, /*function_name=*/nullptr, - /*num_inputs=*/2, /*num_outputs=*/1, &invocation_context), - kLiteRtStatusOk); - ABSL_LOG(INFO) << "Invocation context: " << invocation_context; - - // /////////////////////////////////////////////////////////////////////////// - // Determine tensor buffer requirements. 
- // /////////////////////////////////////////////////////////////////////////// - - int num_tensor_buffer_types; - LiteRtTensorBufferRequirements input_0_tensor_buffer_requirements; - EXPECT_EQ(LiteRtDispatchGetInputRequirements( - invocation_context, /*input_index=*/0, &kInput0TensorType, - &input_0_tensor_buffer_requirements), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtGetNumTensorBufferRequirementsSupportedBufferTypes( - input_0_tensor_buffer_requirements, &num_tensor_buffer_types), - kLiteRtStatusOk); - EXPECT_GE(num_tensor_buffer_types, 1); - LiteRtTensorBufferType input_0_tensor_buffer_type; - EXPECT_EQ(LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( - input_0_tensor_buffer_requirements, /*type_index=*/0, - &input_0_tensor_buffer_type), - kLiteRtStatusOk); - EXPECT_EQ(input_0_tensor_buffer_type, kLiteRtTensorBufferTypeAhwb); - size_t input_0_tensor_buffer_size; - EXPECT_EQ( - LiteRtGetTensorBufferRequirementsBufferSize( - input_0_tensor_buffer_requirements, &input_0_tensor_buffer_size), - kLiteRtStatusOk); - EXPECT_GE(input_0_tensor_buffer_size, sizeof(kTestInput0Tensor)); - - LiteRtTensorBufferRequirements input_1_tensor_buffer_requirements; - EXPECT_EQ(LiteRtDispatchGetInputRequirements( - invocation_context, /*input_index=*/1, &kInput1TensorType, - &input_1_tensor_buffer_requirements), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtGetNumTensorBufferRequirementsSupportedBufferTypes( - input_1_tensor_buffer_requirements, &num_tensor_buffer_types), - kLiteRtStatusOk); - EXPECT_GE(num_tensor_buffer_types, 1); - LiteRtTensorBufferType input_1_tensor_buffer_type; - EXPECT_EQ(LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( - input_1_tensor_buffer_requirements, /*type_index=*/0, - &input_1_tensor_buffer_type), - kLiteRtStatusOk); - EXPECT_EQ(input_1_tensor_buffer_type, kLiteRtTensorBufferTypeAhwb); - size_t input_1_tensor_buffer_size; - EXPECT_EQ( - LiteRtGetTensorBufferRequirementsBufferSize( - input_1_tensor_buffer_requirements, 
&input_1_tensor_buffer_size), - kLiteRtStatusOk); - EXPECT_GE(input_1_tensor_buffer_size, sizeof(kTestInput1Tensor)); - - LiteRtTensorBufferRequirements output_tensor_buffer_requirements; - EXPECT_EQ(LiteRtDispatchGetOutputRequirements( - invocation_context, /*output_index=*/0, &kOutputTensorType, - &output_tensor_buffer_requirements), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtGetNumTensorBufferRequirementsSupportedBufferTypes( - output_tensor_buffer_requirements, &num_tensor_buffer_types), - kLiteRtStatusOk); - EXPECT_GE(num_tensor_buffer_types, 1); - LiteRtTensorBufferType output_tensor_buffer_type; - EXPECT_EQ(LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( - output_tensor_buffer_requirements, /*type_index=*/0, - &output_tensor_buffer_type), - kLiteRtStatusOk); - EXPECT_EQ(output_tensor_buffer_type, kLiteRtTensorBufferTypeAhwb); - size_t output_tensor_buffer_size; - EXPECT_EQ(LiteRtGetTensorBufferRequirementsBufferSize( - output_tensor_buffer_requirements, &output_tensor_buffer_size), - kLiteRtStatusOk); - EXPECT_GE(output_tensor_buffer_size, sizeof(kTestOutputTensor)); - - // /////////////////////////////////////////////////////////////////////////// - // Allocate tensor buffers. 
- // /////////////////////////////////////////////////////////////////////////// - - LiteRtTensorBuffer input_0_tensor_buffer; - EXPECT_EQ(LiteRtCreateManagedTensorBuffer( - input_0_tensor_buffer_type, &kInput0TensorType, - input_0_tensor_buffer_size, &input_0_tensor_buffer), - kLiteRtStatusOk); - - LiteRtTensorBuffer input_1_tensor_buffer; - EXPECT_EQ(LiteRtCreateManagedTensorBuffer( - input_1_tensor_buffer_type, &kInput1TensorType, - input_1_tensor_buffer_size, &input_1_tensor_buffer), - kLiteRtStatusOk); - - LiteRtTensorBuffer output_tensor_buffer; - EXPECT_EQ(LiteRtCreateManagedTensorBuffer( - output_tensor_buffer_type, &kOutputTensorType, - output_tensor_buffer_size, &output_tensor_buffer), - kLiteRtStatusOk); - - // /////////////////////////////////////////////////////////////////////////// - // Register tensor buffers. - // /////////////////////////////////////////////////////////////////////////// - - LiteRtTensorBufferHandle input_1_handle; - EXPECT_EQ(LiteRtDispatchRegisterTensorBuffer( - device_context, input_1_tensor_buffer, &input_1_handle), - kLiteRtStatusOk); - - LiteRtTensorBufferHandle input_0_handle; - EXPECT_EQ(LiteRtDispatchRegisterTensorBuffer( - device_context, input_0_tensor_buffer, &input_0_handle), - kLiteRtStatusOk); - - LiteRtTensorBufferHandle output_handle; - EXPECT_EQ(LiteRtDispatchRegisterTensorBuffer( - device_context, output_tensor_buffer, &output_handle), - kLiteRtStatusOk); - - // /////////////////////////////////////////////////////////////////////////// - // Attach tensor buffers. 
- // /////////////////////////////////////////////////////////////////////////// - - EXPECT_EQ(LiteRtDispatchAttachInput(invocation_context, - /*graph_input_index=*/0, input_0_handle), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtDispatchAttachInput(invocation_context, - /*graph_input_index=*/1, input_1_handle), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtDispatchAttachOutput(invocation_context, - /*graph_output_index=*/0, output_handle), - kLiteRtStatusOk); - - // /////////////////////////////////////////////////////////////////////////// - // Fill the input buffers with data. - // /////////////////////////////////////////////////////////////////////////// - - { - ABSL_LOG(INFO) << "Filling inputs with data"; - void* host_mem_addr; - - ASSERT_EQ(LiteRtLockTensorBuffer(input_0_tensor_buffer, &host_mem_addr), - kLiteRtStatusOk); - std::memcpy(host_mem_addr, kTestInput0Tensor, sizeof(kTestInput0Tensor)); - ASSERT_EQ(LiteRtUnlockTensorBuffer(input_0_tensor_buffer), kLiteRtStatusOk); - - ASSERT_EQ(LiteRtLockTensorBuffer(input_1_tensor_buffer, &host_mem_addr), - kLiteRtStatusOk); - std::memcpy(host_mem_addr, kTestInput1Tensor, sizeof(kTestInput1Tensor)); - ASSERT_EQ(LiteRtUnlockTensorBuffer(input_1_tensor_buffer), kLiteRtStatusOk); - } - - // /////////////////////////////////////////////////////////////////////////// - // Attach sync fences to input buffers. 
- // /////////////////////////////////////////////////////////////////////////// - - Fence input_fence_0 = platforms::darwinn::fence_util::CreateFence(); - Fence input_fence_1 = platforms::darwinn::fence_util::CreateFence(); - - LiteRtEvent input_event_0; - ASSERT_EQ(LiteRtCreateEventFromSyncFenceFd(input_fence_0->GetFd(), - /*owns_fd=*/false, &input_event_0), - kLiteRtStatusOk); - - LiteRtEvent input_event_1; - ASSERT_EQ(LiteRtCreateEventFromSyncFenceFd(input_fence_1->GetFd(), - /*owns_fd=*/false, &input_event_1), - kLiteRtStatusOk); - - ASSERT_EQ(LiteRtDispatchAttachInputEvent( - invocation_context, /*graph_input_index=*/0, input_event_0), - kLiteRtStatusOk); - ASSERT_EQ(LiteRtDispatchAttachInputEvent( - invocation_context, /*graph_input_index=*/1, input_event_1), - kLiteRtStatusOk); - - // /////////////////////////////////////////////////////////////////////////// - // Execute model. - // /////////////////////////////////////////////////////////////////////////// - - ABSL_LOG(INFO) << "Invoking execution..."; - LiteRtEvent output_event = nullptr; - EXPECT_EQ(LiteRtDispatchInvokeAsync(invocation_context, 1, &output_event), - kLiteRtStatusOk); - ASSERT_NE(output_event, nullptr); - - // Attach output event to output tensor buffer. - ASSERT_EQ(LiteRtSetTensorBufferEvent(output_tensor_buffer, output_event), - kLiteRtStatusOk); - - // /////////////////////////////////////////////////////////////////////////// - // Signal input fences. - // /////////////////////////////////////////////////////////////////////////// - - ASSERT_OK(input_fence_0->Signal(/*success=*/true)); - ASSERT_OK(input_fence_1->Signal(/*success=*/true)); - - // /////////////////////////////////////////////////////////////////////////// - // Check output for correctness. 
- // /////////////////////////////////////////////////////////////////////////// - - { - ABSL_LOG(INFO) << "Checking output..."; - void* host_mem_addr; - ASSERT_EQ(LiteRtLockTensorBuffer(output_tensor_buffer, &host_mem_addr), - kLiteRtStatusOk); - auto output = absl::MakeSpan(static_cast(host_mem_addr), - kTestOutputSize); - for (auto i = 0; i < kTestOutputSize; ++i) { - ABSL_LOG(INFO) << output[i] << "\t" << kTestOutputTensor[i]; - } - EXPECT_THAT(output, Pointwise(testing::FloatNear(1e-3), kTestOutputTensor)); - ASSERT_EQ(LiteRtUnlockTensorBuffer(output_tensor_buffer), kLiteRtStatusOk); - } - - // /////////////////////////////////////////////////////////////////////////// - // Clean up resources. - // /////////////////////////////////////////////////////////////////////////// - - LiteRtDestroyEvent(input_event_0); - LiteRtDestroyEvent(input_event_1); - LiteRtDestroyEvent(output_event); - - EXPECT_EQ(LiteRtDispatchDetachInput(invocation_context, - /*graph_input_index=*/0, input_0_handle), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtDispatchDetachInput(invocation_context, - /*graph_input_index=*/1, input_1_handle), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtDispatchDetachOutput(invocation_context, - /*graph_output_index=*/0, output_handle), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtDispatchUnregisterTensorBuffer(device_context, output_handle), - kLiteRtStatusOk); - EXPECT_EQ( - LiteRtDispatchUnregisterTensorBuffer(device_context, input_1_handle), - kLiteRtStatusOk); - EXPECT_EQ( - LiteRtDispatchUnregisterTensorBuffer(device_context, input_0_handle), - kLiteRtStatusOk); - LiteRtDestroyTensorBuffer(output_tensor_buffer); - LiteRtDestroyTensorBuffer(input_1_tensor_buffer); - LiteRtDestroyTensorBuffer(input_0_tensor_buffer); - EXPECT_EQ(LiteRtDispatchInvocationContextDestroy(invocation_context), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtDispatchDeviceContextDestroy(device_context), - kLiteRtStatusOk); -} diff --git 
a/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/dispatch_api_google_tensor_test.cc b/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/dispatch_api_google_tensor_test.cc deleted file mode 100644 index 2d2cca562552ff..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/dispatch_api_google_tensor_test.cc +++ /dev/null @@ -1,291 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include -#include -#include - -#include -#include -#include "absl/log/absl_log.h" -#include "absl/log/log.h" -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h" -#include "tensorflow/lite/experimental/litert/cc/litert_any.h" -#include "tensorflow/lite/experimental/litert/core/filesystem.h" -#include "tensorflow/lite/experimental/litert/test/common.h" -#include "tensorflow/lite/experimental/litert/test/testdata/simple_model_test_vectors.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" - -using ::testing::Pointwise; - -TEST(DispatchApi, GoogleTensor) { -#if !defined(__ANDROID__) - GTEST_SKIP() - << "This test is specific to Android devices with a GoogleTensor eTPU"; -#endif - - LiteRtDispatchOption dispatch_option = { - /*.name=*/kDispatchOptionSharedLibraryDir, - /*.value=*/*litert::ToLiteRtAny(std::any("/data/local/tmp")), - }; - ASSERT_EQ( - LiteRtDispatchInitialize(/*options=*/&dispatch_option, /*num_options=*/1), - kLiteRtStatusOk); - - const char* vendor_id; - EXPECT_EQ(LiteRtDispatchGetVendorId(&vendor_id), kLiteRtStatusOk); - ABSL_LOG(INFO) << "vendor_id: " << vendor_id; - - const char* build_id; - EXPECT_EQ(LiteRtDispatchGetBuildId(&build_id), kLiteRtStatusOk); - ABSL_LOG(INFO) << "build_id: " << build_id; - - LiteRtApiVersion api_version; - EXPECT_EQ(LiteRtDispatchGetApiVersion(&api_version), kLiteRtStatusOk); - ABSL_LOG(INFO) << "api_version: " << api_version.major << "." - << api_version.minor << "." 
<< api_version.patch; - - int capabilities; - EXPECT_EQ(LiteRtDispatchGetCapabilities(&capabilities), kLiteRtStatusOk); - ABSL_LOG(INFO) << "capabilities: " << capabilities; - - LiteRtDispatchDeviceContext device_context = nullptr; - EXPECT_EQ(LiteRtDispatchDeviceContextCreate(&device_context), - kLiteRtStatusOk); - ABSL_LOG(INFO) << "device_context: " << device_context; - - auto model_file_name = - litert::testing::GetTestFilePath(kGoogleTensorModelFileName); - auto model = litert::internal::LoadBinaryFile(model_file_name); - EXPECT_TRUE(model) << model.Error(); - ABSL_LOG(INFO) << "Loaded model " << model_file_name << ", " << model->Size() - << " bytes"; - - // /////////////////////////////////////////////////////////////////////////// - // Set up an invocation context for a given model. - // /////////////////////////////////////////////////////////////////////////// - - LiteRtMemBuffer exec_bytecode_buffer = {/*.fd=*/-1, - /*.base_addr=*/model->Data(), - /*.offset=*/0, - /*.size=*/model->Size()}; - LiteRtDispatchInvocationContext invocation_context = nullptr; - EXPECT_EQ(LiteRtDispatchInvocationContextCreate( - device_context, kLiteRtDispatchExecutableTypeMlModel, - &exec_bytecode_buffer, /*function_name=*/nullptr, - /*num_inputs=*/2, /*num_outputs=*/1, &invocation_context), - kLiteRtStatusOk); - ABSL_LOG(INFO) << "Invocation context: " << invocation_context; - - // /////////////////////////////////////////////////////////////////////////// - // Determine tensor buffer requirements. 
- // /////////////////////////////////////////////////////////////////////////// - - int num_tensor_buffer_types; - LiteRtTensorBufferRequirements input_0_tensor_buffer_requirements; - EXPECT_EQ(LiteRtDispatchGetInputRequirements( - invocation_context, /*input_index=*/0, &kInput0TensorType, - &input_0_tensor_buffer_requirements), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtGetNumTensorBufferRequirementsSupportedBufferTypes( - input_0_tensor_buffer_requirements, &num_tensor_buffer_types), - kLiteRtStatusOk); - EXPECT_GE(num_tensor_buffer_types, 1); - LiteRtTensorBufferType input_0_tensor_buffer_type; - EXPECT_EQ(LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( - input_0_tensor_buffer_requirements, /*type_index=*/0, - &input_0_tensor_buffer_type), - kLiteRtStatusOk); - EXPECT_EQ(input_0_tensor_buffer_type, kLiteRtTensorBufferTypeAhwb); - size_t input_0_tensor_buffer_size; - EXPECT_EQ( - LiteRtGetTensorBufferRequirementsBufferSize( - input_0_tensor_buffer_requirements, &input_0_tensor_buffer_size), - kLiteRtStatusOk); - EXPECT_GE(input_0_tensor_buffer_size, sizeof(kTestInput0Tensor)); - - LiteRtTensorBufferRequirements input_1_tensor_buffer_requirements; - EXPECT_EQ(LiteRtDispatchGetInputRequirements( - invocation_context, /*input_index=*/1, &kInput1TensorType, - &input_1_tensor_buffer_requirements), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtGetNumTensorBufferRequirementsSupportedBufferTypes( - input_1_tensor_buffer_requirements, &num_tensor_buffer_types), - kLiteRtStatusOk); - EXPECT_GE(num_tensor_buffer_types, 1); - LiteRtTensorBufferType input_1_tensor_buffer_type; - EXPECT_EQ(LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( - input_1_tensor_buffer_requirements, /*type_index=*/0, - &input_1_tensor_buffer_type), - kLiteRtStatusOk); - EXPECT_EQ(input_1_tensor_buffer_type, kLiteRtTensorBufferTypeAhwb); - size_t input_1_tensor_buffer_size; - EXPECT_EQ( - LiteRtGetTensorBufferRequirementsBufferSize( - input_1_tensor_buffer_requirements, 
&input_1_tensor_buffer_size), - kLiteRtStatusOk); - EXPECT_GE(input_1_tensor_buffer_size, sizeof(kTestInput1Tensor)); - - LiteRtTensorBufferRequirements output_tensor_buffer_requirements; - EXPECT_EQ(LiteRtDispatchGetOutputRequirements( - invocation_context, /*output_index=*/0, &kOutputTensorType, - &output_tensor_buffer_requirements), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtGetNumTensorBufferRequirementsSupportedBufferTypes( - output_tensor_buffer_requirements, &num_tensor_buffer_types), - kLiteRtStatusOk); - EXPECT_GE(num_tensor_buffer_types, 1); - LiteRtTensorBufferType output_tensor_buffer_type; - EXPECT_EQ(LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( - output_tensor_buffer_requirements, /*type_index=*/0, - &output_tensor_buffer_type), - kLiteRtStatusOk); - EXPECT_EQ(output_tensor_buffer_type, kLiteRtTensorBufferTypeAhwb); - size_t output_tensor_buffer_size; - EXPECT_EQ(LiteRtGetTensorBufferRequirementsBufferSize( - output_tensor_buffer_requirements, &output_tensor_buffer_size), - kLiteRtStatusOk); - EXPECT_GE(output_tensor_buffer_size, sizeof(kTestOutputTensor)); - - // /////////////////////////////////////////////////////////////////////////// - // Allocate tensor buffers. 
- // /////////////////////////////////////////////////////////////////////////// - - LiteRtTensorBuffer input_0_tensor_buffer; - EXPECT_EQ(LiteRtCreateManagedTensorBuffer( - input_0_tensor_buffer_type, &kInput0TensorType, - input_0_tensor_buffer_size, &input_0_tensor_buffer), - kLiteRtStatusOk); - - LiteRtTensorBuffer input_1_tensor_buffer; - EXPECT_EQ(LiteRtCreateManagedTensorBuffer( - input_1_tensor_buffer_type, &kInput1TensorType, - input_1_tensor_buffer_size, &input_1_tensor_buffer), - kLiteRtStatusOk); - - LiteRtTensorBuffer output_tensor_buffer; - EXPECT_EQ(LiteRtCreateManagedTensorBuffer( - output_tensor_buffer_type, &kOutputTensorType, - output_tensor_buffer_size, &output_tensor_buffer), - kLiteRtStatusOk); - - // /////////////////////////////////////////////////////////////////////////// - // Register tensor buffers. - // /////////////////////////////////////////////////////////////////////////// - - LiteRtTensorBufferHandle input_1_handle; - EXPECT_EQ(LiteRtDispatchRegisterTensorBuffer( - device_context, input_1_tensor_buffer, &input_1_handle), - kLiteRtStatusOk); - - LiteRtTensorBufferHandle input_0_handle; - EXPECT_EQ(LiteRtDispatchRegisterTensorBuffer( - device_context, input_0_tensor_buffer, &input_0_handle), - kLiteRtStatusOk); - - LiteRtTensorBufferHandle output_handle; - EXPECT_EQ(LiteRtDispatchRegisterTensorBuffer( - device_context, output_tensor_buffer, &output_handle), - kLiteRtStatusOk); - - // /////////////////////////////////////////////////////////////////////////// - // Attach tensor buffers. 
- // /////////////////////////////////////////////////////////////////////////// - - EXPECT_EQ(LiteRtDispatchAttachInput(invocation_context, - /*graph_input_index=*/0, input_0_handle), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtDispatchAttachInput(invocation_context, - /*graph_input_index=*/1, input_1_handle), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtDispatchAttachOutput(invocation_context, - /*graph_output_index=*/0, output_handle), - kLiteRtStatusOk); - - // /////////////////////////////////////////////////////////////////////////// - // Fill the input buffers with data. - // /////////////////////////////////////////////////////////////////////////// - - { - ABSL_LOG(INFO) << "Filling inputs with data"; - void* host_mem_addr; - - ASSERT_EQ(LiteRtLockTensorBuffer(input_0_tensor_buffer, &host_mem_addr), - kLiteRtStatusOk); - std::memcpy(host_mem_addr, kTestInput0Tensor, sizeof(kTestInput0Tensor)); - ASSERT_EQ(LiteRtUnlockTensorBuffer(input_0_tensor_buffer), kLiteRtStatusOk); - - ASSERT_EQ(LiteRtLockTensorBuffer(input_1_tensor_buffer, &host_mem_addr), - kLiteRtStatusOk); - std::memcpy(host_mem_addr, kTestInput1Tensor, sizeof(kTestInput1Tensor)); - ASSERT_EQ(LiteRtUnlockTensorBuffer(input_1_tensor_buffer), kLiteRtStatusOk); - } - - // /////////////////////////////////////////////////////////////////////////// - // Execute model. - // /////////////////////////////////////////////////////////////////////////// - - ABSL_LOG(INFO) << "Invoking execution..."; - EXPECT_EQ(LiteRtDispatchInvoke(invocation_context), kLiteRtStatusOk); - - // /////////////////////////////////////////////////////////////////////////// - // Check output for correctness. 
- // /////////////////////////////////////////////////////////////////////////// - - { - ABSL_LOG(INFO) << "Checking output..."; - void* host_mem_addr; - ASSERT_EQ(LiteRtLockTensorBuffer(output_tensor_buffer, &host_mem_addr), - kLiteRtStatusOk); - auto output = absl::MakeSpan(static_cast(host_mem_addr), - kTestOutputSize); - for (auto i = 0; i < kTestOutputSize; ++i) { - ABSL_LOG(INFO) << output[i] << "\t" << kTestOutputTensor[i]; - } - EXPECT_THAT(output, Pointwise(testing::FloatNear(1e-3), kTestOutputTensor)); - ASSERT_EQ(LiteRtUnlockTensorBuffer(output_tensor_buffer), kLiteRtStatusOk); - } - - // /////////////////////////////////////////////////////////////////////////// - // Clean up resources. - // /////////////////////////////////////////////////////////////////////////// - - EXPECT_EQ(LiteRtDispatchDetachInput(invocation_context, - /*graph_input_index=*/0, input_0_handle), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtDispatchDetachInput(invocation_context, - /*graph_input_index=*/1, input_1_handle), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtDispatchDetachOutput(invocation_context, - /*graph_output_index=*/0, output_handle), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtDispatchUnregisterTensorBuffer(device_context, output_handle), - kLiteRtStatusOk); - EXPECT_EQ( - LiteRtDispatchUnregisterTensorBuffer(device_context, input_1_handle), - kLiteRtStatusOk); - EXPECT_EQ( - LiteRtDispatchUnregisterTensorBuffer(device_context, input_0_handle), - kLiteRtStatusOk); - LiteRtDestroyTensorBuffer(output_tensor_buffer); - LiteRtDestroyTensorBuffer(input_1_tensor_buffer); - LiteRtDestroyTensorBuffer(input_0_tensor_buffer); - EXPECT_EQ(LiteRtDispatchInvocationContextDestroy(invocation_context), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtDispatchDeviceContextDestroy(device_context), - kLiteRtStatusOk); -} diff --git a/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_device_context.cc 
b/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_device_context.cc deleted file mode 100644 index 342c469a7cdb68..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_device_context.cc +++ /dev/null @@ -1,294 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_device_context.h" - -#include -#include - -#if __ANDROID__ -#include -#endif // __ANDROID__ - -#include "third_party/odml/infra/southbound/sb_api.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" -#include "tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_graph.h" - -using litert::Error; -using litert::Expected; -using litert::Unexpected; - -LiteRtDispatchDeviceContextT::~LiteRtDispatchDeviceContextT() { - if (!thr_graphs_.empty()) { - auto thr_graph_delete = southbound_.api().thr_graph_delete; - if (!thr_graph_delete) { - LITERT_LOG(LITERT_ERROR, "thr_graph_delete not found"); - } else { - 
for (auto* thr_graph : thr_graphs_) { - thr_graph_delete(thr_graph); - } - } - } - - if (thr_context_) { - auto thr_context_delete = southbound_.api().thr_context_delete; - if (!thr_context_delete) { - LITERT_LOG(LITERT_ERROR, "thr_context_delete not found"); - } else { - thr_context_delete(thr_context_); - } - } -} - -Expected -LiteRtDispatchDeviceContextT::Create( - const litert::google_tensor::Southbound& southbound) { - Ptr device_context(new LiteRtDispatchDeviceContextT(southbound)); - - auto thr_context_create = southbound.api().thr_context_create; - if (!thr_context_create) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "thr_context_create not found"); - } - - device_context->thr_context_ = thr_context_create(); - return device_context; -} - -Expected -LiteRtDispatchDeviceContextT::RegisterTensorBuffer( - LiteRtTensorBuffer tensor_buffer) { - LiteRtTensorBufferType tensor_buffer_type; - if (auto status = - LiteRtGetTensorBufferType(tensor_buffer, &tensor_buffer_type); - status != kLiteRtStatusOk) { - return Error(status, "Failed to get buffer type"); - } - - if (tensor_buffer_type != kLiteRtTensorBufferTypeAhwb) { - return Error(kLiteRtStatusErrorUnsupported, "Unsupported buffer type"); - } - - size_t tensor_buffer_size; - if (auto status = - LiteRtGetTensorBufferSize(tensor_buffer, &tensor_buffer_size); - status != kLiteRtStatusOk) { - return Error(status, "Failed to get buffer size"); - } - - size_t tensor_buffer_offset; - if (auto status = - LiteRtGetTensorBufferOffset(tensor_buffer, &tensor_buffer_offset); - status != kLiteRtStatusOk) { - if (status == kLiteRtStatusErrorNotFound) { - tensor_buffer_offset = 0; - } else { - return Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to get buffer offset"); - } - } - - LiteRtRankedTensorType tensor_type; - if (auto status = - LiteRtGetTensorBufferTensorType(tensor_buffer, &tensor_type); - status != kLiteRtStatusOk) { - return Error(status, "Failed to get tensor buffer type"); - } - - auto* 
tensor_strides = tensor_type.layout.strides; - if (tensor_strides != nullptr) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "Tensor strides are not supported"); - } - - AHardwareBuffer* ahwb; -#if LITERT_HAS_AHWB_SUPPORT - if (auto status = LiteRtGetTensorBufferAhwb(tensor_buffer, &ahwb); - status != kLiteRtStatusOk) { - return Error(status, "Failed to get AHWB"); - } -#else - return Error(kLiteRtStatusErrorRuntimeFailure, - "AHardwareBuffer is not supported on this platform"); -#endif - - ThrBufferHandle thr_buffer_handle; - - if (tensor_buffer_offset == 0) { - auto thr_register_buffer = southbound_.api().thr_register_buffer; - if (!thr_register_buffer) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_register_buffer not found"); - } - - if (auto status = thr_register_buffer( - thr_context_, ThrBufferType::kThrBufferTypeAHardwareBuffer, ahwb, - tensor_buffer_size, &thr_buffer_handle); - status != kThrStatusSuccess) { - LITERT_LOG(LITERT_ERROR, "thr_register_buffer failed: %d", status); - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_register_buffer failed"); - } - - } else { - auto thr_register_buffer_with_offset = - southbound_.api().thr_register_buffer_with_offset; - if (!thr_register_buffer_with_offset) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_register_buffer_with_offset not found"); - } - - if (auto status = thr_register_buffer_with_offset( - thr_context_, ThrBufferType::kThrBufferTypeAHardwareBuffer, ahwb, - tensor_buffer_offset, tensor_buffer_size, &thr_buffer_handle); - status != kThrStatusSuccess) { - LITERT_LOG(LITERT_ERROR, "thr_register_buffer_with_offset failed: %d", - status); - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_register_buffer_with_offset failed"); - } - } - - return thr_buffer_handle; -} - -litert::Expected LiteRtDispatchDeviceContextT::UnregisterTensorBuffer( - LiteRtTensorBufferHandle tensor_buffer_handle) { - auto thr_unregister_buffer = southbound_.api().thr_unregister_buffer; - if 
(!thr_unregister_buffer) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_unregister_buffer not found"); - } - - ThrBufferHandle thr_buffer_handle = tensor_buffer_handle; - if (auto status = thr_unregister_buffer(thr_context_, thr_buffer_handle); - status != kThrStatusSuccess) { - LITERT_LOG(LITERT_ERROR, "thr_unregister_buffer failed: %d", status); - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_unregister_buffer failed"); - } - - return {}; -} - -litert::Expected -LiteRtDispatchDeviceContextT::CreateGraph() { - auto thr_graph_create = southbound_.api().thr_graph_create; - if (!thr_graph_create) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_graph_create not found"); - } - - ThrGraph* thr_graph = thr_graph_create(thr_context_); - if (!thr_graph) { - return Error(kLiteRtStatusErrorRuntimeFailure, "thr_graph_create failed"); - } - - return new LiteRtDispatchGraphT(southbound_, thr_graph, this); -} - -litert::Expected LiteRtDispatchDeviceContextT::DestroyGraph( - LiteRtDispatchGraph graph) { - auto thr_graph_delete = southbound_.api().thr_graph_delete; - if (!thr_graph_delete) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_graph_delete not found"); - } - - thr_graphs_.erase(graph->thr_graph()); - - ThrGraph* thr_graph = graph->thr_graph(); - if (auto status = thr_graph_delete(thr_graph); status != kThrStatusSuccess) { - LITERT_LOG(LITERT_ERROR, "thr_graph_destroy failed: %d", status); - return Error(kLiteRtStatusErrorRuntimeFailure, "thr_graph_destroy failed"); - } - - delete graph; - return {}; -} - -litert::Expected -LiteRtDispatchDeviceContextT::LoadExecutable( - LiteRtDispatchExecutableType type, const LiteRtMemBuffer* bytecode_buffer) { - auto thr_load_sq_container = southbound_.api().thr_load_sq_container; - if (!thr_load_sq_container) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_load_sq_container not found"); - } - - ThrSqContainerType thr_type; - switch (type) { - case 
kLiteRtDispatchExecutableTypeDspLibrary: - thr_type = kThrSqContainerTypeFunctionLibrary; - break; - case kLiteRtDispatchExecutableTypeMlModel: - thr_type = kThrSqContainerTypeMlModel; - break; - default: - LITERT_LOG(LITERT_ERROR, "Unexpected executable type: %d", type); - return Error(kLiteRtStatusErrorRuntimeFailure, - "Unexpected executable type"); - } - - ThrSqContainerHandle sq_handle; - ThrStatus status; - if (bytecode_buffer->fd >= 0 && - // Unfortunately thrLoadSqContainerFd doesn't support passing an - // offset. So if the offset is non-zero, we fallback to passing a CPU - // memory address right below. - (bytecode_buffer->offset == 0)) { - bool lazy_loading = false; - status = southbound_.api().thr_load_sq_container_fd( - thr_context_, thr_type, bytecode_buffer->fd, bytecode_buffer->size, - lazy_loading, &sq_handle); - } else { - auto bytecode_ptr = - static_cast(bytecode_buffer->base_addr) + - bytecode_buffer->offset; - status = southbound_.api().thr_load_sq_container( - thr_context_, thr_type, bytecode_ptr, bytecode_buffer->size, - &sq_handle); - } - if (status != kThrStatusSuccess) { - LITERT_LOG(LITERT_ERROR, "thr_load_sq_container failed: %d", status); - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_load_sq_container failed"); - } - - return sq_handle; -} - -litert::Expected LiteRtDispatchDeviceContextT::UnloadExecutable( - LiteRtDispatchExecutableHandle exec_handle) { - auto thr_unload_sq_container = southbound_.api().thr_unload_sq_container; - if (!thr_unload_sq_container) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_unload_sq_container not found"); - } - - ThrSqContainerHandle sq_handle = exec_handle; - if (auto status = thr_unload_sq_container(thr_context_, sq_handle); - status != kThrStatusSuccess) { - LITERT_LOG(LITERT_ERROR, "thr_unload_sq_container failed: %d", status); - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_unload_sq_container failed"); - } - - return {}; -} diff --git 
a/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_device_context.h b/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_device_context.h deleted file mode 100644 index 4a7074d49ede66..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_device_context.h +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_GOOGLE_TENSOR_DISPATCH_LITERT_DISPATCH_DEVICE_CONTEXT_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_GOOGLE_TENSOR_DISPATCH_LITERT_DISPATCH_DEVICE_CONTEXT_H_ - -#include -#include - -#include "absl/container/flat_hash_set.h" -#include "third_party/odml/infra/southbound/sb_api.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" -#include "tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/southbound.h" - -class LiteRtDispatchDeviceContextT { - public: - using Ptr = std::unique_ptr; - - ~LiteRtDispatchDeviceContextT(); - - static litert::Expected Create( - const litert::google_tensor::Southbound& southbound); - - litert::Expected RegisterTensorBuffer( - LiteRtTensorBuffer tensor_buffer); - - litert::Expected UnregisterTensorBuffer( - 
LiteRtTensorBufferHandle tensor_buffer_handle); - - litert::Expected CreateGraph(); - litert::Expected DestroyGraph(LiteRtDispatchGraph graph); - - litert::Expected LoadExecutable( - LiteRtDispatchExecutableType type, - const LiteRtMemBuffer* bytecode_buffer); - - litert::Expected UnloadExecutable( - LiteRtDispatchExecutableHandle exec_handle); - - ThrContext* thr_context() { return thr_context_; } - - void add_graph(ThrGraph* graph) { thr_graphs_.insert(graph); } - - private: - explicit LiteRtDispatchDeviceContextT( - const litert::google_tensor::Southbound& southbound) - : southbound_(southbound) {} - - const litert::google_tensor::Southbound& southbound_; - ThrContext* thr_context_ = nullptr; - absl::flat_hash_set thr_graphs_; -}; - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_GOOGLE_TENSOR_DISPATCH_LITERT_DISPATCH_DEVICE_CONTEXT_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_graph.cc b/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_graph.cc deleted file mode 100644 index d3530b56d57f46..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_graph.cc +++ /dev/null @@ -1,305 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_graph.h" - -#include -#include - -#include "absl/strings/string_view.h" -#include "third_party/odml/infra/southbound/sb_api.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" - -using litert::Error; -using litert::Expected; - -namespace { - -// We store THR names in a global set as a workaround to b/369144429. -std::set* ThrNames = new std::set(); - -absl::string_view ThrNodeIdStr(LiteRtDispatchNodeId node_id) { - auto str = "node_" + std::to_string(node_id); - auto iter = ThrNames->find(str); - if (iter == ThrNames->end()) { - iter = ThrNames->insert(iter, str); - } - return *iter; -} - -} // namespace - -absl::string_view ThrEdgeIdStr(LiteRtDispatchEdgeId edge_id) { - auto str = "edge_" + std::to_string(edge_id); - auto iter = ThrNames->find(str); - if (iter == ThrNames->end()) { - iter = ThrNames->insert(iter, str); - } - return *iter; -} - -litert::Expected LiteRtDispatchGraphT::AddNode( - LiteRtDispatchNodeId node_id, LiteRtDispatchNodeType node_type) { - auto thr_graph_add_sq_node = southbound_.api().thr_graph_add_sq_node; - if (!thr_graph_add_sq_node) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_graph_add_sq_node not found"); - } - - auto thr_node_id = ThrNodeIdStr(node_id); - ThrNodeType thr_node_type; - switch (node_type) { - case kLiteRtDispatchNodeTypeDsp: - thr_node_type = kThrNodeTypeDsp; - break; - case kLiteRtDispatchNodeTypeNpu: - thr_node_type = kThrNodeTypeNpu; - break; - default: - LITERT_LOG(LITERT_ERROR, "Unexpected node type: %d", node_type); - return Error(kLiteRtStatusErrorRuntimeFailure, "Unexpected node type"); - } - - if (auto status = - thr_graph_add_sq_node(thr_graph_, thr_node_id.data(), thr_node_type); 
- status != kThrStatusSuccess) { - LITERT_LOG(LITERT_ERROR, "thr_graph_add_sq_node failed: %d", status); - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_graph_add_sq_node failed"); - } - - return {}; -} - -litert::Expected LiteRtDispatchGraphT::AddEdge( - LiteRtDispatchEdgeId edge_id) { - auto thr_graph_add_edge = southbound_.api().thr_graph_add_edge; - if (!thr_graph_add_edge) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_graph_add_edge not found"); - } - - auto thr_edge_id = ThrEdgeIdStr(edge_id); - ThrEdgeType thr_edge_type = kThrEdgeNoType; - if (auto status = - thr_graph_add_edge(thr_graph_, thr_edge_id.data(), thr_edge_type); - status != kThrStatusSuccess) { - LITERT_LOG(LITERT_ERROR, "thr_graph_add_edge failed: %d", status); - return Error(kLiteRtStatusErrorRuntimeFailure, "thr_graph_add_edge failed"); - } - - return {}; -} - -litert::Expected LiteRtDispatchGraphT::ConnectNodeInput( - LiteRtDispatchNodeId node_id, int input_index, - LiteRtDispatchEdgeId edge_id) { - auto thr_graph_connect_node_input = - southbound_.api().thr_graph_connect_node_input; - if (!thr_graph_connect_node_input) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_graph_connect_node_input not found"); - } - - int next_input_index = NextNodeInputIndex(node_id); - if (input_index != next_input_index) { - LITERT_LOG(LITERT_ERROR, "Unexpected input index %d, expected %d", - input_index, next_input_index); - return Error(kLiteRtStatusErrorRuntimeFailure, "Unexpected input index"); - } - - auto thr_node_id = ThrNodeIdStr(node_id); - auto thr_edge_id = ThrEdgeIdStr(edge_id); - if (auto status = thr_graph_connect_node_input(thr_graph_, thr_node_id.data(), - thr_edge_id.data()); - status != kThrStatusSuccess) { - LITERT_LOG(LITERT_ERROR, "thr_graph_set_input_edge failed: %d", status); - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_graph_set_input_edge failed"); - } - - AddInputEdge(input_index, edge_id); - return {}; -} - -litert::Expected 
LiteRtDispatchGraphT::ConnectNodeOutput( - LiteRtDispatchNodeId node_id, int output_index, - LiteRtDispatchEdgeId edge_id) { - auto thr_graph_connect_node_output = - southbound_.api().thr_graph_connect_node_output; - if (!thr_graph_connect_node_output) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_graph_connect_node_output not found"); - } - - int next_output_index = NextNodeOutputIndex(node_id); - if (output_index != next_output_index) { - LITERT_LOG(LITERT_ERROR, "Unexpected output index %d, expected %d", - output_index, next_output_index); - return Error(kLiteRtStatusErrorRuntimeFailure, "Unexpected output index"); - } - - auto thr_node_id = ThrNodeIdStr(node_id); - auto thr_edge_id = ThrEdgeIdStr(edge_id); - if (auto status = thr_graph_connect_node_output( - thr_graph_, thr_node_id.data(), thr_edge_id.data()); - status != kThrStatusSuccess) { - LITERT_LOG(LITERT_ERROR, "thr_graph_set_output_edge failed: %d", status); - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_graph_set_output_edge failed"); - } - - AddOutputEdge(output_index, edge_id); - return {}; -} - -litert::Expected LiteRtDispatchGraphT::ConnectGraphInput( - int input_index, LiteRtDispatchEdgeId edge_id) { - int next_input_index = NextGraphInputIndex(); - if (input_index != next_input_index) { - LITERT_LOG(LITERT_ERROR, "Unexpected input index %d, expected %d", - input_index, next_input_index); - return Error(kLiteRtStatusErrorRuntimeFailure, "Unexpected input index"); - } - - auto thr_graph_set_input_edge = southbound_.api().thr_graph_set_input_edge; - if (!thr_graph_set_input_edge) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_graph_set_input_edge not found"); - } - - auto thr_edge_id = ThrEdgeIdStr(edge_id); - if (auto status = thr_graph_set_input_edge(thr_graph_, thr_edge_id.data()); - status != kThrStatusSuccess) { - LITERT_LOG(LITERT_ERROR, "thr_graph_set_input_edge failed: %d", status); - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_graph_set_input_edge 
failed"); - } - - return {}; -} - -litert::Expected LiteRtDispatchGraphT::ConnectGraphOutput( - int output_index, LiteRtDispatchEdgeId edge_id) { - int next_output_index = NextGraphOutputIndex(); - if (output_index != next_output_index) { - LITERT_LOG(LITERT_ERROR, "Unexpected output index %d, expected %d", - output_index, next_output_index); - return Error(kLiteRtStatusErrorRuntimeFailure, "Unexpected output index"); - } - - auto thr_graph_set_output_edge = southbound_.api().thr_graph_set_output_edge; - if (!thr_graph_set_output_edge) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_graph_set_output_edge not found"); - } - - auto thr_edge_id = ThrEdgeIdStr(edge_id); - if (auto status = thr_graph_set_output_edge(thr_graph_, thr_edge_id.data()); - status != kThrStatusSuccess) { - LITERT_LOG(LITERT_ERROR, "thr_graph_set_output_edge failed: %d", status); - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_graph_set_output_edge failed"); - } - - return {}; -} - -litert::Expected LiteRtDispatchGraphT::AssignNodeFunction( - LiteRtDispatchNodeId node_id, LiteRtDispatchExecutableHandle exec_handle, - const char* function_name) { - auto thr_graph_assign_sq = southbound_.api().thr_graph_assign_sq; - if (!thr_graph_assign_sq) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_graph_assign_sq not found"); - } - - auto thr_node_id = ThrNodeIdStr(node_id); - ThrSqContainerHandle sq_handle = exec_handle; - // An empty function name represent no function name being provided and - // therefore we must pass a nullptr to the call below, otherwise the SB API - // will expect a model with a signature. See b/378913220. - const char* function_name_ptr = - absl::string_view(function_name).empty() ? 
nullptr : function_name; - if (auto status = thr_graph_assign_sq(thr_graph_, thr_node_id.data(), - sq_handle, function_name_ptr); - status != kThrStatusSuccess) { - LITERT_LOG(LITERT_ERROR, "thr_graph_assign_sq failed: %d", status); - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_graph_assign_sq failed"); - } - - return {}; -} - -litert::Expected LiteRtDispatchGraphT::AnnotateGraph(const char* key, - const char* value) { - auto thr_graph_annotate_graph = southbound_.api().thr_graph_annotate_graph; - if (!thr_graph_annotate_graph) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_graph_annotate_graph not found"); - } - - if (auto status = thr_graph_annotate_graph(thr_graph_, key, value); - status != kThrStatusSuccess) { - LITERT_LOG(LITERT_ERROR, "thr_graph_annotate_graph failed: %d", status); - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_graph_annotate_graph failed"); - } - - return {}; -} - -litert::Expected LiteRtDispatchGraphT::AnnotateNode( - LiteRtDispatchNodeId node_id, const char* key, const char* value) { - auto thr_graph_annotate_node = southbound_.api().thr_graph_annotate_node; - if (!thr_graph_annotate_node) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_graph_annotate_node not found"); - } - - auto thr_node_id = ThrNodeIdStr(node_id); - if (auto status = - thr_graph_annotate_node(thr_graph_, thr_node_id.data(), key, value); - status != kThrStatusSuccess) { - LITERT_LOG(LITERT_ERROR, "thr_graph_annotate_node failed: %d", status); - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_graph_annotate_node failed"); - } - - return {}; -} - -litert::Expected LiteRtDispatchGraphT::AnnotateEdge( - LiteRtDispatchEdgeId edge_id, const char* key, const char* value) { - auto thr_graph_annotate_edge = southbound_.api().thr_graph_annotate_edge; - if (!thr_graph_annotate_edge) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_graph_annotate_edge not found"); - } - - auto thr_edge_id = ThrEdgeIdStr(edge_id); - if (auto 
status = - thr_graph_annotate_edge(thr_graph_, thr_edge_id.data(), key, value); - status != kThrStatusSuccess) { - LITERT_LOG(LITERT_ERROR, "thr_graph_annotate_edge failed: %d", status); - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_graph_annotate_edge failed"); - } - - return {}; -} diff --git a/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_graph.h b/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_graph.h deleted file mode 100644 index 6586e58f9bd637..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_graph.h +++ /dev/null @@ -1,129 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_GOOGLE_TENSOR_DISPATCH_LITERT_DISPATCH_GRAPH_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_GOOGLE_TENSOR_DISPATCH_LITERT_DISPATCH_GRAPH_H_ - -#include -#include - -#include "third_party/odml/infra/southbound/sb_api.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" -#include "tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/southbound.h" - -class LiteRtDispatchGraphT { - public: - LiteRtDispatchGraphT(const litert::google_tensor::Southbound& southbound, - ThrGraph* thr_graph, - LiteRtDispatchDeviceContext device_context) - : southbound_(southbound), - thr_graph_(thr_graph), - device_context_(device_context) {} - - ThrGraph* thr_graph() { return thr_graph_; } - - LiteRtDispatchDeviceContext device_context() { return device_context_; } - - litert::Expected AddNode(LiteRtDispatchNodeId node_id, - LiteRtDispatchNodeType node_type); - litert::Expected AddEdge(LiteRtDispatchEdgeId edge_id); - - litert::Expected InputEdge(int input_index) const { - return IoEdge(input_index, input_edges_); - } - - litert::Expected OutputEdge(int output_index) const { - return IoEdge(output_index, output_edges_); - } - - size_t NumOutputs() const { return output_edges_.size(); } - - litert::Expected ConnectNodeInput(LiteRtDispatchNodeId node_id, - int input_index, - LiteRtDispatchEdgeId edge_id); - - litert::Expected ConnectNodeOutput(LiteRtDispatchNodeId node_id, - int output_index, - LiteRtDispatchEdgeId edge_id); - - litert::Expected ConnectGraphInput(int input_index, - LiteRtDispatchEdgeId edge_id); - - litert::Expected ConnectGraphOutput(int output_index, - LiteRtDispatchEdgeId edge_id); - - litert::Expected AssignNodeFunction( - LiteRtDispatchNodeId node_id, LiteRtDispatchExecutableHandle exec_handle, - const char* function_name); - - 
litert::Expected AnnotateGraph(const char* key, const char* value); - - litert::Expected AnnotateNode(LiteRtDispatchNodeId node_id, - const char* key, const char* value); - - litert::Expected AnnotateEdge(LiteRtDispatchEdgeId edge_id, - const char* key, const char* value); - - private: - using NextNodeIoIndexMap = std::map; - using IoIndexToEdgeIdMap = std::map; - - int NextNodeOutputIndex(LiteRtDispatchNodeId node_id) { - return NextNodeIoIndex(node_id, next_node_output_index_); - } - - int NextGraphInputIndex() { return next_graph_input_index_++; } - - int NextGraphOutputIndex() { return next_graph_output_index_++; } - - int NextNodeIoIndex(LiteRtDispatchNodeId node_id, NextNodeIoIndexMap& map) { - return map[node_id]++; - } - - litert::Expected IoEdge( - int io_index, const IoIndexToEdgeIdMap& map) const { - auto iter = map.find(io_index); - if (iter == map.end()) { - return litert::Unexpected(kLiteRtStatusErrorNotFound, - "Unexpected graph input/output index"); - } - return iter->second; - } - - int NextNodeInputIndex(LiteRtDispatchNodeId node_id) { - return NextNodeIoIndex(node_id, next_node_input_index_); - } - - void AddInputEdge(int input_index, LiteRtDispatchEdgeId edge_id) { - input_edges_[input_index] = edge_id; - } - - void AddOutputEdge(int output_index, LiteRtDispatchEdgeId edge_id) { - output_edges_[output_index] = edge_id; - } - - const litert::google_tensor::Southbound& southbound_; - ThrGraph* thr_graph_; - LiteRtDispatchDeviceContext device_context_; - NextNodeIoIndexMap next_node_input_index_; - NextNodeIoIndexMap next_node_output_index_; - int next_graph_input_index_ = 0; - int next_graph_output_index_ = 0; - IoIndexToEdgeIdMap input_edges_; - IoIndexToEdgeIdMap output_edges_; -}; - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_GOOGLE_TENSOR_DISPATCH_LITERT_DISPATCH_GRAPH_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_invocation_context.cc 
b/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_invocation_context.cc deleted file mode 100644 index ac0a845c56ea4b..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_invocation_context.cc +++ /dev/null @@ -1,613 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_invocation_context.h" - -#include - -#include "absl/strings/string_view.h" -#include "third_party/odml/infra/southbound/sb_api.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_event.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/core/util/tensor_type_util.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" -#include "tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_device_context.h" -#include "tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_graph.h" 
-#include "tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_metrics.h" -#include "tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/southbound.h" - -using litert::Error; -using litert::Expected; -using litert::Unexpected; - -extern absl::string_view ThrEdgeIdStr(LiteRtDispatchEdgeId edge_id); - -namespace { - -constexpr const size_t kEdgeTpuPadding = 64; - -template -inline constexpr auto Pad(X x, Align align) { - return ((x + align - 1) / align) * align; -} - -} // namespace - -litert::Expected -LiteRtDispatchInvocationContextT::CreateFromBytecode( - const litert::google_tensor::Southbound& southbound, - LiteRtDispatchDeviceContext device_context, - LiteRtDispatchExecutableType exec_type, - const LiteRtMemBuffer* exec_bytecode_buffer, const char* function_name, - int num_inputs, int num_outputs) { - auto graph = device_context->CreateGraph(); - if (!graph) { - return graph.Error(); - } - - LiteRtDispatchNodeId node_id = 0; - LiteRtDispatchNodeType node_type; - switch (exec_type) { - case kLiteRtDispatchExecutableTypeDspLibrary: - node_type = kLiteRtDispatchNodeTypeDsp; - break; - case kLiteRtDispatchExecutableTypeMlModel: - node_type = kLiteRtDispatchNodeTypeNpu; - break; - default: - LITERT_LOG(LITERT_ERROR, "Unexpected executable type: %d", exec_type); - return Error(kLiteRtStatusErrorInvalidArgument, - "Unexpected executable type"); - } - - if (auto status = (*graph)->AddNode(node_id, node_type); !status) { - return status.Error(); - } - - auto exec_handle = - device_context->LoadExecutable(exec_type, exec_bytecode_buffer); - if (!exec_handle) { - return exec_handle.Error(); - } - - if (auto status = - (*graph)->AssignNodeFunction(node_id, *exec_handle, function_name); - !status) { - return status.Error(); - } - - LiteRtDispatchEdgeId next_edge_id = 0; - - for (auto input_index = 0; input_index < num_inputs; ++input_index) { - LiteRtDispatchEdgeId edge_id = next_edge_id++; - if (auto status = 
(*graph)->AddEdge(edge_id); !status) { - return status.Error(); - } - if (auto status = (*graph)->ConnectGraphInput(input_index, edge_id); - !status) { - return status.Error(); - } - if (auto status = (*graph)->ConnectNodeInput(node_id, input_index, edge_id); - !status) { - return status.Error(); - } - } - - for (auto output_index = 0; output_index < num_outputs; ++output_index) { - LiteRtDispatchEdgeId edge_id = next_edge_id++; - if (auto status = (*graph)->AddEdge(edge_id); !status) { - return status.Error(); - } - if (auto status = - (*graph)->ConnectNodeOutput(node_id, output_index, edge_id); - !status) { - return status.Error(); - } - if (auto status = (*graph)->ConnectGraphOutput(output_index, edge_id); - !status) { - return status.Error(); - } - } - - auto invocation_context = CreateFromGraph(southbound, device_context, *graph); - if (!invocation_context) { - return invocation_context.Error(); - } - - (*invocation_context)->AttachExecutable(*exec_handle); - - return invocation_context; -} - -litert::Expected -LiteRtDispatchInvocationContextT::CreateFromGraph( - const litert::google_tensor::Southbound& southbound, - LiteRtDispatchDeviceContext device_context, LiteRtDispatchGraph graph) { - auto thr_invocation_context_get = southbound.api().thr_invocation_context_get; - if (!thr_invocation_context_get) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_invocation_context_get not found"); - } - - ThrGraph* thr_graph = graph->thr_graph(); - auto thr_icontext = - thr_invocation_context_get(thr_graph, device_context->thr_context()); - if (!thr_icontext) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_invocation_context_get failed"); - } - - device_context->add_graph(thr_graph); - return Ptr(new LiteRtDispatchInvocationContextT(southbound, thr_icontext, - device_context, graph)); -} - -LiteRtDispatchInvocationContextT::~LiteRtDispatchInvocationContextT() { - auto thr_invocation_context_delete = - southbound_.api().thr_invocation_context_delete; - 
if (!thr_invocation_context_delete) { - LITERT_LOG(LITERT_ERROR, "thr_invocation_context_delete not found"); - } else { - ThrGraph* thr_graph = graph_->thr_graph(); - if (auto status = - thr_invocation_context_delete(thr_graph, thr_invocation_context_); - status != kThrStatusSuccess) { - LITERT_LOG(LITERT_ERROR, "thr_invocation_context_delete failed: %d", - status); - } - } - - if (exec_handle_) { - device_context_->UnloadExecutable(*exec_handle_); - } -} - -namespace { - -Expected GetTensorBufferRequirements( - const LiteRtRankedTensorType& tensor_type) { - auto* tensor_strides = tensor_type.layout.strides; - if (tensor_strides != nullptr) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Tensor strides are not supported on GoogleTensor"); - } - - LiteRtTensorBufferType supported_tensor_buffer_types[] = { - kLiteRtTensorBufferTypeAhwb, - }; - int num_supported_tensor_buffer_types = - sizeof(supported_tensor_buffer_types) / - sizeof(supported_tensor_buffer_types[0]); - - auto buffer_size = litert::internal::GetNumPackedBytes(tensor_type); - if (!buffer_size) { - return Unexpected(buffer_size.Error()); - } - - size_t padded_buffer_size = Pad(*buffer_size, kEdgeTpuPadding); - - LiteRtTensorBufferRequirements requirements; - if (auto status = LiteRtCreateTensorBufferRequirements( - num_supported_tensor_buffer_types, supported_tensor_buffer_types, - padded_buffer_size, /*num_strides=*/0, /*strides=*/nullptr, - &requirements); - status != kLiteRtStatusOk) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to create tensor buffer requirements"); - } - - return requirements; -} -} // namespace - -Expected -LiteRtDispatchInvocationContextT::GetInputRequirements( - int input_index, const LiteRtRankedTensorType& tensor_type) { - return GetTensorBufferRequirements(tensor_type); -} - -Expected -LiteRtDispatchInvocationContextT::GetOutputRequirements( - int output_index, const LiteRtRankedTensorType& tensor_type) { - return 
GetTensorBufferRequirements(tensor_type); -} - -namespace { - -litert::Expected AttachBufferHelper( - const litert::google_tensor::Southbound& southbound, - LiteRtDispatchInvocationContext invocation_context, - LiteRtDispatchEdgeId edge_id, - LiteRtTensorBufferHandle tensor_buffer_handle) { - auto thr_invocation_context_attach_buffer = - southbound.api().thr_invocation_context_attach_buffer; - if (!thr_invocation_context_attach_buffer) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_invocation_context_attach_buffer not found"); - } - - ThrInvocationContext* thr_icontext = - invocation_context->thr_invocation_context(); - ThrContext* thr_context = invocation_context->device_context()->thr_context(); - auto thr_edge_id = ThrEdgeIdStr(edge_id); - ThrBufferHandle thr_buffer_handle = tensor_buffer_handle; - if (auto status = thr_invocation_context_attach_buffer( - thr_icontext, thr_context, thr_edge_id.data(), thr_buffer_handle); - status != kThrStatusSuccess) { - LITERT_LOG(LITERT_ERROR, "thr_invocation_context_attach_buffer failed: %d", - status); - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_invocation_context_attach_buffer failed"); - } - - return {}; -} - -} // namespace - -litert::Expected LiteRtDispatchInvocationContextT::AttachInput( - int graph_input_index, LiteRtTensorBufferHandle tensor_buffer_handle) { - if (auto result = graph_->InputEdge(graph_input_index); result) { - auto edge_id = *result; - return AttachBufferHelper(southbound_, this, edge_id, tensor_buffer_handle); - } else { - return result.Error(); - } -} - -litert::Expected LiteRtDispatchInvocationContextT::AttachOutput( - int graph_output_index, LiteRtTensorBufferHandle tensor_buffer_handle) { - if (auto result = graph_->OutputEdge(graph_output_index); result) { - auto edge_id = *result; - return AttachBufferHelper(southbound_, this, edge_id, tensor_buffer_handle); - } else { - return result.Error(); - } -} - -namespace { - -litert::Expected DetachTensorBufferHelper( - const 
litert::google_tensor::Southbound& southbound, - LiteRtDispatchInvocationContext invocation_context, - LiteRtDispatchEdgeId edge_id, - LiteRtTensorBufferHandle tensor_buffer_handle) { - auto thr_invocation_context_detach_buffer = - southbound.api().thr_invocation_context_detach_buffer; - if (!thr_invocation_context_detach_buffer) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_invocation_context_detach_buffer not found"); - } - - ThrInvocationContext* thr_icontext = - invocation_context->thr_invocation_context(); - ThrContext* thr_context = invocation_context->device_context()->thr_context(); - auto thr_edge_id = ThrEdgeIdStr(edge_id); - ThrBufferHandle thr_buffer_handle = tensor_buffer_handle; - if (auto status = thr_invocation_context_detach_buffer( - thr_icontext, thr_context, thr_edge_id.data(), thr_buffer_handle); - status != kThrStatusSuccess) { - LITERT_LOG(LITERT_ERROR, "thr_invocation_context_detach_buffer failed: %d", - status); - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_invocation_context_detach_buffer failed"); - } - - return {}; -} - -} // namespace - -litert::Expected LiteRtDispatchInvocationContextT::DetachInput( - int graph_input_index, LiteRtTensorBufferHandle tensor_buffer_handle) { - if (auto result = graph_->InputEdge(graph_input_index); result) { - auto edge_id = *result; - return DetachTensorBufferHelper(southbound_, this, edge_id, - tensor_buffer_handle); - } else { - return result.Error(); - } -} - -litert::Expected LiteRtDispatchInvocationContextT::DetachOutput( - int graph_output_index, LiteRtTensorBufferHandle tensor_buffer_handle) { - if (auto result = graph_->OutputEdge(graph_output_index); result) { - auto edge_id = *result; - return DetachTensorBufferHelper(southbound_, this, edge_id, - tensor_buffer_handle); - } else { - return result.Error(); - } -} - -namespace { - -litert::Expected PrepareForInvoke( - const litert::google_tensor::Southbound& southbound, - LiteRtDispatchInvocationContext invocation_context, - 
bool create_output_sync_fence) { - auto thr_invocation_context_prepare_for_invoke = - southbound.api().thr_invocation_context_prepare_for_invoke; - if (!thr_invocation_context_prepare_for_invoke) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_invocation_context_prepare_for_invoke not found"); - } - - ThrInvocationContext* thr_icontext = - invocation_context->thr_invocation_context(); - if (auto status = thr_invocation_context_prepare_for_invoke( - thr_icontext, create_output_sync_fence); - status != kThrStatusSuccess) { - LITERT_LOG(LITERT_ERROR, - "thr_invocation_context_prepare_for_invoke failed: %d", status); - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_invocation_context_prepare_for_invoke failed"); - } - - return {}; -} - -litert::Expected InvokeOnce( - const litert::google_tensor::Southbound& southbound, - LiteRtDispatchInvocationContext invocation_context) { - auto thr_invocation_context_invoke_once = - southbound.api().thr_invocation_context_invoke_once; - if (!thr_invocation_context_invoke_once) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_invocation_context_invoke_once not found"); - } - - ThrInvocationContext* thr_icontext = - invocation_context->thr_invocation_context(); - if (auto status = thr_invocation_context_invoke_once(thr_icontext); - status != kThrStatusSuccess) { - LITERT_LOG(LITERT_ERROR, "thr_invocation_context_invoke_once failed: %d", - status); - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_invocation_context_invoke_once failed"); - } - - return {}; -} - -litert::Expected Wait( - const litert::google_tensor::Southbound& southbound, - LiteRtDispatchInvocationContext invocation_context) { - auto thr_invocation_context_wait = - southbound.api().thr_invocation_context_wait; - if (!thr_invocation_context_wait) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_invocation_context_wait not found"); - } - - ThrInvocationContext* thr_icontext = - invocation_context->thr_invocation_context(); - if 
(auto status = thr_invocation_context_wait(thr_icontext); - status != kThrStatusSuccess) { - LITERT_LOG(LITERT_ERROR, "thr_invocation_context_wait failed: %d", status); - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_invocation_context_wait failed"); - } - - return {}; -} - -} // namespace - -litert::Expected LiteRtDispatchInvocationContextT::Invoke() { - if (auto result = PrepareForInvoke(southbound_, this, - /*create_output_sync_fence=*/false); - !result) { - return result.Error(); - } - if (auto result = InvokeOnce(southbound_, this); !result) { - return result.Error(); - } - return Wait(southbound_, this); -} - -litert::Expected LiteRtDispatchInvocationContextT::AttachInputEvent( - int graph_input_index, LiteRtEvent input_event) { - int input_fence_fd; - if (auto status = LiteRtGetEventSyncFenceFd(input_event, &input_fence_fd); - status != kLiteRtStatusOk) { - return Error(status, "Failed to get sync fence fd from event"); - } - - auto edge = graph_->InputEdge(graph_input_index); - if (!edge) { - LITERT_LOG(LITERT_ERROR, "Unexpected graph input index: %d", - graph_input_index); - return edge.Error(); - } - auto edge_id = *edge; - - auto thr_invocation_context_attach_input_buffer_sync_fence = - southbound_.api().thr_invocation_context_attach_input_buffer_sync_fence; - if (!thr_invocation_context_attach_input_buffer_sync_fence) { - return Error( - kLiteRtStatusErrorRuntimeFailure, - "thr_invocation_context_attach_input_buffer_sync_fence not found"); - } - - auto thr_edge_id = ThrEdgeIdStr(edge_id); - if (auto status = thr_invocation_context_attach_input_buffer_sync_fence( - thr_invocation_context_, thr_edge_id.data(), input_fence_fd); - status != kThrStatusSuccess) { - LITERT_LOG( - LITERT_ERROR, - "thr_invocation_context_attach_input_buffer_sync_fence failed: %d", - status); - return Error( - kLiteRtStatusErrorRuntimeFailure, - "thr_invocation_context_attach_input_buffer_sync_fence failed"); - } - - input_sync_fences_[thr_edge_id.data()] = 
input_fence_fd; - return {}; -} - -namespace { - -litert::Expected GetOutputEvent( - const litert::google_tensor::Southbound& southbound, - LiteRtDispatchInvocationContext invocation_context, int graph_output_index, - LiteRtEvent* output_event) { - auto edge = invocation_context->graph()->OutputEdge(graph_output_index); - if (!edge) { - LITERT_LOG(LITERT_ERROR, "Unexpected graph output index: %d", - graph_output_index); - return edge.Error(); - } - auto edge_id = *edge; - - auto thr_invocation_context_get_output_buffer_sync_fence = - southbound.api().thr_invocation_context_get_output_buffer_sync_fence; - if (!thr_invocation_context_get_output_buffer_sync_fence) { - return Error( - kLiteRtStatusErrorRuntimeFailure, - "thr_invocation_context_get_output_buffer_sync_fence not found"); - } - - ThrInvocationContext* thr_icontext = - invocation_context->thr_invocation_context(); - auto thr_edge_id = ThrEdgeIdStr(edge_id); - int output_fence_fd; - if (auto status = thr_invocation_context_get_output_buffer_sync_fence( - thr_icontext, thr_edge_id.data(), &output_fence_fd); - status != kThrStatusSuccess) { - LITERT_LOG(LITERT_ERROR, - "thr_invocation_context_get_output_buffer_sync_fence failed: %d", - status); - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_invocation_context_get_output_buffer_sync_fence failed"); - } - - if (auto status = LiteRtCreateEventFromSyncFenceFd( - output_fence_fd, /*owns_fd=*/false, output_event); - status != kLiteRtStatusOk) { - return Error(status, "Failed to create event from sync fence fd"); - } - - return {}; -} - -} // namespace - -litert::Expected LiteRtDispatchInvocationContextT::InvokeAsync( - int num_output_events, LiteRtEvent* output_events) { - if (num_output_events != graph_->NumOutputs()) { - LITERT_LOG(LITERT_ERROR, "Unexpected number of output events: %d", - num_output_events); - return Error(kLiteRtStatusErrorInvalidArgument, - "Unexpected number of output events"); - } - - if (auto status = PrepareForInvoke(southbound_, 
this, - /*create_output_sync_fence=*/true); - !status) { - return status.Error(); - } - - if (auto status = InvokeOnce(southbound_, this); !status) { - return status.Error(); - } - - // Deatach input fences. - auto thr_invocation_context_detach_input_buffer_sync_fence = - southbound_.api().thr_invocation_context_detach_input_buffer_sync_fence; - if (!thr_invocation_context_detach_input_buffer_sync_fence) { - return Error( - kLiteRtStatusErrorRuntimeFailure, - "thr_invocation_context_detach_input_buffer_sync_fence not found"); - } - for (const auto& p : input_sync_fences_) { - const auto& thr_edge_id = p.first; - auto input_fence_fd = p.second; - if (auto status = thr_invocation_context_detach_input_buffer_sync_fence( - thr_invocation_context_, thr_edge_id.data(), input_fence_fd); - status != kThrStatusSuccess) { - return Error( - kLiteRtStatusErrorRuntimeFailure, - "thr_invocation_context_deatch_input_buffer_sync_fence failed"); - } - } - input_sync_fences_.clear(); - - // Extract output events. 
- for (auto graph_output_index = 0; graph_output_index < num_output_events; - ++graph_output_index) { - if (auto status = GetOutputEvent(southbound_, this, graph_output_index, - &output_events[graph_output_index]); - !status) { - LITERT_LOG(LITERT_ERROR, "Failed to get event for output %d", - graph_output_index); - return status.Error(); - } - } - - return {}; -} - -litert::Expected LiteRtDispatchInvocationContextT::StartMetricsCollection( - int detail_level) { - auto thr_invocation_context_start_metrics_collection = - southbound_.api().thr_invocation_context_start_metrics_collection; - if (!thr_invocation_context_start_metrics_collection) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_invocation_context_start_metrics_collection not found"); - } - if (auto status = thr_invocation_context_start_metrics_collection( - thr_invocation_context_, detail_level); - status != kThrStatusSuccess) { - LITERT_LOG(LITERT_ERROR, - "thr_invocation_context_start_metrics_collection failed: %d", - status); - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_invocation_context_start_metrics_collection failed"); - } - return {}; -} - -litert::Expected LiteRtDispatchInvocationContextT::StopMetricsCollection( - LiteRtDispatchMetrics* metrics) { - auto thr_invocation_context_stop_metrics_collection = - southbound_.api().thr_invocation_context_stop_metrics_collection; - if (!thr_invocation_context_stop_metrics_collection) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "thr_invocation_context_stop_metrics_collection not found"); - } - ThrInvocationMetrics thr_metrics{.version = 0}; - if (auto status = thr_invocation_context_stop_metrics_collection( - thr_invocation_context_, &thr_metrics); - status != kThrStatusSuccess) { - LITERT_LOG(LITERT_ERROR, - "thr_invocation_context_stop_metrics_collection failed: %d", - status); - *metrics = new LiteRtDispatchMetricsT(/*num_metrics=*/0, - /*metric_names=*/nullptr, - /*metric_values=*/nullptr); - return 
Error(kLiteRtStatusErrorRuntimeFailure, - "thr_invocation_context_stop_metrics_collection failed"); - } - *metrics = new LiteRtDispatchMetricsT(thr_metrics.num_metrics, - thr_metrics.metric_keys, - thr_metrics.metric_values); - return {}; -} diff --git a/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_invocation_context.h b/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_invocation_context.h deleted file mode 100644 index 8cbae593d0874c..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_invocation_context.h +++ /dev/null @@ -1,104 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_GOOGLE_TENSOR_DISPATCH_LITERT_DISPATCH_INVOCATION_CONTEXT_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_GOOGLE_TENSOR_DISPATCH_LITERT_DISPATCH_INVOCATION_CONTEXT_H_ - -#include -#include -#include -#include -#include - -#include "third_party/odml/infra/southbound/sb_api.h" -#include "tensorflow/lite/experimental/litert/c/litert_event.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" -#include "tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/dispatch_api.h" -#include "tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/southbound.h" - -class LiteRtDispatchInvocationContextT { - public: - using Ptr = std::unique_ptr; - - static litert::Expected CreateFromBytecode( - const litert::google_tensor::Southbound& southbound, - LiteRtDispatchDeviceContext device_context, - LiteRtDispatchExecutableType exec_type, - const LiteRtMemBuffer* exec_bytecode_buffer, const char* function_name, - int num_inputs, int num_outputs); - - static litert::Expected CreateFromGraph( - const litert::google_tensor::Southbound& southbound, - LiteRtDispatchDeviceContext device_context, LiteRtDispatchGraph graph); - - ~LiteRtDispatchInvocationContextT(); - - litert::Expected GetInputRequirements( - int input_index, const LiteRtRankedTensorType& tensor_type); - litert::Expected GetOutputRequirements( - int output_index, const LiteRtRankedTensorType& tensor_type); - - litert::Expected AttachInput( - int graph_input_index, LiteRtTensorBufferHandle tensor_buffer_handle); - litert::Expected AttachOutput( - int graph_output_index, LiteRtTensorBufferHandle tensor_buffer_handle); - - litert::Expected DetachInput( - int graph_input_index, LiteRtTensorBufferHandle 
tensor_buffer_handle); - litert::Expected DetachOutput( - int graph_output_index, LiteRtTensorBufferHandle tensor_buffer_handle); - - litert::Expected Invoke(); - litert::Expected InvokeAsync(int num_output_events, - LiteRtEvent* output_events); - litert::Expected StartMetricsCollection(int detail_level); - litert::Expected StopMetricsCollection(LiteRtDispatchMetrics* metrics); - - litert::Expected AttachInputEvent(int graph_input_index, - LiteRtEvent input_event); - - ThrInvocationContext* thr_invocation_context() { - return thr_invocation_context_; - } - - LiteRtDispatchDeviceContext device_context() { return device_context_; } - - LiteRtDispatchGraph graph() { return graph_; } - - private: - LiteRtDispatchInvocationContextT( - const litert::google_tensor::Southbound& southbound, - ThrInvocationContext* thr_invocation_context, - LiteRtDispatchDeviceContext device_context, LiteRtDispatchGraph graph) - : southbound_(southbound), - thr_invocation_context_(thr_invocation_context), - device_context_(device_context), - graph_(graph) {} - - void AttachExecutable(LiteRtDispatchExecutableHandle exec_handle) { - exec_handle_ = exec_handle; - } - - const litert::google_tensor::Southbound& southbound_; - ThrInvocationContext* thr_invocation_context_; - LiteRtDispatchDeviceContext device_context_; - LiteRtDispatchGraph graph_; - std::optional exec_handle_; - std::map input_sync_fences_; -}; - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_GOOGLE_TENSOR_DISPATCH_LITERT_DISPATCH_INVOCATION_CONTEXT_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_metrics.h b/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_metrics.h deleted file mode 100644 index a33a69d4adc237..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_metrics.h +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright 2025 Google LLC. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_GOOGLE_TENSOR_DISPATCH_LITERT_DISPATCH_METRICS_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_GOOGLE_TENSOR_DISPATCH_LITERT_DISPATCH_METRICS_H_ - -#include -#include -#include - -#include "tensorflow/lite/experimental/litert/c/litert_any.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" - -class LiteRtDispatchMetricsT { - public: - // Construct a LiteRtDispatchMetricsT object using C-style arrays and strings. - // `metric_names` is an array of C-style strings representing metric names. - // `metric_values` is an array of int64_t values representing metric values. - // Both `metric_names` and `metric_values` have `num_metrics` elements. - // - // NOTE: The values in the arrays are copied into the LiteRtDispatchMetricsT. 
- LiteRtDispatchMetricsT(int num_metrics, const char** metric_names, - const int64_t* metric_values) - : metric_names_(metric_names, metric_names + num_metrics), - metric_values_(metric_values, metric_values + num_metrics) {} - int GetNumMetrics() const { return metric_names_.size(); } - LiteRtMetric GetMetric(int metric_index) const { - if (metric_index < 0 || metric_index >= GetNumMetrics()) { - return LiteRtMetric{.name = "invalid_metric", - .value = LiteRtAny{.type = kLiteRtAnyTypeNone}}; - } - return LiteRtMetric{ - .name = metric_names_[metric_index].c_str(), - .value = - LiteRtAny{ - .type = kLiteRtAnyTypeInt, - .int_value = metric_values_[metric_index], - }, - }; - } - - private: - const std::vector metric_names_; - const std::vector metric_values_; -}; - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_GOOGLE_TENSOR_DISPATCH_LITERT_DISPATCH_METRICS_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/southbound.cc b/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/southbound.cc deleted file mode 100644 index e103c289d5820b..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/southbound.cc +++ /dev/null @@ -1,163 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/southbound.h" - -#include - -#include -#include -#include - -#include "third_party/odml/infra/southbound/sb_api.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" - -#define Load(H, S) \ - H = reinterpret_cast(::dlsym(dlib_handle_, #S)); \ - if (!H) { \ - LITERT_LOG(LITERT_WARNING, "Failed to load symbol %s: %s", #S, \ - ::dlerror()); \ - } - -namespace litert { -namespace google_tensor { - -namespace { - -// The SouthBound APIs are implemented in the EdgeTPU libraries. -// It used to be implemented in the libedgetpu_util.so and has been moved to -// libedgetpu_litert.so in newer Android builds. -constexpr const char* kLiteRtLibPath = "/vendor/lib64/libedgetpu_litert.so"; -constexpr const char* kEdgeTpuUtilLibPath = "/vendor/lib64/libedgetpu_util.so"; - -} // namespace - -Southbound::Southbound() : api_(new ThrFunctions) {} - -Southbound::~Southbound() { - if (dlib_handle_) { - ::dlclose(dlib_handle_); - } -} - -Expected Southbound::Create( - std::optional shared_library_dir) { - Ptr southbound(new Southbound); - if (auto status = southbound->LoadSymbols(shared_library_dir); !status) { - return Unexpected(status.Error()); - } - - return southbound; -} - -Expected Southbound::LoadSymbols( - std::optional shared_library_dir) { - // Always load the Southbound API library from the vendor partition. - (void)shared_library_dir; - - // Try loading the new EdgeTPU LiteRT library first. If it fails, it might be - // because the Android build is too old. In that case, load the old EdgeTPU - // utility library. 
- dlib_handle_ = ::dlopen(kLiteRtLibPath, RTLD_NOW | RTLD_LOCAL); - if (!dlib_handle_) { - dlib_handle_ = ::dlopen(kEdgeTpuUtilLibPath, RTLD_NOW | RTLD_LOCAL); - if (!dlib_handle_) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to load Southbound shared library"); - } - } - - // Binds all supported symbols from the shared library to the function - // pointers. - Load(api_->thr_initialize, thrInitialize); - - Load(api_->thr_get_vendor_api_version, thrGetVendorApiVersion); - Load(api_->thr_get_vendor_id, thrGetVendorId); - - Load(api_->thr_context_create, thrContextCreate); - Load(api_->thr_context_delete, thrContextDelete); - - Load(api_->thr_graph_create, thrGraphCreate); - Load(api_->thr_graph_delete, thrGraphDelete); - - Load(api_->thr_graph_add_edge, thrGraphAddEdge); - Load(api_->thr_graph_add_sq_node, thrGraphAddSqNode); - - Load(api_->thr_graph_connect_node_input, thrGraphConnectNodeInput); - Load(api_->thr_graph_connect_node_output, thrGraphConnectNodeOutput); - - Load(api_->thr_graph_set_input_edge, thrGraphSetInputEdge); - Load(api_->thr_graph_set_output_edge, thrGraphSetOutputEdge); - - Load(api_->thr_graph_annotate_graph, thrGraphAnnotateGraph); - Load(api_->thr_graph_annotate_edge, thrGraphAnnotateEdge); - Load(api_->thr_graph_annotate_node, thrGraphAnnotateNode); - - Load(api_->thr_load_sq_container, thrLoadSqContainer); - Load(api_->thr_load_sq_container_fd, thrLoadSqContainerFd); - Load(api_->thr_load_sq_container_file, thrLoadSqContainerFile); - Load(api_->thr_unload_sq_container, thrUnloadSqContainer); - - Load(api_->thr_graph_assign_sq, thrGraphAssignSq); - Load(api_->thr_sq_query_scratch_pad, thrSqQueryScratchPad); - Load(api_->thr_sq_attach_scratch_pad_buffer, thrSqAttachScratchPadBuffer); - - Load(api_->thr_register_buffer, thrRegisterBuffer); - Load(api_->thr_register_buffer_with_offset, thrRegisterBufferWithOffset); - Load(api_->thr_unregister_buffer, thrUnregisterBuffer); - - Load(api_->thr_invocation_context_get, 
thrInvocationContextGet); - Load(api_->thr_invocation_context_delete, thrInvocationContextDelete); - - Load(api_->thr_invocation_context_attach_buffer, - thrInvocationContextAttachBuffer); - Load(api_->thr_invocation_context_detach_buffer, - thrInvocationContextDetachBuffer); - - Load(api_->thr_invocation_context_prepare_for_invoke, - thrInvocationContextPrepareForInvoke); - Load(api_->thr_invocation_context_invoke_once, - thrInvocationContextInvokeOnce); - Load(api_->thr_invocation_context_wait, thrInvocationContextWait); - - Load(api_->thr_invocation_context_attach_input_buffer_sync_fence, - thrInvocationContextAttachInputBufferSyncFence); - Load(api_->thr_invocation_context_get_output_buffer_sync_fence, - thrInvocationContextGetOutputBufferSyncFence); - Load(api_->thr_invocation_context_detach_input_buffer_sync_fence, - thrInvocationContextDetachInputBufferSyncFence); - - Load(api_->thr_invocation_context_query_node_scratch_pad, - thrInvocationContextQueryNodeScratchPad); - Load(api_->thr_invocation_context_attach_scratch_pad_buffer, - thrInvocationContextAttachScratchPadBuffer); - - Load(api_->thr_invocation_context_start_metrics_collection, - thrInvocationContextStartMetricsCollection); - Load(api_->thr_invocation_context_stop_metrics_collection, - thrInvocationContextStopMetricsCollection); - - Load(api_->thr_vendor_set_system_attribute_str, - thrVendorSetSystemAttributeStr); - Load(api_->thr_vendor_set_system_attribute_int64, - thrVendorSetSystemAttributeInt64); - - LITERT_LOG(LITERT_INFO, "SouthBound symbols loaded"); - return {}; -} - -} // namespace google_tensor -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/southbound.h b/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/southbound.h deleted file mode 100644 index d3ab7367c7e665..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/southbound.h +++ /dev/null @@ -1,133 +0,0 @@ -// Copyright 2024 
Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_GOOGLE_TENSOR_DISPATCH_SOUTHBOUND_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_GOOGLE_TENSOR_DISPATCH_SOUTHBOUND_H_ - -#include -#include -#include - -#include "third_party/odml/infra/southbound/sb_api.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" - -namespace litert::google_tensor { - -class Southbound { - public: - using Ptr = std::unique_ptr; - struct ThrFunctions; - - Southbound(Southbound&) = delete; - Southbound(Southbound&&) = delete; - Southbound& operator=(const Southbound&) = delete; - Southbound& operator=(Southbound&&) = delete; - - ~Southbound(); - - static Expected Create(std::optional shared_library_dir); - - const ThrFunctions& api() const { return *api_; } - - private: - Southbound(); - Expected LoadSymbols(std::optional shared_library_dir); - - void* dlib_handle_ = nullptr; - std::unique_ptr api_; -}; - -// A convenient struct for holding function pointers to SouthBound symbols. -// These function pointers will be loaded to the shared library on device during -// runtime. 
-struct Southbound::ThrFunctions { - decltype(&thrInitialize) thr_initialize = nullptr; - - decltype(&thrGetVendorApiVersion) thr_get_vendor_api_version = nullptr; - decltype(&thrGetVendorId) thr_get_vendor_id = nullptr; - - decltype(&thrContextCreate) thr_context_create = nullptr; - decltype(&thrContextDelete) thr_context_delete = nullptr; - - decltype(&thrGraphCreate) thr_graph_create = nullptr; - decltype(&thrGraphDelete) thr_graph_delete = nullptr; - - decltype(&thrGraphAddEdge) thr_graph_add_edge = nullptr; - decltype(&thrGraphAddSqNode) thr_graph_add_sq_node = nullptr; - - decltype(&thrGraphConnectNodeInput) thr_graph_connect_node_input = nullptr; - decltype(&thrGraphConnectNodeOutput) thr_graph_connect_node_output = nullptr; - - decltype(&thrGraphSetInputEdge) thr_graph_set_input_edge = nullptr; - decltype(&thrGraphSetOutputEdge) thr_graph_set_output_edge = nullptr; - - decltype(&thrGraphAnnotateGraph) thr_graph_annotate_graph = nullptr; - decltype(&thrGraphAnnotateEdge) thr_graph_annotate_edge = nullptr; - decltype(&thrGraphAnnotateNode) thr_graph_annotate_node = nullptr; - - decltype(&thrLoadSqContainer) thr_load_sq_container = nullptr; - decltype(&thrLoadSqContainerFd) thr_load_sq_container_fd = nullptr; - decltype(&thrLoadSqContainerFile) thr_load_sq_container_file = nullptr; - decltype(&thrUnloadSqContainer) thr_unload_sq_container = nullptr; - - decltype(&thrGraphAssignSq) thr_graph_assign_sq = nullptr; - decltype(&thrSqQueryScratchPad) thr_sq_query_scratch_pad = nullptr; - decltype(&thrSqAttachScratchPadBuffer) thr_sq_attach_scratch_pad_buffer = - nullptr; - - decltype(&thrRegisterBuffer) thr_register_buffer = nullptr; - decltype(&thrRegisterBufferWithOffset) thr_register_buffer_with_offset = - nullptr; - decltype(&thrUnregisterBuffer) thr_unregister_buffer = nullptr; - - decltype(&thrInvocationContextGet) thr_invocation_context_get = nullptr; - decltype(&thrInvocationContextDelete) thr_invocation_context_delete = nullptr; - - 
decltype(&thrInvocationContextAttachBuffer) - thr_invocation_context_attach_buffer = nullptr; - decltype(&thrInvocationContextDetachBuffer) - thr_invocation_context_detach_buffer = nullptr; - - decltype(&thrInvocationContextPrepareForInvoke) - thr_invocation_context_prepare_for_invoke = nullptr; - decltype(&thrInvocationContextInvokeOnce) thr_invocation_context_invoke_once = - nullptr; - decltype(&thrInvocationContextWait) thr_invocation_context_wait = nullptr; - - decltype(&thrInvocationContextAttachInputBufferSyncFence) - thr_invocation_context_attach_input_buffer_sync_fence = nullptr; - decltype(&thrInvocationContextGetOutputBufferSyncFence) - thr_invocation_context_get_output_buffer_sync_fence = nullptr; - decltype(&thrInvocationContextDetachInputBufferSyncFence) - thr_invocation_context_detach_input_buffer_sync_fence = nullptr; - - decltype(&thrInvocationContextQueryNodeScratchPad) - thr_invocation_context_query_node_scratch_pad = nullptr; - decltype(&thrInvocationContextAttachScratchPadBuffer) - thr_invocation_context_attach_scratch_pad_buffer = nullptr; - - decltype(&thrInvocationContextStartMetricsCollection) - thr_invocation_context_start_metrics_collection = nullptr; - decltype(&thrInvocationContextStopMetricsCollection) - thr_invocation_context_stop_metrics_collection = nullptr; - - decltype(&thrVendorSetSystemAttributeStr) - thr_vendor_set_system_attribute_str = nullptr; - decltype(&thrVendorSetSystemAttributeInt64) - thr_vendor_set_system_attribute_int64 = nullptr; -}; - -} // namespace litert::google_tensor - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_GOOGLE_TENSOR_DISPATCH_SOUTHBOUND_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/BUILD b/tensorflow/lite/experimental/litert/vendors/mediatek/BUILD deleted file mode 100644 index 73d1e8e484ebdf..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/BUILD +++ /dev/null @@ -1,44 +0,0 @@ -# Copyright 2024 Google LLC. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -load("//tensorflow/lite/experimental/litert/vendors/mediatek:mediatek_build_defs.bzl", "litert_cc_lib_with_mtk") - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], -) - -litert_cc_lib_with_mtk( - name = "neuron_adapter_api", - srcs = [ - "neuron_adapter_api.cc", - ], - hdrs = [ - "neuron_adapter_api.h", - ], - tags = [ - # Don't build/test in OS until neuron is available. - "nobuilder", - "notap", - ], - deps = [ - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:string_view", - # copybara:uncomment "//third_party/neuro_pilot:latest_host_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_shared_library", - ], -) diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/BUILD b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/BUILD deleted file mode 100644 index 11a4f96268b748..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/BUILD +++ /dev/null @@ -1,160 +0,0 @@ -# Copyright 2024 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -load("//tensorflow/lite/experimental/litert/build_common:litert_build_defs.bzl", "litert_dynamic_lib", "litert_test") - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//tensorflow/lite/experimental/litert/vendors/mediatek/compiler:__subpackages__"], -) - -litert_dynamic_lib( - name = "compiler_plugin", - srcs = ["compiler_plugin.cc"], - hdrs = ["//tensorflow/lite/experimental/litert/vendors/c:litert_compiler_plugin.h"], - export_litert_only = True, - shared_lib_name = "compiler_plugin_so", - so_name = "libLiteRtCompilerPlugin_MediaTek.so", - tags = [ - # Don't build/test in OS until MediaTek SDK is available. 
- "nobuilder", - "notap", - ], - ungrte = True, - visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], - deps = [ - ":compile_model", - ":create_model", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", - "//tensorflow/lite/experimental/litert/vendors/mediatek:neuron_adapter_api", - "//tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations:common_op_legalization", - "//tensorflow/lite/experimental/litert/vendors/mediatek/schema:mediatek_litert_schema", - "//tensorflow/lite/experimental/litert/vendors/mediatek/schema:neuron_litert_schema", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - ], -) - -cc_library( - name = "create_model", - srcs = ["create_model.cc"], - hdrs = ["create_model.h"], - tags = [ - # Don't build/test in OS until MediaTek SDK is available. 
- "nobuilder", - "notap", - ], - deps = [ - # copybara:uncomment "//third_party/neuro_pilot:latest_host_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/vendors/mediatek:neuron_adapter_api", - "//tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations:add_op_legalization", - "//tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations:batch_matmul_op_legalization", - "//tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations:common_op_legalization", - "//tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations:concat_op_legalization", - "//tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations:fully_connected_op_legalization", - "//tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations:gelu_op_legalization", - "//tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations:mean_op_legalization", - "//tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations:mul_op_legalization", - "//tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations:operand_map", - "//tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations:reshape_op_legalization", - "//tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations:rsqrt_op_legalization", - "//tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations:softmax_op_legalization", - "//tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations:sub_op_legalization", - "//tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations:transpose_op_legalization", - 
"//tensorflow/lite/experimental/litert/vendors/mediatek/schema:mediatek_litert_schema", - ], -) - -cc_library( - name = "compile_model", - srcs = ["compile_model.cc"], - hdrs = ["compile_model.h"], - tags = [ - # Don't build/test in OS until MediaTek SDK is available. - "nobuilder", - "notap", - ], - deps = [ - "@com_google_absl//absl/strings:string_view", - # copybara:uncomment "//third_party/neuro_pilot:latest_host_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", - "//tensorflow/lite/experimental/litert/core/model", - "//tensorflow/lite/experimental/litert/vendors/mediatek:neuron_adapter_api", - "//tensorflow/lite/experimental/litert/vendors/mediatek/schema:mediatek_litert_schema", - ], -) - -litert_test( - name = "compiler_plugin_test", - srcs = [ - "compiler_plugin_test.cc", - ], - data = [ - "//tensorflow/lite/experimental/litert/test:mlir_test_data", - "//tensorflow/lite/experimental/litert/test:tflite_test_data", - ], - linkstatic = True, - tags = [ - # Tests with ungrte deps do not currently work on forge. - "no-remote-exec", - "notap", - "nobuilder", - "no_oss", - "nosan", - ], - # Currently this test can only be run on Android because we don't have x86 shared libraries for - # MTK. 
- target_compatible_with = select({ - "@platforms//os:android": [], - "@platforms//os:linux": [], - "//conditions:default": ["@platforms//:incompatible"], - }), - ungrte = True, - use_sys_malloc = True, - deps = [ - ":compiler_plugin", # buildcleaner: keep - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", - "//tensorflow/lite/experimental/litert/core/model", - "//tensorflow/lite/experimental/litert/test:common", - "//tensorflow/lite/experimental/litert/test:matchers_oss", - "//tensorflow/lite/experimental/litert/test:test_models", - "//tensorflow/lite/experimental/litert/vendors/cc:litert_compiler_plugin", - "//tensorflow/lite/experimental/litert/vendors/mediatek:neuron_adapter_api", - "@com_google_absl//absl/log:absl_check", - "@com_google_absl//absl/strings:string_view", - ], -) diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/compile_model.cc b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/compile_model.cc deleted file mode 100644 index 15a5485b20dba4..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/compile_model.cc +++ /dev/null @@ -1,104 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/compile_model.h" - -#include -#include - -#include "neuron/api/NeuronAdapter.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h" - -namespace litert::mediatek { - -Expected CompileModel( - const NeuronAdapterApi& neuron_adapter_api, NeuronModel* model, - std::optional soc_model) { -#if defined(__ANDROID__) - if (soc_model) { - return Error(kLiteRtStatusErrorInvalidArgument, - "JIT compilation for a specific SoC is not supported"); - } -#endif - - // Per MediaTek recommendation, Compilation_create, - // Compilation_createWithOptions, and Compilation_setOptimizationString - // should be used as follow: - // - AOT Compilation: Compilation_createWithOptions only - // - JIT Compilation: Compilation_create and Compilation_setOptimizationString - // The code below takes care of those conditions. 
- - // NOLINTBEGIN - const auto compile_options = -#if __ANDROID__ - std::string(neuron_adapter_api.JitCompileOptions()); -#else - std::string(neuron_adapter_api.AotCompileOptions()); -#endif - // NOLINTEND - - auto compilation = -#if __ANDROID__ - neuron_adapter_api.CreateCompilation(model); -#else - neuron_adapter_api.CreateCompilation(model, compile_options); -#endif - if (!compilation) { - return compilation.Error(); - } - - if (neuron_adapter_api.api().compilation_set_priority( - compilation->get(), NEURON_PRIORITY_HIGH) != NEURON_NO_ERROR) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to set compilation priority"); - } - - if (neuron_adapter_api.api().compilation_set_preference( - compilation->get(), NEURON_PREFER_SUSTAINED_SPEED) != - NEURON_NO_ERROR) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to set compilation preference"); - } - -#if __ANDROID__ - if (!compile_options.empty()) { - if (auto status = - neuron_adapter_api.api().compilation_set_optimization_string( - compilation->get(), compile_options.c_str()); - status != NEURON_NO_ERROR) { - LITERT_LOG(LITERT_INFO, - "NeuronCompilation_setOptimizationString failed with error %d", - status); - return Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to set optimization string"); - } - } -#endif - - if (auto status = - neuron_adapter_api.api().compilation_finish(compilation->get()); - status != NEURON_NO_ERROR) { - LITERT_LOG(LITERT_INFO, "NeuronCompilation_finish failed with error %d", - status); - return Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to finish compilation"); - } - - return compilation; -} - -} // namespace litert::mediatek diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/compile_model.h b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/compile_model.h deleted file mode 100644 index 3e30c0d8451b7b..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/compile_model.h +++ /dev/null @@ 
-1,32 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_COMPILE_MODEL_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_COMPILE_MODEL_H_ - -#include -#include - -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h" - -namespace litert::mediatek { - -Expected CompileModel( - const NeuronAdapterApi& neuron_adapter_api, NeuronModel* model, - std::optional soc_model); - -} // namespace litert::mediatek - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_COMPILE_MODEL_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/compiler_plugin.cc b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/compiler_plugin.cc deleted file mode 100644 index 1f92fb4168f124..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/compiler_plugin.cc +++ /dev/null @@ -1,370 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/compile_model.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/create_model.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/common_op_legalization.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/schema/neuron_schema_generated.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/schema/schema_resolver.h" - -// -// Configurations -// - -using litert::Error; -using litert::Expected; -using litert::mediatek::NeuronAdapterApi; -using litert::mediatek::NeuronCompilationPtr; -using litert::mediatek::NeuronModelPtr; - -namespace { - -constexpr char kPluginManufacturer[] = "MediaTek"; - -// 
clang-format off -constexpr std::pair kPluginSocModels[] = { - {"mt6853", "mt6853"}, - {"mt6877", "mt6877"}, - {"mt6878", "mt6878"}, - {"mt6879", "mt6879"}, - {"mt6886", "mt6886"}, - {"mt6893", "mt6893"}, - {"mt6895", "mt6895"}, - {"mt6897", "mt6897"}, - {"mt6983", "mt6983"}, - {"mt6985", "mt6985"}, - {"mt6989", "mt6989"}, - {"mt6991", "mt6991"}, -}; - -constexpr LiteRtOpCode kSupportedOps[] = { - kLiteRtOpCodeTflAdd, - kLiteRtOpCodeTflMul, - kLiteRtOpCodeTflBatchMatmul, - kLiteRtOpCodeTflFullyConnected, - kLiteRtOpCodeTflReshape, - kLiteRtOpCodeTflTranspose, - kLiteRtOpCodeTflRsqrt, - kLiteRtOpCodeTflConcatenation, - kLiteRtOpCodeTflQuantize, - kLiteRtOpCodeTflSlice, - kLiteRtOpCodeTflSub, - kLiteRtOpCodeTflTanh, - kLiteRtOpCodeTflSoftmax, - kLiteRtOpCodeTflMean, - kLiteRtOpCodeTflGelu, -}; -// clang-format on - -constexpr auto kNumPluginSocModels = - sizeof(kPluginSocModels) / sizeof(kPluginSocModels[0]); - -std::optional FindSocModel(absl::string_view soc_model_name) { - std::optional soc_model; - for (auto i = 0; i < kNumPluginSocModels; ++i) { - if (soc_model_name == kPluginSocModels[i].first) { - soc_model = kPluginSocModels[i].second; - break; - } - } - return soc_model; -} - -} // namespace - -LiteRtStatus LiteRtGetCompilerPluginVersion(LiteRtApiVersion* api_version) { - if (api_version == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - api_version->major = LITERT_API_VERSION_MAJOR; - api_version->minor = LITERT_API_VERSION_MINOR; - api_version->patch = LITERT_API_VERSION_PATCH; - return kLiteRtStatusOk; -} - -const char* LiteRtGetCompilerPluginSocManufacturer() { - return kPluginManufacturer; -} - -LiteRtStatus LiteRtGetCompilerPluginSupportedHardware( - LiteRtCompilerPlugin compiler_plugin, - LiteRtHwAccelerators* supported_hardware) { - if (!compiler_plugin || !supported_hardware) { - return kLiteRtStatusErrorInvalidArgument; - } - *supported_hardware = kLiteRtHwAcceleratorNpu; - return kLiteRtStatusOk; -} - -LiteRtStatus 
LiteRtGetNumCompilerPluginSupportedSocModels( - LiteRtCompilerPlugin compiler_plugin, - LiteRtParamIndex* num_supported_soc_models) { - if (!compiler_plugin || !num_supported_soc_models) { - return kLiteRtStatusErrorInvalidArgument; - } - *num_supported_soc_models = kNumPluginSocModels; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetCompilerPluginSupportedSocModel( - LiteRtCompilerPlugin compiler_plugin, LiteRtParamIndex soc_model_idx, - const char** soc_model_name) { - if (!compiler_plugin || !soc_model_name) { - return kLiteRtStatusErrorInvalidArgument; - } else if (soc_model_idx < 0 || soc_model_idx >= kNumPluginSocModels) { - return kLiteRtStatusErrorInvalidArgument; - } - *soc_model_name = kPluginSocModels[soc_model_idx].first; - return kLiteRtStatusOk; -} - -// -// Compiled Result Definition -// - -// TODO: Revisit this struct after we extend the compiler plugin API to return -// results with more than one single bytecode. -struct LiteRtCompiledResultT { - std::vector graph_names; - neuron::BytecodeBuilder bytebuilder; -}; - -LiteRtStatus LiteRtCompiledResultNumByteCodeModules( - LiteRtCompiledResult compiled_result, LiteRtParamIndex* num_byte_code) { - if (!compiled_result || !num_byte_code) { - return kLiteRtStatusErrorInvalidArgument; - } - *num_byte_code = compiled_result->graph_names.size(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetCompiledResultByteCode( - LiteRtCompiledResult compiled_result, LiteRtParamIndex byte_code_idx, - const void** byte_code, size_t* byte_code_size) { - if (!compiled_result || !byte_code || !byte_code_size || - (byte_code_idx >= compiled_result->graph_names.size())) { - return kLiteRtStatusErrorInvalidArgument; - } - *byte_code = compiled_result->bytebuilder.GetBytecode().first; - *byte_code_size = compiled_result->bytebuilder.GetBytecode().second; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetCompiledResultCallInfo( - LiteRtCompiledResult compiled_result, LiteRtParamIndex call_idx, - const void** 
call_info, size_t* call_info_size, - LiteRtParamIndex* byte_code_idx) { - if (!compiled_result || !call_info || !call_info_size) { - return kLiteRtStatusErrorInvalidArgument; - } else if (call_idx >= compiled_result->graph_names.size()) { - return kLiteRtStatusErrorIndexOOB; - } - - auto& graph_name = compiled_result->graph_names[call_idx]; - *call_info = graph_name.data(); - *call_info_size = graph_name.size(); - // MTK should have one byte code per call. - *byte_code_idx = call_idx; - - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetNumCompiledResultCalls( - LiteRtCompiledResult compiled_result, LiteRtParamIndex* num_calls) { - if (!compiled_result || !num_calls) { - return kLiteRtStatusErrorInvalidArgument; - } - *num_calls = compiled_result->graph_names.size(); - return kLiteRtStatusOk; -} - -void LiteRtDestroyCompiledResult(LiteRtCompiledResult compiled_result) { - delete compiled_result; -} - -// -// Plugin Definition -// - -// Plugins can hold state. -struct LiteRtCompilerPluginT {}; - -LiteRtStatus LiteRtCompilerPluginSetFlags(LiteRtCompilerPlugin compiler_plugin, - LiteRtParamIndex num_flags, - const char** keys, - const char** values) { - // IMPLEMENT ME - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtCreateCompilerPlugin(LiteRtCompilerPlugin* compiler_plugin) { - auto* plugin = new LiteRtCompilerPluginT; - *compiler_plugin = plugin; - return kLiteRtStatusOk; -} - -void LiteRtDestroyCompilerPlugin(LiteRtCompilerPlugin compiler_plugin) { - delete compiler_plugin; -} - -namespace { - -// TODO update this function to match the new legalizations. -bool IsOpSupported(const litert::Op& op) { - // NOTE: Currently we are demoing by just mapping simple f32 mul ops. Use a - // very loose guard for now -- only checking if op code is supported. 
- for (auto supported_op : kSupportedOps) { - if (op.Code() == supported_op && - litert::mediatek::VerifyCommonOp(op, op.Code())) { - return true; - } - } - return false; -} - -} // namespace - -LiteRtStatus LiteRtCompilerPluginPartition(LiteRtCompilerPlugin compiler_plugin, - const char* soc_model, - LiteRtSubgraph subgraph, - LiteRtOpList selected_ops) { - litert::Subgraph graph(subgraph); - for (const auto& op : graph.Ops()) { - if (!IsOpSupported(op)) { - continue; - } - - LITERT_RETURN_IF_ERROR(LiteRtPushOp(selected_ops, op.Get(), 0)); - } - - return kLiteRtStatusOk; -} - -namespace { - -Expected> CompilePartition( - NeuronAdapterApi& neuron_adapter_api, const litert::Subgraph& partition, - const std::string& graph_name, std::optional soc_model) { - auto model = CreateModel(neuron_adapter_api, partition, graph_name); - if (!model) { - return model.Error(); - } - - auto compilation = CompileModel(neuron_adapter_api, model->get(), soc_model); - if (!compilation) { - return compilation.Error(); - } - - size_t bytecode_size; - if (neuron_adapter_api.api().compilation_get_compiled_network_size( - compilation->get(), &bytecode_size) != NEURON_NO_ERROR) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to get compiled network size"); - } - - std::vector bytecode(bytecode_size); - if (neuron_adapter_api.api().compilation_store_compiled_network( - compilation->get(), bytecode.data(), bytecode.size()) != - NEURON_NO_ERROR) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to get compiled network"); - } - - return bytecode; -} - -} // namespace - -LiteRtStatus LiteRtCompilerPluginCompile( - LiteRtCompilerPlugin compiler_plugin, const char* soc_model, - LiteRtModel partitions, LiteRtCompiledResult* compiled_result) { - static constexpr char dla_directory_template[] = "/tmp/tempdir_dla.XXXXXX"; - char* dla_directory_name = mkdtemp(const_cast(dla_directory_template)); - if (dla_directory_name == nullptr) { - LITERT_LOG(LITERT_ERROR, "Failed to make DLA 
temporary directory") - return kLiteRtStatusErrorFileIO; - } - setenv("MTKNN_ADAPTER_DLA_PLATFORM", soc_model, 1); - setenv("MTKNN_ADAPTER_DLA_DIR", dla_directory_name, 1); - - auto model = litert::Model::CreateFromNonOwnedHandle(partitions); - const auto num_partitions = model.NumSubgraphs(); - - LITERT_LOG(LITERT_INFO, - "Starting MediaTek Compilation for %d subgraphs, soc_model=%s", - num_partitions, soc_model); - - auto opt_soc_model = soc_model ? FindSocModel(soc_model) : std::nullopt; - if (opt_soc_model) { - LITERT_LOG(LITERT_ERROR, "Compiling for MediaTek architecture: %s", - *opt_soc_model); - } else if (soc_model) { - LITERT_LOG(LITERT_ERROR, "Unexpected SoC model: %s", soc_model); - rmdir(dla_directory_name); - return kLiteRtStatusErrorInvalidArgument; - } - - // Initialize SDK and load mediatek shared libraries. - - auto api = NeuronAdapterApi::Create(/*shared_library_dir=*/std::nullopt); - if (!api) { - rmdir(dla_directory_name); - return api.Error().Status(); - } - - auto result = std::make_unique(); - - for (auto i = 0; i < num_partitions; ++i) { - auto graph_name = absl::StrFormat("Partition_%d", i); - auto bytecode = - CompilePartition(**api, *model.Subgraph(i), graph_name, opt_soc_model); - rmdir(dla_directory_name); - if (!bytecode) { - LITERT_LOG(LITERT_INFO, "%s", bytecode.Error().Message().c_str()); - return bytecode.Error().Status(); - } - auto bufferIdx = result->bytebuilder.AddBuffer( - graph_name, (int8_t*)bytecode->data(), bytecode->size()); - result->bytebuilder.AddCompiledNetwork( - graph_name, NeuronSchema::CompiledType_AdapterCache, bufferIdx); - result->graph_names.emplace_back(graph_name); - } - - if (!result->bytebuilder.Finish()) { - return kLiteRtStatusErrorCompilation; - } - *compiled_result = result.release(); - return kLiteRtStatusOk; -} diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/compiler_plugin_test.cc b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/compiler_plugin_test.cc 
deleted file mode 100644 index b8bb947587b229..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/compiler_plugin_test.cc +++ /dev/null @@ -1,133 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include - -#include -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" -#include "tensorflow/lite/experimental/litert/test/common.h" -#include "tensorflow/lite/experimental/litert/test/test_models.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin.h" -#include "tensorflow/lite/experimental/litert/vendors/cc/litert_compiler_plugin.h" - -namespace litert { -namespace { - -using ::testing::Values; - -// clang-format off -const auto kSupportedOps = Values( - "add_cst.tflite", - "add_simple.tflite", - "simple_add_op.tflite"); -// clang-format on - -TEST(TestMediatekPlugin, GetConfigInfo) { - EXPECT_STREQ(LiteRtGetCompilerPluginSocManufacturer(), "MediaTek"); - - auto plugin = CreatePlugin(); - - 
LiteRtParamIndex num_supported_soc_models; - ASSERT_EQ(LiteRtGetNumCompilerPluginSupportedSocModels( - plugin.get(), &num_supported_soc_models), - kLiteRtStatusOk); - ASSERT_EQ(num_supported_soc_models, 12); - - const char* config_id; - ASSERT_EQ( - LiteRtGetCompilerPluginSupportedSocModel(plugin.get(), 0, &config_id), - kLiteRtStatusOk); - EXPECT_STREQ(config_id, "mt6853"); -} - -TEST(TestMediatekPlugin, PartitionAdd) { - auto plugin = CreatePlugin(); - auto model = testing::LoadTestFileModel("add_simple.tflite"); - - LiteRtOpListT selected_op_list; - ASSERT_EQ(LiteRtCompilerPluginPartition(plugin.get(), /*soc_model=*/nullptr, - model.Subgraph(0)->Get(), - &selected_op_list), - kLiteRtStatusOk); - const auto selected_ops = selected_op_list.Values(); - - ASSERT_EQ(selected_ops.size(), 1); - EXPECT_EQ(selected_ops[0].first->OpCode(), kLiteRtOpCodeTflAdd); -} - -// ///////////////////////////////////////////////////////////////////////////// - -class MtkPluginOpCompatibilityTest - : public ::testing::TestWithParam {}; - -TEST_P(MtkPluginOpCompatibilityTest, SupportedOpsTest) { -#ifndef __ANDROID__ - GTEST_SKIP() << "Loading shared lib not currently supported on linux."; -#endif // __ANDROID__ - - LITERT_LOG(LITERT_INFO, "Testing TFLite model: %s", GetParam().c_str()); - auto plugin = CreatePlugin(); - auto model = testing::LoadTestFileModel(GetParam()); - - LiteRtCompiledResult compiled; - ASSERT_EQ(LiteRtCompilerPluginCompile(plugin.get(), /*soc_model=*/nullptr, - model.Get(), &compiled), - kLiteRtStatusOk); - - LiteRtParamIndex num_byte_code; - ASSERT_EQ(LiteRtCompiledResultNumByteCodeModules(compiled, &num_byte_code), - kLiteRtStatusOk); - ASSERT_EQ(num_byte_code, 1); - - const void* byte_code; - size_t byte_code_size; - - ASSERT_EQ(LiteRtGetCompiledResultByteCode(compiled, /*byte_code_idx=*/0, - &byte_code, &byte_code_size), - kLiteRtStatusOk); - - absl::string_view byte_code_string(reinterpret_cast(byte_code), - byte_code_size); - 
ASSERT_FALSE(byte_code_string.empty()); - - const void* op_data; - size_t op_data_size; - LiteRtParamIndex byte_code_idx; - - ASSERT_EQ(LiteRtGetCompiledResultCallInfo(compiled, /*call_idx=*/0, &op_data, - &op_data_size, &byte_code_idx), - kLiteRtStatusOk); - - EXPECT_EQ(byte_code_idx, 0); - - absl::string_view op_data_string(reinterpret_cast(op_data), - op_data_size); - EXPECT_EQ(op_data_string, "Partition_0"); - - LiteRtDestroyCompiledResult(compiled); -} - -INSTANTIATE_TEST_SUITE_P(SupportedOpsTest, MtkPluginOpCompatibilityTest, - kSupportedOps); - -} // namespace -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/create_model.cc b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/create_model.cc deleted file mode 100644 index c7b3ca2d3ebdb0..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/create_model.cc +++ /dev/null @@ -1,164 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/create_model.h" - -#include -#include -#include - -#include "neuron/api/NeuronAdapter.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/add_op_legalization.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/batch_matmul_op_legalization.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/common_op_legalization.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/concat_op_legalization.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/fully_connected_op_legalization.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/gelu_op_legalization.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/mean_op_legalization.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/mul_op_legalization.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/reshape_op_legalization.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/rsqrt_op_legalization.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/softmax_op_legalization.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/sub_op_legalization.h" -#include 
"tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/transpose_op_legalization.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/schema/schema_resolver.h" - -namespace litert::mediatek { - -Expected CreateModel(const NeuronAdapterApi& neuron_adapter_api, - const litert::Subgraph& partition, - const std::string& model_name) { - auto model = neuron_adapter_api.CreateModel(); - if (!model) { - return model.Error(); - } - - if (neuron_adapter_api.api().model_set_name( - model->get(), model_name.c_str()) != NEURON_NO_ERROR) { - return Error(kLiteRtStatusErrorRuntimeFailure, "Failed to set model name"); - } - - OperandMap operand_map(neuron_adapter_api, model->get()); - - std::vector input_indices; - for (const auto& input : partition.Inputs()) { - auto operand_index = operand_map.GetOperandIndex(input); - if (!operand_index) { - return operand_index.Error(); - } - input_indices.push_back(*operand_index); - } - - std::vector output_indices; - for (const auto& output : partition.Outputs()) { - auto operand_index = operand_map.GetOperandIndex(output); - if (!operand_index) { - return operand_index.Error(); - } - output_indices.push_back(*operand_index); - } - - if (neuron_adapter_api.api().model_identify_inputs_and_outputs( - model->get(), input_indices.size(), input_indices.data(), - output_indices.size(), output_indices.data()) != NEURON_NO_ERROR) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to identify model I/Os"); - } - - for (const auto& op : partition.Ops()) { - Expected status; - switch (op.Code()) { - case kLiteRtOpCodeTflAdd: - status = - LegalizeAddOp(neuron_adapter_api, model->get(), operand_map, op); - break; - case kLiteRtOpCodeTflMul: - status = - LegalizeMulOp(neuron_adapter_api, model->get(), operand_map, op); - break; - case kLiteRtOpCodeTflBatchMatmul: - status = LegalizeBatchMatMulOp(neuron_adapter_api, 
model->get(), - operand_map, op); - break; - case kLiteRtOpCodeTflFullyConnected: - status = LegalizeFullyConnectedOp(neuron_adapter_api, model->get(), - operand_map, op); - break; - case kLiteRtOpCodeTflReshape: - status = LegalizeReshapeOp(neuron_adapter_api, model->get(), - operand_map, op); - break; - case kLiteRtOpCodeTflTranspose: - status = LegalizeTransposeOp(neuron_adapter_api, model->get(), - operand_map, op); - break; - case kLiteRtOpCodeTflRsqrt: - status = - LegalizeRsqrtOp(neuron_adapter_api, model->get(), operand_map, op); - break; - case kLiteRtOpCodeTflConcatenation: - status = - LegalizeConcatOp(neuron_adapter_api, model->get(), operand_map, op); - break; - case kLiteRtOpCodeTflQuantize: - status = LegalizeCommonOp(neuron_adapter_api, model->get(), operand_map, - op, NEURON_QUANTIZE); - break; - case kLiteRtOpCodeTflSlice: - status = LegalizeCommonOp(neuron_adapter_api, model->get(), operand_map, - op, NEURON_SLICE); - break; - case kLiteRtOpCodeTflTanh: - status = LegalizeCommonOp(neuron_adapter_api, model->get(), operand_map, - op, NEURON_TANH); - break; - case kLiteRtOpCodeTflSub: - status = - LegalizeSubOp(neuron_adapter_api, model->get(), operand_map, op); - break; - case kLiteRtOpCodeTflSoftmax: - status = LegalizeSoftmaxOp(neuron_adapter_api, model->get(), - operand_map, op); - break; - case kLiteRtOpCodeTflMean: - status = - LegalizeMeanOp(neuron_adapter_api, model->get(), operand_map, op); - break; - case kLiteRtOpCodeTflGelu: - status = - LegalizeGeluOp(neuron_adapter_api, model->get(), operand_map, op); - break; - default: - return Error(kLiteRtStatusErrorRuntimeFailure, "Unsupported op"); - } - - if (!status) { - return status.Error(); - } - } - - if (neuron_adapter_api.api().model_finish(model->get()) != NEURON_NO_ERROR) { - return Error(kLiteRtStatusErrorRuntimeFailure, "Failed to finish model"); - } - - return model; -} - -} // namespace litert::mediatek diff --git 
a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/create_model.h b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/create_model.h deleted file mode 100644 index 6e958d691a80e1..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/create_model.h +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_CREATE_MODEL_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_CREATE_MODEL_H_ - -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h" - -namespace litert::mediatek { - -// Create a new NeuronModel Graph from given LiteRt Graph. 
-Expected CreateModel(const NeuronAdapterApi& neuron_adapter_api, - const Subgraph& partition, - const std::string& model_name); - -} // namespace litert::mediatek - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_CREATE_MODEL_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/BUILD b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/BUILD deleted file mode 100644 index abd020b3ce9725..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/BUILD +++ /dev/null @@ -1,338 +0,0 @@ -# Copyright 2024 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//tensorflow/lite/experimental/litert/vendors/mediatek/compiler:__subpackages__"], -) - -cc_library( - name = "operand_map", - srcs = ["operand_map.cc"], - hdrs = ["operand_map.h"], - tags = [ - # Don't build/test in OS until MediaTek SDK is available. 
- "nobuilder", - "notap", - ], - deps = [ - "neuron_utils", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/cc:litert_element_type", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/vendors/mediatek:neuron_adapter_api", - "@com_google_absl//absl/container:flat_hash_map", - ], -) - -cc_library( - name = "neuron_utils", - srcs = ["neuron_utils.cc"], - hdrs = ["neuron_utils.h"], - tags = [ - # Don't build/test in OS until MediaTek SDK is available. - "nobuilder", - "notap", - ], - deps = [ - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/vendors/mediatek:neuron_adapter_api", - ], -) - -cc_library( - name = "add_op_legalization", - srcs = ["add_op_legalization.cc"], - hdrs = ["add_op_legalization.h"], - tags = [ - # Don't build/test in OS until MediaTek SDK is available. - "nobuilder", - "notap", - ], - deps = [ - "operand_map", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/vendors/mediatek:neuron_adapter_api", - ], -) - -cc_library( - name = "mul_op_legalization", - srcs = ["mul_op_legalization.cc"], - hdrs = ["mul_op_legalization.h"], - tags = [ - # Don't build/test in OS until MediaTek SDK is available. 
- "nobuilder", - "notap", - ], - deps = [ - "operand_map", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/vendors/mediatek:neuron_adapter_api", - ], -) - -cc_library( - name = "batch_matmul_op_legalization", - srcs = ["batch_matmul_op_legalization.cc"], - hdrs = ["batch_matmul_op_legalization.h"], - tags = [ - # Don't build/test in OS until MediaTek SDK is available. - "nobuilder", - "notap", - ], - deps = [ - "operand_map", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/vendors/mediatek:neuron_adapter_api", - ], -) - -cc_library( - name = "fully_connected_op_legalization", - srcs = ["fully_connected_op_legalization.cc"], - hdrs = ["fully_connected_op_legalization.h"], - tags = [ - # Don't build/test in OS until MediaTek SDK is available. - "nobuilder", - "notap", - ], - deps = [ - "operand_map", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/vendors/mediatek:neuron_adapter_api", - ], -) - -cc_library( - name = "reshape_op_legalization", - srcs = ["reshape_op_legalization.cc"], - hdrs = ["reshape_op_legalization.h"], - tags = [ - # Don't build/test in OS until MediaTek SDK is available. 
- "nobuilder", - "notap", - ], - deps = [ - "operand_map", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/vendors/mediatek:neuron_adapter_api", - ], -) - -cc_library( - name = "transpose_op_legalization", - srcs = ["transpose_op_legalization.cc"], - hdrs = ["transpose_op_legalization.h"], - tags = [ - # Don't build/test in OS until MediaTek SDK is available. - "nobuilder", - "notap", - ], - deps = [ - "operand_map", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/vendors/mediatek:neuron_adapter_api", - ], -) - -cc_library( - name = "rsqrt_op_legalization", - srcs = ["rsqrt_op_legalization.cc"], - hdrs = ["rsqrt_op_legalization.h"], - tags = [ - # Don't build/test in OS until MediaTek SDK is available. - "nobuilder", - "notap", - ], - deps = [ - "operand_map", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/vendors/mediatek:neuron_adapter_api", - ], -) - -cc_library( - name = "concat_op_legalization", - srcs = ["concat_op_legalization.cc"], - hdrs = ["concat_op_legalization.h"], - tags = [ - # Don't build/test in OS until MediaTek SDK is available. 
- "nobuilder", - "notap", - ], - deps = [ - "operand_map", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/vendors/mediatek:neuron_adapter_api", - ], -) - -cc_library( - name = "quantize_op_legalization", - srcs = ["quantize_op_legalization.cc"], - hdrs = ["quantize_op_legalization.h"], - tags = [ - # Don't build/test in OS until MediaTek SDK is available. - "nobuilder", - "notap", - ], - deps = [ - "operand_map", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/vendors/mediatek:neuron_adapter_api", - ], -) - -cc_library( - name = "common_op_legalization", - srcs = ["common_op_legalization.cc"], - hdrs = ["common_op_legalization.h"], - tags = [ - # Don't build/test in OS until MediaTek SDK is available. - "nobuilder", - "notap", - ], - deps = [ - "operand_map", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/vendors/mediatek:neuron_adapter_api", - ], -) - -cc_library( - name = "sub_op_legalization", - srcs = ["sub_op_legalization.cc"], - hdrs = ["sub_op_legalization.h"], - tags = [ - # Don't build/test in OS until MediaTek SDK is available. 
- "nobuilder", - "notap", - ], - deps = [ - "operand_map", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/vendors/mediatek:neuron_adapter_api", - ], -) - -cc_library( - name = "softmax_op_legalization", - srcs = ["softmax_op_legalization.cc"], - hdrs = ["softmax_op_legalization.h"], - tags = [ - # Don't build/test in OS until MediaTek SDK is available. - "nobuilder", - "notap", - ], - deps = [ - "operand_map", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/vendors/mediatek:neuron_adapter_api", - ], -) - -cc_library( - name = "mean_op_legalization", - srcs = ["mean_op_legalization.cc"], - hdrs = ["mean_op_legalization.h"], - tags = [ - # Don't build/test in OS until MediaTek SDK is available. - "nobuilder", - "notap", - ], - deps = [ - "operand_map", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/vendors/mediatek:neuron_adapter_api", - ], -) - -cc_library( - name = "gelu_op_legalization", - srcs = ["gelu_op_legalization.cc"], - hdrs = ["gelu_op_legalization.h"], - tags = [ - # Don't build/test in OS until MediaTek SDK is available. 
- "nobuilder", - "notap", - ], - deps = [ - "operand_map", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/vendors/mediatek:neuron_adapter_api", - ], -) diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/add_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/add_op_legalization.cc deleted file mode 100644 index 47194694f7de7b..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/add_op_legalization.cc +++ /dev/null @@ -1,76 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/add_op_legalization.h" - -#include -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_options.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h" - -namespace litert::mediatek { - -Expected LegalizeAddOp(const NeuronAdapterApi& neuron_adapter_api, - NeuronModel* model, OperandMap& operand_map, - const litert::Op& op) { - LITERT_LOG(LITERT_INFO, "Legalize Add"); - std::vector input_indices; - for (auto& input : op.Inputs()) { - auto id = operand_map.GetOperandIndex(input); - if (!id) { - return id.Error(); - } - input_indices.push_back(*id); - } - - // A NEURON_ADD operation takes a 3rd scalar operand, which is used to pass a - // TfLiteFusedActivation value. 
- uint32_t tfl_fused_activation; - if (auto status = - LiteRtGetAddFusedActivationOption(op.Get(), &tfl_fused_activation); - status != kLiteRtStatusOk) { - return Error(status, "Failed to get fused activation"); - } - auto fused_activation_operand_index = - operand_map.AddScalarInt32(tfl_fused_activation); - if (!fused_activation_operand_index) { - return fused_activation_operand_index.Error(); - } - input_indices.push_back(*fused_activation_operand_index); - - std::vector output_indices; - for (auto& output : op.Outputs()) { - auto id = operand_map.GetOperandIndex(output); - if (!id) { - return id.Error(); - } - output_indices.push_back(*id); - } - - if (ModelAddOperation(neuron_adapter_api, model, /*type=*/NEURON_ADD, - input_indices, output_indices) != NEURON_NO_ERROR) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to add NEURON_ADD op"); - } - - return {}; -} - -} // namespace litert::mediatek diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/add_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/add_op_legalization.h deleted file mode 100644 index d774d6bcb972e9..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/add_op_legalization.h +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_ADD_OP_LEGALIZATION_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_ADD_OP_LEGALIZATION_H_ - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h" - -namespace litert::mediatek { - -Expected LegalizeAddOp(const NeuronAdapterApi& neuron_adapter_api, - NeuronModel* model, OperandMap& operand_map, - const litert::Op& op); - -} // namespace litert::mediatek - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_ADD_OP_LEGALIZATION_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/batch_matmul_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/batch_matmul_op_legalization.cc deleted file mode 100644 index 23cbe10ff09386..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/batch_matmul_op_legalization.cc +++ /dev/null @@ -1,89 +0,0 @@ -// Copyright (c) 2025 MediaTek Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/batch_matmul_op_legalization.h" - -#include -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_options.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h" - -namespace litert::mediatek { - -Expected LegalizeBatchMatMulOp(const NeuronAdapterApi& neuron_adapter_api, - NeuronModel* model, - OperandMap& operand_map, - const litert::Op& op) { - LITERT_LOG(LITERT_INFO, "Legalize BatchMatMul"); - std::vector input_indices; - for (auto& input : op.Inputs()) { - auto id = operand_map.GetOperandIndex(input); - if (!id) { - return id.Error(); - } - input_indices.push_back(*id); - } - - // A NEURON_BATCH_MATMUL operation takes 2 scalar operand, which is used to - // pass a adjX, adjY value. 
- bool tfl_matmul_param_adj_x = 0, tfl_matmul_param_adj_y = 0; - if (auto status = - LiteRtGetBatchMatmulAdjXOption(op.Get(), &tfl_matmul_param_adj_x); - status != kLiteRtStatusOk) { - return Error(status, "Failed to get batch matmul adjX"); - } - - if (auto status = - LiteRtGetBatchMatmulAdjYOption(op.Get(), &tfl_matmul_param_adj_y); - status != kLiteRtStatusOk) { - return Error(status, "Failed to get batch matmul adjY"); - } - - auto adj_x_operand_index = operand_map.AddScalarBool(tfl_matmul_param_adj_x); - if (!adj_x_operand_index) { - return adj_x_operand_index.Error(); - } - input_indices.push_back(*adj_x_operand_index); - - auto adj_j_operand_index = operand_map.AddScalarBool(tfl_matmul_param_adj_y); - if (!adj_j_operand_index) { - return adj_j_operand_index.Error(); - } - input_indices.push_back(*adj_j_operand_index); - - std::vector output_indices; - for (auto& output : op.Outputs()) { - auto id = operand_map.GetOperandIndex(output); - if (!id) { - return id.Error(); - } - output_indices.push_back(*id); - } - - if (ModelAddOperation(neuron_adapter_api, model, /*type=*/NEURON_BATCH_MATMUL, - input_indices, output_indices) != NEURON_NO_ERROR) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to add NEURON_BATCH_MATMUL op"); - } - - return {}; -} - -} // namespace litert::mediatek diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/batch_matmul_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/batch_matmul_op_legalization.h deleted file mode 100644 index 227c6563713a58..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/batch_matmul_op_legalization.h +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright (c) 2025 MediaTek Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_BATCH_MATMUL_OP_LEGALIZATION_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_BATCH_MATMUL_OP_LEGALIZATION_H_ - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h" - -namespace litert::mediatek { - -Expected LegalizeBatchMatMulOp(const NeuronAdapterApi& neuron_adapter_api, - NeuronModel* model, - OperandMap& operand_map, - const litert::Op& op); - -} // namespace litert::mediatek - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_BATCH_MATMUL_OP_LEGALIZATION_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/common_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/common_op_legalization.cc deleted file mode 100644 index 0c3a62f8997b3a..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/common_op_legalization.cc +++ /dev/null @@ -1,66 +0,0 @@ -// Copyright (c) 2025 MediaTek Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/common_op_legalization.h" - -#include -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_options.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h" - -namespace litert::mediatek { - -bool VerifyCommonOp(const litert::Op& op, LiteRtOpCode op_code) { - // Do some common check - return true; -} - -Expected LegalizeCommonOp(const NeuronAdapterApi& neuron_adapter_api, - NeuronModel* model, OperandMap& operand_map, - const litert::Op& op, - NeuronOperationType mtk_operation_type) { - LITERT_LOG(LITERT_INFO, "Legalize Op: %d", mtk_operation_type); - std::vector input_indices; - for (auto& input : op.Inputs()) { - auto id = operand_map.GetOperandIndex(input); - if (!id) { - return id.Error(); - } - input_indices.push_back(*id); - } - - std::vector output_indices; - for (auto& output : op.Outputs()) { - auto id = operand_map.GetOperandIndex(output); - if (!id) { - return id.Error(); - } - output_indices.push_back(*id); - } - - if (ModelAddOperation(neuron_adapter_api, model, /*type=*/mtk_operation_type, - input_indices, output_indices) != 
NEURON_NO_ERROR) { - return Error(kLiteRtStatusErrorRuntimeFailure, "Failed to add operation"); - } - - return {}; -} - -} // namespace litert::mediatek diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/common_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/common_op_legalization.h deleted file mode 100644 index 5995f77e888bf1..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/common_op_legalization.h +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright (c) 2025 MediaTek Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_COMMON_OP_LEGALIZATION_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_COMMON_OP_LEGALIZATION_H_ - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h" - -namespace litert::mediatek { - -bool VerifyCommonOp(const litert::Op& op, LiteRtOpCode op_code); - -Expected LegalizeCommonOp(const NeuronAdapterApi& neuron_adapter_api, - NeuronModel* model, OperandMap& operand_map, - const litert::Op& op, - NeuronOperationType mtk_operation_type); - -} // namespace litert::mediatek - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_COMMON_OP_LEGALIZATION_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/concat_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/concat_op_legalization.cc deleted file mode 100644 index 3320272f9c65f1..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/concat_op_legalization.cc +++ /dev/null @@ -1,77 +0,0 @@ -// Copyright (c) 2025 MediaTek Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/concat_op_legalization.h" - -#include -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_options.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h" - -namespace litert::mediatek { - -Expected LegalizeConcatOp(const NeuronAdapterApi& neuron_adapter_api, - NeuronModel* model, OperandMap& operand_map, - const litert::Op& op) { - LITERT_LOG(LITERT_INFO, "Legalize Concate"); - std::vector input_indices; - for (auto& input : op.Inputs()) { - auto id = operand_map.GetOperandIndex(input); - if (!id) { - return id.Error(); - } - input_indices.push_back(*id); - } - - // A NEURON_CONCAT operation takes an additional scalar operand, which is used - // to pass as a axis. 
- int32_t axis; - if (auto status = LiteRtGetConcatenationAxisOption(op.Get(), &axis); - status != kLiteRtStatusOk) { - return Error(status, "Failed to get new shape option"); - } - - auto axis_operand_index = operand_map.AddScalarInt32(axis); - if (!axis_operand_index) { - return axis_operand_index.Error(); - } - - input_indices.push_back(*axis_operand_index); - - std::vector output_indices; - for (auto& output : op.Outputs()) { - auto id = operand_map.GetOperandIndex(output); - if (!id) { - return id.Error(); - } - output_indices.push_back(*id); - } - - if (ModelAddOperation(neuron_adapter_api, model, - /*type=*/NEURON_CONCATENATION, input_indices, - output_indices) != NEURON_NO_ERROR) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to add NEURON_CONCAT operation"); - } - - return {}; -} - -} // namespace litert::mediatek diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/concat_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/concat_op_legalization.h deleted file mode 100644 index e7f1294ec39df4..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/concat_op_legalization.h +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright (c) 2025 MediaTek Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_CONCAT_OP_LEGALIZATION_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_CONCAT_OP_LEGALIZATION_H_ - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h" - -namespace litert::mediatek { - -Expected LegalizeConcatOp(const NeuronAdapterApi& neuron_adapter_api, - NeuronModel* model, OperandMap& operand_map, - const litert::Op& op); - -} // namespace litert::mediatek - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_CONCAT_OP_LEGALIZATION_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/fully_connected_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/fully_connected_op_legalization.cc deleted file mode 100644 index 877c7511649a4a..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/fully_connected_op_legalization.cc +++ /dev/null @@ -1,129 +0,0 @@ -// Copyright (c) 2025 MediaTek Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/fully_connected_op_legalization.h" - -#include -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_options.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h" - -#define GET_RANK(op) ((op).RankedTensorType()->Layout().Rank()) -#define GET_DIMENSION(op) ((op).RankedTensorType()->Layout().Dimensions()) - -namespace litert::mediatek { - -Expected LegalizeFullyConnectedOp( - const NeuronAdapterApi& neuron_adapter_api, NeuronModel* model, - OperandMap& operand_map, const litert::Op& op) { - LITERT_LOG(LITERT_INFO, "Legalize Fully Connected"); - std::vector input_indices; - for (auto& input : op.Inputs()) { - auto id = operand_map.GetOperandIndex(input); - if (!id) { - return id.Error(); - } - input_indices.push_back(*id); - } - - // for beta - if (input_indices.size() < 3) { - auto weights_shape = GET_DIMENSION(op.Inputs()[1]); - std::vector bias_shape = { - static_cast(weights_shape[0])}; - std::vector bias_data(bias_shape[0], 0); - auto bias_data_operand = - operand_map.AddTensorByType(NEURON_TENSOR_QUANT8_SYMM, bias_shape, - bias_data.data(), bias_data.size() * 1); - input_indices.push_back(*bias_data_operand); - } - - // A NEURON_FULLY_CONNECTED operation takes a 4rd scalar operand, which is - // used to pass a TfLiteFusedActivation value. 
- uint32_t tfl_fused_activation; - if (auto status = LiteRtGetFullyConnectedFusedActivationOption( - op.Get(), &tfl_fused_activation); - status != kLiteRtStatusOk) { - return Error(status, "Failed to get fused activation"); - } - auto fused_activation_operand_index = - operand_map.AddScalarInt32(tfl_fused_activation); - if (!fused_activation_operand_index) { - return fused_activation_operand_index.Error(); - } - input_indices.push_back(*fused_activation_operand_index); - - auto output_operand = OperandType::Create(op.Outputs()[0]); - std::vector output_indices; - - if (GET_RANK(op.Outputs()[0]) > 2) { - // if output_operand shape , reshape to - auto last_dim = output_operand->GetDimension().back(); - auto elements = output_operand->GetElementCount(); - std::vector new_dimension = {elements / last_dim, last_dim}; - if (auto res = output_operand->Reshape(new_dimension); !res) { - return res.Error(); - } - auto intermediate_operand = operand_map.AddOperand(*output_operand); - output_indices.push_back(*intermediate_operand); - } else { - auto output_operand = operand_map.GetOperandIndex(op.Outputs()[0]); - output_indices.push_back(*output_operand); - if (!output_operand) { - return output_operand.Error(); - } - } - - if (ModelAddOperation(neuron_adapter_api, model, - /*type=*/NEURON_FULLY_CONNECTED, input_indices, - output_indices) != NEURON_NO_ERROR) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to set NEURON_FULLY_CONNECTED operation"); - } - - if (GET_RANK(op.Outputs()[0]) > 2) { - // intermediate as reshape input - input_indices = {output_indices.back()}; - auto output_operand = operand_map.GetOperandIndex(op.Outputs()[0]); - if (!output_operand) { - return output_operand.Error(); - } - - auto dimension = op.Outputs()[0].RankedTensorType()->Layout().Dimensions(); - std::vector new_shape(dimension.begin(), dimension.end()); - std::vector tensor_shape = {(uint32_t)new_shape.size()}; - auto new_shape_operand_index = operand_map.AddTensorInt32( - 
tensor_shape, new_shape.data(), new_shape.size() * sizeof(int32_t)); - if (!new_shape_operand_index) { - return new_shape_operand_index.Error(); - } - input_indices.push_back(*new_shape_operand_index); - output_indices = {*output_operand}; - if (ModelAddOperation(neuron_adapter_api, model, /*type=*/NEURON_RESHAPE, - input_indices, output_indices) != NEURON_NO_ERROR) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to add Reshape after FC"); - } - } - - return {}; -} - -} // namespace litert::mediatek diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/fully_connected_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/fully_connected_op_legalization.h deleted file mode 100644 index 68d6a319295b90..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/fully_connected_op_legalization.h +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright (c) 2025 MediaTek Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_FULLY_CONNECTED_OP_LEGALIZATION_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_FULLY_CONNECTED_OP_LEGALIZATION_H_ - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h" - -namespace litert::mediatek { - -Expected LegalizeFullyConnectedOp( - const NeuronAdapterApi& neuron_adapter_api, NeuronModel* model, - OperandMap& operand_map, const litert::Op& op); - -} // namespace litert::mediatek - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_FULLY_CONNECTED_OP_LEGALIZATION_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/gelu_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/gelu_op_legalization.cc deleted file mode 100644 index 32af1156fae40b..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/gelu_op_legalization.cc +++ /dev/null @@ -1,70 +0,0 @@ -// Copyright (c) 2025 MediaTek Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/gelu_op_legalization.h" - -#include -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_options.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h" - -namespace litert::mediatek { - -constexpr uint32_t kGeluApproximateTanh = 1; - -Expected LegalizeGeluOp(const NeuronAdapterApi& neuron_adapter_api, - NeuronModel* model, OperandMap& operand_map, - const litert::Op& op) { - LITERT_LOG(LITERT_INFO, "Legalize Gelu"); - std::vector input_indices; - for (auto& input : op.Inputs()) { - auto id = operand_map.GetOperandIndex(input); - if (!id) { - return id.Error(); - } - input_indices.push_back(*id); - } - - auto approximate_operand = operand_map.AddScalarUInt32(kGeluApproximateTanh); - if (!approximate_operand) { - return approximate_operand.Error(); - } - - input_indices.push_back(*approximate_operand); - - std::vector output_indices; - for (auto& output : op.Outputs()) { - auto id = operand_map.GetOperandIndex(output); - if (!id) { - return id.Error(); - } - output_indices.push_back(*id); - } - - if (ModelAddOperation(neuron_adapter_api, model, /*type=*/NEURON_GELU_V2, - input_indices, output_indices) != NEURON_NO_ERROR) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to add GELU operation"); - } - - return {}; -} - -} // namespace litert::mediatek diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/gelu_op_legalization.h 
b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/gelu_op_legalization.h deleted file mode 100644 index 9249263c77e902..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/gelu_op_legalization.h +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright (c) 2025 MediaTek Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_GELU_OP_LEGALIZATION_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_GELU_OP_LEGALIZATION_H_ - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h" - -namespace litert::mediatek { - -Expected LegalizeGeluOp(const NeuronAdapterApi& neuron_adapter_api, - NeuronModel* model, OperandMap& operand_map, - const litert::Op& op); - -} // namespace litert::mediatek - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_GELU_OP_LEGALIZATION_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/mean_op_legalization.cc 
b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/mean_op_legalization.cc deleted file mode 100644 index efbad31106b51b..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/mean_op_legalization.cc +++ /dev/null @@ -1,75 +0,0 @@ -// Copyright (c) 2025 MediaTek Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/mean_op_legalization.h" - -#include -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_options.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h" - -namespace litert::mediatek { - -Expected LegalizeMeanOp(const NeuronAdapterApi& neuron_adapter_api, - NeuronModel* model, OperandMap& operand_map, - const litert::Op& op) { - LITERT_LOG(LITERT_INFO, "Legalize Mean"); - std::vector input_indices; - for (auto& input : op.Inputs()) { - auto id = operand_map.GetOperandIndex(input); - if (!id) { - return id.Error(); - } - input_indices.push_back(*id); - } - - // A NEURON_Mean 
operation takes an additional scalar operand, which is - // used to pass a keepdims. - bool keepdims; - if (auto status = LiteRtGetMeanKeepDimsOption(op.Get(), &keepdims); - status != kLiteRtStatusOk) { - return Error(status, "Failed to get beta"); - } - LITERT_LOG(LITERT_INFO, "keepdims: %d", keepdims); - auto keepdims_operand = operand_map.AddScalarInt32(keepdims); - if (!keepdims_operand) { - return keepdims_operand.Error(); - } - input_indices.push_back(*keepdims_operand); - - std::vector output_indices; - for (auto& output : op.Outputs()) { - auto id = operand_map.GetOperandIndex(output); - if (!id) { - return id.Error(); - } - output_indices.push_back(*id); - } - - if (ModelAddOperation(neuron_adapter_api, model, /*type=*/NEURON_MEAN, - input_indices, output_indices) != NEURON_NO_ERROR) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to add NEURON_MEAN operation"); - } - - return {}; -} - -} // namespace litert::mediatek diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/mean_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/mean_op_legalization.h deleted file mode 100644 index fc36f646d75836..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/mean_op_legalization.h +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright (c) 2025 MediaTek Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_MEAN_OP_LEGALIZATION_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_MEAN_OP_LEGALIZATION_H_ - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h" - -namespace litert::mediatek { - -Expected LegalizeMeanOp(const NeuronAdapterApi& neuron_adapter_api, - NeuronModel* model, OperandMap& operand_map, - const litert::Op& op); - -} // namespace litert::mediatek - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_MEAN_OP_LEGALIZATION_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/mul_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/mul_op_legalization.cc deleted file mode 100644 index b78f1640f8bc92..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/mul_op_legalization.cc +++ /dev/null @@ -1,76 +0,0 @@ -// Copyright (c) 2025 MediaTek Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/mul_op_legalization.h" - -#include -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_options.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h" - -namespace litert::mediatek { - -Expected LegalizeMulOp(const NeuronAdapterApi& neuron_adapter_api, - NeuronModel* model, OperandMap& operand_map, - const litert::Op& op) { - LITERT_LOG(LITERT_INFO, "Legalize Mul"); - std::vector input_indices; - for (auto& input : op.Inputs()) { - auto id = operand_map.GetOperandIndex(input); - if (!id) { - return id.Error(); - } - input_indices.push_back(*id); - } - - // A NEURON_MUL operation takes a 3rd scalar operand, which is used to pass a - // TfLiteFusedActivation value. 
- uint32_t tfl_fused_activation; - if (auto status = - LiteRtGetMulFusedActivationOption(op.Get(), &tfl_fused_activation); - status != kLiteRtStatusOk) { - return Error(status, "Failed to get fused activation"); - } - auto fused_activation_operand_index = - operand_map.AddScalarInt32(tfl_fused_activation); - if (!fused_activation_operand_index) { - return fused_activation_operand_index.Error(); - } - input_indices.push_back(*fused_activation_operand_index); - - std::vector output_indices; - for (auto& output : op.Outputs()) { - auto id = operand_map.GetOperandIndex(output); - if (!id) { - return id.Error(); - } - output_indices.push_back(*id); - } - - if (ModelAddOperation(neuron_adapter_api, model, /*type=*/NEURON_MUL, - input_indices, output_indices) != NEURON_NO_ERROR) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to add NEURON_MUL operation"); - } - - return {}; -} - -} // namespace litert::mediatek diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/mul_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/mul_op_legalization.h deleted file mode 100644 index 8ff1c325fe3f27..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/mul_op_legalization.h +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright (c) 2025 MediaTek Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_MUL_OP_LEGALIZATION_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_MUL_OP_LEGALIZATION_H_ - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h" - -namespace litert::mediatek { - -Expected LegalizeMulOp(const NeuronAdapterApi& neuron_adapter_api, - NeuronModel* model, OperandMap& operand_map, - const litert::Op& op); - -} // namespace litert::mediatek - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_MUL_OP_LEGALIZATION_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/neuron_utils.cc b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/neuron_utils.cc deleted file mode 100644 index 059d51cc10dc9c..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/neuron_utils.cc +++ /dev/null @@ -1,104 +0,0 @@ -// Copyright (c) 2025 MediaTek Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/neuron_utils.h" - -namespace litert::mediatek { - -Expected GetNeuronTensorType(const Tensor& t) { - auto ranked_tensor_type = t.RankedTensorType(); - if (!ranked_tensor_type) { - return ranked_tensor_type.Error(); - } - - int32_t mtk_type; - switch (ranked_tensor_type->ElementType()) { - case ElementType::Float32: - mtk_type = NEURON_TENSOR_FLOAT32; - break; - case ElementType::Float16: - mtk_type = NEURON_TENSOR_FLOAT16; - break; - case ElementType::Int32: - mtk_type = NEURON_TENSOR_INT32; - break; - case ElementType::Int16: - if (t.QTypeId() == kLiteRtQuantizationPerTensor) { - mtk_type = NEURON_TENSOR_QUANT16_SYMM; - } else { - return Error(kLiteRtStatusErrorRuntimeFailure, - "Int16 is not supported."); - } - break; - case ElementType::Int8: - if (t.QTypeId() == kLiteRtQuantizationPerTensor) { - mtk_type = NEURON_TENSOR_QUANT8_SYMM; - } else if (t.QTypeId() == kLiteRtQuantizationPerChannel) { - mtk_type = NEURON_TENSOR_QUANT8_SYMM_PER_CHANNEL; - } else { - return Error(kLiteRtStatusErrorRuntimeFailure, - "Int8 is not supported."); - } - break; - default: - return Error(kLiteRtStatusErrorRuntimeFailure, - absl::StrFormat("Unsupported element type: %d", - ranked_tensor_type->ElementType())); - } - return mtk_type; -} - -Expected GetNeuronDataSize(NeuronTensorType type) { - switch (type) { - case NEURON_FLOAT32: - case NEURON_TENSOR_FLOAT32: - case NEURON_INT32: - case NEURON_TENSOR_INT32: - return 4; - case NEURON_FLOAT16: - case NEURON_TENSOR_FLOAT16: - case NEURON_EXT_TENSOR_QUANT16_ASYMM_SIGNED: - return 2; - case NEURON_BOOL: - case NEURON_TENSOR_BOOL8: - case NEURON_TENSOR_QUANT8_ASYMM: - case NEURON_TENSOR_QUANT8_ASYMM_SIGNED: - return 1; - default: - return Error(kLiteRtStatusErrorRuntimeFailure, - "Get Data Size fail for Neuron Type"); - } - return Error(kLiteRtStatusErrorRuntimeFailure, "Unexpected neuron type"); -} - -Expected IsQuantizedType(NeuronTensorType type) 
{ - switch (type) { - case NEURON_TENSOR_QUANT16_SYMM: - case NEURON_TENSOR_QUANT16_ASYMM: - case NEURON_TENSOR_QUANT8_ASYMM: - case NEURON_TENSOR_QUANT8_ASYMM_SIGNED: - return true; - } - return false; -} - -NeuronReturnCode ModelAddOperation(const NeuronAdapterApi& api, - NeuronModel* model, NeuronOperationType type, - std::vector input, - std::vector output) { - return api.api().model_add_operation(model, type, input.size(), input.data(), - output.size(), output.data()); -}; - -} // namespace litert::mediatek diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/neuron_utils.h b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/neuron_utils.h deleted file mode 100644 index 27633fef2d746f..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/neuron_utils.h +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright (c) 2025 MediaTek Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_NEURON_UTILS_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_NEURON_UTILS_H_ - -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h" - -namespace litert::mediatek { -using NeuronTensorType = int32_t; -using NeuronReturnCode = int32_t; - -Expected GetNeuronTensorType(const Tensor& t); - -Expected GetNeuronDataSize(NeuronTensorType type); - -Expected IsQuantizedType(NeuronTensorType type); - -NeuronReturnCode ModelAddOperation(const NeuronAdapterApi& api, - NeuronModel* model, NeuronOperationType type, - std::vector input, - std::vector output); - -} // namespace litert::mediatek - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_NEURON_UTILS_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.cc b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.cc deleted file mode 100644 index 40347a2c00bbe5..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.cc +++ /dev/null @@ -1,80 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.h" - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/cc/litert_element_type.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h" - -namespace litert::mediatek { - -Expected OperandMap::Register(const NeuronOperandType& operand_type) { - if (neuron_adapter_api_.api().model_add_operand(model_, &operand_type) != - NEURON_NO_ERROR) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to register model operand"); - } - return AllocateOperandIndex(); -} - -Expected OperandMap::Register(const Tensor& t) { - auto operand_type = OperandType::Create(t); - if (!operand_type) { - return operand_type.Error(); - } - - auto operand_index = - Register(static_cast(*operand_type)); - if (!operand_index) { - return operand_index.Error(); - } - LITERT_LOG(LITERT_INFO, "\nOperandIndex: %d", operand_index.Value()); - operand_type->Info(); - - if (t.HasWeights()) { - auto weights = t.Weights().Bytes(); - if (t.QTypeId() == kLiteRtQuantizationPerChannel) { - auto quant_param = operand_type->GetPerChannelQuantParams().Value(); - if 
(neuron_adapter_api_.api().model_set_symm_per_channel_quant_params( - model_, *operand_index, &quant_param) != NEURON_NO_ERROR) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to set param of per channel quant params"); - } - } - if (neuron_adapter_api_.api().model_set_operand_value( - model_, *operand_index, weights.data(), weights.size()) != - NEURON_NO_ERROR) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to set value of tensor weights"); - } - } - - map_[t.Get()] = *operand_index; - return *operand_index; -} - -} // namespace litert::mediatek diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.h b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.h deleted file mode 100644 index fb79626d6cf988..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.h +++ /dev/null @@ -1,269 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_OPERAND_MAP_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_OPERAND_MAP_H_ - -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/neuron_utils.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h" - -namespace litert::mediatek { - -class OperandType : public NeuronOperandType { - public: - static Expected Create(const Tensor& t) { - auto ranked_tensor_type = t.RankedTensorType(); - if (!ranked_tensor_type) { - return ranked_tensor_type.Error(); - } - - auto tensor_dimensions = ranked_tensor_type->Layout().Dimensions(); - std::vector mtk_dimensions; - mtk_dimensions.reserve(tensor_dimensions.size()); - std::copy(tensor_dimensions.begin(), tensor_dimensions.end(), - std::back_inserter(mtk_dimensions)); - - // tensor type dimensions couldn't be zero. - if (mtk_dimensions.size() == 0) { - mtk_dimensions = { - 1, - }; - } - - // BlockWise Quantize is not supported now. 
- if (t.HasQuantization() && t.QTypeId() == kLiteRtQuantizationBlockWise) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "Doesn't support BlockWise quantize now"); - } - - auto mtk_type = GetNeuronTensorType(t); - if (!mtk_type) { - return mtk_type.Error(); - } - - if (t.QTypeId() == kLiteRtQuantizationPerTensor) { - auto quant_info = t.PerTensorQuantization(); - LITERT_LOG(LITERT_INFO, "zeroPoint: %d, scale: %f", quant_info.zero_point, - quant_info.scale); - return OperandType(*mtk_type, std::move(mtk_dimensions), quant_info.scale, - quant_info.zero_point, std::nullopt); - } else if (t.QTypeId() == kLiteRtQuantizationPerChannel) { - auto quant_info = t.PerChannelQuantization(); - NeuronSymmPerChannelQuantParams params; - params.scaleCount = quant_info.num_channels; - params.scales = quant_info.scales; - params.channelDim = quant_info.quantized_dimension; - LITERT_LOG(LITERT_INFO, "quantized_dimension: %d", - quant_info.quantized_dimension); - LITERT_LOG(LITERT_INFO, "params.channelDim: %d", params.channelDim); - return OperandType(*mtk_type, std::move(mtk_dimensions), 0, 0, params); - } else { - return OperandType(*mtk_type, std::move(mtk_dimensions), /*scale*/ 0, - /*zero_point*/ 0, std::nullopt); - } - } - - void Info() { - std::string vector = "["; - for (int i = 0; i < dimensionCount; i++) { - vector += std::to_string(dimensions_[i]); - vector += ","; - } - vector += "]"; - LITERT_LOG(LITERT_INFO, - "\n[Type] %d" - "\n[zeroPoint]%d" - "\n[scale]%f" - "\n[dimensionCount]%u" - "\n[dimensions]%s\n", - type, zeroPoint, scale, dimensionCount, vector.c_str()); - } - - OperandType(const OperandType&) = delete; - - OperandType(OperandType&& other) - : dimensions_(std::move(other.dimensions_)), - neuron_per_channel_params_(other.neuron_per_channel_params_) { - // Copy all the scalar fields from other. - *static_cast(this) = - *static_cast(&other); - // Reset the pointer fields by using own data. 
- dimensions = dimensions_.data(); - }; - - Expected Reshape(std::vector& shape) { - auto elements = GetElementCount(); - if (elements != std::accumulate(shape.begin(), shape.end(), 1, - std::multiplies())) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "the elements is not the same"); - } - this->dimensions_ = shape; - this->dimensionCount = this->dimensions_.size(); - this->dimensions = this->dimensions_.data(); - return {}; - } - - Expected GetPerChannelQuantParams() { - if (!neuron_per_channel_params_.has_value()) { - return Error(kLiteRtStatusErrorRuntimeFailure, "No quant param is set"); - } - return neuron_per_channel_params_.value(); - } - - int32_t GetNeuronType() const { return this->type; } - - std::vector GetDimension() { return this->dimensions_; } - - uint32_t GetElementCount() { - return std::accumulate(dimensions_.begin(), dimensions_.end(), 1, - std::multiplies()); - } - - uint32_t GetRank() { return this->dimensions_.size(); } - - OperandType& operator=(const OperandType&) = delete; - OperandType& operator=(OperandType&& other) = delete; - - private: - explicit OperandType(int32_t mtk_type, std::vector&& mtk_dimensions, - float scale, int32_t zero_point, - std::optional pararms) - : dimensions_(std::move(mtk_dimensions)), - neuron_per_channel_params_(pararms) { - this->scale = scale; - this->zeroPoint = zero_point; - this->type = mtk_type; - this->dimensionCount = dimensions_.size(); - this->dimensions = dimensions_.data(); - } - - std::vector dimensions_; - - std::optional neuron_per_channel_params_ = - std::nullopt; -}; - -// This class takes care of registering Tensors and scalars with a given -// NeuronModel and returing their "operand index", which is how the MTK SDK -// handles them. -class OperandMap { - public: - OperandMap(const NeuronAdapterApi& neuron_adapter_api, NeuronModel* model) - : neuron_adapter_api_(neuron_adapter_api), model_(model) {} - - // Add a scalar operand to the model. 
- Expected AddScalarBool(bool value) { - return AddScalar(NEURON_BOOL, value); - } - Expected AddScalarInt32(int32_t value) { - return AddScalar(NEURON_INT32, value); - } - Expected AddScalarUInt32(uint32_t value) { - return AddScalar(NEURON_UINT32, value); - } - Expected AddScalarFloat32(float value) { - return AddScalar(NEURON_FLOAT32, value); - } - - // Add a tensor operand to the model - Expected AddTensorInt32(std::vector& shape, - const void* data, const size_t data_size) { - return AddTensor(NEURON_TENSOR_INT32, shape, data, data_size); - } - - // Add a tensor operand to the model - Expected AddTensorByType(int mtk_type, std::vector& shape, - const void* data, const size_t data_size) { - return AddTensor(mtk_type, shape, data, data_size); - } - - Expected AddOperand(const NeuronOperandType& operand) { - return Register(operand); - } - - // Find the operand index for a given tensor and, if not done already, add the - // tensor as an operand in the model. - Expected GetOperandIndex(const Tensor& t) { - auto i = map_.find(t.Get()); - if (i != map_.end()) { - return i->second; - } else { - return Register(t); - } - } - - private: - Expected Register(const Tensor& t); - Expected Register(const NeuronOperandType& operand_type); - uint32_t AllocateOperandIndex() { return next_operand_index_++; } - - template - Expected AddScalar(int32_t mtk_type, T value) { - const NeuronOperandType scalar_type = { - .type = mtk_type, - .dimensionCount = 0, - .dimensions = nullptr, - }; - auto operand_index = Register(scalar_type); - if (!operand_index) { - return operand_index.Error(); - } - if (neuron_adapter_api_.api().model_set_operand_value( - model_, *operand_index, &value, sizeof(value)) != NEURON_NO_ERROR) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to set value of scalar operand"); - } - return operand_index; - } - - Expected AddTensor(int32_t mtk_type, - const std::vector& shape, - const void* data, const size_t data_size) { - const NeuronOperandType 
scalar_type = { - .type = mtk_type, - .dimensionCount = (uint32_t)shape.size(), - .dimensions = shape.data(), - }; - auto operand_index = Register(scalar_type); - if (!operand_index) { - return operand_index.Error(); - } - if (neuron_adapter_api_.api().model_set_operand_value( - model_, *operand_index, data, data_size) != NEURON_NO_ERROR) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to set value of tensor operand"); - } - return operand_index; - } - - const NeuronAdapterApi& neuron_adapter_api_; - NeuronModel* model_; - int next_operand_index_ = 0; - absl::flat_hash_map map_; -}; - -} // namespace litert::mediatek - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_OPERAND_MAP_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/quantize_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/quantize_op_legalization.cc deleted file mode 100644 index 662b93d66d1508..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/quantize_op_legalization.cc +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright (c) 2025 MediaTek Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/quantize_op_legalization.h" - -#include -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_options.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h" - -namespace litert::mediatek { - -Expected LegalizeQuantizeOp(const NeuronAdapterApi& neuron_adapter_api, - NeuronModel* model, OperandMap& operand_map, - const litert::Op& op) { - LITERT_LOG(LITERT_INFO, "Legalize Quantize"); - std::vector input_indices; - for (auto& input : op.Inputs()) { - auto id = operand_map.GetOperandIndex(input); - if (!id) { - return id.Error(); - } - input_indices.push_back(*id); - } - - std::vector output_indices; - for (auto& output : op.Outputs()) { - auto id = operand_map.GetOperandIndex(output); - if (!id) { - return id.Error(); - } - output_indices.push_back(*id); - } - - if (ModelAddOperation(neuron_adapter_api, model, /*type=*/NEURON_QUANTIZE, - input_indices, output_indices) != NEURON_NO_ERROR) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to add NEURON_QUANTIZE operation"); - } - - return {}; -} - -} // namespace litert::mediatek diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/quantize_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/quantize_op_legalization.h deleted file mode 100644 index d2db3761f374de..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/quantize_op_legalization.h +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright (c) 
2025 MediaTek Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_QUANTIZE_OP_LEGALIZATION_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_QUANTIZE_OP_LEGALIZATION_H_ - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h" - -namespace litert::mediatek { - -Expected LegalizeQuantizeOp(const NeuronAdapterApi& neuron_adapter_api, - NeuronModel* model, OperandMap& operand_map, - const litert::Op& op); - -} // namespace litert::mediatek - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_QUANTIZE_OP_LEGALIZATION_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/reshape_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/reshape_op_legalization.cc deleted file mode 100644 index f9a9af0e8f1fb1..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/reshape_op_legalization.cc +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright (c) 2025 MediaTek Inc. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/reshape_op_legalization.h" - -#include -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_options.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h" - -namespace litert::mediatek { - -Expected LegalizeReshapeOp(const NeuronAdapterApi& neuron_adapter_api, - NeuronModel* model, OperandMap& operand_map, - const litert::Op& op) { - LITERT_LOG(LITERT_INFO, "Legalize Reshape"); - std::vector input_indices; - for (auto& input : op.Inputs()) { - auto id = operand_map.GetOperandIndex(input); - if (!id) { - return id.Error(); - } - input_indices.push_back(*id); - } - - std::vector output_indices; - for (auto& output : op.Outputs()) { - auto id = operand_map.GetOperandIndex(output); - if (!id) { - return id.Error(); - } - output_indices.push_back(*id); - } - - if (ModelAddOperation(neuron_adapter_api, model, /*type=*/NEURON_RESHAPE, - input_indices, output_indices) != NEURON_NO_ERROR) { - return 
Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to add NEURON_RESHAPE operation"); - } - - return {}; -} - -} // namespace litert::mediatek diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/reshape_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/reshape_op_legalization.h deleted file mode 100644 index d8b3b3246ecbb8..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/reshape_op_legalization.h +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright (c) 2025 MediaTek Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_RESHAPE_OP_LEGALIZATION_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_RESHAPE_OP_LEGALIZATION_H_ - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h" - -namespace litert::mediatek { - -Expected LegalizeReshapeOp(const NeuronAdapterApi& neuron_adapter_api, - NeuronModel* model, OperandMap& operand_map, - const litert::Op& op); - -} // namespace litert::mediatek - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_RESHAPE_OP_LEGALIZATION_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/rsqrt_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/rsqrt_op_legalization.cc deleted file mode 100644 index 8b35a9d0163174..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/rsqrt_op_legalization.cc +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright (c) 2025 MediaTek Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/rsqrt_op_legalization.h" - -#include -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_options.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h" - -namespace litert::mediatek { - -Expected LegalizeRsqrtOp(const NeuronAdapterApi& neuron_adapter_api, - NeuronModel* model, OperandMap& operand_map, - const litert::Op& op) { - LITERT_LOG(LITERT_INFO, "Legalize Rsqrt"); - std::vector input_indices; - for (auto& input : op.Inputs()) { - auto id = operand_map.GetOperandIndex(input); - if (!id) { - return id.Error(); - } - input_indices.push_back(*id); - } - - std::vector output_indices; - for (auto& output : op.Outputs()) { - auto id = operand_map.GetOperandIndex(output); - if (!id) { - return id.Error(); - } - output_indices.push_back(*id); - } - - if (ModelAddOperation(neuron_adapter_api, model, /*type=*/NEURON_RSQRT, - input_indices, output_indices) != NEURON_NO_ERROR) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to add NEURON_RSQRT operation"); - } - - return {}; -} - -} // namespace litert::mediatek diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/rsqrt_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/rsqrt_op_legalization.h deleted file mode 100644 index b8ae369796fc78..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/rsqrt_op_legalization.h +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright (c) 2025 MediaTek Inc. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_RSQRT_OP_LEGALIZATION_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_RSQRT_OP_LEGALIZATION_H_ - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h" - -namespace litert::mediatek { - -Expected LegalizeRsqrtOp(const NeuronAdapterApi& neuron_adapter_api, - NeuronModel* model, OperandMap& operand_map, - const litert::Op& op); - -} // namespace litert::mediatek - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_RSQRT_OP_LEGALIZATION_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/softmax_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/softmax_op_legalization.cc deleted file mode 100644 index 1f0ea602cec504..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/softmax_op_legalization.cc +++ /dev/null @@ -1,74 +0,0 @@ -// Copyright (c) 2025 MediaTek Inc. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/softmax_op_legalization.h" - -#include -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_options.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h" - -namespace litert::mediatek { - -Expected LegalizeSoftmaxOp(const NeuronAdapterApi& neuron_adapter_api, - NeuronModel* model, OperandMap& operand_map, - const litert::Op& op) { - LITERT_LOG(LITERT_INFO, "Legalize Softmax"); - std::vector input_indices; - for (auto& input : op.Inputs()) { - auto id = operand_map.GetOperandIndex(input); - if (!id) { - return id.Error(); - } - input_indices.push_back(*id); - } - - // A NEURON_Softmax operation takes an additional scalar operand, which is - // used to pass a Beta value. 
- float beta; - if (auto status = LiteRtGetSoftmaxBetaOption(op.Get(), &beta); - status != kLiteRtStatusOk) { - return Error(status, "Failed to get beta"); - } - auto beta_operand = operand_map.AddScalarFloat32(beta); - if (!beta_operand) { - return beta_operand.Error(); - } - input_indices.push_back(*beta_operand); - - std::vector output_indices; - for (auto& output : op.Outputs()) { - auto id = operand_map.GetOperandIndex(output); - if (!id) { - return id.Error(); - } - output_indices.push_back(*id); - } - - if (ModelAddOperation(neuron_adapter_api, model, /*type=*/NEURON_SOFTMAX, - input_indices, output_indices) != NEURON_NO_ERROR) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to add NEURON_SOFTMAX operation"); - } - - return {}; -} - -} // namespace litert::mediatek diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/softmax_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/softmax_op_legalization.h deleted file mode 100644 index 22c9ea4f1aed63..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/softmax_op_legalization.h +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright (c) 2025 MediaTek Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_SOFTMAX_OP_LEGALIZATION_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_SOFTMAX_OP_LEGALIZATION_H_ - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h" - -namespace litert::mediatek { - -Expected LegalizeSoftmaxOp(const NeuronAdapterApi& neuron_adapter_api, - NeuronModel* model, OperandMap& operand_map, - const litert::Op& op); - -} // namespace litert::mediatek - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_SOFTMAX_OP_LEGALIZATION_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/sub_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/sub_op_legalization.cc deleted file mode 100644 index 0b26d24bc39f73..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/sub_op_legalization.cc +++ /dev/null @@ -1,76 +0,0 @@ -// Copyright (c) 2025 MediaTek Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/sub_op_legalization.h" - -#include -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_options.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h" - -namespace litert::mediatek { - -Expected LegalizeSubOp(const NeuronAdapterApi& neuron_adapter_api, - NeuronModel* model, OperandMap& operand_map, - const litert::Op& op) { - LITERT_LOG(LITERT_INFO, "Legalize Sub"); - std::vector input_indices; - for (auto& input : op.Inputs()) { - auto id = operand_map.GetOperandIndex(input); - if (!id) { - return id.Error(); - } - input_indices.push_back(*id); - } - - // A NEURON_SUB operation takes a 3rd scalar operand, which is used to pass a - // TfLiteFusedActivation value. 
- uint32_t tfl_fused_activation; - if (auto status = - LiteRtGetSubFusedActivationOption(op.Get(), &tfl_fused_activation); - status != kLiteRtStatusOk) { - return Error(status, "Failed to get fused activation"); - } - auto fused_activation_operand_index = - operand_map.AddScalarInt32(tfl_fused_activation); - if (!fused_activation_operand_index) { - return fused_activation_operand_index.Error(); - } - input_indices.push_back(*fused_activation_operand_index); - - std::vector output_indices; - for (auto& output : op.Outputs()) { - auto id = operand_map.GetOperandIndex(output); - if (!id) { - return id.Error(); - } - output_indices.push_back(*id); - } - - if (ModelAddOperation(neuron_adapter_api, model, /*type=*/NEURON_SUB, - input_indices, output_indices) != NEURON_NO_ERROR) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to add value of NEURON_SUB fused activation"); - } - - return {}; -} - -} // namespace litert::mediatek diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/sub_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/sub_op_legalization.h deleted file mode 100644 index bc1e783d55f7b2..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/sub_op_legalization.h +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright (c) 2025 MediaTek Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_SUB_OP_LEGALIZATION_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_SUB_OP_LEGALIZATION_H_ - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h" - -namespace litert::mediatek { - -Expected LegalizeSubOp(const NeuronAdapterApi& neuron_adapter_api, - NeuronModel* model, OperandMap& operand_map, - const litert::Op& op); - -} // namespace litert::mediatek - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_SUB_OP_LEGALIZATION_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/transpose_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/transpose_op_legalization.cc deleted file mode 100644 index 754a77677a4167..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/transpose_op_legalization.cc +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright (c) 2025 MediaTek Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/transpose_op_legalization.h" - -#include -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_options.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h" - -namespace litert::mediatek { - -Expected LegalizeTransposeOp(const NeuronAdapterApi& neuron_adapter_api, - NeuronModel* model, OperandMap& operand_map, - const litert::Op& op) { - LITERT_LOG(LITERT_INFO, "Legalize Transpose"); - std::vector input_indices; - for (auto& input : op.Inputs()) { - auto id = operand_map.GetOperandIndex(input); - if (!id) { - return id.Error(); - } - input_indices.push_back(*id); - } - - std::vector output_indices; - for (auto& output : op.Outputs()) { - auto id = operand_map.GetOperandIndex(output); - if (!id) { - return id.Error(); - } - output_indices.push_back(*id); - } - - if (ModelAddOperation(neuron_adapter_api, model, /*type=*/NEURON_TRANSPOSE, - input_indices, output_indices) != NEURON_NO_ERROR) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to add reshape operation"); - } - - return {}; -} - -} // namespace litert::mediatek diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/transpose_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/transpose_op_legalization.h deleted file mode 100644 index 94b445b6218025..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/transpose_op_legalization.h +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright (c) 
2025 MediaTek Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_TRANSPOSE_OP_LEGALIZATION_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_TRANSPOSE_OP_LEGALIZATION_H_ - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h" - -namespace litert::mediatek { - -Expected LegalizeTransposeOp(const NeuronAdapterApi& neuron_adapter_api, - NeuronModel* model, OperandMap& operand_map, - const litert::Op& op); - -} // namespace litert::mediatek - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_TRANSPOSE_OP_LEGALIZATION_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/BUILD b/tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/BUILD deleted file mode 100644 index 7315f1598a9476..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/BUILD +++ /dev/null @@ -1,122 +0,0 @@ -# Copyright 2024 Google LLC. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -load("//tensorflow/lite/experimental/litert/build_common:litert_build_defs.bzl", "copy_file", "litert_dynamic_lib") - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], -) - -litert_dynamic_lib( - name = "dispatch_api", - srcs = [ - "dispatch_api.cc", - "litert_dispatch_device_context.cc", - "litert_dispatch_invocation_context.cc", - ], - hdrs = [ - "litert_dispatch_device_context.h", - "litert_dispatch_invocation_context.h", - ], - copts = [ - "-Os", - "-fno-exceptions", - "-fno-unwind-tables", - "-fno-asynchronous-unwind-tables", - "-ffunction-sections", - "-fdata-sections", - ], - export_litert_only = True, - linkopts = select({ - "//tensorflow:android": ["-landroid"], - "//conditions:default": [], - }) + [ - "-Wl,-soname=libLiteRtDispatch_GoogleTensor.so", - "-Wl,-lc++abi", - ], - shared_lib_name = "dispatch_api_so", - so_name = "libLiteRtDispatch_Mediatek.so", - tags = [ - # Remove when sdk is available to bazel. 
- "nobuilder", - "notap", - ], - visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], - deps = [ - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/container:node_hash_set", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/neuro_pilot:latest_host_headers", - "//tensorflow/lite/experimental/litert/c:litert_runtime_c_api_shared_lib", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/core:dynamic_loading", - "//tensorflow/lite/experimental/litert/core/util:tensor_type_util", - "//tensorflow/lite/experimental/litert/vendors/c:litert_dispatch_c_api", - "//tensorflow/lite/experimental/litert/vendors/mediatek:neuron_adapter_api", - "//tensorflow/lite/experimental/litert/vendors/mediatek/schema:mediatek_litert_schema", - ], -) - -# This is cc_library target for `libLiteRtDispatch_Mediatek.so`. -cc_library( - name = "dispatch_api_shared_lib", - srcs = [":dispatch_api_so"], -) - -# Copies the shared library so that it is available for use in test data as libLiteRtDispatch_Mediatek.so. 
-copy_file( - name = "copy_dispatch_api_so", - src = "//tensorflow/lite/experimental/litert/vendors/mediatek/dispatch:dispatch_api_so", - target = "libLiteRtDispatch_Mediatek.so", -) - -cc_test( - name = "dispatch_api_mediatek_test", - srcs = [ - "dispatch_api_mediatek_test.cc", - ], - data = [ - ":dispatch_api_so", - ], - linkopts = select({ - "//tensorflow:android": ["-landroid"], - "//conditions:default": [], - }), - tags = [ - "no-remote-exec", - "no_oss", - "nobuilder", - "nosan", - "notap", - ], - deps = [ - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:absl_log", - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/neuro_pilot:latest_host_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_tensor_buffer", - "//tensorflow/lite/experimental/litert/cc:litert_any", - "//tensorflow/lite/experimental/litert/core:filesystem", - "//tensorflow/lite/experimental/litert/test:common", - "//tensorflow/lite/experimental/litert/test:simple_model_npu", - "//tensorflow/lite/experimental/litert/vendors/c:litert_dispatch_c_api", - "@com_google_googletest//:gtest_main", - ], -) diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/README.md b/tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/README.md deleted file mode 100644 index 35a6130c76d318..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/README.md +++ /dev/null @@ -1,4 +0,0 @@ -Test case can dispatch_api_mediatek_test can be run on a device with a MetiaTek -mt6989 SoC with the following comands - -$ ../../../google/run_test_on_android.sh dispatch_api_mediatek_test diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/dispatch_api.cc b/tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/dispatch_api.cc deleted file mode 100644 index b8c5da6ee392f6..00000000000000 --- 
a/tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/dispatch_api.cc +++ /dev/null @@ -1,327 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include -#include -#include - -#if LITERT_HAS_AHWB_SUPPORT -#include -#endif - -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch_api.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/litert_dispatch_device_context.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/litert_dispatch_invocation_context.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/schema/schema_resolver.h" - -namespace { - -litert::mediatek::NeuronAdapterApi* TheNeuronAdapter; -char BuildId[256]; - -} // namespace - -namespace litert { -namespace mediatek { - -// 
///////////////////////////////////////////////////////////////////////////// -// Basic Execution API -// ///////////////////////////////////////////////////////////////////////////// - -const char* GetSharedLibraryDir(const LiteRtDispatchOption* options, - int num_options) { - for (auto i = 0; i < num_options; ++i) { - auto& option = options[i]; - if (!strcmp(option.name, kDispatchOptionSharedLibraryDir)) { - return option.value.str_value; - } - } - return nullptr; -} - -LiteRtStatus LiteRtInitialize(const LiteRtDispatchOption* options, - int num_options) { - auto* shared_library_dir = GetSharedLibraryDir(options, num_options); - std::optional shared_library_dir_opt = - shared_library_dir ? std::make_optional(std::string(shared_library_dir)) - : std::nullopt; - - if (auto neuron_adapter_api = - litert::mediatek::NeuronAdapterApi::Create(shared_library_dir_opt); - neuron_adapter_api) { - TheNeuronAdapter = neuron_adapter_api->release(); - } else { - LITERT_LOG(LITERT_INFO, "Initialization failure: %s", - neuron_adapter_api.Error().Message().c_str()); - return neuron_adapter_api.Error().Status(); - } - - auto get_version = TheNeuronAdapter->api().get_version; - if (!get_version) { - LITERT_LOG(LITERT_ERROR, "get_version not found"); - return kLiteRtStatusErrorRuntimeFailure; - } - - NeuronRuntimeVersion version; - if (get_version(&version) != NEURON_NO_ERROR) { - LITERT_LOG(LITERT_ERROR, "Failed to get version"); - return kLiteRtStatusErrorRuntimeFailure; - } - LITERT_LOG(LITERT_INFO, "Neuron SDK version: %d.%d.%d", version.major, - version.minor, version.patch); - - snprintf(BuildId, sizeof(BuildId), - "MediaTek Dispatch API version %d.%d.%d, NeuronAdaptor API version " - "%d.%d.%d", - LITERT_API_VERSION_MAJOR, LITERT_API_VERSION_MINOR, - LITERT_API_VERSION_PATCH, version.major, version.minor, - version.patch); - BuildId[sizeof(BuildId) - 1] = 0; - - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetVendorId(const char** vendor_id) { - *vendor_id = "MediaTek"; - 
return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetBuildId(const char** build_id) { - *build_id = BuildId; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetCapabilities(int* capabilities) { - *capabilities = kLiteRtDispatchCapabilitiesBasic; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtDeviceContextCreate( - LiteRtDispatchDeviceContext* device_context) { - if (auto context = LiteRtDispatchDeviceContextT::Create(*TheNeuronAdapter); - context) { - *device_context = context->release(); - return kLiteRtStatusOk; - } else { - LITERT_LOG(LITERT_ERROR, "Failed to create device context: %s", - context.Error().Message().c_str()); - return context.Error().Status(); - } -} - -LiteRtStatus LiteRtDeviceContextDestroy( - LiteRtDispatchDeviceContext device_context) { - delete device_context; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetInputRequirements( - LiteRtDispatchInvocationContext invocation_context, int input_index, - const LiteRtRankedTensorType* tensor_type, - LiteRtTensorBufferRequirements* tensor_buffer_requirements) { - if (auto requirements = - invocation_context->GetInputRequirements(input_index, *tensor_type); - requirements) { - *tensor_buffer_requirements = *requirements; - return kLiteRtStatusOk; - } else { - LITERT_LOG(LITERT_ERROR, "Failed to get tensor buffer requirements: %s", - requirements.Error().Message().c_str()); - return requirements.Error().Status(); - } -} - -LiteRtStatus LiteRtGetOutputRequirements( - LiteRtDispatchInvocationContext invocation_context, int output_index, - const LiteRtRankedTensorType* tensor_type, - LiteRtTensorBufferRequirements* tensor_buffer_requirements) { - if (auto requirements = - invocation_context->GetOutputRequirements(output_index, *tensor_type); - requirements) { - *tensor_buffer_requirements = *requirements; - return kLiteRtStatusOk; - } else { - LITERT_LOG(LITERT_ERROR, "Failed to get tensor buffer requirements: %s", - requirements.Error().Message().c_str()); - return 
requirements.Error().Status(); - } -} - -LiteRtStatus LiteRtRegisterTensorBuffer( - LiteRtDispatchDeviceContext device_context, - LiteRtTensorBuffer tensor_buffer, - LiteRtTensorBufferHandle* tensor_buffer_handle) { - if (auto result = device_context->RegisterTensorBuffer(tensor_buffer); - result) { - *tensor_buffer_handle = *result; - return kLiteRtStatusOk; - } else { - LITERT_LOG(LITERT_ERROR, "Failed to register tensor buffer: %s", - result.Error().Message().c_str()); - return result.Error().Status(); - } -} - -LiteRtStatus LiteRtUnregisterTensorBuffer( - LiteRtDispatchDeviceContext device_context, - LiteRtTensorBufferHandle tensor_buffer_handle) { - if (auto status = - device_context->UnregisterTensorBuffer(tensor_buffer_handle); - !status) { - LITERT_LOG(LITERT_ERROR, "Failed to unregister tensor buffer: %s", - status.Error().Message().c_str()); - return status.Error().Status(); - } - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtInvocationContextCreate( - LiteRtDispatchDeviceContext device_context, - LiteRtDispatchExecutableType exec_type, - const LiteRtMemBuffer* exec_bytecode_buffer, const char* function_name, - int num_inputs, int num_outputs, - LiteRtDispatchInvocationContext* invocation_context) { - auto context = LiteRtDispatchInvocationContextT::Create( - *TheNeuronAdapter, device_context, exec_type, exec_bytecode_buffer, - function_name, num_inputs, num_outputs); - if (!context) { - LITERT_LOG(LITERT_ERROR, "Failed to create context from context binary: %s", - context.Error().Message().c_str()); - return context.Error().Status(); - } - *invocation_context = context->release(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtInvocationContextDestroy( - LiteRtDispatchInvocationContext invocation_context) { - delete invocation_context; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtAttachInput( - LiteRtDispatchInvocationContext invocation_context, int graph_input_index, - LiteRtTensorBufferHandle tensor_buffer_handle) { - if (auto status = 
invocation_context->AttachInput(graph_input_index, - tensor_buffer_handle); - !status) { - LITERT_LOG(LITERT_ERROR, "Failed to attach input: %s", - status.Error().Message().c_str()); - return status.Error().Status(); - } - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtAttachOutput( - LiteRtDispatchInvocationContext invocation_context, int graph_output_index, - LiteRtTensorBufferHandle tensor_buffer_handle) { - if (auto status = invocation_context->AttachOutput(graph_output_index, - tensor_buffer_handle); - !status) { - LITERT_LOG(LITERT_ERROR, "Failed to attach output: %s", - status.Error().Message().c_str()); - return status.Error().Status(); - } - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtDetachInput( - LiteRtDispatchInvocationContext invocation_context, int graph_input_index, - LiteRtTensorBufferHandle tensor_buffer_handle) { - if (auto status = invocation_context->DetachInput(graph_input_index, - tensor_buffer_handle); - !status) { - LITERT_LOG(LITERT_ERROR, "Failed to detach input: %s", - status.Error().Message().c_str()); - return status.Error().Status(); - } - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtDetachOutput( - LiteRtDispatchInvocationContext invocation_context, int graph_output_index, - LiteRtTensorBufferHandle tensor_buffer_handle) { - if (auto status = invocation_context->DetachOutput(graph_output_index, - tensor_buffer_handle); - !status) { - LITERT_LOG(LITERT_ERROR, "Failed to detach output: %s", - status.Error().Message().c_str()); - return status.Error().Status(); - } - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtInvoke(LiteRtDispatchInvocationContext invocation_context) { - if (auto status = invocation_context->Invoke(); !status) { - LITERT_LOG(LITERT_ERROR, "Failed to invoke context: %s", - status.Error().Message().c_str()); - return status.Error().Status(); - } - return kLiteRtStatusOk; -} - -} // namespace mediatek -} // namespace litert - -// ///////////////////////////////////////////////////////////////////////////// 
- -namespace { - -LiteRtDispatchInterface TheInterface = { - .initialize = litert::mediatek::LiteRtInitialize, - .get_vendor_id = litert::mediatek::LiteRtGetVendorId, - .get_build_id = litert::mediatek::LiteRtGetBuildId, - .get_capabilities = litert::mediatek::LiteRtGetCapabilities, - .device_context_create = litert::mediatek::LiteRtDeviceContextCreate, - .device_context_destroy = litert::mediatek::LiteRtDeviceContextDestroy, - .get_input_requirements = litert::mediatek::LiteRtGetInputRequirements, - .get_output_requirements = litert::mediatek::LiteRtGetOutputRequirements, - .register_tensor_buffer = litert::mediatek::LiteRtRegisterTensorBuffer, - .unregister_tensor_buffer = litert::mediatek::LiteRtUnregisterTensorBuffer, - .invocation_context_create = - litert::mediatek::LiteRtInvocationContextCreate, - .invocation_context_destroy = - litert::mediatek::LiteRtInvocationContextDestroy, - .attach_input = litert::mediatek::LiteRtAttachInput, - .attach_output = litert::mediatek::LiteRtAttachOutput, - .detach_input = litert::mediatek::LiteRtDetachInput, - .detach_output = litert::mediatek::LiteRtDetachOutput, - .invoke = litert::mediatek::LiteRtInvoke, -}; - -LiteRtDispatchApi TheApi = { - .version = {.major = LITERT_API_VERSION_MAJOR, - .minor = LITERT_API_VERSION_MINOR, - .patch = LITERT_API_VERSION_PATCH}, - .interface = &TheInterface, - .async_interface = nullptr, - .graph_interface = nullptr, -}; - -} // namespace - -LiteRtStatus LiteRtDispatchGetApi(LiteRtDispatchApi* api) { - *api = TheApi; - return kLiteRtStatusOk; -} diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/dispatch_api_mediatek_test.cc b/tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/dispatch_api_mediatek_test.cc deleted file mode 100644 index 9926f55e6884b6..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/dispatch_api_mediatek_test.cc +++ /dev/null @@ -1,638 +0,0 @@ -// Copyright 2024 Google LLC. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include -#include - -#include -#include -#include "absl/log/absl_log.h" -#include "absl/log/log.h" -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h" -#include "tensorflow/lite/experimental/litert/cc/litert_any.h" -#include "tensorflow/lite/experimental/litert/core/filesystem.h" -#include "tensorflow/lite/experimental/litert/test/common.h" -#include "tensorflow/lite/experimental/litert/test/testdata/simple_model_test_vectors.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" - -using ::testing::Pointwise; - -TEST(MediaTek, DispatchApiWithAhwb) { -#if !defined(__ANDROID__) - GTEST_SKIP() - << "This test is specific to Android devices with a MediaTek NPU"; -#endif - - LiteRtDispatchOption dispatch_option = { - /*.name=*/kDispatchOptionSharedLibraryDir, - /*.value=*/*litert::ToLiteRtAny(std::any("/data/local/tmp")), - }; - ASSERT_EQ( - LiteRtDispatchInitialize(/*options=*/&dispatch_option, /*num_options=*/1), - kLiteRtStatusOk); - - const char* vendor_id; - EXPECT_EQ(LiteRtDispatchGetVendorId(&vendor_id), kLiteRtStatusOk); - ABSL_LOG(INFO) << "vendor_id: " << vendor_id; - - const char* build_id; - EXPECT_EQ(LiteRtDispatchGetBuildId(&build_id), kLiteRtStatusOk); - 
ABSL_LOG(INFO) << "build_id: " << build_id; - - LiteRtApiVersion api_version; - EXPECT_EQ(LiteRtDispatchGetApiVersion(&api_version), kLiteRtStatusOk); - ABSL_LOG(INFO) << "api_version: " << api_version.major << "." - << api_version.minor << "." << api_version.patch; - - int capabilities; - EXPECT_EQ(LiteRtDispatchGetCapabilities(&capabilities), kLiteRtStatusOk); - ABSL_LOG(INFO) << "capabilities: " << capabilities; - - LiteRtDispatchDeviceContext device_context = nullptr; - EXPECT_EQ(LiteRtDispatchDeviceContextCreate(&device_context), - kLiteRtStatusOk); - ABSL_LOG(INFO) << "device_context: " << device_context; - - auto model_file_name = - litert::testing::GetTestFilePath(kMediaTekModelFileName); - auto model = litert::internal::LoadBinaryFile(model_file_name); - EXPECT_TRUE(model) << model.Error(); - ABSL_LOG(INFO) << "Loaded model " << model_file_name << ", " << model->Size() - << " bytes"; - - // /////////////////////////////////////////////////////////////////////////// - // Set up an invocation context for a given model. - // /////////////////////////////////////////////////////////////////////////// - - LiteRtMemBuffer exec_bytecode_buffer = {/*.fd=*/-1, - /*.base_addr=*/model->Data(), - /*.offset=*/0, - /*.size=*/model->Size()}; - LiteRtDispatchInvocationContext invocation_context = nullptr; - EXPECT_EQ(LiteRtDispatchInvocationContextCreate( - device_context, kLiteRtDispatchExecutableTypeMlModel, - &exec_bytecode_buffer, /*function_name=*/nullptr, - /*num_inputs=*/2, /*num_outputs=*/1, &invocation_context), - kLiteRtStatusOk); - ABSL_LOG(INFO) << "Invocation context: " << invocation_context; - - // /////////////////////////////////////////////////////////////////////////// - // Determine tensor buffer requirements. 
- // /////////////////////////////////////////////////////////////////////////// - - int num_tensor_buffer_types; - LiteRtTensorBufferRequirements input_0_tensor_buffer_requirements; - EXPECT_EQ(LiteRtDispatchGetInputRequirements( - invocation_context, /*input_index=*/0, &kInput0TensorType, - &input_0_tensor_buffer_requirements), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtGetNumTensorBufferRequirementsSupportedBufferTypes( - input_0_tensor_buffer_requirements, &num_tensor_buffer_types), - kLiteRtStatusOk); - EXPECT_GE(num_tensor_buffer_types, 1); - LiteRtTensorBufferType input_0_tensor_buffer_type; - EXPECT_EQ(LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( - input_0_tensor_buffer_requirements, /*type_index=*/0, - &input_0_tensor_buffer_type), - kLiteRtStatusOk); - EXPECT_EQ(input_0_tensor_buffer_type, kLiteRtTensorBufferTypeAhwb); - size_t input_0_tensor_buffer_size; - EXPECT_EQ( - LiteRtGetTensorBufferRequirementsBufferSize( - input_0_tensor_buffer_requirements, &input_0_tensor_buffer_size), - kLiteRtStatusOk); - EXPECT_GE(input_0_tensor_buffer_size, sizeof(kTestInput0Tensor)); - - LiteRtTensorBufferRequirements input_1_tensor_buffer_requirements; - EXPECT_EQ(LiteRtDispatchGetInputRequirements( - invocation_context, /*input_index=*/1, &kInput1TensorType, - &input_1_tensor_buffer_requirements), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtGetNumTensorBufferRequirementsSupportedBufferTypes( - input_1_tensor_buffer_requirements, &num_tensor_buffer_types), - kLiteRtStatusOk); - EXPECT_GE(num_tensor_buffer_types, 1); - LiteRtTensorBufferType input_1_tensor_buffer_type; - EXPECT_EQ(LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( - input_1_tensor_buffer_requirements, /*type_index=*/0, - &input_1_tensor_buffer_type), - kLiteRtStatusOk); - EXPECT_EQ(input_1_tensor_buffer_type, kLiteRtTensorBufferTypeAhwb); - size_t input_1_tensor_buffer_size; - EXPECT_EQ( - LiteRtGetTensorBufferRequirementsBufferSize( - input_1_tensor_buffer_requirements, 
&input_1_tensor_buffer_size), - kLiteRtStatusOk); - EXPECT_GE(input_1_tensor_buffer_size, sizeof(kTestInput1Tensor)); - - LiteRtTensorBufferRequirements output_tensor_buffer_requirements; - EXPECT_EQ(LiteRtDispatchGetOutputRequirements( - invocation_context, /*output_index=*/0, &kOutputTensorType, - &output_tensor_buffer_requirements), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtGetNumTensorBufferRequirementsSupportedBufferTypes( - output_tensor_buffer_requirements, &num_tensor_buffer_types), - kLiteRtStatusOk); - EXPECT_GE(num_tensor_buffer_types, 1); - LiteRtTensorBufferType output_tensor_buffer_type; - EXPECT_EQ(LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( - output_tensor_buffer_requirements, /*type_index=*/0, - &output_tensor_buffer_type), - kLiteRtStatusOk); - EXPECT_EQ(output_tensor_buffer_type, kLiteRtTensorBufferTypeAhwb); - size_t output_tensor_buffer_size; - EXPECT_EQ(LiteRtGetTensorBufferRequirementsBufferSize( - output_tensor_buffer_requirements, &output_tensor_buffer_size), - kLiteRtStatusOk); - EXPECT_GE(output_tensor_buffer_size, sizeof(kTestOutputTensor)); - - // /////////////////////////////////////////////////////////////////////////// - // Allocate tensor buffers. 
- // /////////////////////////////////////////////////////////////////////////// - - LiteRtTensorBuffer input_0_tensor_buffer; - EXPECT_EQ(LiteRtCreateManagedTensorBuffer( - input_0_tensor_buffer_type, &kInput0TensorType, - input_0_tensor_buffer_size, &input_0_tensor_buffer), - kLiteRtStatusOk); - - LiteRtTensorBuffer input_1_tensor_buffer; - EXPECT_EQ(LiteRtCreateManagedTensorBuffer( - input_1_tensor_buffer_type, &kInput1TensorType, - input_1_tensor_buffer_size, &input_1_tensor_buffer), - kLiteRtStatusOk); - - LiteRtTensorBuffer output_tensor_buffer; - EXPECT_EQ(LiteRtCreateManagedTensorBuffer( - output_tensor_buffer_type, &kOutputTensorType, - output_tensor_buffer_size, &output_tensor_buffer), - kLiteRtStatusOk); - - // /////////////////////////////////////////////////////////////////////////// - // Register tensor buffers. - // /////////////////////////////////////////////////////////////////////////// - - LiteRtTensorBufferHandle input_1_handle; - EXPECT_EQ(LiteRtDispatchRegisterTensorBuffer( - device_context, input_1_tensor_buffer, &input_1_handle), - kLiteRtStatusOk); - - LiteRtTensorBufferHandle input_0_handle; - EXPECT_EQ(LiteRtDispatchRegisterTensorBuffer( - device_context, input_0_tensor_buffer, &input_0_handle), - kLiteRtStatusOk); - - LiteRtTensorBufferHandle output_handle; - EXPECT_EQ(LiteRtDispatchRegisterTensorBuffer( - device_context, output_tensor_buffer, &output_handle), - kLiteRtStatusOk); - - // /////////////////////////////////////////////////////////////////////////// - // Attach tensor buffers. 
- // /////////////////////////////////////////////////////////////////////////// - - EXPECT_EQ(LiteRtDispatchAttachInput(invocation_context, - /*graph_input_index=*/0, input_0_handle), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtDispatchAttachInput(invocation_context, - /*graph_input_index=*/1, input_1_handle), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtDispatchAttachOutput(invocation_context, - /*graph_output_index=*/0, output_handle), - kLiteRtStatusOk); - - // /////////////////////////////////////////////////////////////////////////// - // Fill the input buffers with data. - // /////////////////////////////////////////////////////////////////////////// - - { - ABSL_LOG(INFO) << "Filling inputs with data"; - void* host_mem_addr; - - ASSERT_EQ(LiteRtLockTensorBuffer(input_0_tensor_buffer, &host_mem_addr), - kLiteRtStatusOk); - std::memcpy(host_mem_addr, kTestInput0Tensor, sizeof(kTestInput0Tensor)); - ASSERT_EQ(LiteRtUnlockTensorBuffer(input_0_tensor_buffer), kLiteRtStatusOk); - - ASSERT_EQ(LiteRtLockTensorBuffer(input_1_tensor_buffer, &host_mem_addr), - kLiteRtStatusOk); - std::memcpy(host_mem_addr, kTestInput1Tensor, sizeof(kTestInput1Tensor)); - ASSERT_EQ(LiteRtUnlockTensorBuffer(input_1_tensor_buffer), kLiteRtStatusOk); - } - - // /////////////////////////////////////////////////////////////////////////// - // Execute model. - // /////////////////////////////////////////////////////////////////////////// - - ABSL_LOG(INFO) << "Invoking execution..."; - EXPECT_EQ(LiteRtDispatchInvoke(invocation_context), kLiteRtStatusOk); - - // /////////////////////////////////////////////////////////////////////////// - // Check output for correctness. 
- // /////////////////////////////////////////////////////////////////////////// - - { - ABSL_LOG(INFO) << "Checking output..."; - void* host_mem_addr; - ASSERT_EQ(LiteRtLockTensorBuffer(output_tensor_buffer, &host_mem_addr), - kLiteRtStatusOk); - auto output = absl::MakeSpan(static_cast(host_mem_addr), - kTestOutputSize); - for (auto i = 0; i < kTestOutputSize; ++i) { - ABSL_LOG(INFO) << output[i] << "\t" << kTestOutputTensor[i]; - } - EXPECT_THAT(output, Pointwise(testing::FloatNear(1e-3), kTestOutputTensor)); - ASSERT_EQ(LiteRtUnlockTensorBuffer(output_tensor_buffer), kLiteRtStatusOk); - } - - // /////////////////////////////////////////////////////////////////////////// - // Fill the input buffers with more data. - // /////////////////////////////////////////////////////////////////////////// - - { - ABSL_LOG(INFO) << "Filling inputs with data"; - void* host_mem_addr; - - ASSERT_EQ(LiteRtLockTensorBuffer(input_0_tensor_buffer, &host_mem_addr), - kLiteRtStatusOk); - std::memcpy(host_mem_addr, kTestInput0Tensor_2, - sizeof(kTestInput0Tensor_2)); - ASSERT_EQ(LiteRtUnlockTensorBuffer(input_0_tensor_buffer), kLiteRtStatusOk); - - ASSERT_EQ(LiteRtLockTensorBuffer(input_1_tensor_buffer, &host_mem_addr), - kLiteRtStatusOk); - std::memcpy(host_mem_addr, kTestInput1Tensor_2, - sizeof(kTestInput1Tensor_2)); - ASSERT_EQ(LiteRtUnlockTensorBuffer(input_1_tensor_buffer), kLiteRtStatusOk); - } - - // /////////////////////////////////////////////////////////////////////////// - // Execute model once more. - // /////////////////////////////////////////////////////////////////////////// - - ABSL_LOG(INFO) << "Invoking second execution..."; - EXPECT_EQ(LiteRtDispatchInvoke(invocation_context), kLiteRtStatusOk); - - // /////////////////////////////////////////////////////////////////////////// - // Check output for correctness. 
- // /////////////////////////////////////////////////////////////////////////// - - { - ABSL_LOG(INFO) << "Checking output..."; - void* host_mem_addr; - ASSERT_EQ(LiteRtLockTensorBuffer(output_tensor_buffer, &host_mem_addr), - kLiteRtStatusOk); - auto output = absl::MakeSpan(static_cast(host_mem_addr), - kTestOutputSize); - for (auto i = 0; i < kTestOutputSize; ++i) { - ABSL_LOG(INFO) << output[i] << "\t" << kTestOutputTensor_2[i]; - } - EXPECT_THAT(output, - Pointwise(testing::FloatNear(1e-3), kTestOutputTensor_2)); - ASSERT_EQ(LiteRtUnlockTensorBuffer(output_tensor_buffer), kLiteRtStatusOk); - } - - // /////////////////////////////////////////////////////////////////////////// - // Clean up resources. - // /////////////////////////////////////////////////////////////////////////// - - EXPECT_EQ(LiteRtDispatchDetachInput(invocation_context, - /*graph_input_index=*/0, input_0_handle), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtDispatchDetachInput(invocation_context, - /*graph_input_index=*/1, input_1_handle), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtDispatchDetachOutput(invocation_context, - /*graph_output_index=*/0, output_handle), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtDispatchUnregisterTensorBuffer(device_context, output_handle), - kLiteRtStatusOk); - EXPECT_EQ( - LiteRtDispatchUnregisterTensorBuffer(device_context, input_1_handle), - kLiteRtStatusOk); - EXPECT_EQ( - LiteRtDispatchUnregisterTensorBuffer(device_context, input_0_handle), - kLiteRtStatusOk); - LiteRtDestroyTensorBuffer(output_tensor_buffer); - LiteRtDestroyTensorBuffer(input_1_tensor_buffer); - LiteRtDestroyTensorBuffer(input_0_tensor_buffer); - EXPECT_EQ(LiteRtDispatchInvocationContextDestroy(invocation_context), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtDispatchDeviceContextDestroy(device_context), - kLiteRtStatusOk); -} - -TEST(MediaTek, DispatchApiWithDmaBuf) { -#if !defined(__ANDROID__) - GTEST_SKIP() - << "This test is specific to Android devices with a MediaTek NPU"; -#endif - - 
EXPECT_EQ(LiteRtDispatchInitialize(/*options=*/nullptr, /*num_options=*/0), - kLiteRtStatusOk); - - const char* vendor_id; - EXPECT_EQ(LiteRtDispatchGetVendorId(&vendor_id), kLiteRtStatusOk); - ABSL_LOG(INFO) << "vendor_id: " << vendor_id; - - const char* build_id; - EXPECT_EQ(LiteRtDispatchGetBuildId(&build_id), kLiteRtStatusOk); - ABSL_LOG(INFO) << "build_id: " << build_id; - - LiteRtApiVersion api_version; - EXPECT_EQ(LiteRtDispatchGetApiVersion(&api_version), kLiteRtStatusOk); - ABSL_LOG(INFO) << "api_version: " << api_version.major << "." - << api_version.minor << "." << api_version.patch; - - int capabilities; - EXPECT_EQ(LiteRtDispatchGetCapabilities(&capabilities), kLiteRtStatusOk); - ABSL_LOG(INFO) << "capabilities: " << capabilities; - - LiteRtDispatchDeviceContext device_context = nullptr; - EXPECT_EQ(LiteRtDispatchDeviceContextCreate(&device_context), - kLiteRtStatusOk); - ABSL_LOG(INFO) << "device_context: " << device_context; - - auto model_file_name = - litert::testing::GetTestFilePath(kMediaTekModelFileName); - auto model = litert::internal::LoadBinaryFile(model_file_name); - EXPECT_TRUE(model) << model.Error(); - ABSL_LOG(INFO) << "Loaded model " << model_file_name << ", " << model->Size() - << " bytes"; - - // /////////////////////////////////////////////////////////////////////////// - // Set up an invocation context for a given model. 
- // /////////////////////////////////////////////////////////////////////////// - - LiteRtMemBuffer exec_bytecode_buffer = {/*.fd=*/-1, - /*.base_addr=*/model->Data(), - /*.offset=*/0, - /*.size=*/model->Size()}; - LiteRtDispatchInvocationContext invocation_context = nullptr; - EXPECT_EQ(LiteRtDispatchInvocationContextCreate( - device_context, kLiteRtDispatchExecutableTypeMlModel, - &exec_bytecode_buffer, /*function_name=*/nullptr, - /*num_inputs=*/2, /*num_outputs=*/1, &invocation_context), - kLiteRtStatusOk); - ABSL_LOG(INFO) << "Invocation context: " << invocation_context; - - // /////////////////////////////////////////////////////////////////////////// - // Determine tensor buffer requirements. - // /////////////////////////////////////////////////////////////////////////// - - int num_tensor_buffer_types; - LiteRtTensorBufferRequirements input_0_tensor_buffer_requirements; - EXPECT_EQ(LiteRtDispatchGetInputRequirements( - invocation_context, /*input_index=*/0, &kInput0TensorType, - &input_0_tensor_buffer_requirements), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtGetNumTensorBufferRequirementsSupportedBufferTypes( - input_0_tensor_buffer_requirements, &num_tensor_buffer_types), - kLiteRtStatusOk); - EXPECT_GE(num_tensor_buffer_types, 2); - LiteRtTensorBufferType input_0_tensor_buffer_type; - EXPECT_EQ(LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( - input_0_tensor_buffer_requirements, /*type_index=*/1, - &input_0_tensor_buffer_type), - kLiteRtStatusOk); - EXPECT_EQ(input_0_tensor_buffer_type, kLiteRtTensorBufferTypeDmaBuf); - size_t input_0_tensor_buffer_size; - EXPECT_EQ( - LiteRtGetTensorBufferRequirementsBufferSize( - input_0_tensor_buffer_requirements, &input_0_tensor_buffer_size), - kLiteRtStatusOk); - EXPECT_GE(input_0_tensor_buffer_size, sizeof(kTestInput0Tensor)); - - LiteRtTensorBufferRequirements input_1_tensor_buffer_requirements; - EXPECT_EQ(LiteRtDispatchGetInputRequirements( - invocation_context, /*input_index=*/1, &kInput1TensorType, - 
&input_1_tensor_buffer_requirements), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtGetNumTensorBufferRequirementsSupportedBufferTypes( - input_1_tensor_buffer_requirements, &num_tensor_buffer_types), - kLiteRtStatusOk); - EXPECT_GE(num_tensor_buffer_types, 2); - LiteRtTensorBufferType input_1_tensor_buffer_type; - EXPECT_EQ(LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( - input_1_tensor_buffer_requirements, /*type_index=*/1, - &input_1_tensor_buffer_type), - kLiteRtStatusOk); - EXPECT_EQ(input_1_tensor_buffer_type, kLiteRtTensorBufferTypeDmaBuf); - size_t input_1_tensor_buffer_size; - EXPECT_EQ( - LiteRtGetTensorBufferRequirementsBufferSize( - input_1_tensor_buffer_requirements, &input_1_tensor_buffer_size), - kLiteRtStatusOk); - EXPECT_GE(input_1_tensor_buffer_size, sizeof(kTestInput1Tensor)); - - LiteRtTensorBufferRequirements output_tensor_buffer_requirements; - EXPECT_EQ(LiteRtDispatchGetOutputRequirements( - invocation_context, /*output_index=*/0, &kOutputTensorType, - &output_tensor_buffer_requirements), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtGetNumTensorBufferRequirementsSupportedBufferTypes( - output_tensor_buffer_requirements, &num_tensor_buffer_types), - kLiteRtStatusOk); - EXPECT_GE(num_tensor_buffer_types, 2); - LiteRtTensorBufferType output_tensor_buffer_type; - EXPECT_EQ(LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( - output_tensor_buffer_requirements, /*type_index=*/1, - &output_tensor_buffer_type), - kLiteRtStatusOk); - EXPECT_EQ(output_tensor_buffer_type, kLiteRtTensorBufferTypeDmaBuf); - size_t output_tensor_buffer_size; - EXPECT_EQ(LiteRtGetTensorBufferRequirementsBufferSize( - output_tensor_buffer_requirements, &output_tensor_buffer_size), - kLiteRtStatusOk); - EXPECT_GE(output_tensor_buffer_size, sizeof(kTestOutputTensor)); - - // /////////////////////////////////////////////////////////////////////////// - // Allocate tensor buffers. 
- // /////////////////////////////////////////////////////////////////////////// - - LiteRtTensorBuffer input_0_tensor_buffer; - EXPECT_EQ(LiteRtCreateManagedTensorBuffer( - input_0_tensor_buffer_type, &kInput0TensorType, - input_0_tensor_buffer_size, &input_0_tensor_buffer), - kLiteRtStatusOk); - - LiteRtTensorBuffer input_1_tensor_buffer; - EXPECT_EQ(LiteRtCreateManagedTensorBuffer( - input_1_tensor_buffer_type, &kInput1TensorType, - input_1_tensor_buffer_size, &input_1_tensor_buffer), - kLiteRtStatusOk); - - LiteRtTensorBuffer output_tensor_buffer; - EXPECT_EQ(LiteRtCreateManagedTensorBuffer( - output_tensor_buffer_type, &kOutputTensorType, - output_tensor_buffer_size, &output_tensor_buffer), - kLiteRtStatusOk); - - // /////////////////////////////////////////////////////////////////////////// - // Register tensor buffers. - // /////////////////////////////////////////////////////////////////////////// - - LiteRtTensorBufferHandle input_1_handle; - EXPECT_EQ(LiteRtDispatchRegisterTensorBuffer( - device_context, input_1_tensor_buffer, &input_1_handle), - kLiteRtStatusOk); - - LiteRtTensorBufferHandle input_0_handle; - EXPECT_EQ(LiteRtDispatchRegisterTensorBuffer( - device_context, input_0_tensor_buffer, &input_0_handle), - kLiteRtStatusOk); - - LiteRtTensorBufferHandle output_handle; - EXPECT_EQ(LiteRtDispatchRegisterTensorBuffer( - device_context, output_tensor_buffer, &output_handle), - kLiteRtStatusOk); - - // /////////////////////////////////////////////////////////////////////////// - // Attach tensor buffers. 
- // /////////////////////////////////////////////////////////////////////////// - - EXPECT_EQ(LiteRtDispatchAttachInput(invocation_context, - /*graph_input_index=*/0, input_0_handle), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtDispatchAttachInput(invocation_context, - /*graph_input_index=*/1, input_1_handle), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtDispatchAttachOutput(invocation_context, - /*graph_output_index=*/0, output_handle), - kLiteRtStatusOk); - - // /////////////////////////////////////////////////////////////////////////// - // Fill the input buffers with data. - // /////////////////////////////////////////////////////////////////////////// - - { - ABSL_LOG(INFO) << "Filling inputs with data"; - void* host_mem_addr; - - ASSERT_EQ(LiteRtLockTensorBuffer(input_0_tensor_buffer, &host_mem_addr), - kLiteRtStatusOk); - std::memcpy(host_mem_addr, kTestInput0Tensor, sizeof(kTestInput0Tensor)); - ASSERT_EQ(LiteRtUnlockTensorBuffer(input_0_tensor_buffer), kLiteRtStatusOk); - - ASSERT_EQ(LiteRtLockTensorBuffer(input_1_tensor_buffer, &host_mem_addr), - kLiteRtStatusOk); - std::memcpy(host_mem_addr, kTestInput1Tensor, sizeof(kTestInput1Tensor)); - ASSERT_EQ(LiteRtUnlockTensorBuffer(input_1_tensor_buffer), kLiteRtStatusOk); - } - - // /////////////////////////////////////////////////////////////////////////// - // Execute model. - // /////////////////////////////////////////////////////////////////////////// - - ABSL_LOG(INFO) << "Invoking execution..."; - EXPECT_EQ(LiteRtDispatchInvoke(invocation_context), kLiteRtStatusOk); - - // /////////////////////////////////////////////////////////////////////////// - // Check output for correctness. 
- // /////////////////////////////////////////////////////////////////////////// - - { - ABSL_LOG(INFO) << "Checking output..."; - void* host_mem_addr; - ASSERT_EQ(LiteRtLockTensorBuffer(output_tensor_buffer, &host_mem_addr), - kLiteRtStatusOk); - auto output = absl::MakeSpan(static_cast(host_mem_addr), - kTestOutputSize); - for (auto i = 0; i < kTestOutputSize; ++i) { - ABSL_LOG(INFO) << output[i] << "\t" << kTestOutputTensor[i]; - } - EXPECT_THAT(output, Pointwise(testing::FloatNear(1e-3), kTestOutputTensor)); - ASSERT_EQ(LiteRtUnlockTensorBuffer(output_tensor_buffer), kLiteRtStatusOk); - } - - // /////////////////////////////////////////////////////////////////////////// - // Fill the input buffers with more data. - // /////////////////////////////////////////////////////////////////////////// - - { - ABSL_LOG(INFO) << "Filling inputs with data"; - void* host_mem_addr; - - ASSERT_EQ(LiteRtLockTensorBuffer(input_0_tensor_buffer, &host_mem_addr), - kLiteRtStatusOk); - std::memcpy(host_mem_addr, kTestInput0Tensor_2, - sizeof(kTestInput0Tensor_2)); - ASSERT_EQ(LiteRtUnlockTensorBuffer(input_0_tensor_buffer), kLiteRtStatusOk); - - ASSERT_EQ(LiteRtLockTensorBuffer(input_1_tensor_buffer, &host_mem_addr), - kLiteRtStatusOk); - std::memcpy(host_mem_addr, kTestInput1Tensor_2, - sizeof(kTestInput1Tensor_2)); - ASSERT_EQ(LiteRtUnlockTensorBuffer(input_1_tensor_buffer), kLiteRtStatusOk); - } - - // /////////////////////////////////////////////////////////////////////////// - // Execute model once more. - // /////////////////////////////////////////////////////////////////////////// - - ABSL_LOG(INFO) << "Invoking second execution..."; - EXPECT_EQ(LiteRtDispatchInvoke(invocation_context), kLiteRtStatusOk); - - // /////////////////////////////////////////////////////////////////////////// - // Check output for correctness. 
- // /////////////////////////////////////////////////////////////////////////// - - { - ABSL_LOG(INFO) << "Checking output..."; - void* host_mem_addr; - ASSERT_EQ(LiteRtLockTensorBuffer(output_tensor_buffer, &host_mem_addr), - kLiteRtStatusOk); - auto output = absl::MakeSpan(static_cast(host_mem_addr), - kTestOutputSize); - for (auto i = 0; i < kTestOutputSize; ++i) { - ABSL_LOG(INFO) << output[i] << "\t" << kTestOutputTensor_2[i]; - } - EXPECT_THAT(output, - Pointwise(testing::FloatNear(1e-3), kTestOutputTensor_2)); - ASSERT_EQ(LiteRtUnlockTensorBuffer(output_tensor_buffer), kLiteRtStatusOk); - } - - // /////////////////////////////////////////////////////////////////////////// - // Clean up resources. - // /////////////////////////////////////////////////////////////////////////// - - EXPECT_EQ(LiteRtDispatchDetachInput(invocation_context, - /*graph_input_index=*/0, input_0_handle), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtDispatchDetachInput(invocation_context, - /*graph_input_index=*/1, input_1_handle), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtDispatchDetachOutput(invocation_context, - /*graph_output_index=*/0, output_handle), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtDispatchUnregisterTensorBuffer(device_context, output_handle), - kLiteRtStatusOk); - EXPECT_EQ( - LiteRtDispatchUnregisterTensorBuffer(device_context, input_1_handle), - kLiteRtStatusOk); - EXPECT_EQ( - LiteRtDispatchUnregisterTensorBuffer(device_context, input_0_handle), - kLiteRtStatusOk); - LiteRtDestroyTensorBuffer(output_tensor_buffer); - LiteRtDestroyTensorBuffer(input_1_tensor_buffer); - LiteRtDestroyTensorBuffer(input_0_tensor_buffer); - EXPECT_EQ(LiteRtDispatchInvocationContextDestroy(invocation_context), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtDispatchDeviceContextDestroy(device_context), - kLiteRtStatusOk); -} diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/litert_dispatch_device_context.cc 
b/tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/litert_dispatch_device_context.cc deleted file mode 100644 index b728f5c5c15d5f..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/litert_dispatch_device_context.cc +++ /dev/null @@ -1,190 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/litert_dispatch_device_context.h" - -#include - -#include -#include - -#include "neuron/api/NeuronAdapter.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h" - -using litert::Error; - -LiteRtDispatchDeviceContextT::~LiteRtDispatchDeviceContextT() = default; - -litert::Expected -LiteRtDispatchDeviceContextT::Create( - const litert::mediatek::NeuronAdapterApi& neuron_adapter_api) { - return std::unique_ptr( - new LiteRtDispatchDeviceContextT(neuron_adapter_api)); -} - -litert::Expected -LiteRtDispatchDeviceContextT::RegisterTensorBuffer( - LiteRtTensorBuffer tensor_buffer) 
{ - LiteRtTensorBufferType tensor_buffer_type; - LITERT_RETURN_IF_ERROR( - LiteRtGetTensorBufferType(tensor_buffer, &tensor_buffer_type)); - - if (tensor_buffer_type != kLiteRtTensorBufferTypeAhwb && - tensor_buffer_type != kLiteRtTensorBufferTypeDmaBuf) { - return Error(kLiteRtStatusErrorUnsupported, "Unsupported buffer type"); - } - - size_t tensor_buffer_size; - LITERT_RETURN_IF_ERROR( - LiteRtGetTensorBufferSize(tensor_buffer, &tensor_buffer_size)); - - size_t tensor_buffer_offset; - if (auto status = - LiteRtGetTensorBufferOffset(tensor_buffer, &tensor_buffer_offset); - status != kLiteRtStatusOk) { - if (status == kLiteRtStatusErrorNotFound) { - tensor_buffer_offset = 0; - } else { - return Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to get buffer offset"); - } - } - - LiteRtRankedTensorType tensor_type; - LITERT_RETURN_IF_ERROR( - LiteRtGetTensorBufferTensorType(tensor_buffer, &tensor_type)); - - auto* tensor_strides = tensor_type.layout.strides; - if (tensor_strides != nullptr) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "Tensor strides are not supported"); - } - - switch (tensor_buffer_type) { - case kLiteRtTensorBufferTypeAhwb: -#if LITERT_HAS_AHWB_SUPPORT - AHardwareBuffer* ahwb; - if (auto status = LiteRtGetTensorBufferAhwb(tensor_buffer, &ahwb); - status != kLiteRtStatusOk) { - return Error(status, "Failed to get AHWB"); - } -#else - return Error(kLiteRtStatusErrorRuntimeFailure, - "AHardwareBuffer is not supported on this platform"); -#endif // LITERT_HAS_AHWB_SUPPORT - NeuronMemory* neuron_memory; -#if LITERT_HAS_AHWB_SUPPORT - if (neuron_adapter_api_.api().memory_create_from_ahwb( - ahwb, &neuron_memory) != NEURON_NO_ERROR) { - return litert::Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to create NeuronMemory from AHWB"); - } - return neuron_memory_registry_.Register(neuron_memory, tensor_buffer_size, - tensor_buffer_offset); -#else - (void)neuron_adapter_api_; - return litert::Unexpected( - kLiteRtStatusErrorRuntimeFailure, 
- "AHardwareBuffer is not supported on this platform"); -#endif // LITERT_HAS_AHWB_SUPPORT - break; - - case kLiteRtTensorBufferTypeDmaBuf: - - int fd; -#if LITERT_HAS_DMABUF_SUPPORT - void* addr; - if (auto status = - LiteRtGetTensorBufferDmaBufBuffer(tensor_buffer, &addr, &fd); - status != kLiteRtStatusOk) { - return Error(status, "Failed to get DMA-BUF"); - } -#else - return Error(kLiteRtStatusErrorRuntimeFailure, - "DMA-BUF is not supported on this platform"); -#endif // LITERT_HAS_DMABUF_SUPPORT - if (neuron_adapter_api_.api().memory_create_from_fd( - tensor_buffer_size, /*protect*/ PROT_READ | PROT_WRITE, fd, - tensor_buffer_offset, &neuron_memory) != NEURON_NO_ERROR) { - return litert::Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to create NeuronMemory from DMA-BUF"); - } - return neuron_memory_registry_.Register(neuron_memory, tensor_buffer_size, - tensor_buffer_offset); - break; - - default: - LITERT_LOG(LITERT_ERROR, "Unsupported buffer type: %d", - tensor_buffer_type); - return litert::Unexpected(kLiteRtStatusErrorUnsupported); - } -} - -LiteRtDispatchDeviceContextT::NeuronMemoryRegistry::~NeuronMemoryRegistry() { - for (auto i = 0; i < records_.size(); ++i) { - auto& record = records_[i]; - if (record.neuron_memory != nullptr) { - neuron_adapter_api_.api().memory_free(record.neuron_memory); - } - } -} - -LiteRtTensorBufferHandle -LiteRtDispatchDeviceContextT::NeuronMemoryRegistry::Register( - NeuronMemory* neuron_memory, size_t size, size_t offset) { - int dest_index = -1; - for (auto i = 0; i < records_.size(); ++i) { - if (!records_[i].neuron_memory) { - dest_index = i; - break; - } - } - if (dest_index < 0) { - dest_index = records_.size(); - records_.push_back({}); - } - auto& dest = records_[dest_index]; - dest = {neuron_memory, size, offset}; - return dest_index; -} - -litert::Expected -LiteRtDispatchDeviceContextT::NeuronMemoryRegistry::Unregister( - LiteRtTensorBufferHandle tensor_buffer_handle) { - auto record = 
Find(tensor_buffer_handle); - if (!record) { - return record.Error(); - } else { - auto& mem = (*record)->neuron_memory; - neuron_adapter_api_.api().memory_free(mem); - mem = nullptr; - return {}; - } -} - -litert::Expected -LiteRtDispatchDeviceContextT::NeuronMemoryRegistry::Find( - LiteRtTensorBufferHandle tensor_buffer_handle) { - if (tensor_buffer_handle < 0 || tensor_buffer_handle >= records_.size()) { - return litert::Unexpected(kLiteRtStatusErrorInvalidArgument, - "Invalid tensor buffer handle"); - } - return &records_[tensor_buffer_handle]; -} diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/litert_dispatch_device_context.h b/tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/litert_dispatch_device_context.h deleted file mode 100644 index 483701fe919acc..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/litert_dispatch_device_context.h +++ /dev/null @@ -1,87 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_DISPATCH_LITERT_DISPATCH_DEVICE_CONTEXT_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_DISPATCH_LITERT_DISPATCH_DEVICE_CONTEXT_H_ - -#include - -#include "neuron/api/NeuronAdapter.h" -#include "absl/container/flat_hash_set.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h" - -class LiteRtDispatchDeviceContextT { - public: - using Ptr = std::unique_ptr; - struct NeuronMemoryInfo { - NeuronMemory* neuron_memory; - size_t size; - size_t offset; - }; - - ~LiteRtDispatchDeviceContextT(); - - static litert::Expected Create( - const litert::mediatek::NeuronAdapterApi& neuron_adapter_api); - - litert::Expected RegisterTensorBuffer( - LiteRtTensorBuffer tensor_buffer); - - litert::Expected UnregisterTensorBuffer( - LiteRtTensorBufferHandle tensor_buffer_handle) { - return neuron_memory_registry_.Unregister(tensor_buffer_handle); - } - - litert::Expected GetNeuronMemoryInfo( - LiteRtTensorBufferHandle tensor_buffer_handle) { - auto record = neuron_memory_registry_.Find(tensor_buffer_handle); - if (!record) { - return record.Error(); - } else { - return NeuronMemoryInfo(**record); - } - } - - private: - class NeuronMemoryRegistry { - public: - explicit NeuronMemoryRegistry( - const litert::mediatek::NeuronAdapterApi& neuron_adapter_api) - : neuron_adapter_api_(neuron_adapter_api) {} - ~NeuronMemoryRegistry(); - LiteRtTensorBufferHandle Register(NeuronMemory* neuron_memory, size_t size, - size_t offset); - litert::Expected Unregister( - LiteRtTensorBufferHandle tensor_buffer_handle); - litert::Expected Find( - LiteRtTensorBufferHandle tensor_buffer_handle); - - private: - const litert::mediatek::NeuronAdapterApi& neuron_adapter_api_; - 
std::vector records_; - }; - - explicit LiteRtDispatchDeviceContextT( - const litert::mediatek::NeuronAdapterApi& neuron_adapter_api) - : neuron_adapter_api_(neuron_adapter_api), - neuron_memory_registry_(neuron_adapter_api) {} - - const litert::mediatek::NeuronAdapterApi& neuron_adapter_api_; - NeuronMemoryRegistry neuron_memory_registry_; -}; - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_DISPATCH_LITERT_DISPATCH_DEVICE_CONTEXT_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/litert_dispatch_invocation_context.cc b/tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/litert_dispatch_invocation_context.cc deleted file mode 100644 index 2f235e182c9614..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/litert_dispatch_invocation_context.cc +++ /dev/null @@ -1,435 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/litert_dispatch_invocation_context.h" - -#include -#include -#include -#include -#include -#include -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/litert_dispatch_device_context.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/schema/schema_resolver.h" - -using litert::Error; -using litert::Expected; -using litert::mediatek::NeuronCompilationPtr; -using litert::mediatek::NeuronExecutionPtr; -using litert::mediatek::NeuronModelPtr; - -namespace { - -Expected> LoadFromCachedNetwork( - const litert::mediatek::NeuronAdapterApi& neuron_adapter_api, - const void* bytecode_addr, size_t bytecode_size) { - NeuronModel* model; - NeuronCompilation* compilation; - if (neuron_adapter_api.api().model_restore_from_compiled_network( - &model, &compilation, bytecode_addr, bytecode_size) != - NEURON_NO_ERROR) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to restore model from compiled network"); - } - return std::make_pair( - NeuronModelPtr{model, neuron_adapter_api.api().model_free}, - NeuronCompilationPtr{compilation, - neuron_adapter_api.api().compilation_free}); -} - -uint16_t GetRestoreDlaExtensionOperandType( - const litert::mediatek::NeuronAdapterApi& neuron_adapter_api) { - NeuronRuntimeVersion version; - neuron_adapter_api.api().get_version(&version); - // The values below were suggested by MTK. 
- if (version.major >= 8) { - return 0x0200; - } else { - return 0x0100; - } -} - -Expected> LoadFromDlaBytecode( - const litert::mediatek::NeuronAdapterApi& neuron_adapter_api, - const void* bytecode_addr, size_t bytecode_size, int num_inputs, - int num_outputs) { - Expected model = neuron_adapter_api.CreateModel(); - if (!model) { - return model.Error(); - } - - // fake input, the real outputs are loaded by compiled network. - constexpr const NeuronOperandType fake_io_operand_type{ - .type = NEURON_TENSOR_FLOAT32, - .dimensionCount = 0, - .scale = 0.0f, - .zeroPoint = 0, - }; - - std::vector input_op_number; - input_op_number.reserve(num_inputs); - for (auto i = 0; i < num_inputs; i++) { - if (neuron_adapter_api.api().model_add_operand( - model->get(), &fake_io_operand_type) != NEURON_NO_ERROR) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to add input operand"); - } - input_op_number.emplace_back(i); - } - - const uint16_t kNetworkOperandRestoreData = - GetRestoreDlaExtensionOperandType(neuron_adapter_api); - constexpr const uint16_t kRestoreDlaExtensionOperationType = 0; - constexpr const char* kExtensionRestoreCompiledNetwork = - "com.mediatek.compiled_network"; - - int32_t operand_type; - if (neuron_adapter_api.api().model_get_extension_operand_type( - model->get(), kExtensionRestoreCompiledNetwork, - kNetworkOperandRestoreData, &operand_type) != NEURON_NO_ERROR) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to getextension operand"); - } - - const NeuronOperandType extension_operand_type{ - .type = operand_type, - .dimensionCount = 0, - .scale = 0.0f, - .zeroPoint = 0, - }; - if (neuron_adapter_api.api().model_add_operand( - model->get(), &extension_operand_type) != NEURON_NO_ERROR) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to add extension operand"); - } - input_op_number.emplace_back(input_op_number.size()); - if (neuron_adapter_api.api().model_set_operand_value( - model->get(), input_op_number.back(), 
bytecode_addr, bytecode_size) != - NEURON_NO_ERROR) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to set extension operand value"); - } - - std::vector output_op_number; - for (auto i = 0; i < num_outputs; i++) { - if (neuron_adapter_api.api().model_add_operand( - model->get(), &fake_io_operand_type) != NEURON_NO_ERROR) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to add output operand"); - } - output_op_number.emplace_back(input_op_number.size() + i); - } - - int32_t operation_type; - if (neuron_adapter_api.api().model_get_extension_operation_type( - model->get(), kExtensionRestoreCompiledNetwork, - kRestoreDlaExtensionOperationType, - &operation_type) != NEURON_NO_ERROR) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to get extension operation"); - } - - // Add extension operation - if (neuron_adapter_api.api().model_add_operation( - model->get(), static_cast(operation_type), - input_op_number.size(), input_op_number.data(), - output_op_number.size(), - output_op_number.data()) != NEURON_NO_ERROR) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to add extension operation"); - } - - if (neuron_adapter_api.api().model_identify_inputs_and_outputs( - model->get(), input_op_number.size() - 1, input_op_number.data(), - output_op_number.size(), - output_op_number.data()) != NEURON_NO_ERROR) { - return Error(kLiteRtStatusErrorRuntimeFailure, "Failed to identify I/Os"); - } - - if (neuron_adapter_api.api().model_finish(model->get()) != NEURON_NO_ERROR) { - return Error(kLiteRtStatusErrorRuntimeFailure, "Failed to finish model"); - } - - auto compilation = neuron_adapter_api.CreateCompilation(model->get()); - if (!compilation) { - return compilation.Error(); - } - - if (neuron_adapter_api.api().compilation_set_priority( - compilation->get(), NEURON_PRIORITY_HIGH) != NEURON_NO_ERROR) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to set compilation priority"); - } - - if 
(neuron_adapter_api.api().compilation_set_preference( - compilation->get(), NEURON_PREFER_SUSTAINED_SPEED) != - NEURON_NO_ERROR) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to set compilation preference"); - } - - // We use AOT compile options since the DLA file was compiled ahead of time. - const auto compile_options = - std::string(neuron_adapter_api.AotCompileOptions()); - if (!compile_options.empty()) { - if (neuron_adapter_api.api().compilation_set_optimization_string( - compilation->get(), compile_options.c_str()) != NEURON_NO_ERROR) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to set optimization string"); - } - } - - if (neuron_adapter_api.api().compilation_finish(compilation->get()) != - NEURON_NO_ERROR) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to finish compilation"); - } - - return std::make_pair(std::move(*model), std::move(*compilation)); -} - -Expected> -LoadModelAndCompilation( - const litert::mediatek::NeuronAdapterApi& neuron_adapter_api, - const void* bytecode_addr, size_t bytecode_size, int num_inputs, - int num_outputs) { - if (auto result = LoadFromDlaBytecode(neuron_adapter_api, bytecode_addr, - bytecode_size, num_inputs, num_outputs); - !result) { - return LoadFromCachedNetwork(neuron_adapter_api, bytecode_addr, - bytecode_size); - } else { - return result; - } -} - -} // namespace - -Expected -LiteRtDispatchInvocationContextT::Create( - litert::mediatek::NeuronAdapterApi& neuron_adapter_api, - LiteRtDispatchDeviceContext device_context, - LiteRtDispatchExecutableType exec_type, - const LiteRtMemBuffer* exec_bytecode_buffer, const char* function_name, - int num_inputs, int num_outputs) { - neuron::SchemaResolver resolver; - - const void* exec_bytecode_ptr = - static_cast(exec_bytecode_buffer->base_addr) + - exec_bytecode_buffer->offset; - auto exec_bytecode_size = exec_bytecode_buffer->size; - auto res = resolver.Initialize((const uint8_t*)exec_bytecode_ptr, - exec_bytecode_size); - if 
(res.HasValue() && res.Value()) { - std::string func = function_name != nullptr ? function_name : ""; - auto graph = resolver.GetCompiledGraph(func); - if (!graph.has_value()) { - return litert::Error(kLiteRtStatusErrorRuntimeFailure, - "Couldn't find the subgraph"); - } - auto compile_graph = graph.value().GetCompiledNetwork(); - if (!compile_graph) { - return compile_graph.Error(); - } - std::tie(exec_bytecode_ptr, exec_bytecode_size) = compile_graph.Value(); - } - - auto model_and_compilation = - LoadModelAndCompilation(neuron_adapter_api, exec_bytecode_ptr, - exec_bytecode_size, num_inputs, num_outputs); - if (!model_and_compilation) { - return model_and_compilation.Error(); - } - - auto& model = model_and_compilation->first; - auto& compilation = model_and_compilation->second; - - auto execution = neuron_adapter_api.CreateExecution(compilation.get()); - if (!execution) { - return execution.Error(); - } - - if (neuron_adapter_api.api().execution_set_boost_hint( - execution->get(), 100) != NEURON_NO_ERROR) { - return litert::Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to set execution boost hint"); - } - - return Ptr(new LiteRtDispatchInvocationContextT( - neuron_adapter_api, device_context, model.release(), - compilation.release(), execution->release(), num_inputs, num_outputs)); -} - -LiteRtDispatchInvocationContextT::~LiteRtDispatchInvocationContextT() { - if (execution_) { - neuron_adapter_api_.api().execution_free(execution_); - } - if (compilation_) { - neuron_adapter_api_.api().compilation_free(compilation_); - } - if (model_) { - neuron_adapter_api_.api().model_free(model_); - } -} - -LiteRtDispatchInvocationContextT::IoRequirementsBuilder::IoRequirementsBuilder( - size_t buffer_size, const std::vector& padded_dimensions) - : buffer_size_(buffer_size) { - auto rank = padded_dimensions.size(); - strides_.resize(rank); - strides_[0] = 1; - for (auto i = 1; i < rank; ++i) { - strides_[i] = padded_dimensions[i - 1]; - } -} - -Expected 
-LiteRtDispatchInvocationContextT::IoRequirementsBuilder::Create() { - static constexpr std::array kSupportedTensorBufferTypes = { -#if defined(__ANDROID__) - kLiteRtTensorBufferTypeAhwb, -#endif // __ANDROID__ - kLiteRtTensorBufferTypeDmaBuf, - }; - - LiteRtTensorBufferRequirements requirements; - if (auto status = LiteRtCreateTensorBufferRequirements( - kSupportedTensorBufferTypes.size(), - kSupportedTensorBufferTypes.data(), buffer_size_, strides_.size(), - strides_.data(), &requirements); - status != kLiteRtStatusOk) { - return litert::Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to create tensor buffer requirements"); - } - - return requirements; -} - -Expected -LiteRtDispatchInvocationContextT::GetInputRequirements( - int input_index, const LiteRtRankedTensorType& tensor_type) { - if (!input_requirements_builders_[input_index]) { - size_t buffer_size; - if (neuron_adapter_api_.api().compilation_get_input_padded_size( - compilation_, input_index, &buffer_size) != NEURON_NO_ERROR) { - return litert::Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to get input padded size"); - } - - std::vector padded_dimensions(tensor_type.layout.rank); - if (neuron_adapter_api_.api().compilation_get_input_padded_dimensions( - compilation_, input_index, padded_dimensions.data()) != - NEURON_NO_ERROR) { - return litert::Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to get input padded dimensions"); - } - - input_requirements_builders_[input_index] = - std::make_unique(buffer_size, padded_dimensions); - } - - return input_requirements_builders_[input_index]->Create(); -} - -Expected -LiteRtDispatchInvocationContextT::GetOutputRequirements( - int output_index, const LiteRtRankedTensorType& tensor_type) { - if (!output_requirements_builders_[output_index]) { - size_t buffer_size; - if (neuron_adapter_api_.api().compilation_get_output_padded_size( - compilation_, output_index, &buffer_size) != NEURON_NO_ERROR) { - return litert::Error(kLiteRtStatusErrorRuntimeFailure, - 
"Failed to get output padded size"); - } - - std::vector padded_dimensions(tensor_type.layout.rank); - if (neuron_adapter_api_.api().compilation_get_output_padded_dimensions( - compilation_, output_index, padded_dimensions.data()) != - NEURON_NO_ERROR) { - return litert::Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to get output padded dimensions"); - } - - output_requirements_builders_[output_index] = - std::make_unique(buffer_size, padded_dimensions); - } - - return output_requirements_builders_[output_index]->Create(); -} - -Expected LiteRtDispatchInvocationContextT::AttachInput( - int graph_input_index, LiteRtTensorBufferHandle tensor_buffer_handle) { - auto neuron_memory_info = - device_context_->GetNeuronMemoryInfo(tensor_buffer_handle); - if (!neuron_memory_info) { - return litert::Error(neuron_memory_info.Error()); - } - - if (neuron_adapter_api_.api().execution_set_input_from_memory( - execution_, graph_input_index, nullptr, - neuron_memory_info->neuron_memory, neuron_memory_info->offset, - neuron_memory_info->size) != NEURON_NO_ERROR) { - return litert::Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to set execution input from memory"); - } - return {}; -} - -Expected LiteRtDispatchInvocationContextT::AttachOutput( - int graph_output_index, LiteRtTensorBufferHandle tensor_buffer_handle) { - auto neuron_memory_info = - device_context_->GetNeuronMemoryInfo(tensor_buffer_handle); - if (!neuron_memory_info) { - return litert::Error(neuron_memory_info.Error()); - } - - if (neuron_adapter_api_.api().execution_set_output_from_memory( - execution_, graph_output_index, nullptr, - neuron_memory_info->neuron_memory, neuron_memory_info->offset, - neuron_memory_info->size) != NEURON_NO_ERROR) { - return litert::Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to set execution output from memory"); - } - return {}; -} - -Expected LiteRtDispatchInvocationContextT::DetachInput( - int graph_input_index, LiteRtTensorBufferHandle tensor_buffer_handle) { - // 
Nothing to do. - return {}; -} - -Expected LiteRtDispatchInvocationContextT::DetachOutput( - int graph_output_index, LiteRtTensorBufferHandle tensor_buffer_handle) { - // Nothing to do. - return {}; -} - -Expected LiteRtDispatchInvocationContextT::Invoke() { - if (neuron_adapter_api_.api().execution_compute(execution_) != - NEURON_NO_ERROR) { - return litert::Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to execute network"); - } - return {}; -} diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/litert_dispatch_invocation_context.h b/tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/litert_dispatch_invocation_context.h deleted file mode 100644 index f58ee976b693e2..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/litert_dispatch_invocation_context.h +++ /dev/null @@ -1,94 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_DISPATCH_LITERT_DISPATCH_INVOCATION_CONTEXT_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_DISPATCH_LITERT_DISPATCH_INVOCATION_CONTEXT_H_ - -#include - -#include "neuron/api/NeuronAdapter.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h" - -class LiteRtDispatchInvocationContextT { - public: - using Ptr = std::unique_ptr; - - static litert::Expected Create( - litert::mediatek::NeuronAdapterApi& neuron_adapter_api, - LiteRtDispatchDeviceContext device_context, - LiteRtDispatchExecutableType exec_type, - const LiteRtMemBuffer* exec_bytecode_buffer, const char* function_name, - int num_inputs, int num_outputs); - - ~LiteRtDispatchInvocationContextT(); - - litert::Expected GetInputRequirements( - int input_index, const LiteRtRankedTensorType& tensor_type); - - litert::Expected GetOutputRequirements( - int output_index, const LiteRtRankedTensorType& tensor_type); - - litert::Expected AttachInput( - int graph_input_index, LiteRtTensorBufferHandle tensor_buffer_handle); - litert::Expected AttachOutput( - int graph_output_index, LiteRtTensorBufferHandle tensor_buffer_handle); - - litert::Expected DetachInput( - int graph_input_index, LiteRtTensorBufferHandle tensor_buffer_handle); - litert::Expected DetachOutput( - int graph_output_index, LiteRtTensorBufferHandle tensor_buffer_handle); - - litert::Expected Invoke(); - - private: - class IoRequirementsBuilder { - public: - IoRequirementsBuilder(size_t buffer_size, - const std::vector& padded_dimensions); - litert::Expected Create(); - - private: - size_t buffer_size_; - std::vector strides_; - }; - - 
LiteRtDispatchInvocationContextT( - const litert::mediatek::NeuronAdapterApi& neuron_adapter_api, - LiteRtDispatchDeviceContext device_context, NeuronModel* model, - NeuronCompilation* compilation, NeuronExecution* execution, - int num_inputs, int num_outputs) - : neuron_adapter_api_(neuron_adapter_api), - device_context_(device_context), - model_(model), - compilation_(compilation), - execution_(execution), - input_requirements_builders_(num_inputs), - output_requirements_builders_(num_outputs) {} - - const litert::mediatek::NeuronAdapterApi& neuron_adapter_api_; - LiteRtDispatchDeviceContext device_context_; - NeuronModel* model_; - NeuronCompilation* compilation_; - NeuronExecution* execution_; - std::vector> - input_requirements_builders_; - std::vector> - output_requirements_builders_; -}; - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_DISPATCH_LITERT_DISPATCH_INVOCATION_CONTEXT_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/mediatek_build_defs.bzl b/tensorflow/lite/experimental/litert/vendors/mediatek/mediatek_build_defs.bzl deleted file mode 100644 index 5427e9e0d29521..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/mediatek_build_defs.bzl +++ /dev/null @@ -1,91 +0,0 @@ -# Copyright 2024 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Build definitions for Mediatek backend.""" - -load("//tensorflow/lite/experimental/litert/build_common:litert_build_defs.bzl", "append_rule_kwargs", "litert_bin", "litert_lib", "make_rpaths") - -_MTK_STD_LIBS_HOST = [ - # copybara:uncomment_begin(google-only) - # "//third_party/neuro_pilot:latest/host/lib/libc++.so.1", - # "//third_party/neuro_pilot:latest/host/lib/libstdc++.so.6", - # copybara:uncomment_end -] # @unused - -_MTK_NEURON_ADAPTER_SO = [ - # copybara:uncomment_begin(google-only) - # "//third_party/neuro_pilot:latest/host/lib/libneuron_adapter.so", - # copybara:uncomment_end -] - -# TODO: Make rpaths dynamic with "$(location {})". -_MTK_HOST_RPATHS = [ - # copybara:uncomment_begin(google-only) - # "third_party/neuro_pilot/latest/host/lib", - # copybara:uncomment_end -] - -def _litert_with_mtk_base( - litert_rule, - use_custom_libcc, - **litert_rule_kwargs): - if use_custom_libcc: - # TODO: Figure out strategy for custom libcc. - fail("Custom libcc not yet supported") - - data_x86_64 = [] - data_x86_64.extend(_MTK_NEURON_ADAPTER_SO) - append_rule_kwargs( - litert_rule_kwargs, - data = select({ - "//tensorflow:linux_x86_64": data_x86_64, - "//conditions:default": [], - }), - linkopts = select({ - "//tensorflow:linux_x86_64": [make_rpaths(_MTK_HOST_RPATHS)], - "//conditions:default": [], - }), - ) - - litert_rule(**litert_rule_kwargs) - -def litert_cc_lib_with_mtk( - use_custom_libcc = False, - **litert_lib_kwargs): - """Creates a litert_lib target with Mediatek backend dependencies. - - Args: - use_custom_libcc: Whether to use a custom libcc. Not yet supported. - **litert_lib_kwargs: Keyword arguments passed to litert_lib. - """ - _litert_with_mtk_base( - litert_lib, - use_custom_libcc, - **litert_lib_kwargs - ) - -def litert_cc_bin_with_mtk( - use_custom_libcc = False, - **litert_bin_kwargs): - """Creates a litert_bin target with Mediatek backend dependencies. - - Args: - use_custom_libcc: Whether to use a custom libcc. Not yet supported. 
- **litert_bin_kwargs: Keyword arguments passed to litert_bin. - """ - _litert_with_mtk_base( - litert_bin, - use_custom_libcc, - **litert_bin_kwargs - ) diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.cc b/tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.cc deleted file mode 100644 index ab3cbfaddb1287..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.cc +++ /dev/null @@ -1,187 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h" - -#include - -#include -#include -#include -#include -#include - -#include "absl/strings/str_cat.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_shared_library.h" - -#define LOAD_SYMB(S, H) \ - if (auto maybe_H = dlib_.LookupSymbol(#S); maybe_H.HasValue()) { \ - H = reinterpret_cast(std::move(maybe_H).Value()); \ - } else { \ - LITERT_LOG(LITERT_WARNING, "Failed to load symbol %s: %s", #S, \ - dlib_.DlError()); \ - } - -namespace litert { -namespace mediatek { - -NeuronAdapterApi::NeuronAdapterApi() : api_(new Api) {} - -litert::Expected NeuronAdapterApi::Create( - std::optional shared_library_dir) { - std::unique_ptr neuron_adapter_api(new NeuronAdapterApi); - if (auto status = neuron_adapter_api->LoadSymbols(shared_library_dir); - !status) { - LITERT_LOG(LITERT_ERROR, "Failed to load NeuronAdapter shared library: %s", - status.Error().Message().c_str()); - return status.Error(); - } - - return neuron_adapter_api; -} - -litert::Expected NeuronAdapterApi::LoadSymbols( - std::optional shared_library_dir) { - constexpr auto kLibNeuronAdapterLib = "libneuron_adapter.so"; - - const std::vector so_paths = { - // The following preinstalled library is for system partition - // applications. - "libneuronusdk_adapter.mtk.so", "libneuron_adapter_mgvi.so", - kLibNeuronAdapterLib, - // Finally, the app may want to provide their own version of the library. - shared_library_dir.has_value() - ? 
absl::StrCat(*shared_library_dir, "/", kLibNeuronAdapterLib) - : kLibNeuronAdapterLib}; - for (auto& so_path : so_paths) { - auto maybe_dlib = SharedLibrary::Load(so_path, RtldFlags::Default()); - if (maybe_dlib.HasValue()) { - dlib_ = std::move(maybe_dlib).Value(); - } - } - - if (!dlib_.Loaded()) { - return litert::Error(kLiteRtStatusErrorDynamicLoading, - "Failed to load NeuronAdapter shared library"); - } - - LITERT_LOG(LITERT_INFO, "Loaded NeuronAdapter shared library."); - - // Binds all supported symbols from the shared library to the function - // pointers. - LOAD_SYMB(NeuronCompilation_create, api_->compilation_create); - LOAD_SYMB(NeuronCompilation_createWithOptions, - api_->compilation_create_with_options); - LOAD_SYMB(NeuronCompilation_finish, api_->compilation_finish); - LOAD_SYMB(NeuronCompilation_free, api_->compilation_free); - LOAD_SYMB(NeuronCompilation_getInputPaddedDimensions, - api_->compilation_get_input_padded_dimensions); - LOAD_SYMB(NeuronCompilation_getInputPaddedSize, - api_->compilation_get_input_padded_size); - LOAD_SYMB(NeuronCompilation_getOutputPaddedDimensions, - api_->compilation_get_output_padded_dimensions); - LOAD_SYMB(NeuronCompilation_getOutputPaddedSize, - api_->compilation_get_output_padded_size); - LOAD_SYMB(NeuronCompilation_setOptimizationString, - api_->compilation_set_optimization_string); - LOAD_SYMB(NeuronCompilation_setPreference, api_->compilation_set_preference); - LOAD_SYMB(NeuronCompilation_setPriority, api_->compilation_set_priority); - LOAD_SYMB(NeuronExecution_compute, api_->execution_compute); - LOAD_SYMB(NeuronExecution_create, api_->execution_create); - LOAD_SYMB(NeuronExecution_free, api_->execution_free); - LOAD_SYMB(NeuronCompilation_getCompiledNetworkSize, - api_->compilation_get_compiled_network_size); - LOAD_SYMB(NeuronCompilation_storeCompiledNetwork, - api_->compilation_store_compiled_network); - LOAD_SYMB(NeuronExecution_setBoostHint, api_->execution_set_boost_hint); - 
LOAD_SYMB(NeuronExecution_setInputFromMemory, - api_->execution_set_input_from_memory); - LOAD_SYMB(NeuronExecution_setOutputFromMemory, - api_->execution_set_output_from_memory); - LOAD_SYMB(NeuronMemory_createFromAHardwareBuffer, - api_->memory_create_from_ahwb); - LOAD_SYMB(NeuronMemory_createFromFd, api_->memory_create_from_fd); - LOAD_SYMB(NeuronMemory_free, api_->memory_free); - LOAD_SYMB(NeuronModel_addOperand, api_->model_add_operand); - LOAD_SYMB(NeuronModel_addOperation, api_->model_add_operation); - LOAD_SYMB(NeuronModel_create, api_->model_create); - LOAD_SYMB(NeuronModel_finish, api_->model_finish); - LOAD_SYMB(NeuronModel_free, api_->model_free); - LOAD_SYMB(NeuronModel_getExtensionOperandType, - api_->model_get_extension_operand_type); - LOAD_SYMB(NeuronModel_getExtensionOperationType, - api_->model_get_extension_operation_type); - LOAD_SYMB(NeuronModel_identifyInputsAndOutputs, - api_->model_identify_inputs_and_outputs); - LOAD_SYMB(NeuronModel_restoreFromCompiledNetwork, - api_->model_restore_from_compiled_network); - LOAD_SYMB(NeuronModel_setName, api_->model_set_name); - LOAD_SYMB(NeuronModel_setOperandValue, api_->model_set_operand_value); - LOAD_SYMB(NeuronModel_setOperandSymmPerChannelQuantParams, - api_->model_set_symm_per_channel_quant_params); - LOAD_SYMB(Neuron_getVersion, api_->get_version); - - LITERT_LOG(LITERT_INFO, "NeuronAdapter symbols loaded"); - return {}; -} - -Expected NeuronAdapterApi::CreateModel() const { - NeuronModel* model; - if (api().model_create(&model) != NEURON_NO_ERROR) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to create NeuroModel"); - } - return NeuronModelPtr{model, api().model_free}; -} - -Expected NeuronAdapterApi::CreateCompilation( - NeuronModel* model) const { - NeuronCompilation* compilation; - if (api().compilation_create(model, &compilation) != NEURON_NO_ERROR) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to create NeuronCompilation"); - } - return 
NeuronCompilationPtr{compilation, api().compilation_free}; -} - -Expected NeuronAdapterApi::CreateCompilation( - NeuronModel* model, const std::string& compile_options) const { - NeuronCompilation* compilation; - if (auto status = api().compilation_create_with_options( - model, &compilation, compile_options.c_str()); - status != NEURON_NO_ERROR) { - LITERT_LOG(LITERT_ERROR, - "NeuronCompilation_createWithOptions failed with error %d", - status); - return Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to create NeuronCompilation"); - } - return NeuronCompilationPtr{compilation, api().compilation_free}; -} - -Expected NeuronAdapterApi::CreateExecution( - NeuronCompilation* compilation) const { - NeuronExecution* execution; - if (api().execution_create(compilation, &execution) != NEURON_NO_ERROR) { - return litert::Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to create execution"); - } - return NeuronExecutionPtr{execution, api().execution_free}; -} - -} // namespace mediatek -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h b/tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h deleted file mode 100644 index 7d61d2c027f2d1..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter_api.h +++ /dev/null @@ -1,151 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_NEURON_ADAPTER_API_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_NEURON_ADAPTER_API_H_ - -#include -#include -#include - -#include "neuron/api/NeuronAdapter.h" -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_shared_library.h" - -#if LITERT_HAS_AHWB_SUPPORT -#include -#else -struct AHardwareBuffer {}; -#endif - -namespace litert::mediatek { - -using NeuronModelPtr = std::unique_ptr; -using NeuronCompilationPtr = - std::unique_ptr; -using NeuronExecutionPtr = - std::unique_ptr; - -class NeuronAdapterApi { - public: - using Ptr = std::unique_ptr; - struct Api; - - NeuronAdapterApi(NeuronAdapterApi&) = delete; - NeuronAdapterApi(NeuronAdapterApi&&) = delete; - NeuronAdapterApi& operator=(const NeuronAdapterApi&) = delete; - NeuronAdapterApi& operator=(NeuronAdapterApi&&) = delete; - - static Expected Create(std::optional shared_library_dir); - - const Api& api() const { return *api_; } - - absl::string_view AotCompileOptions() const { - // Option `import_forever` has been recommended by MediaTek to reduce memory - // footprint when using the same I/O buffers across multiple invocations. - return "--apusys-config \"{ \\\"import_forever\\\": true }\""; - } - - absl::string_view JitCompileOptions() const { return ""; } - - Expected CreateModel() const; - - Expected CreateCompilation(NeuronModel* model) const; - - Expected CreateCompilation( - NeuronModel* model, const std::string& compile_options) const; - - Expected CreateExecution( - NeuronCompilation* compilation) const; - - private: - NeuronAdapterApi(); - litert::Expected LoadSymbols( - std::optional shared_library_dir); - - // Handle to the shared library that implements the Neuron API. - // - // This will keep the shared library open until the NeuronAdapterApi object is - // destroyed. 
- SharedLibrary dlib_; - std::unique_ptr api_; -}; - -// This is not part of the provided NeuronAdapter header for some reason. -int NeuronCompilation_createWithOptions(NeuronModel* model, - NeuronCompilation** compilation, - const char* options); - -// A convenient struct for holding function pointers to NeuronAdapter API -// symbols. These function pointers will be loaded to the shared library on -// device during runtime. -struct NeuronAdapterApi::Api { - decltype(&NeuronCompilation_create) compilation_create = nullptr; - decltype(&NeuronCompilation_createWithOptions) - compilation_create_with_options = nullptr; - decltype(&NeuronCompilation_finish) compilation_finish = nullptr; - decltype(&NeuronCompilation_free) compilation_free = nullptr; - decltype(&NeuronCompilation_getCompiledNetworkSize) - compilation_get_compiled_network_size = nullptr; - decltype(&NeuronCompilation_getInputPaddedDimensions) - compilation_get_input_padded_dimensions = nullptr; - decltype(&NeuronCompilation_getInputPaddedSize) - compilation_get_input_padded_size = nullptr; - decltype(&NeuronCompilation_getOutputPaddedDimensions) - compilation_get_output_padded_dimensions = nullptr; - decltype(&NeuronCompilation_getOutputPaddedSize) - compilation_get_output_padded_size = nullptr; - decltype(&NeuronCompilation_setOptimizationString) - compilation_set_optimization_string = nullptr; - decltype(&NeuronCompilation_setPreference) compilation_set_preference = - nullptr; - decltype(&NeuronCompilation_setPriority) compilation_set_priority = nullptr; - decltype(&NeuronCompilation_storeCompiledNetwork) - compilation_store_compiled_network = nullptr; - decltype(&NeuronExecution_compute) execution_compute = nullptr; - decltype(&NeuronExecution_create) execution_create = nullptr; - decltype(&NeuronExecution_free) execution_free = nullptr; - decltype(&NeuronExecution_setBoostHint) execution_set_boost_hint = nullptr; - decltype(&NeuronExecution_setInputFromMemory) - execution_set_input_from_memory = 
nullptr; - decltype(&NeuronExecution_setOutputFromMemory) - execution_set_output_from_memory = nullptr; - decltype(&NeuronMemory_createFromAHardwareBuffer) memory_create_from_ahwb = - nullptr; - decltype(&NeuronMemory_createFromFd) memory_create_from_fd = nullptr; - decltype(&NeuronMemory_free) memory_free = nullptr; - decltype(&NeuronModel_addOperand) model_add_operand = nullptr; - decltype(&NeuronModel_addOperation) model_add_operation = nullptr; - decltype(&NeuronModel_create) model_create = nullptr; - decltype(&NeuronModel_finish) model_finish = nullptr; - decltype(&NeuronModel_free) model_free = nullptr; - decltype(&NeuronModel_getExtensionOperandType) - model_get_extension_operand_type = nullptr; - decltype(&NeuronModel_getExtensionOperationType) - model_get_extension_operation_type = nullptr; - decltype(&NeuronModel_identifyInputsAndOutputs) - model_identify_inputs_and_outputs = nullptr; - decltype(&NeuronModel_restoreFromCompiledNetwork) - model_restore_from_compiled_network = nullptr; - decltype(&NeuronModel_setName) model_set_name = nullptr; - decltype(&NeuronModel_setOperandValue) model_set_operand_value = nullptr; - decltype(&NeuronModel_setOperandSymmPerChannelQuantParams) - model_set_symm_per_channel_quant_params = nullptr; - decltype(&Neuron_getVersion) get_version = nullptr; -}; - -} // namespace litert::mediatek - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_NEURON_ADAPTER_API_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/schema/BUILD b/tensorflow/lite/experimental/litert/vendors/mediatek/schema/BUILD deleted file mode 100644 index f17fd689d66796..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/schema/BUILD +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright (c) 2025 MediaTek Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -load("@flatbuffers//:build_defs.bzl", "flatbuffer_cc_library") -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], -) - -flatbuffer_cc_library( - name = "neuron_litert_schema", - srcs = ["neuron_schema.fbs"], - compatible_with = get_compatible_with_portable(), -) - -cc_library( - name = "mediatek_litert_schema", - hdrs = [ - "schema_resolver.h", - ], - visibility = ["//visibility:public"], - deps = [ - "neuron_litert_schema", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "@com_google_absl//absl/strings:str_format", - "@flatbuffers//:runtime_cc", - ], -) diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/schema/neuron_schema.fbs b/tensorflow/lite/experimental/litert/vendors/mediatek/schema/neuron_schema.fbs deleted file mode 100644 index 6d515fd7972e7c..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/schema/neuron_schema.fbs +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright (c) 2025 MediaTek Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -namespace NeuronSchema; - -enum CompiledType : byte { - DLA = 0, - DLB, - AdapterCache -} - -table Index { - value: int = -1; -} - -table Identifier { - value: string; -} - -// BufferIndicate to specify how to point to a buffer -union BufferIndicate { - Index, - Identifier, -} - -table Subgraph { - entry_point: string; // Entry point of the subgraph - type: CompiledType; // Type of the compiled subgraph - compiled_index: BufferIndicate; // index to the buffer at Graphs.data - weight_share_index: [BufferIndicate]; // index to the buffer at Graphs.data[index]. if empty, no weight share. -} - -table Graphs { - version: short; // Version of the graph schema - subgraphs: [Subgraph]; // List of subgraphs - data: [Buffer]; - external: ExternalBuffer; -} - -table Buffer { - identifier: string; - data: [byte]; // Binary data -} - -// List of external buffer that doesn't store in this schema -table ExternalBuffer { - identifiers: [string]; -} - -root_type Graphs; \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/schema/schema_resolver.h b/tensorflow/lite/experimental/litert/vendors/mediatek/schema/schema_resolver.h deleted file mode 100644 index 9fd871da7e0f0e..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/schema/schema_resolver.h +++ /dev/null @@ -1,184 +0,0 @@ -// Copyright (c) 2025 MediaTek Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_SCHEMA_SCHEMA_RESOLVER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_SCHEMA_SCHEMA_RESOLVER_H_ - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "absl/strings/str_format.h" -#include "flatbuffers/buffer.h" // from @flatbuffers -#include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers -#include "flatbuffers/verifier.h" // from @flatbuffers -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/vendors/mediatek/schema/neuron_schema_generated.h" - -namespace neuron { - -inline bool IsNeuronSchema(const uint8_t* buffer, size_t size) { - if (buffer == nullptr) { - return false; - } - flatbuffers::Verifier verifier(buffer, size); - return NeuronSchema::VerifyGraphsBuffer(verifier); -} - -class CompiledGraph { - public: - CompiledGraph(const NeuronSchema::Graphs& g, const NeuronSchema::Subgraph& s) - : graph_(g), subgraph_(s) {}; - - litert::Expected> GetCompiledNetwork() { - // Neuron Adapter doesn't support DLB for now. - assert(GetCompiledType() != NeuronSchema::CompiledType_DLB); - // TODO: Support the external buffer. 
- assert(subgraph_.compiled_index_type() == - NeuronSchema::BufferIndicate_Index); - auto index = subgraph_.compiled_index_as_Index(); - return GetBuffer(index->value()); - } - - NeuronSchema::CompiledType GetCompiledType() { return subgraph_.type(); } - - litert::Expected> GetBuffer(int32_t i) { - auto array_size = graph_.data()->size(); - if (i >= array_size) { - return litert::Error( - kLiteRtStatusErrorIndexOOB, - absl::StrFormat("Buffer array index %d is OOB, the array size : %d", - i, array_size)); - } - auto buffer = graph_.data()->Get(i); - return std::pair(buffer->data()->data(), - buffer->data()->size()); - } - - private: - const NeuronSchema::Graphs& graph_; - const NeuronSchema::Subgraph& subgraph_; -}; - -class SchemaResolver { - public: - SchemaResolver() = default; - - litert::Expected Initialize(const uint8_t* buffer, size_t size) { - if (!IsNeuronSchema(buffer, size)) { - return litert::Error(kLiteRtStatusErrorInvalidFlatbuffer, - "buffer is not a valid NeuronSchema"); - } - graph_ = NeuronSchema::GetGraphs(buffer); - - auto subgraphs = graph_->subgraphs(); - for (const auto& subgraph : *subgraphs) { - auto graph_name = subgraph->entry_point()->str(); - if (entry_points_.count(graph_name)) { - // shouldn't have the same name between graphs. 
- return false; - } else { - LITERT_LOG(LITERT_INFO, "Found graph: %s", graph_name.c_str()); - entry_points_[graph_name] = subgraph; - } - } - LITERT_LOG(LITERT_INFO, "There are %u subgraphs in the bytecode", - entry_points_.size()); - return true; - } - - std::optional GetCompiledGraph(std::string& name) { - if (entry_points_.count(name) == 0) { - return std::nullopt; - } - return CompiledGraph(*graph_, *entry_points_[name]); - }; - - private: - const NeuronSchema::Graphs* graph_ = nullptr; - - std::unordered_map entry_points_; -}; - -class BytecodeBuilder { - public: - BytecodeBuilder() = default; - - int32_t AddCompiledNetwork(std::string& entry_point, - NeuronSchema::CompiledType type, - int32_t buffer_index) { - auto index = NeuronSchema::CreateIndex(fb_, buffer_index); - auto subgraph = NeuronSchema::CreateSubgraph( - fb_, fb_.CreateString(entry_point), type, - NeuronSchema::BufferIndicate_Index, index.Union()); - - subgraphs_.push_back(subgraph); - return subgraphs_count_++; - }; - - int32_t AddBuffer(std::string& identifier, const std::vector& data) { - auto buffer = - NeuronSchema::CreateBufferDirect(fb_, identifier.c_str(), &data); - graph_data_.push_back(buffer); - return buffer_count_++; - } - - int32_t AddBuffer(std::string& identifier, const int8_t* data, - size_t length) { - auto data_offset = fb_.CreateVector(data, length); - auto identifier_offset = fb_.CreateString(identifier); - auto buffer = - NeuronSchema::CreateBuffer(fb_, identifier_offset, data_offset); - graph_data_.push_back(buffer); - return buffer_count_++; - } - - bool Finish() { - auto graphs = - NeuronSchema::CreateGraphsDirect(fb_, 1, &subgraphs_, &graph_data_); - fb_.Finish(graphs); - raw_buffer_ = {fb_.GetBufferPointer(), fb_.GetSize()}; - return true; - } - - std::pair GetBytecode() { - if (!raw_buffer_.has_value()) { - return {nullptr, 0}; - } - return raw_buffer_.value(); - } - - private: - ::flatbuffers::FlatBufferBuilder fb_; - - std::optional> raw_buffer_; - - 
std::vector<::flatbuffers::Offset> subgraphs_; - - std::vector<::flatbuffers::Offset> graph_data_; - - int32_t subgraphs_count_ = 0; - int32_t buffer_count_ = 0; -}; - -}; // namespace neuron - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_SCHEMA_SCHEMA_RESOLVER_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/supported_soc.csv b/tensorflow/lite/experimental/litert/vendors/mediatek/supported_soc.csv deleted file mode 100644 index 0d792650611649..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/supported_soc.csv +++ /dev/null @@ -1,17 +0,0 @@ -# manufacturer,model,android_api_level -Mediatek,MT6897,UNKNOWN -Mediatek,MT6895Z_A/TCZA,UNKNOWN -Mediatek,MT6985,UNKNOWN -Mediatek,MT6989,UNKNOWN -Mediatek,MT6983,UNKNOWN -Mediatek,MT6895Z/TCZA,UNKNOWN -Mediatek,MT6895Z_B/TCZA,UNKNOWN -Mediatek,MT6991,UNKNOWN -Mediatek,MT6983Z/CZA,UNKNOWN -Mediatek,MT6983W/CZA,UNKNOWN -Mediatek,MT6895,UNKNOWN -Mediatek,MT6983Z/TCZA,UNKNOWN -Mediatek,MT6991(ENG),UNKNOWN -Mediatek,MT6895Z/CZA,UNKNOWN -Mediatek,MT6989(ENG),UNKNOWN -Mediatek,MT6985(ENG),UNKNOWN diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/BUILD b/tensorflow/lite/experimental/litert/vendors/qualcomm/BUILD deleted file mode 100644 index 6a76dc9594a1c3..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/BUILD +++ /dev/null @@ -1,146 +0,0 @@ -# Copyright 2024 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -load("//tensorflow/lite/experimental/litert/build_common:litert_build_defs.bzl", "litert_lib", "litert_test") -load("//tensorflow/lite/experimental/litert/vendors/qualcomm:qualcomm_build_defs.bzl", "litert_cc_lib_with_qnn") - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], -) - -cc_library( - name = "common", - hdrs = ["common.h"], - deps = [ - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model", - ], -) - -litert_lib( - name = "qnn_log", - srcs = ["qnn_log.cc"], - hdrs = ["qnn_log.h"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - deps = [ - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - ], -) - -cc_library( - name = "qnn_manager_hdr", - hdrs = ["qnn_manager.h"], - deps = [ - ":common", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_shared_library", - ], -) - -litert_cc_lib_with_qnn( - name = "qnn_manager", - srcs = [ - "qnn_manager.cc", - ], - hdrs = ["qnn_manager.h"], - include_system = True, - tags = [ - # Don't build/test in OS until qnn is available. 
- "nobuilder", - ], - ungrte = True, - deps = [ - ":common", - ":qnn_log", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_shared_library", - "//tensorflow/lite/experimental/litert/core:dynamic_loading", - ], -) - -litert_test( - name = "qnn_manager_test", - srcs = ["qnn_manager_test.cc"], - linkstatic = True, - tags = [ - # Tests with ungrte deps do not currently work on forge. - "no-remote-exec", - "notap", - # Don't build/test in OS until qnn is available. - "nobuilder", - "no_oss", - # Sanitizer runtime doesn't work with anything that loads libQnnHtp.so. - "nosan", - ], - # This test can be run only on Android and Linux. 
- target_compatible_with = select({ - "@platforms//os:android": [], - "@platforms//os:linux": [], - "//conditions:default": ["@platforms//:incompatible"], - }), - deps = [ - ":qnn_manager", - "//tensorflow/lite/experimental/litert/test:common", - "//tensorflow/lite/experimental/litert/test:matchers_oss", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/tools:dump", - ], -) - -cc_library( - name = "context_binary_info", - srcs = ["context_binary_info.cc"], - hdrs = ["context_binary_info.h"], - deps = [ - ":qnn_manager", - ":qnn_tensor", - "@com_google_absl//absl/strings:string_view", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - ], -) - -cc_library( - name = "qnn_tensor", - srcs = ["qnn_tensor.cc"], - hdrs = ["qnn_tensor.h"], - deps = [ - "@com_google_absl//absl/log:absl_check", - "@com_google_absl//absl/strings:string_view", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - ], -) diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/common.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/common.h deleted file mode 100644 index 34b8971460c466..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/common.h +++ /dev/null @@ -1,100 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMMON_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMMON_H_ - -#include "third_party/qairt/latest/include/QNN/QnnCommon.h" -#include "third_party/qairt/latest/include/QNN/QnnInterface.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "third_party/qairt/latest/include/QNN/System/QnnSystemInterface.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" - -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus - -#define LITERT_RETURN_STATUS_IF_QNN_NOT_OK(expr) \ - if (QNN_SUCCESS != (expr)) { \ - return kLiteRtStatusErrorNotFound; \ - } - -// Pointers to functions of a dynamically loaded QNN library. -typedef QNN_INTERFACE_VER_TYPE QnnApi; - -// Pointers to functions of a dynamically loaded QNN system library. -typedef QNN_SYSTEM_INTERFACE_VER_TYPE QnnSystemApi; - -// QNN backend library should be on DT_RUNPATH (-rpath). -static const char kLibQnnHtpSo[] = "libQnnHtp.so"; - -// QNN backend library should be on DT_RUNPATH (-rpath). -static const char kLibQnnSystemSo[] = "libQnnSystem.so"; - -// Map LiteRT element type to Qnn counterpart. 
-inline LiteRtStatus LegalizeElementType(litert::ElementType litert_type, - Qnn_DataType_t* qnn_type) { - switch (litert_type) { - case litert::ElementType::Bool: - *qnn_type = QNN_DATATYPE_BOOL_8; - break; - case litert::ElementType::Int4: - *qnn_type = QNN_DATATYPE_SFIXED_POINT_4; - break; - case litert::ElementType::Int8: - *qnn_type = QNN_DATATYPE_INT_8; - break; - case litert::ElementType::Int16: - *qnn_type = QNN_DATATYPE_INT_16; - break; - case litert::ElementType::Int32: - *qnn_type = QNN_DATATYPE_INT_32; - break; - case litert::ElementType::Int64: - *qnn_type = QNN_DATATYPE_INT_64; - break; - case litert::ElementType::UInt8: - *qnn_type = QNN_DATATYPE_UINT_8; - break; - case litert::ElementType::UInt16: - *qnn_type = QNN_DATATYPE_UINT_16; - break; - case litert::ElementType::UInt32: - *qnn_type = QNN_DATATYPE_UINT_32; - break; - case litert::ElementType::UInt64: - *qnn_type = QNN_DATATYPE_UINT_64; - break; - case litert::ElementType::Float16: - *qnn_type = QNN_DATATYPE_FLOAT_16; - break; - case litert::ElementType::Float32: - *qnn_type = QNN_DATATYPE_FLOAT_32; - break; - case litert::ElementType::Float64: - *qnn_type = QNN_DATATYPE_FLOAT_64; - break; - default: - return kLiteRtStatusErrorUnsupported; - } - return kLiteRtStatusOk; -} - -#ifdef __cplusplus -} -#endif // __cplusplus - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMMON_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/BUILD b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/BUILD deleted file mode 100644 index 6c54037f0803a1..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/BUILD +++ /dev/null @@ -1,206 +0,0 @@ -# Copyright 2024 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -load("//tensorflow/lite/experimental/litert/build_common:litert_build_defs.bzl", "litert_dynamic_lib", "litert_lib", "litert_test") - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//visibility:private"], -) - -litert_dynamic_lib( - name = "qnn_compiler_plugin", - srcs = ["qnn_compiler_plugin.cc"], - hdrs = ["//tensorflow/lite/experimental/litert/vendors/c:litert_compiler_plugin.h"], - export_litert_only = True, - shared_lib_name = "qnn_compiler_plugin_so", - so_name = "libLiteRtCompilerPlugin_Qualcomm.so", - tags = [ - # Don't build/test in OS until qnn is available. 
- "nobuilder", - ], - ungrte = True, - visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], - deps = [ - ":qnn_compose_graph", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", - "//tensorflow/lite/experimental/litert/core/model", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core:tensor_pool", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:op_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) - -litert_test( - name = "qnn_compiler_plugin_test", - srcs = [ - "qnn_compiler_plugin_test.cc", - ], - data = [ - "//tensorflow/lite/experimental/litert/test:mlir_test_data", - "//tensorflow/lite/experimental/litert/test:tflite_test_data", - ], - linkstatic = True, - tags = [ - # Tests with ungrte deps do not currently work on forge. - "no-remote-exec", - "notap", - # Don't build/test in OS until qnn is available. - "nobuilder", - "no_oss", - # Sanitizer runtime doesn't work with anything that loads libQnnHtp.so. - "nosan", - ], - # This test can be run only on Android and Linux. 
- target_compatible_with = select({ - "@platforms//os:android": [], - "@platforms//os:linux": [], - "//conditions:default": ["@platforms//:incompatible"], - }), - use_sys_malloc = True, - deps = [ - ":qnn_compiler_plugin", # buildcleaner: keep - "@com_google_absl//absl/log:absl_check", - "@com_google_absl//absl/strings:string_view", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", - "//tensorflow/lite/experimental/litert/core/model", - "//tensorflow/lite/experimental/litert/test:common", - "//tensorflow/lite/experimental/litert/test:matchers_oss", - "//tensorflow/lite/experimental/litert/test:test_models", - "//tensorflow/lite/experimental/litert/vendors/cc:litert_compiler_plugin", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations:quantize_op_legalization", - ], -) - -litert_lib( - name = "qnn_compose_graph", - srcs = ["qnn_compose_graph.cc"], - hdrs = ["qnn_compose_graph.h"], - tags = [ - # Don't build/test in OS until qnn is available. 
- "nobuilder", - ], - deps = [ - ":graph_mapper", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_element_type", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", - "//tensorflow/lite/experimental/litert/core/model", - "//tensorflow/lite/experimental/litert/tools:dump", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:common", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core:common", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core:tensor_pool", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders:cast_op_builder", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders:concatenation_op_builder", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders:conv2d_op_builder", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders:depthwise_conv2d_op_builder", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders:dynamic_update_slice_op_builder", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders:elementwise_op_builder", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders:embedding_lookup_op_builder", - 
"//tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders:fully_connected_op_builder", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders:fully_connected_op_builder_htp", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders:gather_op_builder", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders:gelu_op_builder", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders:hard_swish_op_builder", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders:leaky_relu_op_builder", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders:matmul_op_builder", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders:mean_op_builder", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders:op_builder", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders:pack_op_builder", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders:pool2d_op_builder", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders:quantize_op_builder", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders:reduce_op_builder", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders:relu6_op_builder", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders:relu_op_builder", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders:reshape_op_builder", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders:resize_op_builder", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders:rms_norm_op_builder", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders:select_op_builder", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders:slice_op_builder", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders:softmax_op_builder", - 
"//tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders:spatial_transform_op_builder", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders:split_op_builder", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders:tanh_op_builder", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders:transpose_op_builder", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:op_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:quantize_params_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) - -litert_lib( - name = "graph_mapper", - srcs = [ - "graph_mapper.cc", - ], - hdrs = ["graph_mapper.h"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], - deps = [ - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_element_type", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:common", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", - ], -) diff 
--git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/BUILD b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/BUILD deleted file mode 100644 index fa0e3f55e19b93..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/BUILD +++ /dev/null @@ -1,123 +0,0 @@ -# Copyright 2024 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler:__subpackages__"], -) - -cc_library( - name = "qnn_tensor", - srcs = ["qnn_tensor.cc"], - hdrs = ["qnn_tensor.h"], - tags = [ - # Don't build/test in OS until qnn is available. 
- "nobuilder", - ], - deps = [ - "@com_google_absl//absl/log:absl_check", - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:common", - ], -) - -cc_test( - name = "qnn_tensor_test", - srcs = ["qnn_tensor_test.cc"], - data = [ - "//tensorflow/lite/experimental/litert/test:mlir_test_data", - "//tensorflow/lite/experimental/litert/test:tflite_test_data", - ], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - "no_oss", - ], - deps = [ - ":qnn_tensor", - "@com_google_googletest//:gtest_main", - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/test:common", - "//tensorflow/lite/experimental/litert/test:matchers", - "//tensorflow/lite/experimental/litert/test:test_models", - ], -) - -cc_library( - name = "qnn_op", - srcs = ["qnn_op.cc"], - hdrs = ["qnn_op.h"], - tags = [ - # Don't build/test in OS until qnn is available. 
- "nobuilder", - ], - deps = [ - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/cc:litert_model", - ], -) - -cc_test( - name = "qnn_op_test", - srcs = ["qnn_op_test.cc"], - data = ["//tensorflow/lite/experimental/litert/test:mlir_test_data"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - "no_oss", - ], - deps = [ - ":qnn_op", - "@com_google_googletest//:gtest_main", - "@com_google_absl//absl/strings", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/test:common", - "//tensorflow/lite/experimental/litert/test:matchers", - ], -) - -cc_test( - name = "op_compatibility_test", - srcs = ["op_compatibility_test.cc"], - data = ["//tensorflow/lite/experimental/litert/test:mlir_test_data"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - "no_oss", - ], - deps = [ - ":qnn_op", - "@com_google_googletest//:gtest_main", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/test:common", - "//tensorflow/lite/experimental/litert/test:matchers", - ], -) diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/op_compatibility_test.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/op_compatibility_test.cc deleted file mode 100644 index 477711417441f9..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/op_compatibility_test.cc +++ /dev/null @@ -1,82 +0,0 @@ -// Copyright 2024 Google LLC. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include - -#include -#include "absl/strings/match.h" -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/test/common.h" -#include "tensorflow/lite/experimental/litert/test/matchers.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op.h" - -namespace { - -static constexpr absl::string_view kOpTpl = "simple_%s_op.tflite"; -struct OpInfo { - std::string op_name; - std::string expected_type_name; -}; - -// TODOL: b/365299994 - Add "stablehlo_scatter" once muti subgraphs is -// supported. 
-// clang-format off -const auto kSupportedOps = testing::Values( - OpInfo{"add", "ElementWiseAdd"}, - OpInfo{"mul", "ElementWiseMultiply"}, - OpInfo{"batch_matmul", "MatMul"}, - OpInfo{"concatenation", "Concat"}, - OpInfo{"div", "ElementWiseDivide"}, - OpInfo{"fully_connected", "FullyConnected"}, - OpInfo{"reshape", "Reshape"}, - OpInfo{"rsqrt", "ElementWiseRsqrt"}, - OpInfo{"select_v2", "ElementWiseSelect"}, - OpInfo{"select", "ElementWiseSelect"}, - OpInfo{"strided_slice", "StridedSlice"}, - OpInfo{"slice", "StridedSlice"}, - OpInfo{"softmax", "Softmax"}, - OpInfo{"sub", "ElementWiseSubtract"}, - OpInfo{"tanh", "Tanh"}, - OpInfo{"transpose", "Transpose"}); -// clang-format on - -class OpCompatibilityTest : public ::testing::TestWithParam {}; - -TEST_P(OpCompatibilityTest, SupportedOpsTest) { - auto test_params = GetParam(); - std::string model_path = absl::StrFormat(kOpTpl, test_params.op_name); - auto model = litert::testing::LoadTestFileModel(model_path); - auto subgraph = model.MainSubgraph(); - EXPECT_TRUE(subgraph); - auto ops = subgraph->Ops(); - - Qnn_OpConfig_t qnn_op = litert::qnn::BuildDefaultOp(); - LITERT_ASSERT_OK(litert::qnn::LegalizeOp(ops.front().Get(), qnn_op)); - - EXPECT_TRUE(absl::StrContains(qnn_op.v1.name, test_params.op_name)); - EXPECT_STREQ(qnn_op.v1.packageName, "qti.aisw"); - EXPECT_STREQ(qnn_op.v1.typeName, test_params.expected_type_name.c_str()); - - EXPECT_EQ(qnn_op.v1.numOfInputs, 0); - EXPECT_EQ(qnn_op.v1.numOfOutputs, 0); - EXPECT_EQ(qnn_op.v1.numOfParams, 0); - - litert::qnn::ResetOp(qnn_op); -} - -INSTANTIATE_TEST_SUITE_P(SupportedOpsTest, OpCompatibilityTest, kSupportedOps); - -} // namespace diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op.cc deleted file mode 100644 index 0a6949afaf7807..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op.cc +++ /dev/null @@ -1,147 +0,0 @@ 
-// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op.h" - -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" - -// A macro dance to create a unique literal string given a prefix. -#define STRINGIFY(x) #x -#define QNN_OP_NAME(prefix) STRINGIFY(prefix##__COUNTER) - -namespace litert::qnn { - -namespace { - -// Maps "op-code" related information (name, packageName, typeName) from src -// to dest. 
-LiteRtStatus LegalizeOpType(const Op& src, Qnn_OpConfig_t& dest) { - switch (src.Code()) { - case kLiteRtOpCodeTflMul: - dest.v1.name = QNN_OP_NAME(mul_); - dest.v1.packageName = "qti.aisw"; - dest.v1.typeName = "ElementWiseMultiply"; - break; - case kLiteRtOpCodeTflAdd: - dest.v1.name = QNN_OP_NAME("add"); - dest.v1.packageName = "qti.aisw"; - dest.v1.typeName = "ElementWiseAdd"; - break; - case kLiteRtOpCodeTflBatchMatmul: - dest.v1.name = QNN_OP_NAME("batch_matmul"); - dest.v1.packageName = "qti.aisw"; - dest.v1.typeName = "MatMul"; - break; - case kLiteRtOpCodeTflConcatenation: - dest.v1.name = QNN_OP_NAME("concatenation"); - dest.v1.packageName = "qti.aisw"; - dest.v1.typeName = "Concat"; - break; - case kLiteRtOpCodeTflDiv: - dest.v1.name = QNN_OP_NAME("div"); - dest.v1.packageName = "qti.aisw"; - dest.v1.typeName = "ElementWiseDivide"; - break; - case kLiteRtOpCodeTflFullyConnected: - dest.v1.name = QNN_OP_NAME("fully_connected"); - dest.v1.packageName = "qti.aisw"; - dest.v1.typeName = "FullyConnected"; - break; - case kLiteRtOpCodeTflReshape: - dest.v1.name = QNN_OP_NAME("reshape"); - dest.v1.packageName = "qti.aisw"; - dest.v1.typeName = "Reshape"; - break; - case kLiteRtOpCodeTflRsqrt: - dest.v1.name = QNN_OP_NAME("rsqrt"); - dest.v1.packageName = "qti.aisw"; - dest.v1.typeName = "ElementWiseRsqrt"; - break; - case kLiteRtOpCodeTflSelectV2: - dest.v1.name = QNN_OP_NAME("select_v2"); - dest.v1.packageName = "qti.aisw"; - dest.v1.typeName = "ElementWiseSelect"; - break; - case kLiteRtOpCodeTflSelect: - dest.v1.name = QNN_OP_NAME("select"); - dest.v1.packageName = "qti.aisw"; - dest.v1.typeName = "ElementWiseSelect"; - break; - case kLiteRtOpCodeTflStridedSlice: - dest.v1.name = QNN_OP_NAME("strided_slice"); - dest.v1.packageName = "qti.aisw"; - dest.v1.typeName = "StridedSlice"; - break; - case kLiteRtOpCodeTflSlice: - dest.v1.name = QNN_OP_NAME("slice"); - dest.v1.packageName = "qti.aisw"; - dest.v1.typeName = "StridedSlice"; - break; - case 
kLiteRtOpCodeTflSoftmax: - dest.v1.name = QNN_OP_NAME("softmax"); - dest.v1.packageName = "qti.aisw"; - dest.v1.typeName = "Softmax"; - break; - case kLiteRtOpCodeTflSub: - dest.v1.name = QNN_OP_NAME("sub"); - dest.v1.packageName = "qti.aisw"; - dest.v1.typeName = "ElementWiseSubtract"; - break; - case kLiteRtOpCodeTflTanh: - dest.v1.name = QNN_OP_NAME("tanh"); - dest.v1.packageName = "qti.aisw"; - dest.v1.typeName = "Tanh"; - break; - case kLiteRtOpCodeTflTranspose: - dest.v1.name = QNN_OP_NAME("transpose"); - dest.v1.packageName = "qti.aisw"; - dest.v1.typeName = "Transpose"; - break; - default: - return kLiteRtStatusErrorUnsupported; - } - return kLiteRtStatusOk; -} - -} // namespace - -Qnn_OpConfig_t BuildDefaultOp() { - Qnn_OpConfig_t op = QNN_OPCONFIG_INIT; - ResetOp(op); - return op; -} -Qnn_Param_t BuildDefaultParam() { - Qnn_Param_t param = QNN_PARAM_INIT; - ResetParam(param); - return param; -} - -void ResetOp(Qnn_OpConfig_t& op) { - op = QNN_OPCONFIG_INIT; - op.version = QNN_OPCONFIG_VERSION_1; - op.v1 = QNN_OPCONFIG_V1_INIT; -} - -void ResetParam(Qnn_Param_t& param) { param = QNN_PARAM_INIT; } -LiteRtStatus LegalizeOp(LiteRtOp src, Qnn_OpConfig_t& dest) { - ResetOp(dest); - Op op(src); - return LegalizeOpType(op, dest); -} - -} // namespace litert::qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op.h deleted file mode 100644 index 20e0f27f798b98..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op.h +++ /dev/null @@ -1,53 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_IR_QNN_OP_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_IR_QNN_OP_H_ - -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" - -namespace litert::qnn { - -// -// Initialize QNN Op. -// - -// NOTE: Any referential data within a QNN Op -// is allocated with "new" and must be explicitly cleaned up with ResetOp. - -// Construct a "blank" QNN Op. -Qnn_OpConfig_t BuildDefaultOp(); - -// Construct a "blank" QNN Param. -Qnn_Param_t BuildDefaultParam(); - -// Reset the given tensor, deallocating anything on the heap that it points to. -void ResetOp(Qnn_OpConfig_t& op); - -// Reset the given param, deallocating anything on the heap that it points to. -void ResetParam(Qnn_Param_t& param); - -// -// Legalize LiteRt Op to Analogous QNN Construct. -// - -// Map src op onto dest. Resets dest before doing anything. This only handles -// attribute-like info. It does not set edges (in/out tensors). 
-LiteRtStatus LegalizeOp(LiteRtOp src, Qnn_OpConfig_t& dest); - -} // namespace litert::qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_IR_QNN_OP_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op_test.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op_test.cc deleted file mode 100644 index dd78cfca40b88c..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op_test.cc +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op.h" - -#include -#include -#include "absl/strings/match.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/test/common.h" -#include "tensorflow/lite/experimental/litert/test/matchers.h" - -namespace { - -using testing::litert::IsError; - -TEST(TestInitQnnOp, BuildDefaultOp) { - Qnn_OpConfig_t op = litert::qnn::BuildDefaultOp(); - ASSERT_EQ(op.version, QNN_OPCONFIG_VERSION_1); -} - -TEST(TestLegalizeOp, SimpleSupportedOp) { - auto model = litert::testing::LoadTestFileModel("one_mul.tflite"); - auto subgraph = model.MainSubgraph(); - EXPECT_TRUE(subgraph); - auto ops = subgraph->Ops(); - - Qnn_OpConfig_t qnn_op = litert::qnn::BuildDefaultOp(); - LITERT_ASSERT_OK(litert::qnn::LegalizeOp(ops.front().Get(), qnn_op)); - - EXPECT_TRUE(absl::StrContains(qnn_op.v1.name, "mul")); - EXPECT_STREQ(qnn_op.v1.packageName, "qti.aisw"); - EXPECT_STREQ(qnn_op.v1.typeName, "ElementWiseMultiply"); - - EXPECT_EQ(qnn_op.v1.numOfInputs, 0); - EXPECT_EQ(qnn_op.v1.numOfOutputs, 0); - EXPECT_EQ(qnn_op.v1.numOfParams, 0); - - litert::qnn::ResetOp(qnn_op); -} - -TEST(TestLegalizeOp, UnsupportedOp) { - auto model = litert::testing::LoadTestFileModel("simple_floor_mod_op.tflite"); - auto subgraph = model.MainSubgraph(); - EXPECT_TRUE(subgraph); - auto ops = subgraph->Ops(); - - Qnn_OpConfig_t qnn_op = litert::qnn::BuildDefaultOp(); - EXPECT_THAT(litert::qnn::LegalizeOp(ops.front().Get(), qnn_op), - IsError(kLiteRtStatusErrorUnsupported)); - - litert::qnn::ResetOp(qnn_op); -} - -} // namespace diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor.cc deleted file mode 100644 index 4a308f6da78012..00000000000000 --- 
a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor.cc +++ /dev/null @@ -1,253 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor.h" - -#include - -#include "absl/log/absl_check.h" -#include "absl/types/span.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/common.h" - -namespace litert::qnn { - -namespace { - -LiteRtStatus LegalizeShapeInfo(const litert::Layout& src, Qnn_Tensor_t& dest) { - LITERT_ENSURE_SUPPORTED(!src.HasStrides(), "Strides not yet supported"); - - dest.v2.rank = src.Rank(); - // Ad-hoc fix: rank 0 tensor needs to be single element 1D tensor in QNN. 
- if (dest.v2.rank == 0) { - LITERT_LOG(LITERT_INFO, "Setting rank 0 tensor to single element tensor"); - dest.v2.rank = 1; - dest.v2.dimensions = new uint32_t[1]; - dest.v2.dimensions[0] = 1; - return kLiteRtStatusOk; - } - - dest.v2.dimensions = new uint32_t[dest.v2.rank]; - for (int i = 0; i < dest.v2.rank; ++i) { - const auto src_dim = src.Dimensions()[i]; - LITERT_ENSURE(src_dim >= 1, kLiteRtStatusErrorInvalidArgument, - "Cannot pass dim < 1 to QNN Tensor."); - - dest.v2.dimensions[i] = src.Dimensions()[i]; - } - return kLiteRtStatusOk; -} - -void FreeTensorDims(Qnn_Tensor_t& tensor) { - if (tensor.version == QNN_TENSOR_VERSION_2 && - tensor.v2.dimensions != nullptr) { - delete[] tensor.v2.dimensions; - tensor.v2.dimensions = nullptr; - tensor.v2.rank = 0; - } -} - -void FreePerChannelQuantization(Qnn_Tensor_t& tensor) { - if (tensor.v2.quantizeParams.quantizationEncoding == - QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET) { - delete[] tensor.v2.quantizeParams.axisScaleOffsetEncoding.scaleOffset; - tensor.v2.quantizeParams.axisScaleOffsetEncoding.scaleOffset = nullptr; - tensor.v2.quantizeParams.axisScaleOffsetEncoding.numScaleOffsets = 0; - } -} - -} // namespace - -void SetInputTensorAttrs(Qnn_Tensor_t& tensor) { - ABSL_DCHECK(tensor.version == QNN_TENSOR_VERSION_2); - tensor.v2.type = QNN_TENSOR_TYPE_APP_WRITE; - tensor.v2.memType = QNN_TENSORMEMTYPE_RAW; - tensor.v2.clientBuf = QNN_CLIENT_BUFFER_INIT; -} - -void SetOutputTensorAttrs(Qnn_Tensor_t& tensor) { - ABSL_DCHECK(tensor.version == QNN_TENSOR_VERSION_2); - tensor.v2.type = QNN_TENSOR_TYPE_APP_READ; -} - -void SetResultTensorAttrs(Qnn_Tensor_t& tensor) { - ABSL_DCHECK(tensor.version == QNN_TENSOR_VERSION_2); - tensor.v2.memType = QNN_TENSORMEMTYPE_RAW; - tensor.v2.type = QNN_TENSOR_TYPE_NATIVE; -} - -void ResetTensor(Qnn_Tensor_t& tensor) { - FreeTensorDims(tensor); - FreePerChannelQuantization(tensor); - tensor = QNN_TENSOR_INIT; - tensor.version = QNN_TENSOR_VERSION_2; - tensor.v2 = 
QNN_TENSOR_V2_INIT; - tensor.v2.dataFormat = QNN_TENSOR_DATA_FORMAT_DENSE; - tensor.v2.memType = QNN_TENSORMEMTYPE_RAW; -} - -Qnn_Tensor_t BuildDefaultTensor(uint32_t id) { - Qnn_Tensor_t tensor = QNN_TENSOR_INIT; - ResetTensor(tensor); - tensor.v2.id = id; - return tensor; -} - -Qnn_Tensor_t BuildDefaultTensor() { return BuildDefaultTensor(0); } - -Qnn_Tensor_t BuildInputTensor() { - auto tensor = BuildDefaultTensor(); - SetInputTensorAttrs(tensor); - return tensor; -} - -Qnn_ClientBuffer_t BuildDefaultClientBuffer() { - Qnn_ClientBuffer_t client_buf = QNN_CLIENT_BUFFER_INIT; - client_buf.data = nullptr; - client_buf.dataSize = 0; - return client_buf; -} - -Qnn_Tensor_t BuildOutputTensor() { - Qnn_Tensor_t tensor = BuildDefaultTensor(); - SetOutputTensorAttrs(tensor); - return tensor; -} - -uint32_t MoveToId(Qnn_Tensor_t& tensor) { - const auto id = tensor.v2.id; - ResetTensor(tensor); - tensor.v2.id = id; - return id; -} - -void SetPerChannelQuantization( - Qnn_Tensor_t& tensor, - const LiteRtQuantizationPerChannel& lite_rt_quantization_per_channel) { - tensor.v2.quantizeParams.quantizationEncoding = - QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET; - - tensor.v2.quantizeParams.axisScaleOffsetEncoding = QNN_AXIS_SCALE_OFFSET_INIT; - tensor.v2.quantizeParams.axisScaleOffsetEncoding.axis = - lite_rt_quantization_per_channel.quantized_dimension; - tensor.v2.quantizeParams.axisScaleOffsetEncoding.numScaleOffsets = - lite_rt_quantization_per_channel.num_channels; - - // Allocates memory for scaleOffset array. 
- tensor.v2.quantizeParams.axisScaleOffsetEncoding.scaleOffset = - new Qnn_ScaleOffset_t[lite_rt_quantization_per_channel.num_channels]; - - for (int i = 0; i < lite_rt_quantization_per_channel.num_channels; ++i) { - tensor.v2.quantizeParams.axisScaleOffsetEncoding.scaleOffset[i].scale = - lite_rt_quantization_per_channel.scales[i]; - tensor.v2.quantizeParams.axisScaleOffsetEncoding.scaleOffset[i].offset = - lite_rt_quantization_per_channel.zero_points[i]; - } -} - -void SetPerTensorQuantization( - Qnn_Tensor_t& tensor, - const LiteRtQuantizationPerTensor& lite_rt_quantization_per_tensor) { - tensor.v2.quantizeParams.quantizationEncoding = - QNN_QUANTIZATION_ENCODING_SCALE_OFFSET; - tensor.v2.quantizeParams.scaleOffsetEncoding.scale = - lite_rt_quantization_per_tensor.scale; - tensor.v2.quantizeParams.scaleOffsetEncoding.offset = - lite_rt_quantization_per_tensor.zero_point; -} - -LiteRtStatus LegalizeQuntizationParameter(const litert::Tensor& src, - Qnn_Tensor_t& dest) { - LiteRtQuantizationTypeId lite_rt_quantization_type_id = src.QTypeId(); - switch (lite_rt_quantization_type_id) { - case kLiteRtQuantizationPerTensor: - SetPerTensorQuantization(dest, src.PerTensorQuantization()); - return kLiteRtStatusOk; - case kLiteRtQuantizationPerChannel: - SetPerChannelQuantization(dest, src.PerChannelQuantization()); - return kLiteRtStatusOk; - default: - LITERT_LOG(LITERT_ERROR, "Unsupported quantization type."); - return kLiteRtStatusErrorInvalidArgument; - } -} - -LiteRtStatus LegalizeTensor(const litert::Tensor& src, Qnn_Tensor_t& dest) { - if (src.TypeId() != kLiteRtRankedTensorType) { - return kLiteRtStatusErrorInvalidArgument; - } - - ResetTensor(dest); - - if (src.HasQuantization()) { - LITERT_RETURN_IF_ERROR(LegalizeQuntizationParameter(src, dest)); - } - - auto src_ranked_tensor_type = src.RankedTensorType(); - if (!src_ranked_tensor_type) { - LITERT_LOG(LITERT_ERROR, "%s", - src_ranked_tensor_type.Error().Message().c_str()); - return 
src_ranked_tensor_type.Error().Status(); - } - - Qnn_DataType_t* qnn_data_type = &dest.v2.dataType; - LITERT_RETURN_IF_ERROR(LegalizeElementType( - src_ranked_tensor_type->ElementType(), qnn_data_type)); - - LITERT_RETURN_IF_ERROR( - LegalizeShapeInfo(src_ranked_tensor_type->Layout(), dest)); - - const bool is_subgraph_in = src.IsSubgraphInput(); - const bool is_subgraph_out = src.IsSubgraphOutput(); - const bool is_constant = src.IsConstant(); - - LITERT_ENSURE(!(is_subgraph_in && is_subgraph_out), - kLiteRtStatusErrorInvalidArgument, - "Malformed tensor, cannot be both subgraph in and out."); - if (is_constant) { - LITERT_LOG(LITERT_INFO, "Adding constant tensor %s to qnn graph", - dest.v2.name); - LITERT_ENSURE(src.HasWeights(), kLiteRtStatusErrorInvalidLegalization, - "Empty weights for constant tensor."); - Qnn_ClientBuffer_t client_buf = BuildDefaultClientBuffer(); - client_buf.data = (void*)src.Weights().Bytes().data(); - client_buf.dataSize = src.Weights().Bytes().size(); - dest.v2.clientBuf = client_buf; - dest.v2.memType = QNN_TENSORMEMTYPE_RAW; - dest.v2.type = QNN_TENSOR_TYPE_STATIC; - dest.v2.isDynamicDimensions = nullptr; - } - - if (is_subgraph_in) { - LITERT_LOG(LITERT_INFO, "Adding subgraph input tensor to qnn graph"); - SetInputTensorAttrs(dest); - } - if (is_subgraph_out) { - LITERT_LOG(LITERT_INFO, "Adding subgraph output tensor to qnn graph"); - SetOutputTensorAttrs(dest); - } - if (!is_constant && !is_subgraph_in && !is_subgraph_out) { - LITERT_LOG(LITERT_INFO, "Adding result tensor to qnn graph"); - SetResultTensorAttrs(dest); - } - - return kLiteRtStatusOk; -} - -} // namespace litert::qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor.h deleted file mode 100644 index 607cc4c3decba9..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor.h +++ /dev/null @@ -1,75 +0,0 @@ -// 
Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_IR_QNN_TENSOR_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_IR_QNN_TENSOR_H_ - -#include - -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" - -namespace litert::qnn { - -// -// Initialize QNN Tensors. -// - -// NOTE: Within LiteRt land, all Qnn Tensors are treated as "v2". Any -// referential data (like dimensions : uint32_t*) within a QNN Tensor -// is allocated with "new" and must be explicitly cleaned up with ResetTensor. - -// Construct a "blank" QNN Tensor. -Qnn_Tensor_t BuildDefaultTensor(); - -// Construct a "blank" QNN Tensor with given id. -Qnn_Tensor_t BuildDefaultTensor(uint32_t id); - -// Constructa a "blank" QNN Tensor meant to be used as a graph input. -Qnn_Tensor_t BuildInputTensor(); - -// Constructa a "blank" QNN Tensor meant to be used as a graph output. -Qnn_Tensor_t BuildOutputTensor(); - -Qnn_ClientBuffer_t BuildDefaultClientBuffer(); - -// Adds attributes to given tensor making it amenable for use as graph input. -void SetInputTensorAttrs(Qnn_Tensor_t& tensor); - -// Adds attributes to given tensor making it amenable for use as graph output. 
-void SetOutputTensorAttrs(Qnn_Tensor_t& tensor); - -// Adds attributes to given tensor making it amenable for uses a intermediate -// output. -void SetResultTensorAttrs(Qnn_Tensor_t& tensor); - -// Reset the given tensor, deallocating anything on the heap that it points to. -void ResetTensor(Qnn_Tensor_t& tensor); - -// Resets all fields other than id in the given tensor and returns the id for -// convenience. Only the id is needed to traffic QNN Tensors after they have -// been registered with the context. -uint32_t MoveToId(Qnn_Tensor_t& tensor); - -// -// Legalize LiteRt Tensors to Analogous QNN Construct. -// - -// Map src tensor onto dest. Resets dest before doing anything. -LiteRtStatus LegalizeTensor(const litert::Tensor& src, Qnn_Tensor_t& dest); - -} // namespace litert::qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_IR_QNN_TENSOR_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor_test.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor_test.cc deleted file mode 100644 index ba38fd211457a8..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor_test.cc +++ /dev/null @@ -1,202 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor.h" - -#include -#include -#include "absl/types/span.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/test/common.h" -#include "tensorflow/lite/experimental/litert/test/matchers.h" -#include "tensorflow/lite/experimental/litert/test/test_models.h" - -namespace { - -constexpr float kSimpleMulQuantModelOutputScale = 0.00028621565f; -constexpr float kSimpleMulQuantModelOutputOffset = 0; - -TEST(TestInitQnnTensor, BuildDefaultTensor) { - Qnn_Tensor_t tensor = litert::qnn::BuildDefaultTensor(); - ASSERT_EQ(tensor.version, QNN_TENSOR_VERSION_2); - EXPECT_EQ(tensor.v2.dataFormat, QNN_TENSOR_DATA_FORMAT_DENSE); - EXPECT_EQ(tensor.v2.rank, 0); - EXPECT_EQ(tensor.v2.dimensions, nullptr); - EXPECT_EQ(tensor.v2.id, 0); -} - -TEST(TestInitQnnTensor, BuildDefaultTensorWithId) { - Qnn_Tensor_t tensor = litert::qnn::BuildDefaultTensor(2); - ASSERT_EQ(tensor.version, QNN_TENSOR_VERSION_2); - EXPECT_EQ(tensor.v2.dataFormat, QNN_TENSOR_DATA_FORMAT_DENSE); - EXPECT_EQ(tensor.v2.rank, 0); - EXPECT_EQ(tensor.v2.dimensions, nullptr); - EXPECT_EQ(tensor.v2.id, 2); -} - -TEST(TestInitQnnTensor, BuildDefaultInputTensor) { - Qnn_Tensor_t tensor = litert::qnn::BuildInputTensor(); - ASSERT_EQ(tensor.version, QNN_TENSOR_VERSION_2); - EXPECT_EQ(tensor.v2.type, QNN_TENSOR_TYPE_APP_WRITE); - EXPECT_EQ(tensor.v2.memType, QNN_TENSORMEMTYPE_RAW); - EXPECT_EQ(tensor.v2.clientBuf.dataSize, 0); -} - -TEST(TestInitQnnTensor, SetInputTensor) { - Qnn_Tensor_t tensor = litert::qnn::BuildDefaultTensor(); - litert::qnn::SetInputTensorAttrs(tensor); - ASSERT_EQ(tensor.version, QNN_TENSOR_VERSION_2); - EXPECT_EQ(tensor.v2.type, QNN_TENSOR_TYPE_APP_WRITE); - EXPECT_EQ(tensor.v2.memType, QNN_TENSORMEMTYPE_RAW); - 
EXPECT_EQ(tensor.v2.clientBuf.dataSize, 0); -} - -TEST(TestInitQnnTensor, BuildDefaultOutputTensor) { - Qnn_Tensor_t tensor = litert::qnn::BuildOutputTensor(); - ASSERT_EQ(tensor.version, QNN_TENSOR_VERSION_2); - EXPECT_EQ(tensor.v2.type, QNN_TENSOR_TYPE_APP_READ); -} - -TEST(TestInitQnnTensor, SetOutputTensor) { - Qnn_Tensor_t tensor = litert::qnn::BuildDefaultTensor(); - litert::qnn::SetOutputTensorAttrs(tensor); - ASSERT_EQ(tensor.version, QNN_TENSOR_VERSION_2); - EXPECT_EQ(tensor.v2.type, QNN_TENSOR_TYPE_APP_READ); -} - -TEST(TestInitQnnTensor, MoveToId) { - Qnn_Tensor_t tensor = litert::qnn::BuildDefaultTensor(2); - - litert::qnn::SetOutputTensorAttrs(tensor); - ASSERT_EQ(tensor.version, QNN_TENSOR_VERSION_2); - EXPECT_EQ(tensor.v2.type, QNN_TENSOR_TYPE_APP_READ); - - EXPECT_EQ(litert::qnn::MoveToId(tensor), 2); - EXPECT_EQ(tensor.v2.id, 2); - EXPECT_EQ(tensor.v2.type, QNN_TENSOR_TYPE_UNDEFINED); -} - -TEST(TestLegalizeTensor, SimpleSupportedTensorSubgraphInput) { - auto model = litert::testing::LoadTestFileModel("one_mul.tflite"); - auto subgraph = model.MainSubgraph(); - EXPECT_TRUE(subgraph); - auto outputs = subgraph->Outputs(); - - auto qnn_tensor = litert::qnn::BuildDefaultTensor(); - const auto& output_tensor = outputs.front(); - LITERT_ASSERT_OK(litert::qnn::LegalizeTensor(output_tensor, qnn_tensor)); - - ASSERT_EQ(qnn_tensor.version, QNN_TENSOR_VERSION_2); - EXPECT_EQ(qnn_tensor.v2.dataType, QNN_DATATYPE_FLOAT_32); - EXPECT_EQ(qnn_tensor.v2.type, QNN_TENSOR_TYPE_APP_READ); - - ASSERT_EQ(qnn_tensor.v2.rank, 2); - ASSERT_NE(qnn_tensor.v2.dimensions, nullptr); - EXPECT_THAT(absl::MakeConstSpan(qnn_tensor.v2.dimensions, 2), - ::testing::ElementsAreArray({2, 2})); - - litert::qnn::ResetTensor(qnn_tensor); -} - -TEST(TestLegalizeTensor, SimpleSupportedTensor) { - auto model = litert::testing::LoadTestFileModel("simple_multi_op.tflite"); - - auto subgraph = model.MainSubgraph(); - EXPECT_TRUE(subgraph); - auto ops = subgraph->Ops(); - auto op_outs = 
ops.at(1).Outputs(); - - auto qnn_tensor = litert::qnn::BuildDefaultTensor(); - const auto& op_out = op_outs.front(); - LITERT_ASSERT_OK(litert::qnn::LegalizeTensor(op_out, qnn_tensor)); - - ASSERT_EQ(qnn_tensor.version, QNN_TENSOR_VERSION_2); - EXPECT_EQ(qnn_tensor.v2.dataType, QNN_DATATYPE_FLOAT_32); - EXPECT_EQ(qnn_tensor.v2.type, QNN_TENSOR_TYPE_NATIVE); - - ASSERT_EQ(qnn_tensor.v2.rank, 2); - ASSERT_NE(qnn_tensor.v2.dimensions, nullptr); - EXPECT_THAT(absl::MakeConstSpan(qnn_tensor.v2.dimensions, 2), - ::testing::ElementsAreArray({2, 2})); - - litert::qnn::ResetTensor(qnn_tensor); -} - -TEST(TestLegalizeTensor, SimpleQuantizedTensor) { - auto model = litert::testing::LoadTestFileModel(kQSimpleMul16x16Model); - - auto subgraph = model.MainSubgraph(); - EXPECT_TRUE(subgraph); - auto ops = subgraph->Ops(); - auto op_outs = ops.at(0).Outputs(); - - auto qnn_tensor = litert::qnn::BuildDefaultTensor(); - const auto& op_out = op_outs.front(); - LITERT_ASSERT_OK(litert::qnn::LegalizeTensor(op_out, qnn_tensor)); - - ASSERT_EQ(qnn_tensor.version, QNN_TENSOR_VERSION_2); - EXPECT_EQ(qnn_tensor.v2.dataType, QNN_DATATYPE_INT_16); - EXPECT_EQ(qnn_tensor.v2.type, QNN_TENSOR_TYPE_APP_READ); - - ASSERT_EQ(qnn_tensor.v2.quantizeParams.quantizationEncoding, - QNN_QUANTIZATION_ENCODING_SCALE_OFFSET); - ASSERT_FLOAT_EQ(qnn_tensor.v2.quantizeParams.scaleOffsetEncoding.scale, - kSimpleMulQuantModelOutputScale); - - ASSERT_FLOAT_EQ(qnn_tensor.v2.quantizeParams.scaleOffsetEncoding.offset, - kSimpleMulQuantModelOutputOffset); - litert::qnn::ResetTensor(qnn_tensor); -} - -TEST(TestLegalizeTensor, PerChannelQuantizedTensor) { - auto model = litert::testing::LoadTestFileModel(kQKeyEinsum16x8Model); - - auto subgraph = model.MainSubgraph(); - EXPECT_TRUE(subgraph); - auto ops = subgraph->Ops(); - auto op_ins = ops.at(1).Inputs(); - - auto qnn_tensor = litert::qnn::BuildDefaultTensor(); - const auto& per_channel_quant_tensor = op_ins[1]; - LITERT_ASSERT_OK( - 
litert::qnn::LegalizeTensor(per_channel_quant_tensor, qnn_tensor)); - - EXPECT_EQ(qnn_tensor.v2.dataType, QNN_DATATYPE_INT_8); - - LiteRtQuantizationPerChannel per_channel_quant_params = - per_channel_quant_tensor.PerChannelQuantization(); - - ASSERT_EQ(qnn_tensor.v2.quantizeParams.quantizationEncoding, - QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET); - EXPECT_EQ(qnn_tensor.v2.quantizeParams.axisScaleOffsetEncoding.axis, - per_channel_quant_params.quantized_dimension); - EXPECT_EQ( - qnn_tensor.v2.quantizeParams.axisScaleOffsetEncoding.numScaleOffsets, - per_channel_quant_params.num_channels); - for (int i = 0; i < per_channel_quant_params.num_channels; ++i) { - ASSERT_FLOAT_EQ( - qnn_tensor.v2.quantizeParams.axisScaleOffsetEncoding.scaleOffset[i] - .scale, - per_channel_quant_params.scales[i]); - ASSERT_EQ( - qnn_tensor.v2.quantizeParams.axisScaleOffsetEncoding.scaleOffset[i] - .offset, - per_channel_quant_params.zero_points[i]); - } - litert::qnn::ResetTensor(qnn_tensor); -} - -} // namespace diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.cc deleted file mode 100644 index e0be0c0c8650ae..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.cc +++ /dev/null @@ -1,190 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" - -#include -#include - -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "third_party/qairt/latest/include/QNN/HTP/QnnHtpGraph.h" -#include "third_party/qairt/latest/include/QNN/QnnCommon.h" -#include "third_party/qairt/latest/include/QNN/QnnGraph.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/common.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager.h" - -namespace litert::qnn { - -inline absl::Span GetDefaultGraphConfigs() { - static std::array graph_custom_configs; - // QNN suggest always enable relax precision. - graph_custom_configs[0] = QNN_HTP_GRAPH_CUSTOM_CONFIG_INIT; - graph_custom_configs[0].option = QNN_HTP_GRAPH_CONFIG_OPTION_PRECISION; - graph_custom_configs[0].precision = QNN_PRECISION_FLOAT16; - // Default use O3 for now. - graph_custom_configs[1] = QNN_HTP_GRAPH_CUSTOM_CONFIG_INIT; - graph_custom_configs[1].option = QNN_HTP_GRAPH_CONFIG_OPTION_OPTIMIZATION; - graph_custom_configs[1].optimizationOption.type = - QNN_HTP_GRAPH_OPTIMIZATION_TYPE_FINALIZE_OPTIMIZATION_FLAG; - // Change to 2 if you want to use O2 (default). 
- graph_custom_configs[1].optimizationOption.floatValue = 3; - - static std::array graph_configs; - graph_configs[0] = QNN_GRAPH_CONFIG_INIT; - graph_configs[0].option = QNN_GRAPH_CONFIG_OPTION_CUSTOM; - graph_configs[0].customConfig = &graph_custom_configs[0]; - - graph_configs[1] = QNN_GRAPH_CONFIG_INIT; - graph_configs[1].option = QNN_GRAPH_CONFIG_OPTION_CUSTOM; - graph_configs[1].customConfig = &graph_custom_configs[1]; - - static std::array result = { - &graph_configs[0], &graph_configs[1], nullptr}; - - return absl::MakeSpan(result.data(), result.size()); -} - -inline absl::Span GetLegacyGraphConfigs() { - static QnnHtpGraph_CustomConfig_t graph_custom_config; - // Default use O3 for now. - graph_custom_config = QNN_HTP_GRAPH_CUSTOM_CONFIG_INIT; - graph_custom_config.option = QNN_HTP_GRAPH_CONFIG_OPTION_OPTIMIZATION; - graph_custom_config.optimizationOption.type = - QNN_HTP_GRAPH_OPTIMIZATION_TYPE_FINALIZE_OPTIMIZATION_FLAG; - // Change to 2 if you want to use O2 (default). - graph_custom_config.optimizationOption.floatValue = 3; - - static QnnGraph_Config_t graph_config; - graph_config = QNN_GRAPH_CONFIG_INIT; - graph_config.option = QNN_GRAPH_CONFIG_OPTION_CUSTOM; - graph_config.customConfig = &graph_custom_config; - - static std::array result = {&graph_config, - nullptr}; - - return absl::MakeSpan(result.data(), result.size()); -} - -absl::Span GraphMapper::PickGraphConfigHeuristic() { - if (qnn_.IsLegacySocModel()) { - return GetLegacyGraphConfigs(); - } else { - return GetDefaultGraphConfigs(); - } -} - -LiteRtStatus GraphMapper::AssignTensorName(Qnn_Tensor_t& qnn_tensor) { - char* name = nullptr; - const int written = asprintf(&name, "Tensor_%d", cur_tensor_num_++); - LITERT_ENSURE(written != -1 && name != nullptr, kLiteRtStatusErrorNotFound, - "Failed to make tensor name"); - qnn_tensor.v2.name = name; - return kLiteRtStatusOk; -} - -absl::flat_hash_map& GraphMapper::CurrentScope() { - return current_scope_; -} - -LiteRtStatus 
GraphMapper::LookupInScope(LiteRtTensor litert_tensor, - Qnn_Tensor_t& qnn_tensor) { - // If we go in topological order, this should never happen. TODO: add - // "internal error" status code. - const auto qnn_id = CurrentScope().find(litert_tensor); - // when qnn_id is not found, the tensor is a constant tensor thats not been - // added qnn graph. - if (qnn_id == CurrentScope().end()) { - LITERT_LOG(LITERT_INFO, "Adding constant tensor %s to qnn graph", - qnn_tensor.v2.name); - LITERT_RETURN_IF_ERROR(LegalizeAndRegister(litert_tensor, qnn_tensor)); - LITERT_RETURN_IF_ERROR(PushToScope(litert_tensor, qnn_tensor)); - // } - return kLiteRtStatusOk; - } - LITERT_LOG(LITERT_INFO, "Found tensor %d in current_scope.", qnn_id->second); - ResetTensor(qnn_tensor); - qnn_tensor.v2.id = qnn_id->second; - - return kLiteRtStatusOk; -} - -LiteRtStatus GraphMapper::PushToScope(LiteRtTensor litert_tensor, - Qnn_Tensor_t& qnn_tensor) { - CurrentScope()[litert_tensor] = MoveToId(qnn_tensor); - return kLiteRtStatusOk; -} - -QnnManager& GraphMapper::Qnn() { return qnn_; } - -Qnn_GraphHandle_t& GraphMapper::QnnGraph() { return qnn_graph_; } - -LiteRtStatus GraphMapper::LegalizeAndRegister(LiteRtTensor litert_tensor, - Qnn_Tensor_t& qnn_tensor) { - litert::Tensor tensor(litert_tensor); - LITERT_RETURN_IF_ERROR(LegalizeTensor(tensor, qnn_tensor)); - LITERT_RETURN_IF_ERROR(AssignTensorName(qnn_tensor)); - - // Set tensor as graph output if it is used by other Ops. 
- if (graph_outpus_.contains(litert_tensor)) { - LITERT_LOG(LITERT_INFO, "Setting tensor %d as Graph output", - qnn_tensor.v2.id); - qnn_tensor.v2.type = QNN_TENSOR_TYPE_APP_READ; - } - - LITERT_RETURN_STATUS_IF_QNN_NOT_OK( - qnn_.Api()->tensorCreateGraphTensor(QnnGraph(), &qnn_tensor)); - - LITERT_LOG(LITERT_INFO, "Legalized and registered tensor %d", - qnn_tensor.v2.id); - - for (int i = 0; i < qnn_tensor.v2.rank; ++i) { - LITERT_LOG(LITERT_INFO, "qnn_tensor dim[%d] = %d", i, - qnn_tensor.v2.dimensions[i]); - } - - return kLiteRtStatusOk; -} - -LiteRtStatus GraphMapper::IsLiteRtSubgraphSupported() { - // For now, we assume all LiteRt subgraphs are supported. - // TODO: b/381133565: Implement or remove this function. - return kLiteRtStatusOk; -} - -LiteRtStatus GraphMapper::InitQnnGraph(absl::string_view qnn_graph_name) { - LITERT_RETURN_STATUS_IF_QNN_NOT_OK( - qnn_.Api()->graphCreate(context_handle_, qnn_graph_name.data(), - PickGraphConfigHeuristic().data(), &QnnGraph())); - return kLiteRtStatusOk; -} - -LiteRtStatus GraphMapper::Finalize() { - LITERT_RETURN_STATUS_IF_QNN_NOT_OK( - qnn_.Api()->graphFinalize(QnnGraph(), nullptr, nullptr)); - return kLiteRtStatusOk; -} - -} // namespace litert::qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h deleted file mode 100644 index 3e70e9f222e442..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h +++ /dev/null @@ -1,125 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_GRAPH_MAPPER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_GRAPH_MAPPER_H_ - -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/container/flat_hash_set.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "third_party/qairt/latest/include/QNN/QnnCommon.h" -#include "third_party/qairt/latest/include/QNN/QnnGraph.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager.h" - -namespace litert::qnn { - -// Algorithm class for managing "scope" when mapping litert Subgraphs -// to QNN Graphs. -class GraphMapper { - public: - GraphMapper(LiteRtSubgraph subgraph, QnnManager& qnn, - Qnn_ContextHandle_t context_handle) - : subgraph_(Subgraph(subgraph)), - qnn_(qnn), - context_handle_(context_handle) {} - - // Legalize given LiteRtTensors attributes into QNN Tensor registered with - // QNN context. Result QNN Tensor is empty except for the canonical id - // assigned by QNN Api. - LiteRtStatus LegalizeAndRegister(LiteRtTensor litert_tensor, - Qnn_Tensor_t& qnn_tensor); - - // Find ID associated with evaluated litert Tensor and add it to given - // QNN Tensor. 
- LiteRtStatus LookupInScope(LiteRtTensor litert_tensor, - Qnn_Tensor_t& qnn_tensor); - - // Adds new mapping to scope. All fields other than ID in given QNN Tensor are - // cleared and its ID is added to "current_scope". Expects QNN Tensor has - // already been registered with context. - LiteRtStatus PushToScope(LiteRtTensor litert_tensor, - Qnn_Tensor_t& qnn_tensor); - - // NOTE: QNN Tensors must be created with a unique name. This will ensure - // uniqueness but will want to have more meaningful names in the future. - LiteRtStatus AssignTensorName(Qnn_Tensor_t& qnn_tensor); - - // QNN Sdk Accessors - QnnManager& Qnn(); - Qnn_GraphHandle_t& QnnGraph(); - - // CC Convenience Accessors - const Subgraph& Graph() const { return subgraph_; } - - // Accessor for current scope. - // Since each QNN Tensor needs to have a unique name globally within each QNN - // context, we maintain "Current scope", which is a map of evaluated - // LiteRtTensors to their resolved QNN Tensor ID. - absl::flat_hash_map& CurrentScope(); - - // Can implementation handle given LiteRtSubgraph topology (see comment at - // bottom of file). - LiteRtStatus IsLiteRtSubgraphSupported(); - - // Initialize QNN Graph with given name. Call this after parsing - // LiteRtSubgraph. - LiteRtStatus InitQnnGraph(absl::string_view qnn_graph_name); - - // Finalize QNN Graph. Call this after all ops have been mapped. - LiteRtStatus Finalize(); - - inline void RegisterOutput(LiteRtTensor litert_tensor) { - graph_outpus_.insert(litert_tensor); - } - - // Pick graph config based on subgraph. - absl::Span PickGraphConfigHeuristic(); - - inline bool IsTensorOutput(LiteRtTensor litert_tensor) { - return graph_outpus_.contains(litert_tensor); - } - - private: - const Subgraph subgraph_; - - // Set of all outputs of the graph. - absl::flat_hash_set graph_outpus_; - - // Maps evaluated tensors to their resolved QNN Tensor ID. 
- absl::flat_hash_map current_scope_; - - // - // QNN Sdk State - // - QnnManager& qnn_; - Qnn_ContextHandle_t context_handle_; - Qnn_GraphHandle_t qnn_graph_ = nullptr; - - // - // Tensor Naming - // - - uint32_t cur_tensor_num_ = 0; -}; - -} // namespace litert::qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_GRAPH_MAPPER_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/BUILD b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/BUILD deleted file mode 100644 index 46f27e985fd262..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/BUILD +++ /dev/null @@ -1,922 +0,0 @@ -# Copyright 2024 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -load("//tensorflow/lite/experimental/litert/build_common:litert_build_defs.bzl", "litert_lib") - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//visibility:private"], -) - -litert_lib( - name = "legalization", - hdrs = ["legalization.h"], - visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], - deps = [ - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", - ], -) - -litert_lib( - name = "add_op_legalization", - srcs = ["add_op_legalization.cc"], - hdrs = ["add_op_legalization.h"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], - deps = [ - ":legalization", - ":util", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:common", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", - 
"//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", - ], -) - -litert_lib( - name = "batch_matmul_op_legalization", - srcs = ["batch_matmul_op_legalization.cc"], - hdrs = ["batch_matmul_op_legalization.h"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], - deps = [ - ":legalization", - ":util", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:common", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", - ], -) - -litert_lib( - name = "cast_op_legalization", - srcs = ["cast_op_legalization.cc"], - hdrs = ["cast_op_legalization.h"], - tags = [ - # Don't build/test in OS until qnn is available. 
- "nobuilder", - ], - visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], - deps = [ - ":legalization", - ":util", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:common", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", - ], -) - -litert_lib( - name = "concatenation_op_legalization", - srcs = ["concatenation_op_legalization.cc"], - hdrs = ["concatenation_op_legalization.h"], - tags = [ - # Don't build/test in OS until qnn is available. 
- "nobuilder", - ], - visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], - deps = [ - ":legalization", - ":util", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:common", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", - ], -) - -litert_lib( - name = "cos_op_legalization", - srcs = ["cos_op_legalization.cc"], - hdrs = ["cos_op_legalization.h"], - tags = [ - # Don't build/test in OS until qnn is available. 
- "nobuilder", - ], - visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], - deps = [ - ":legalization", - ":util", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:common", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", - ], -) - -litert_lib( - name = "div_op_legalization", - srcs = ["div_op_legalization.cc"], - hdrs = ["div_op_legalization.h"], - tags = [ - # Don't build/test in OS until qnn is available. 
- "nobuilder", - ], - visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], - deps = [ - ":legalization", - ":util", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:common", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", - ], -) - -litert_lib( - name = "dynamic_update_slice_op_legalization", - srcs = ["dynamic_update_slice_op_legalization.cc"], - hdrs = ["dynamic_update_slice_op_legalization.h"], - tags = [ - # Don't build/test in OS until qnn is available. 
- "nobuilder", - ], - visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], - deps = [ - ":legalization", - ":util", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:common", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", - ], -) - -litert_lib( - name = "embedding_lookup_op_legalization", - srcs = ["embedding_lookup_op_legalization.cc"], - hdrs = ["embedding_lookup_op_legalization.h"], - tags = [ - # Don't build/test in OS until qnn is available. 
- "nobuilder", - ], - visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], - deps = [ - ":legalization", - ":util", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:common", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", - ], -) - -litert_lib( - name = "transpose_op_legalization", - srcs = ["transpose_op_legalization.cc"], - hdrs = ["transpose_op_legalization.h"], - tags = [ - # Don't build/test in OS until qnn is available. 
- "nobuilder", - ], - visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], - deps = [ - ":legalization", - ":util", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:common", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", - ], -) - -litert_lib( - name = "fully_connected_op_legalization", - srcs = ["fully_connected_op_legalization.cc"], - hdrs = ["fully_connected_op_legalization.h"], - tags = [ - # Don't build/test in OS until qnn is available. 
- "nobuilder", - ], - visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], - deps = [ - ":legalization", - ":util", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:common", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", - ], -) - -litert_lib( - name = "gelu_op_legalization", - srcs = ["gelu_op_legalization.cc"], - hdrs = ["gelu_op_legalization.h"], - tags = [ - # Don't build/test in OS until qnn is available. 
- "nobuilder", - ], - visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], - deps = [ - ":legalization", - ":util", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:common", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", - ], -) - -litert_lib( - name = "greater_op_legalization", - srcs = ["greater_op_legalization.cc"], - hdrs = ["greater_op_legalization.h"], - tags = [ - # Don't build/test in OS until qnn is available. 
- "nobuilder", - ], - visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], - deps = [ - ":legalization", - ":util", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:common", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", - ], -) - -litert_lib( - name = "less_op_legalization", - srcs = ["less_op_legalization.cc"], - hdrs = ["less_op_legalization.h"], - tags = [ - # Don't build/test in OS until qnn is available. 
- "nobuilder", - ], - visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], - deps = [ - ":legalization", - ":util", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:common", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", - ], -) - -litert_lib( - name = "logical_and_op_legalization", - srcs = ["logical_and_op_legalization.cc"], - hdrs = ["logical_and_op_legalization.h"], - tags = [ - # Don't build/test in OS until qnn is available. 
- "nobuilder", - ], - visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], - deps = [ - ":legalization", - ":util", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:common", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", - ], -) - -litert_lib( - name = "mul_op_legalization", - srcs = ["mul_op_legalization.cc"], - hdrs = ["mul_op_legalization.h"], - tags = [ - # Don't build/test in OS until qnn is available. 
- "nobuilder", - ], - visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], - deps = [ - ":legalization", - ":util", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:common", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", - ], -) - -litert_lib( - name = "pack_op_legalization", - srcs = ["pack_op_legalization.cc"], - hdrs = ["pack_op_legalization.h"], - tags = [ - # Don't build/test in OS until qnn is available. 
- "nobuilder", - ], - visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], - deps = [ - ":legalization", - ":util", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:common", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", - ], -) - -litert_lib( - name = "quantize_op_legalization", - srcs = ["quantize_op_legalization.cc"], - hdrs = ["quantize_op_legalization.h"], - tags = [ - # Don't build/test in OS until qnn is available. 
- "nobuilder", - ], - visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], - deps = [ - ":legalization", - ":util", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_element_type", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:common", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", - ], -) - -litert_lib( - name = "reshape_op_legalization", - srcs = ["reshape_op_legalization.cc"], - hdrs = ["reshape_op_legalization.h"], - tags = [ - # Don't build/test in OS until qnn is available. 
- "nobuilder", - ], - visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], - deps = [ - ":legalization", - ":util", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:common", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", - ], -) - -litert_lib( - name = "rsqrt_op_legalization", - srcs = ["rsqrt_op_legalization.cc"], - hdrs = ["rsqrt_op_legalization.h"], - tags = [ - # Don't build/test in OS until qnn is available. 
- "nobuilder", - ], - visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], - deps = [ - ":legalization", - ":util", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:common", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", - ], -) - -litert_lib( - name = "sin_op_legalization", - srcs = ["sin_op_legalization.cc"], - hdrs = ["sin_op_legalization.h"], - tags = [ - # Don't build/test in OS until qnn is available. 
- "nobuilder", - ], - visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], - deps = [ - ":legalization", - ":util", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:common", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", - ], -) - -litert_lib( - name = "select_op_legalization", - srcs = ["select_op_legalization.cc"], - hdrs = ["select_op_legalization.h"], - tags = [ - # Don't build/test in OS until qnn is available. 
- "nobuilder", - ], - visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], - deps = [ - ":legalization", - ":util", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:common", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", - ], -) - -litert_lib( - name = "slice_op_legalization", - srcs = ["slice_op_legalization.cc"], - hdrs = ["slice_op_legalization.h"], - tags = [ - # Don't build/test in OS until qnn is available. 
- "nobuilder", - ], - visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], - deps = [ - ":legalization", - ":util", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:common", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", - ], -) - -litert_lib( - name = "sum_op_legalization", - srcs = ["sum_op_legalization.cc"], - hdrs = ["sum_op_legalization.h"], - tags = [ - # Don't build/test in OS until qnn is available. 
- "nobuilder", - ], - visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], - deps = [ - ":legalization", - ":util", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:common", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", - ], -) - -litert_lib( - name = "sub_op_legalization", - srcs = ["sub_op_legalization.cc"], - hdrs = ["sub_op_legalization.h"], - tags = [ - # Don't build/test in OS until qnn is available. 
- "nobuilder", - ], - visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], - deps = [ - ":legalization", - ":util", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:common", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", - ], -) - -litert_lib( - name = "softmax_op_legalization", - srcs = ["softmax_op_legalization.cc"], - hdrs = ["softmax_op_legalization.h"], - tags = [ - # Don't build/test in OS until qnn is available. 
- "nobuilder", - ], - visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], - deps = [ - ":legalization", - ":util", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:common", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", - ], -) - -litert_lib( - name = "tanh_op_legalization", - srcs = ["tanh_op_legalization.cc"], - hdrs = ["tanh_op_legalization.h"], - tags = [ - # Don't build/test in OS until qnn is available. 
- "nobuilder", - ], - visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], - deps = [ - ":legalization", - ":util", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/c:litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:common", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", - ], -) - -litert_lib( - name = "util", - srcs = ["util.cc"], - hdrs = ["util.h"], - tags = [ - # Don't build/test in OS until qnn is available. 
- "nobuilder", - ], - visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], - deps = [ - ":legalization", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", - "//tensorflow/lite/experimental/litert/tools:dump", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:common", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", - ], -) diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/add_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/add_op_legalization.cc deleted file mode 100644 index a2a8da69bdc816..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/add_op_legalization.cc +++ /dev/null @@ -1,51 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/add_op_legalization.h" - -#include - -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h" - -namespace litert::qnn { - -static constexpr absl::string_view kQnnAddOpTypeName = "ElementWiseAdd"; -static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw"; -static constexpr absl::string_view kAddOpFmt = "add_%d"; - -LiteRtStatus AddOpLegalization::LegalizeOp(const litert::Op& src, - Qnn_OpConfig_t& dest, - GraphMapper& graph_mapper) { - if (src.Code() != kLiteRtOpCodeTflAdd) { - return kLiteRtStatusLegalizeNoMatch; - } - std::string op_name = absl::StrFormat(kAddOpFmt, op_counter_++); - LITERT_RETURN_IF_ERROR(SetOpInfo(op_name.c_str(), - kDefaultQnnOpPackageName.data(), - kQnnAddOpTypeName.data(), dest)); - LITERT_RETURN_IF_ERROR(LegalizeSimpleOp(src, dest, graph_mapper)); - LITERT_LOG(LITERT_INFO, "Legalized add op", ""); - return kLiteRtStatusOk; -} - -} // 
namespace litert::qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/add_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/add_op_legalization.h deleted file mode 100644 index c8301cb124666b..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/add_op_legalization.h +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_ADD_OP_LEGALIZATION_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_ADD_OP_LEGALIZATION_H_ - -#include -#include - -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h" - -namespace litert::qnn { - -class AddOpLegalization : public Legalization { - public: - AddOpLegalization() = default; - ~AddOpLegalization() = default; - using Ptr = std::unique_ptr; - static Ptr Create() { return std::make_unique(); } - - LiteRtStatus LegalizeOp(const litert::Op& src, Qnn_OpConfig_t& dest, - GraphMapper& graph_mapper); - - private: - // Counter to ensure unique op names. - uint32_t op_counter_ = 0; -}; - -} // namespace litert::qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_ADD_OP_LEGALIZATION_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/batch_matmul_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/batch_matmul_op_legalization.cc deleted file mode 100644 index 0685a751243054..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/batch_matmul_op_legalization.cc +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/batch_matmul_op_legalization.h" - -#include - -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h" - -namespace litert::qnn { - -static constexpr absl::string_view kQnnBatchMatmulOpTypeName = "MatMul"; -static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw"; -static constexpr absl::string_view kBatchMatmulOpFmt = "batch_matmul_%d"; - -LiteRtStatus BatchMatmulOpLegalization::LegalizeOp(const litert::Op& src, - Qnn_OpConfig_t& dest, - GraphMapper& graph_mapper) { - if (src.Code() != kLiteRtOpCodeTflBatchMatmul) { - return kLiteRtStatusLegalizeNoMatch; - } - std::string op_name = absl::StrFormat(kBatchMatmulOpFmt, op_counter_++); - LITERT_RETURN_IF_ERROR(SetOpInfo(op_name.c_str(), - kDefaultQnnOpPackageName.data(), - kQnnBatchMatmulOpTypeName.data(), dest)); - 
LITERT_RETURN_IF_ERROR(LegalizeSimpleOp(src, dest, graph_mapper)); - LITERT_LOG(LITERT_INFO, "Legalized batch_matmul op", ""); - return kLiteRtStatusOk; -} - -} // namespace litert::qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/batch_matmul_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/batch_matmul_op_legalization.h deleted file mode 100644 index 60aee1f164f079..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/batch_matmul_op_legalization.h +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_BATCH_MATMUL_OP_LEGALIZATION_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_BATCH_MATMUL_OP_LEGALIZATION_H_ - -#include -#include - -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h" - -namespace litert::qnn { - -class BatchMatmulOpLegalization : public Legalization { - public: - BatchMatmulOpLegalization() = default; - ~BatchMatmulOpLegalization() = default; - using Ptr = std::unique_ptr; - static Ptr Create() { return std::make_unique(); } - - LiteRtStatus LegalizeOp(const litert::Op& src, Qnn_OpConfig_t& dest, - GraphMapper& graph_mapper); - - private: - // Counter to ensure unique op names. - uint32_t op_counter_ = 0; -}; - -} // namespace litert::qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_BATCH_MATMUL_OP_LEGALIZATION_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/cast_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/cast_op_legalization.cc deleted file mode 100644 index 8a3bdef7138a6d..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/cast_op_legalization.cc +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/cast_op_legalization.h" - -#include - -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h" - -namespace litert::qnn { - -static constexpr absl::string_view kQnnCastOpTypeName = "Cast"; -static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw"; -static constexpr absl::string_view kCastOpFmt = "cast_%d"; - -LiteRtStatus CastOpLegalization::LegalizeOp(const Op& src, Qnn_OpConfig_t& dest, - GraphMapper& graph_mapper) { - if (src.Code() != kLiteRtOpCodeTflCast) { - return kLiteRtStatusLegalizeNoMatch; - } - std::string op_name = absl::StrFormat(kCastOpFmt, op_counter_++); - LITERT_RETURN_IF_ERROR(SetOpInfo(op_name.c_str(), - kDefaultQnnOpPackageName.data(), - kQnnCastOpTypeName.data(), dest)); - LITERT_RETURN_IF_ERROR(LegalizeSimpleOp(src, dest, graph_mapper)); - LITERT_LOG(LITERT_INFO, "Legalized cast op", ""); - return kLiteRtStatusOk; -} - -} // namespace litert::qnn diff --git 
a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/cast_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/cast_op_legalization.h deleted file mode 100644 index fecbe54be7643d..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/cast_op_legalization.h +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_CAST_OP_LEGALIZATION_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_CAST_OP_LEGALIZATION_H_ - -#include -#include - -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h" - -namespace litert::qnn { - -class CastOpLegalization : public Legalization { - public: - CastOpLegalization() = default; - ~CastOpLegalization() = default; - using Ptr = std::unique_ptr; - static Ptr Create() { return std::make_unique(); } - - LiteRtStatus LegalizeOp(const litert::Op& src, Qnn_OpConfig_t& dest, - GraphMapper& graph_mapper); - - 
private: - // Counter to ensure unique op names. - uint32_t op_counter_ = 0; -}; - -} // namespace litert::qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_CAST_OP_LEGALIZATION_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/concatenation_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/concatenation_op_legalization.cc deleted file mode 100644 index 11fd3f526fb8ed..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/concatenation_op_legalization.cc +++ /dev/null @@ -1,100 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/concatenation_op_legalization.h" - -#include -#include - -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/c/litert_options.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/common.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h" - -namespace litert::qnn { - -static constexpr absl::string_view kQnnConcatenationOpTypeName = "Concat"; -static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw"; -static constexpr absl::string_view kConcatenationOpFmt = "concatenation_%d"; - -static constexpr int kReduceConcatenationOpOutputSize = 1; -static constexpr int kReduceConcatenationOpParamSize = 1; - -LiteRtStatus ConcatenationOpLegalization::LegalizeOp( - const Op& src, Qnn_OpConfig_t& dest, GraphMapper& graph_mapper) { - if (src.Code() != kLiteRtOpCodeTflConcatenation) { - return kLiteRtStatusLegalizeNoMatch; - } - DumpLegalization(*src.Get()); - std::string op_name = absl::StrFormat(kConcatenationOpFmt, op_counter_++); - LITERT_RETURN_IF_ERROR(SetOpInfo(op_name.c_str(), - kDefaultQnnOpPackageName.data(), - kQnnConcatenationOpTypeName.data(), dest)); - - // Look up op input tensors in scope. 
- const auto op_ins = src.Inputs(); - LITERT_STACK_ARRAY(Qnn_Tensor_t, qnn_op_ins, op_ins.size(), QNN_TENSOR_INIT); - - Qnn_Tensor_t* cur_qnn_op_in = qnn_op_ins; - for (const auto& op_in : op_ins) { - LITERT_RETURN_IF_ERROR( - graph_mapper.LookupInScope(op_in.Get(), *cur_qnn_op_in)); - ++cur_qnn_op_in; - } - - // QNN concatenation op expects 1 output tensor. - const auto op_outs = src.Outputs(); - LITERT_STACK_ARRAY(Qnn_Tensor_t, qnn_op_outs, - kReduceConcatenationOpOutputSize, QNN_TENSOR_INIT); - LITERT_RETURN_IF_ERROR( - graph_mapper.LegalizeAndRegister(op_outs.front().Get(), qnn_op_outs[0])); - LITERT_RETURN_IF_ERROR( - graph_mapper.PushToScope(op_outs.front().Get(), qnn_op_outs[0])); - - // Extract axis option from concatenation op. - int32_t axis; - LITERT_RETURN_IF_ERROR(LiteRtGetConcatenationAxisOption(src.Get(), &axis)); - - // Construct the scalar "axis" param. - Qnn_Param_t axis_param = BuildDefaultParam(); - axis_param.paramType = QNN_PARAMTYPE_SCALAR; - axis_param.name = "axis"; - Qnn_Scalar_t axis_scalar = QNN_SCALAR_INIT; - axis_scalar.dataType = QNN_DATATYPE_UINT_32; - axis_scalar.int32Value = axis; - axis_param.scalarParam = axis_scalar; - - Qnn_Param_t concatenation_params[] = {axis_param}; - dest.v1.inputTensors = qnn_op_ins; - dest.v1.numOfInputs = op_ins.size(); - dest.v1.outputTensors = qnn_op_outs; - dest.v1.numOfOutputs = kReduceConcatenationOpOutputSize; - dest.v1.numOfParams = kReduceConcatenationOpParamSize; - dest.v1.params = concatenation_params; - - LITERT_RETURN_STATUS_IF_QNN_NOT_OK( - graph_mapper.Qnn().Api()->graphAddNode(graph_mapper.QnnGraph(), dest)); - - LITERT_LOG(LITERT_INFO, "Legalized concatenation op", ""); - return kLiteRtStatusOk; -} - -} // namespace litert::qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/concatenation_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/concatenation_op_legalization.h deleted file mode 100644 index 
b3c26971b57c43..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/concatenation_op_legalization.h +++ /dev/null @@ -1,51 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_CONCATENATION_OP_LEGALIZATION_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_CONCATENATION_OP_LEGALIZATION_H_ - -#include -#include - -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h" - -namespace litert::qnn { - -class ConcatenationOpLegalization : public Legalization { - public: - ConcatenationOpLegalization() = default; - ~ConcatenationOpLegalization() = default; - using Ptr = std::unique_ptr; - static Ptr Create() { - return std::make_unique(); - } - - LiteRtStatus LegalizeOp(const Op& src, Qnn_OpConfig_t& dest, - GraphMapper& graph_mapper); - - private: - // Counter to ensure unique op names. 
- uint32_t op_counter_ = 0; -}; - -} // namespace litert::qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_CONCATENATION_OP_LEGALIZATION_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/cos_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/cos_op_legalization.cc deleted file mode 100644 index 7bd555b31cef8b..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/cos_op_legalization.cc +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/cos_op_legalization.h" - -#include - -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h" - -namespace litert::qnn { - -static constexpr absl::string_view kQnnCosOpTypeName = "ElementWiseCos"; -static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw"; -static constexpr absl::string_view kCosOpFmt = "cos_%d"; - -LiteRtStatus CosOpLegalization::LegalizeOp(const Op& src, Qnn_OpConfig_t& dest, - GraphMapper& graph_mapper) { - if (src.Code() != kLiteRtOpCodeTflCos) { - return kLiteRtStatusLegalizeNoMatch; - } - std::string op_name = absl::StrFormat(kCosOpFmt, op_counter_++); - LITERT_RETURN_IF_ERROR(SetOpInfo(op_name.c_str(), - kDefaultQnnOpPackageName.data(), - kQnnCosOpTypeName.data(), dest)); - LITERT_RETURN_IF_ERROR(LegalizeSimpleOp(src, dest, graph_mapper)); - LITERT_LOG(LITERT_INFO, "Legalized cos op", ""); - return kLiteRtStatusOk; -} - -} // namespace litert::qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/cos_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/cos_op_legalization.h deleted file mode 100644 index 6a35da2fb12d4c..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/cos_op_legalization.h +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright 2024 Google LLC. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_COS_OP_LEGALIZATION_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_COS_OP_LEGALIZATION_H_ - -#include -#include - -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h" - -namespace litert::qnn { - -class CosOpLegalization : public Legalization { - public: - CosOpLegalization() = default; - ~CosOpLegalization() = default; - using UniquePtr = std::unique_ptr; - static UniquePtr Create() { return std::make_unique(); } - - LiteRtStatus LegalizeOp(const Op& src, Qnn_OpConfig_t& dest, - GraphMapper& graph_mapper); - - private: - // Counter to ensure unique op names. 
- uint32_t op_counter_ = 0; -}; - -} // namespace litert::qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_COS_OP_LEGALIZATION_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/div_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/div_op_legalization.cc deleted file mode 100644 index 947bad6f719b0f..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/div_op_legalization.cc +++ /dev/null @@ -1,51 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/div_op_legalization.h" - -#include - -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h" - -namespace litert::qnn { - -static constexpr absl::string_view kQnnDivOpTypeName = "ElementWiseDivide"; -static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw"; -static constexpr absl::string_view kDivOpFmt = "div_%d"; - -LiteRtStatus DivOpLegalization::LegalizeOp(const litert::Op& src, - Qnn_OpConfig_t& dest, - GraphMapper& graph_mapper) { - if (src.Code() != kLiteRtOpCodeTflDiv) { - return kLiteRtStatusLegalizeNoMatch; - } - std::string op_name = absl::StrFormat(kDivOpFmt, op_counter_++); - LITERT_RETURN_IF_ERROR(SetOpInfo(op_name.c_str(), - kDefaultQnnOpPackageName.data(), - kQnnDivOpTypeName.data(), dest)); - LITERT_RETURN_IF_ERROR(LegalizeSimpleOp(src, dest, graph_mapper)); - LITERT_LOG(LITERT_INFO, "Legalized div op", ""); - return kLiteRtStatusOk; -} - -} // namespace litert::qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/div_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/div_op_legalization.h deleted file mode 100644 index a22b91248a4661..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/div_op_legalization.h +++ /dev/null @@ -1,49 
+0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_DIV_OP_LEGALIZATION_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_DIV_OP_LEGALIZATION_H_ - -#include -#include - -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h" - -namespace litert::qnn { - -class DivOpLegalization : public Legalization { - public: - DivOpLegalization() = default; - ~DivOpLegalization() = default; - using Ptr = std::unique_ptr; - static Ptr Create() { return std::make_unique(); } - - LiteRtStatus LegalizeOp(const litert::Op& src, Qnn_OpConfig_t& dest, - GraphMapper& graph_mapper); - - private: - // Counter to ensure unique op names. 
- uint32_t op_counter_ = 0; -}; - -} // namespace litert::qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_DIV_OP_LEGALIZATION_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/dynamic_update_slice_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/dynamic_update_slice_op_legalization.cc deleted file mode 100644 index 1511802a788a1b..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/dynamic_update_slice_op_legalization.cc +++ /dev/null @@ -1,304 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/dynamic_update_slice_op_legalization.h" - -#include -#include -#include - -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/common.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h" - -namespace litert::qnn { - -static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw"; - -// Dynamic update slice op info. -static constexpr int kDynamicUpdateSliceOpOperandIndex = 0; -static constexpr int kDynamicUpdateSliceOpUpdateIndex = 1; -static constexpr int kDynamicUpdateSliceOpIndicesIndex = 2; - -// ScatterND op config. -static constexpr absl::string_view kQnnScatterNdOpTypeName = "ScatterNd"; -static constexpr absl::string_view kScatterNdOpFmt = "dus_scatter_nd_%d"; -static constexpr int kScatterNDOpInputSize = 3; -static constexpr int kScatterNDOpOutputSize = 1; -static constexpr int kScatterNDOutputRank = 4; -static constexpr int kScatterNDParamSize = 0; - -// Strided slice op config. 
-static constexpr absl::string_view kStridedSliceOpTypeName = "StridedSlice"; -static constexpr absl::string_view kStridedSliceOpFmt = "dus_strided_slice_%d"; -static constexpr int kStridedSliceOpInputSize = 1; -static constexpr int kStridedSliceOpOutputSize = 1; -static constexpr int kStridedSliceOpOutputRank = 1; -static constexpr int kStridedSliceParamSize = 1; -static constexpr absl::string_view kRangesParamName = "ranges"; -static constexpr int kRangesParamRank = 2; -static constexpr int kRangesParamArgSize = 3; - -// Reshape op config. -static constexpr absl::string_view kReshapeOpTypeName = "Reshape"; -static constexpr absl::string_view kReshapeOpFmt = "dus_reshape_%d"; -static constexpr int kReshapeOpInputSize = 1; -static constexpr int kReshapeOpOutputSize = 1; -static constexpr int kReshapeOpOutputRank = 2; -static constexpr int kReshapeParamSize = 0; - -// Transpose op config. -static constexpr absl::string_view kTransposeOpTypeName = "Transpose"; -static constexpr absl::string_view kTransposeOperandOpFmt = - "dus_transpose_operand_%d"; -static constexpr absl::string_view kTransposeUpdateOpFmt = - "dus_transpose_update_%d"; -static constexpr absl::string_view kTransposeResultOpFmt = - "dus_transpose_result_%d"; -static constexpr int kTransposeOpInputSize = 1; -static constexpr int kTransposeOpOutputSize = 1; -static constexpr int kTransposeOpOutputRank = 4; -static constexpr int kTransposeParamSize = 1; -static constexpr absl::string_view kPermParamName = "perm"; -static constexpr int kPermParamRank = 1; -static constexpr int kPermParamArgSize = 4; - -LiteRtStatus DynamicUpdateSliceOpLegalization::LegalizeOp( - const Op& src, Qnn_OpConfig_t& dest, GraphMapper& graph_mapper) { - if (src.Code() != kLiteRtOpCodeTflDynamicUpdateSlice) { - return kLiteRtStatusLegalizeNoMatch; - } - DumpLegalization(*src.Get()); - - // Legalize input tensors, lookup operand tensor in scope. 
- const auto op_ins = src.Inputs(); - LITERT_STACK_ARRAY(Qnn_Tensor_t, qnn_op_ins, kScatterNDOpInputSize, - QNN_TENSOR_INIT); - - Qnn_Tensor_t* cur_qnn_op_in = qnn_op_ins; - for (const auto& op_in : op_ins) { - LITERT_RETURN_IF_ERROR( - graph_mapper.LookupInScope(op_in.Get(), *cur_qnn_op_in)); - ++cur_qnn_op_in; - } - // Legalize op data type. - Qnn_DataType_t OperandDataType, UpdateDataType; - LITERT_RETURN_IF_ERROR(LegalizeElementType( - op_ins[kDynamicUpdateSliceOpOperandIndex].ElementType(), - &OperandDataType)); - LITERT_RETURN_IF_ERROR(LegalizeElementType( - op_ins[kDynamicUpdateSliceOpUpdateIndex].ElementType(), &UpdateDataType)); - - //=========================================================================== - // Step 1.1 Build strided slice op. Extract slice index from input[2] - // input: [0, x, 0, 0] (LiteRT.DUS input[2]) - // output: [x] - Qnn_OpConfig_t strided_slice_op = BuildDefaultOp(); - std::string op_name = absl::StrFormat(kStridedSliceOpFmt, op_counter_); - LITERT_RETURN_IF_ERROR( - SetOpInfo(op_name.c_str(), kDefaultQnnOpPackageName.data(), - kStridedSliceOpTypeName.data(), strided_slice_op)); - - // Prepare strided slice op params. - std::vector ranges = {1, 2, 1}; - std::vector ranges_dims = {1, kRangesParamArgSize}; - Qnn_Param_t range_param = BuildDefaultParam(); - LITERT_RETURN_STATUS_IF_QNN_NOT_OK(BuildQnnTesnorParam( - ranges.data(), ranges_dims.data(), QNN_DATATYPE_INT_32, kRangesParamRank, - kRangesParamName.data(), graph_mapper, range_param)); - - // Prepare strided slice op outputs. - Qnn_Tensor_t strided_slice_op_out = BuildDefaultTensor(); - std::vector slice_op_out_dims = {1}; - LITERT_RETURN_STATUS_IF_QNN_NOT_OK(BuildAndRegisterQnnNativeTensor( - QNN_DATATYPE_INT_32, kStridedSliceOpOutputRank, slice_op_out_dims.data(), - graph_mapper, strided_slice_op_out)); - - // Configure strided slice op. 
- LITERT_RETURN_STATUS_IF_QNN_NOT_OK(BuildAndRegisterQnnOp( - kStridedSliceOpInputSize, &qnn_op_ins[kDynamicUpdateSliceOpIndicesIndex], - kStridedSliceOpOutputSize, &strided_slice_op_out, strided_slice_op, - kStridedSliceParamSize, &range_param, graph_mapper)); - - LITERT_LOG(LITERT_INFO, "Added strided slice op for dus"); - - //=========================================================================== - // Step 1.2 Build reshape op. Construct input tensor shape for QNN.ScatterND - // op. - // input: [x] (QNN.StridedSlice output) - // output: [[x]] - Qnn_OpConfig_t reshape_op = BuildDefaultOp(); - std::string reshpae_op_name = absl::StrFormat(kReshapeOpFmt, op_counter_); - LITERT_RETURN_IF_ERROR(SetOpInfo(op_name.c_str(), - kDefaultQnnOpPackageName.data(), - kReshapeOpTypeName.data(), reshape_op)); - - // Prepare reshape op output tensor. - Qnn_Tensor_t reshape_op_out = BuildDefaultTensor(); - std::vector reshape_op_out_dims = {1, 1}; - LITERT_RETURN_STATUS_IF_QNN_NOT_OK(BuildAndRegisterQnnNativeTensor( - QNN_DATATYPE_INT_32, kReshapeOpOutputRank, reshape_op_out_dims.data(), - graph_mapper, reshape_op_out)); - - // Configure reshape op. - LITERT_RETURN_STATUS_IF_QNN_NOT_OK(BuildAndRegisterQnnOp( - kReshapeOpInputSize, &strided_slice_op_out, kReshapeOpOutputSize, - &reshape_op_out, reshape_op, kReshapeParamSize, nullptr, graph_mapper)); - - LITERT_LOG(LITERT_INFO, "Added reshape op for dus"); - - //=========================================================================== - // Step 2 Build transpose op. Swap the first two dimensions of the input - // tensor[0] and input tensor[1]. - // op. 
- // input: [a, b, c, d] (LiteRT.DUS input[0]/input[1] ) - // output: [b, a, c, d] - Qnn_OpConfig_t transpose_operand_op = BuildDefaultOp(); - Qnn_OpConfig_t transpose_update_op = BuildDefaultOp(); - std::string transpose_operand_op_name = - absl::StrFormat(kTransposeOperandOpFmt, op_counter_); - std::string transpose_update_op_name = - absl::StrFormat(kTransposeUpdateOpFmt, op_counter_); - LITERT_RETURN_IF_ERROR(SetOpInfo( - transpose_operand_op_name.c_str(), kDefaultQnnOpPackageName.data(), - kTransposeOpTypeName.data(), transpose_operand_op)); - LITERT_RETURN_IF_ERROR(SetOpInfo( - transpose_update_op_name.c_str(), kDefaultQnnOpPackageName.data(), - kTransposeOpTypeName.data(), transpose_update_op)); - - // Prepare transpose op params. - std::vector perm = {1, 0, 2, 3}; - std::vector perm_dims = {kPermParamArgSize}; - Qnn_Param_t perm_param = BuildDefaultParam(); - LITERT_RETURN_STATUS_IF_QNN_NOT_OK(BuildQnnTesnorParam( - perm.data(), perm_dims.data(), QNN_DATATYPE_UINT_32, kPermParamRank, - kPermParamName.data(), graph_mapper, perm_param)); - - // Prepare transpose op outputs. - Qnn_Tensor_t transpose_operand_op_output = BuildDefaultTensor(); - Qnn_Tensor_t transpose_update_op_output = BuildDefaultTensor(); - - // Cast const int to uint32_t. 
- auto cast_f = [](int const_int) { return static_cast(const_int); }; - - std::vector transpose_operand_op_output_dims( - kTransposeOpOutputRank); - std::vector transpose_update_op_output_dims(kTransposeOpOutputRank); - auto operand_dims = src.Inputs()[kDynamicUpdateSliceOpOperandIndex] - .RankedTensorType() - ->Layout() - .Dimensions(); - transpose_operand_op_output_dims[0] = cast_f(operand_dims[1]); - transpose_operand_op_output_dims[1] = cast_f(operand_dims[0]); - transpose_operand_op_output_dims[2] = cast_f(operand_dims[2]); - transpose_operand_op_output_dims[3] = cast_f(operand_dims[3]); - - auto update_dims = src.Inputs()[kDynamicUpdateSliceOpUpdateIndex] - .RankedTensorType() - ->Layout() - .Dimensions(); - transpose_update_op_output_dims[0] = cast_f(update_dims[1]); - transpose_update_op_output_dims[1] = cast_f(update_dims[0]); - transpose_update_op_output_dims[2] = cast_f(update_dims[2]); - transpose_update_op_output_dims[3] = cast_f(update_dims[3]); - - LITERT_RETURN_STATUS_IF_QNN_NOT_OK(BuildAndRegisterQnnNativeTensor( - OperandDataType, kTransposeOpOutputRank, - transpose_operand_op_output_dims.data(), graph_mapper, - transpose_operand_op_output)); - LITERT_RETURN_STATUS_IF_QNN_NOT_OK(BuildAndRegisterQnnNativeTensor( - UpdateDataType, kTransposeOpOutputRank, - transpose_update_op_output_dims.data(), graph_mapper, - transpose_update_op_output)); - - // Configure transpose ops. 
- LITERT_RETURN_STATUS_IF_QNN_NOT_OK(BuildAndRegisterQnnOp( - kTransposeOpInputSize, &qnn_op_ins[kDynamicUpdateSliceOpOperandIndex], - kTransposeOpOutputSize, &transpose_operand_op_output, - transpose_operand_op, kTransposeParamSize, &perm_param, graph_mapper)); - LITERT_RETURN_STATUS_IF_QNN_NOT_OK(BuildAndRegisterQnnOp( - kTransposeOpInputSize, &qnn_op_ins[kDynamicUpdateSliceOpUpdateIndex], - kTransposeOpOutputSize, &transpose_update_op_output, transpose_update_op, - kTransposeParamSize, &perm_param, graph_mapper)); - - //=========================================================================== - // Step 3 Build ScatterND op. - Qnn_OpConfig_t scatter_nd_op = BuildDefaultOp(); - std::string scatter_nd_op_name = - absl::StrFormat(kScatterNdOpFmt, op_counter_); - LITERT_RETURN_IF_ERROR( - SetOpInfo(scatter_nd_op_name.c_str(), kDefaultQnnOpPackageName.data(), - kQnnScatterNdOpTypeName.data(), scatter_nd_op)); - - // Prepare scatter nd op output tensor. - Qnn_Tensor_t scatter_nd_op_output = BuildDefaultTensor(); - LITERT_RETURN_STATUS_IF_QNN_NOT_OK( - BuildAndRegisterQnnNativeTensor(OperandDataType, kScatterNDOutputRank, - transpose_operand_op_output_dims.data(), - graph_mapper, scatter_nd_op_output)); - - // Configure ScatterND op. - LITERT_STACK_ARRAY(Qnn_Tensor_t, scatter_nd_op_ins, kScatterNDOpInputSize, - QNN_TENSOR_INIT); - scatter_nd_op_ins[0] = transpose_operand_op_output; - scatter_nd_op_ins[1] = reshape_op_out; - scatter_nd_op_ins[2] = transpose_update_op_output; - LITERT_RETURN_STATUS_IF_QNN_NOT_OK(BuildAndRegisterQnnOp( - kScatterNDOpInputSize, scatter_nd_op_ins, kScatterNDOpOutputSize, - &scatter_nd_op_output, scatter_nd_op, kScatterNDParamSize, nullptr, - graph_mapper)); - - //=========================================================================== - // Step 4 Build final transpose op. Swap back the first two dimensions of the - // scatter nd op output. - // op. 
- // input: [b, a, c, d] (QNN.ScatterND output) - // output: [a, b, c, d] - std::string transpose_result_op_name = absl::StrFormat( - kTransposeResultOpFmt, /*increase counter*/ op_counter_++); - LITERT_RETURN_IF_ERROR(SetOpInfo(transpose_result_op_name.c_str(), - kDefaultQnnOpPackageName.data(), - kTransposeOpTypeName.data(), dest)); - - // Legalize op outputs and update scope. - const auto op_outs = src.Outputs(); - LITERT_STACK_ARRAY(Qnn_Tensor_t, qnn_op_outs, op_outs.size(), - QNN_TENSOR_INIT); - LITERT_RETURN_IF_ERROR( - graph_mapper.LegalizeAndRegister(op_outs.front().Get(), qnn_op_outs[0])); - LITERT_RETURN_IF_ERROR( - graph_mapper.PushToScope(op_outs.front().Get(), qnn_op_outs[0])); - - // Configure transpose op. - LITERT_RETURN_STATUS_IF_QNN_NOT_OK(BuildAndRegisterQnnOp( - kTransposeOpInputSize, &scatter_nd_op_output, kTransposeOpOutputSize, - &qnn_op_outs[0], dest, kTransposeParamSize, &perm_param, graph_mapper)); - - LITERT_LOG(LITERT_INFO, "Legalized dynamic update slice op"); - - return kLiteRtStatusOk; -} - -} // namespace litert::qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/dynamic_update_slice_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/dynamic_update_slice_op_legalization.h deleted file mode 100644 index 2a497f4f5cfcdb..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/dynamic_update_slice_op_legalization.h +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_DYNAMIC_UPDATE_SLICE_OP_LEGALIZATION_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_DYNAMIC_UPDATE_SLICE_OP_LEGALIZATION_H_ - -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h" - -namespace litert::qnn { - -class DynamicUpdateSliceOpLegalization : public Legalization { - public: - DynamicUpdateSliceOpLegalization() = default; - ~DynamicUpdateSliceOpLegalization() = default; - using Ptr = std::unique_ptr; - static Ptr Create() { - return std::make_unique(); - } - - LiteRtStatus LegalizeOp(const Op& src, Qnn_OpConfig_t& dest, - GraphMapper& graph_mapper); - - private: - uint32_t op_counter_ = 0; -}; - -} // namespace litert::qnn -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_DYNAMIC_UPDATE_SLICE_OP_LEGALIZATION_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/embedding_lookup_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/embedding_lookup_op_legalization.cc deleted file mode 100644 index ecab067e3846a5..00000000000000 --- 
a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/embedding_lookup_op_legalization.cc +++ /dev/null @@ -1,103 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/embedding_lookup_op_legalization.h" - -#include - -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/common.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h" - -namespace litert::qnn { - -static constexpr absl::string_view kQnnEmbeddingLookupOpTypeName = "Gather"; -static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw"; -static constexpr absl::string_view kEmbeddingLookupOpFmt = - "embedding_lookup_%d"; - -static constexpr int 
kReduceEmbeddingLookupOpOutputSize = 1; -static constexpr int kReduceEmbeddingLookupOpParamSize = 1; - -static constexpr int kEmbeddingLookupOpTableInputIndex = 1; -static constexpr int kEmbeddingLookupOpLookipInputIndex = 0; -static constexpr int kQnnGatherOpTableInputIndex = 0; -static constexpr int kQnnGatherOpLookupInputIndex = 1; - -LiteRtStatus EmbeddingLookupOpLegalization::LegalizeOp( - const Op& src, Qnn_OpConfig_t& dest, GraphMapper& graph_mapper) { - if (src.Code() != kLiteRtOpCodeTflEmbeddingLookup) { - return kLiteRtStatusLegalizeNoMatch; - } - DumpLegalization(*src.Get()); - std::string op_name = absl::StrFormat(kEmbeddingLookupOpFmt, op_counter_++); - LITERT_RETURN_IF_ERROR(SetOpInfo(op_name.c_str(), - kDefaultQnnOpPackageName.data(), - kQnnEmbeddingLookupOpTypeName.data(), dest)); - - // Look up op input tensors in scope. - const auto op_ins = src.Inputs(); - LITERT_STACK_ARRAY(Qnn_Tensor_t, qnn_op_ins, op_ins.size(), QNN_TENSOR_INIT); - - LITERT_RETURN_IF_ERROR(graph_mapper.LookupInScope( - op_ins[kEmbeddingLookupOpLookipInputIndex].Get(), - qnn_op_ins[kQnnGatherOpLookupInputIndex])); - LITERT_RETURN_IF_ERROR(graph_mapper.LookupInScope( - op_ins[kEmbeddingLookupOpTableInputIndex].Get(), - qnn_op_ins[kQnnGatherOpTableInputIndex])); - - // QNN embedding_lookup op expects 1 output tensor. - const auto op_outs = src.Outputs(); - LITERT_STACK_ARRAY(Qnn_Tensor_t, qnn_op_outs, - kReduceEmbeddingLookupOpOutputSize, QNN_TENSOR_INIT); - LITERT_RETURN_IF_ERROR( - graph_mapper.LegalizeAndRegister(op_outs.front().Get(), qnn_op_outs[0])); - LITERT_RETURN_IF_ERROR( - graph_mapper.PushToScope(op_outs.front().Get(), qnn_op_outs[0])); - - // Construct the scalar "axis" param. - Qnn_Param_t axis_param = BuildDefaultParam(); - axis_param.paramType = QNN_PARAMTYPE_SCALAR; - axis_param.name = "axis"; - Qnn_Scalar_t axis_scalar = QNN_SCALAR_INIT; - axis_scalar.dataType = QNN_DATATYPE_INT_32; - // Embedding lookup op expects axis to always be 0. 
- axis_scalar.int32Value = 0; - axis_param.scalarParam = axis_scalar; - - Qnn_Param_t embedding_lookup_params[] = {axis_param}; - dest.v1.inputTensors = qnn_op_ins; - dest.v1.numOfInputs = op_ins.size(); - dest.v1.outputTensors = qnn_op_outs; - dest.v1.numOfOutputs = kReduceEmbeddingLookupOpOutputSize; - dest.v1.numOfParams = kReduceEmbeddingLookupOpParamSize; - dest.v1.params = embedding_lookup_params; - - LITERT_RETURN_STATUS_IF_QNN_NOT_OK( - graph_mapper.Qnn().Api()->graphAddNode(graph_mapper.QnnGraph(), dest)); - - LITERT_LOG(LITERT_INFO, "Legalized embedding_lookup op", ""); - return kLiteRtStatusOk; -} - -} // namespace litert::qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/embedding_lookup_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/embedding_lookup_op_legalization.h deleted file mode 100644 index e8bae779d2ae64..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/embedding_lookup_op_legalization.h +++ /dev/null @@ -1,51 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_EMBEDDING_LOOKUP_OP_LEGALIZATION_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_EMBEDDING_LOOKUP_OP_LEGALIZATION_H_ - -#include -#include - -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h" - -namespace litert::qnn { - -class EmbeddingLookupOpLegalization : public Legalization { - public: - EmbeddingLookupOpLegalization() = default; - ~EmbeddingLookupOpLegalization() = default; - using Ptr = std::unique_ptr; - static Ptr Create() { - return std::make_unique(); - } - - LiteRtStatus LegalizeOp(const Op& src, Qnn_OpConfig_t& dest, - GraphMapper& graph_mapper); - - private: - // Counter to ensure unique op names. - uint32_t op_counter_ = 0; -}; - -} // namespace litert::qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_EMBEDDING_LOOKUP_OP_LEGALIZATION_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/fully_connected_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/fully_connected_op_legalization.cc deleted file mode 100644 index fca0d31c26a987..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/fully_connected_op_legalization.cc +++ /dev/null @@ -1,51 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/fully_connected_op_legalization.h" - -#include - -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h" - -namespace litert::qnn { - -static constexpr absl::string_view kQnnFullyConnectedOpTypeName = - "FullyConnected"; -static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw"; -static constexpr absl::string_view kFullyConnectedOpFmt = "fully_connected_%d"; - -LiteRtStatus FullyConnectedOpLegalization::LegalizeOp( - const Op& src, Qnn_OpConfig_t& dest, GraphMapper& graph_mapper) { - if (src.Code() != kLiteRtOpCodeTflFullyConnected) { - return kLiteRtStatusLegalizeNoMatch; - } - std::string op_name = absl::StrFormat(kFullyConnectedOpFmt, op_counter_++); - LITERT_RETURN_IF_ERROR(SetOpInfo(op_name.c_str(), - kDefaultQnnOpPackageName.data(), - kQnnFullyConnectedOpTypeName.data(), dest)); - LITERT_RETURN_IF_ERROR(LegalizeSimpleOp(src, dest, graph_mapper)); - - LITERT_LOG(LITERT_INFO, "Legalized fully_connected op", ""); - 
return kLiteRtStatusOk; -} - -} // namespace litert::qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/fully_connected_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/fully_connected_op_legalization.h deleted file mode 100644 index 0ff2983e59e708..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/fully_connected_op_legalization.h +++ /dev/null @@ -1,51 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_FULLY_CONNECTED_OP_LEGALIZATION_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_FULLY_CONNECTED_OP_LEGALIZATION_H_ - -#include -#include - -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h" - -namespace litert::qnn { - -class FullyConnectedOpLegalization : public Legalization { - public: - FullyConnectedOpLegalization() = default; - ~FullyConnectedOpLegalization() = default; - using Ptr = std::unique_ptr; - static Ptr Create() { - return std::make_unique(); - } - - LiteRtStatus LegalizeOp(const Op& src, Qnn_OpConfig_t& dest, - GraphMapper& graph_mapper); - - private: - // Counter to ensure unique op names. - uint32_t op_counter_ = 0; -}; - -} // namespace litert::qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_FULLY_CONNECTED_OP_LEGALIZATION_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/gelu_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/gelu_op_legalization.cc deleted file mode 100644 index 3b769d9bd7521e..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/gelu_op_legalization.cc +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/gelu_op_legalization.h" - -#include - -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h" - -namespace litert::qnn { - -static constexpr absl::string_view kQnnGeluOpTypeName = "Gelu"; -static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw"; -static constexpr absl::string_view kGeluOpFmt = "gelu_%d"; - -LiteRtStatus GeluOpLegalization::LegalizeOp(const Op& src, Qnn_OpConfig_t& dest, - GraphMapper& graph_mapper) { - if (src.Code() != kLiteRtOpCodeTflGelu) { - return kLiteRtStatusLegalizeNoMatch; - } - const std::string op_name = absl::StrFormat(kGeluOpFmt, op_counter_++); - LITERT_RETURN_IF_ERROR(SetOpInfo(op_name.c_str(), - kDefaultQnnOpPackageName.data(), - kQnnGeluOpTypeName.data(), dest)); - LITERT_RETURN_IF_ERROR(LegalizeSimpleOp(src, dest, graph_mapper)); - LITERT_LOG(LITERT_INFO, "Legalized gelu op", ""); - return kLiteRtStatusOk; -} - -} // 
namespace litert::qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/gelu_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/gelu_op_legalization.h deleted file mode 100644 index fdb31f5300d07c..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/gelu_op_legalization.h +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_GELU_OP_LEGALIZATION_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_GELU_OP_LEGALIZATION_H_ - -#include -#include - -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h" - -namespace litert::qnn { - -class GeluOpLegalization : public Legalization { - public: - GeluOpLegalization() = default; - ~GeluOpLegalization() = default; - using Ptr = std::unique_ptr; - static Ptr Create() { return std::make_unique(); } - - LiteRtStatus LegalizeOp(const Op& src, Qnn_OpConfig_t& dest, - GraphMapper& graph_mapper); - - private: - // Counter to ensure unique op names. - uint32_t op_counter_ = 0; -}; - -} // namespace litert::qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_GELU_OP_LEGALIZATION_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/greater_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/greater_op_legalization.cc deleted file mode 100644 index d07ca4f086c708..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/greater_op_legalization.cc +++ /dev/null @@ -1,51 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Ungreater required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/greater_op_legalization.h" - -#include - -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h" - -namespace litert::qnn { - -static constexpr absl::string_view kQnnGreaterOpTypeName = "ElementWiseGreater"; -static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw"; -static constexpr absl::string_view kGreaterOpFmt = "greater_%d"; - -LiteRtStatus GreaterOpLegalization::LegalizeOp(const litert::Op& src, - Qnn_OpConfig_t& dest, - GraphMapper& graph_mapper) { - if (src.Code() != kLiteRtOpCodeTflGreater) { - return kLiteRtStatusLegalizeNoMatch; - } - std::string op_name = absl::StrFormat(kGreaterOpFmt, op_counter_++); - LITERT_RETURN_IF_ERROR(SetOpInfo(op_name.c_str(), - kDefaultQnnOpPackageName.data(), - kQnnGreaterOpTypeName.data(), dest)); - LITERT_RETURN_IF_ERROR(LegalizeSimpleOp(src, dest, graph_mapper)); - LITERT_LOG(LITERT_INFO, "Legalized greater op", 
""); - return kLiteRtStatusOk; -} - -} // namespace litert::qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/greater_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/greater_op_legalization.h deleted file mode 100644 index bb353420291c00..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/greater_op_legalization.h +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_GREATER_OP_LEGALIZATION_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_GREATER_OP_LEGALIZATION_H_ - -#include -#include - -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h" - -namespace litert::qnn { - -class GreaterOpLegalization : public Legalization { - public: - GreaterOpLegalization() = default; - ~GreaterOpLegalization() = default; - using Ptr = std::unique_ptr; - static Ptr Create() { return std::make_unique(); } - - LiteRtStatus LegalizeOp(const litert::Op& src, Qnn_OpConfig_t& dest, - GraphMapper& graph_mapper); - - private: - // Counter to ensure unique op names. - uint32_t op_counter_ = 0; -}; - -} // namespace litert::qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_GREATER_OP_LEGALIZATION_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h deleted file mode 100644 index 5f7c8ef96062ef..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_LEGALIZATION_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_LEGALIZATION_H_ - -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" - -#define STRINGIFY(x) #x -#define QNN_OP_NAME(prefix) STRINGIFY(prefix##__COUNTER__) - -namespace litert::qnn { - -class Legalization { - public: - Legalization() = default; - virtual ~Legalization() = default; - - virtual LiteRtStatus LegalizeOp(const Op& src, Qnn_OpConfig_t& dest, - GraphMapper& graph_mapper) = 0; - - // Sets the op name, package name, and type. - // Note: All argument strings can't be de-allocated until the op has been - // registered with the qnn api. i.e graphAddNode(). 
- inline LiteRtStatus SetOpInfo(const char* name, const char* op_package_name, - const char* op_type, Qnn_OpConfig_t& op) { - op.v1.name = name; - op.v1.packageName = op_package_name; - op.v1.typeName = op_type; - return kLiteRtStatusOk; - } -}; - -} // namespace litert::qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_LEGALIZATION_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/less_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/less_op_legalization.cc deleted file mode 100644 index 23d45e4ba4a6fb..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/less_op_legalization.cc +++ /dev/null @@ -1,51 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/less_op_legalization.h" - -#include - -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h" - -namespace litert::qnn { - -static constexpr absl::string_view kQnnLessOpTypeName = "ElementWiseLess"; -static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw"; -static constexpr absl::string_view kLessOpFmt = "less_%d"; - -LiteRtStatus LessOpLegalization::LegalizeOp(const litert::Op& src, - Qnn_OpConfig_t& dest, - GraphMapper& graph_mapper) { - if (src.Code() != kLiteRtOpCodeTflLess) { - return kLiteRtStatusLegalizeNoMatch; - } - std::string op_name = absl::StrFormat(kLessOpFmt, op_counter_++); - LITERT_RETURN_IF_ERROR(SetOpInfo(op_name.c_str(), - kDefaultQnnOpPackageName.data(), - kQnnLessOpTypeName.data(), dest)); - LITERT_RETURN_IF_ERROR(LegalizeSimpleOp(src, dest, graph_mapper)); - LITERT_LOG(LITERT_INFO, "Legalized less op", ""); - return kLiteRtStatusOk; -} - -} // namespace litert::qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/less_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/less_op_legalization.h deleted file mode 100644 index b16c5335f01a8e..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/less_op_legalization.h +++ /dev/null 
@@ -1,49 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_LESS_OP_LEGALIZATION_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_LESS_OP_LEGALIZATION_H_ - -#include -#include - -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h" - -namespace litert::qnn { - -class LessOpLegalization : public Legalization { - public: - LessOpLegalization() = default; - ~LessOpLegalization() = default; - using Ptr = std::unique_ptr; - static Ptr Create() { return std::make_unique(); } - - LiteRtStatus LegalizeOp(const litert::Op& src, Qnn_OpConfig_t& dest, - GraphMapper& graph_mapper); - - private: - // Counter to ensure unique op names. 
- uint32_t op_counter_ = 0; -}; - -} // namespace litert::qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_LESS_OP_LEGALIZATION_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/logical_and_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/logical_and_op_legalization.cc deleted file mode 100644 index 1a1bc4dbdc7aa5..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/logical_and_op_legalization.cc +++ /dev/null @@ -1,51 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/logical_and_op_legalization.h" - -#include - -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h" - -namespace litert::qnn { - -static constexpr absl::string_view kQnnLogicalAndOpTypeName = "ElementWiseAnd"; -static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw"; -static constexpr absl::string_view kLogicalAndOpFmt = "logical_and_%d"; - -LiteRtStatus LogicalAndOpLegalization::LegalizeOp(const Op& src, - Qnn_OpConfig_t& dest, - GraphMapper& graph_mapper) { - if (src.Code() != kLiteRtOpCodeTflLogicalAnd) { - return kLiteRtStatusLegalizeNoMatch; - } - std::string op_name = absl::StrFormat(kLogicalAndOpFmt, op_counter_++); - LITERT_RETURN_IF_ERROR(SetOpInfo(op_name.c_str(), - kDefaultQnnOpPackageName.data(), - kQnnLogicalAndOpTypeName.data(), dest)); - LITERT_RETURN_IF_ERROR(LegalizeSimpleOp(src, dest, graph_mapper)); - LITERT_LOG(LITERT_INFO, "Legalized logical_and op", ""); - return kLiteRtStatusOk; -} - -} // namespace litert::qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/logical_and_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/logical_and_op_legalization.h deleted file mode 100644 index ec5c5c2a03bf5e..00000000000000 --- 
a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/logical_and_op_legalization.h +++ /dev/null @@ -1,51 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_LOGICAL_AND_OP_LEGALIZATION_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_LOGICAL_AND_OP_LEGALIZATION_H_ - -#include -#include - -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h" - -namespace litert::qnn { - -class LogicalAndOpLegalization : public Legalization { - public: - LogicalAndOpLegalization() = default; - ~LogicalAndOpLegalization() = default; - using UniquePtr = std::unique_ptr; - static UniquePtr Create() { - return std::make_unique(); - } - - LiteRtStatus LegalizeOp(const Op& src, Qnn_OpConfig_t& dest, - GraphMapper& graph_mapper); - - private: - // Counter to ensure unique op names. 
- uint32_t op_counter_ = 0; -}; - -} // namespace litert::qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_LOGICAL_AND_OP_LEGALIZATION_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/mul_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/mul_op_legalization.cc deleted file mode 100644 index 4185740e2cb2b0..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/mul_op_legalization.cc +++ /dev/null @@ -1,51 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/mul_op_legalization.h" - -#include - -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h" - -namespace litert::qnn { - -static constexpr absl::string_view kQnnMulOpTypeName = "ElementWiseMultiply"; -static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw"; -static constexpr absl::string_view kMulOpFmt = "mul_%d"; - -LiteRtStatus MulOpLegalization::LegalizeOp(const Op& src, Qnn_OpConfig_t& dest, - GraphMapper& graph_mapper) { - if (src.Code() != kLiteRtOpCodeTflMul) { - return kLiteRtStatusLegalizeNoMatch; - } - std::string op_name = absl::StrFormat(kMulOpFmt, op_counter_++); - LITERT_RETURN_IF_ERROR(SetOpInfo(op_name.c_str(), - kDefaultQnnOpPackageName.data(), - kQnnMulOpTypeName.data(), dest)); - LITERT_RETURN_IF_ERROR(LegalizeSimpleOp(src, dest, graph_mapper)); - LITERT_LOG(LITERT_INFO, "Legalized mul op", ""); - return kLiteRtStatusOk; -} - -} // namespace litert::qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/mul_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/mul_op_legalization.h deleted file mode 100644 index 098d0954430d50..00000000000000 --- 
a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/mul_op_legalization.h +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_MUL_OP_LEGALIZATION_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_MUL_OP_LEGALIZATION_H_ - -#include -#include - -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h" - -namespace litert::qnn { - -class MulOpLegalization : public Legalization { - public: - MulOpLegalization() = default; - ~MulOpLegalization() = default; - using Ptr = std::unique_ptr; - static Ptr Create() { return std::make_unique(); } - - LiteRtStatus LegalizeOp(const Op& src, Qnn_OpConfig_t& dest, - GraphMapper& graph_mapper); - - private: - // Counter to ensure unique op names. 
- uint32_t op_counter_ = 0; -}; - -} // namespace litert::qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_MUL_OP_LEGALIZATION_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/pack_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/pack_op_legalization.cc deleted file mode 100644 index 6e1f3d350813fd..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/pack_op_legalization.cc +++ /dev/null @@ -1,140 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/pack_op_legalization.h" - -#include -#include -#include - -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/c/litert_options.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/common.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h" - -namespace litert::qnn { - -// Pack op config. -static constexpr absl::string_view kQnnPackOpTypeName = "Pack"; -static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw"; -static constexpr absl::string_view kPackOpFmt = "pack_%d"; -static constexpr absl::string_view kPackOpAxisParamName = "axis"; -static constexpr int kPackOpAxisParamSize = 1; -static constexpr int kPackScalarsOpOutputRank = 2; - -// Reshape op config. 
-static constexpr absl::string_view kReshapeOpTypeName = "Reshape"; -static constexpr absl::string_view kReshapeOpFmt = "pack_reshape_%d"; -static constexpr int kReshapeOpInputSize = 1; -static constexpr int kReshapeOpOutputSize = 1; -static constexpr int kReshapeParamSize = 0; - -LiteRtStatus PackOpLegalization::LegalizeOp(const Op& src, Qnn_OpConfig_t& dest, - GraphMapper& graph_mapper) { - if (src.Code() != kLiteRtOpCodeTflPack) { - return kLiteRtStatusLegalizeNoMatch; - } - std::string pack_op_name = absl::StrFormat(kPackOpFmt, op_counter_); - DumpLegalization(*src.Get()); - - // Legalize input tensors, lookup operand tensor in scope. - const auto op_ins = src.Inputs(); - LITERT_STACK_ARRAY(Qnn_Tensor_t, qnn_op_ins, op_ins.size(), QNN_TENSOR_INIT); - Qnn_Tensor_t* cur_qnn_op_in = qnn_op_ins; - for (const auto& op_in : op_ins) { - LITERT_RETURN_IF_ERROR( - graph_mapper.LookupInScope(op_in.Get(), *cur_qnn_op_in)); - ++cur_qnn_op_in; - } - - // Legalize output tensors. - const auto op_outs = src.Outputs(); - LITERT_STACK_ARRAY(Qnn_Tensor_t, qnn_op_outs, op_outs.size(), - QNN_TENSOR_INIT); - LITERT_RETURN_IF_ERROR( - graph_mapper.LegalizeAndRegister(op_outs.front().Get(), qnn_op_outs[0])); - LITERT_RETURN_IF_ERROR( - graph_mapper.PushToScope(op_outs.front().Get(), qnn_op_outs[0])); - - // Get axis option and build QNN scalar param. - int32_t axis; - LITERT_RETURN_IF_ERROR(LiteRtGetPackAxisOption(src.Get(), &axis)); - uint32_t axis_value = static_cast(axis); - - Qnn_Param_t axis_param = BuildDefaultParam(); - LITERT_RETURN_STATUS_IF_QNN_NOT_OK(BuildQnnScalarParam( - axis_value, QNN_DATATYPE_UINT_32, kPackOpAxisParamName.data(), - graph_mapper, axis_param)); - - // Qnn does not support Packing scalars, scalar value are legalized as 1D - // tensor with single element. In such case, we need to add a reshape op to - // convert result packed 2D tensor to 1D tensor. 
- auto input_layout = op_ins[0].RankedTensorType()->Layout(); - if (input_layout.Rank() == 0) { - // prepare Pack op output tensor. - Qnn_Tensor_t pack_op_out = BuildDefaultTensor(); - uint32_t pack_op_out_rank = kPackScalarsOpOutputRank; - Qnn_DataType_t PackOpDataType = QNN_DATATYPE_UNDEFINED; - - LITERT_RETURN_IF_ERROR( - LegalizeElementType(op_ins[0].ElementType(), &PackOpDataType)); - std::vector pack_op_out_dims = { - static_cast(op_ins.size())}; - - LITERT_RETURN_STATUS_IF_QNN_NOT_OK(BuildAndRegisterQnnNativeTensor( - PackOpDataType, pack_op_out_rank, pack_op_out_dims.data(), graph_mapper, - pack_op_out)); - - // Build Pack op. - Qnn_OpConfig_t pack_op = BuildDefaultOp(); - LITERT_RETURN_IF_ERROR(SetOpInfo(pack_op_name.c_str(), - kDefaultQnnOpPackageName.data(), - kQnnPackOpTypeName.data(), pack_op)); - LITERT_RETURN_STATUS_IF_QNN_NOT_OK(BuildAndRegisterQnnOp( - op_ins.size(), qnn_op_ins, op_outs.size(), &pack_op_out, pack_op, - kPackOpAxisParamSize, &axis_param, graph_mapper)); - - // Build Reshape op. 
- std::string reshape_op_name = absl::StrFormat(kReshapeOpFmt, op_counter_); - LITERT_RETURN_IF_ERROR(SetOpInfo(reshape_op_name.c_str(), - kDefaultQnnOpPackageName.data(), - kReshapeOpTypeName.data(), dest)); - LITERT_RETURN_STATUS_IF_QNN_NOT_OK(BuildAndRegisterQnnOp( - kReshapeOpInputSize, &pack_op_out, kReshapeOpOutputSize, qnn_op_outs, - dest, kReshapeParamSize, nullptr, graph_mapper)); - } else { - LITERT_RETURN_IF_ERROR(SetOpInfo(pack_op_name.c_str(), - kDefaultQnnOpPackageName.data(), - kQnnPackOpTypeName.data(), dest)); - BuildAndRegisterQnnOp(op_ins.size(), qnn_op_ins, op_outs.size(), - qnn_op_outs, dest, kPackOpAxisParamSize, &axis_param, - graph_mapper); - } - op_counter_++; - - LITERT_LOG(LITERT_INFO, "Legalized pack op", ""); - return kLiteRtStatusOk; -} - -} // namespace litert::qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/pack_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/pack_op_legalization.h deleted file mode 100644 index 42bd24f95b7813..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/pack_op_legalization.h +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_PACK_OP_LEGALIZATION_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_PACK_OP_LEGALIZATION_H_ - -#include -#include - -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h" - -namespace litert::qnn { - -class PackOpLegalization : public Legalization { - public: - PackOpLegalization() = default; - ~PackOpLegalization() = default; - using Ptr = std::unique_ptr; - static Ptr Create() { return std::make_unique(); } - - LiteRtStatus LegalizeOp(const Op& src, Qnn_OpConfig_t& dest, - GraphMapper& graph_mapper); - - private: - // Counter to ensure unique op names. - uint32_t op_counter_ = 0; -}; - -} // namespace litert::qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_PACK_OP_LEGALIZATION_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/quantize_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/quantize_op_legalization.cc deleted file mode 100644 index bf16efb347447d..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/quantize_op_legalization.cc +++ /dev/null @@ -1,174 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/quantize_op_legalization.h" - -#include -#include -#include - -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_element_type.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/common.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h" - -namespace litert::qnn { - -static constexpr absl::string_view kQnnConvertOpTypeName = "Convert"; -static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw"; -static constexpr absl::string_view kConvertOpFmt = "q_convert_%d"; - -static constexpr absl::string_view kQnnQuantizeOpTypeName = "Quantize"; -static constexpr absl::string_view kQuantizeOpFmt = "quantize_%d"; - -static constexpr absl::string_view kQnnCastOpTypeName = "Cast"; -static constexpr absl::string_view kCastOpFmt = "q_cast_%d"; - -// SFIXED_8 and UFIXED_8 offset diff 
-static constexpr int kSUFixed8OffsetDiff = 128; -// SFIXED_16 and UFIXED_16 offset diff -static constexpr int kSUFixed16OffsetDiff = 32768; - -LiteRtStatus QuantizeOpLegalization::LegalizeQuantizeOpAsConvertOp( - const litert::Op& src, Qnn_OpConfig_t& dest) { - std::string op_name = absl::StrFormat(kConvertOpFmt, op_counter_++); - LITERT_RETURN_IF_ERROR(SetOpInfo(op_name.c_str(), - kDefaultQnnOpPackageName.data(), - kQnnConvertOpTypeName.data(), dest)); - return kLiteRtStatusOk; -} - -LiteRtStatus QuantizeOpLegalization::LegalizeQuantizeOpAsCastOp( - const litert::Op& src, Qnn_OpConfig_t& dest) { - std::string op_name = absl::StrFormat(kCastOpFmt, op_counter_++); - LITERT_RETURN_IF_ERROR(SetOpInfo(op_name.c_str(), - kDefaultQnnOpPackageName.data(), - kQnnCastOpTypeName.data(), dest)); - return kLiteRtStatusOk; -} - -LiteRtStatus QuantizeOpLegalization::LegalizeQuantizeOpAsQuantizeOp( - const litert::Op& src, Qnn_OpConfig_t& dest) { - std::string op_name = absl::StrFormat(kQuantizeOpFmt, op_counter_++); - LITERT_RETURN_IF_ERROR(SetOpInfo(op_name.c_str(), - kDefaultQnnOpPackageName.data(), - kQnnQuantizeOpTypeName.data(), dest)); - return kLiteRtStatusOk; -} - -inline bool IsTensorUInt8(Tensor& tensor) { - return tensor.RankedTensorType()->ElementType() == ElementType::UInt8; -} -inline bool IsTensorInt8(Tensor& tensor) { - return tensor.RankedTensorType()->ElementType() == ElementType::Int8; -} -inline bool IsTensorUInt16(Tensor& tensor) { - return tensor.RankedTensorType()->ElementType() == ElementType::UInt16; -} -inline bool IsTensorInt16(Tensor& tensor) { - return tensor.RankedTensorType()->ElementType() == ElementType::Int16; -} - -inline bool IsTensorPerTensorQuantized(Tensor& tensor) { - return (IsTensorInt8(tensor) || IsTensorUInt8(tensor) || - IsTensorInt16(tensor) || IsTensorUInt16(tensor)) && - tensor.QTypeId() == kLiteRtQuantizationPerTensor; -} - -inline bool WithinCastRange(Tensor& input_tensor, Tensor& output_tensor, - const int offst_diff) { - 
return (std::fabs(input_tensor.PerTensorQuantization().scale - - output_tensor.PerTensorQuantization().scale)) < - std::numeric_limits::epsilon() && - std::abs(input_tensor.PerTensorQuantization().zero_point - - output_tensor.PerTensorQuantization().zero_point) == - offst_diff; -} - -LiteRtStatus QuantizeOpLegalization::ConfigureQnnOp(const litert::Op& src, - Qnn_OpConfig_t& dest) { - const bool is_input_tensor_per_tensor_quantized = - IsTensorPerTensorQuantized(src.Inputs().front()); - const bool is_output_tensor_per_tensor_quantized = - IsTensorPerTensorQuantized(src.Outputs().front()); - - if (is_input_tensor_per_tensor_quantized && - is_output_tensor_per_tensor_quantized) { - // Check if the input and output tensors are int8/uint8 or int16/uint16. - const bool is_input_tensor_int8 = IsTensorInt8(src.Inputs().front()); - const bool is_input_tensor_uint8 = IsTensorUInt8(src.Inputs().front()); - const bool is_input_tensor_int16 = IsTensorInt16(src.Inputs().front()); - const bool is_input_tensor_uint16 = IsTensorUInt16(src.Inputs().front()); - const bool is_output_tensor_int8 = IsTensorInt8(src.Outputs().front()); - const bool is_output_tensor_uint8 = IsTensorUInt8(src.Outputs().front()); - const bool is_output_tensor_int16 = IsTensorInt16(src.Outputs().front()); - const bool is_output_tensor_uint16 = IsTensorUInt16(src.Outputs().front()); - - if ((is_input_tensor_int8 && is_output_tensor_uint8) || - (is_input_tensor_uint8 && is_output_tensor_int8)) { - // Case if the input and output tensors are int8/uint8. 
- const bool is_quantization_range_within_cast_range = WithinCastRange( - src.Inputs().front(), src.Outputs().front(), kSUFixed8OffsetDiff); - if (is_quantization_range_within_cast_range) { - LITERT_RETURN_STATUS_IF_QNN_NOT_OK( - LegalizeQuantizeOpAsCastOp(src, dest)); - LITERT_LOG(LITERT_INFO, "Configured quantize op to Cast Op"); - return kLiteRtStatusOk; - } - } else if ((is_input_tensor_int16 && is_output_tensor_uint16) || - (is_input_tensor_uint16 && is_output_tensor_int16)) { - // Case if the input and output tensors are int16/uint16. - const bool is_quantization_range_within_cast_range = WithinCastRange( - src.Inputs().front(), src.Outputs().front(), kSUFixed16OffsetDiff); - if (is_quantization_range_within_cast_range) { - LITERT_RETURN_STATUS_IF_QNN_NOT_OK( - LegalizeQuantizeOpAsCastOp(src, dest)); - LITERT_LOG(LITERT_INFO, "Configured quantize op to Cast Op"); - return kLiteRtStatusOk; - } - } - LITERT_RETURN_STATUS_IF_QNN_NOT_OK( - LegalizeQuantizeOpAsConvertOp(src, dest)); - LITERT_LOG(LITERT_INFO, "Configured quantize op to Convert Op"); - return kLiteRtStatusOk; - } - - // Not per tensor quantized, legalize to Quantize Op. 
- LITERT_RETURN_STATUS_IF_QNN_NOT_OK(LegalizeQuantizeOpAsQuantizeOp(src, dest)); - LITERT_LOG(LITERT_INFO, "Legalized quantize op to Quantize Op"); - return kLiteRtStatusOk; -} - -LiteRtStatus QuantizeOpLegalization::LegalizeOp(const litert::Op& src, - Qnn_OpConfig_t& dest, - GraphMapper& graph_mapper) { - if (src.Code() != kLiteRtOpCodeTflQuantize) { - return kLiteRtStatusLegalizeNoMatch; - } - ConfigureQnnOp(src, dest); - LITERT_RETURN_IF_ERROR(LegalizeSimpleOp(src, dest, graph_mapper)); - LITERT_LOG(LITERT_INFO, "Legalized quantize Op"); - return kLiteRtStatusOk; -} - -} // namespace litert::qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/quantize_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/quantize_op_legalization.h deleted file mode 100644 index 7621f701b87a3d..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/quantize_op_legalization.h +++ /dev/null @@ -1,65 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_QUANTIZE_OP_LEGALIZATION_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_QUANTIZE_OP_LEGALIZATION_H_ - -#include -#include - -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h" - -namespace litert::qnn { - -class QuantizeOpLegalization : public Legalization { - public: - QuantizeOpLegalization() = default; - ~QuantizeOpLegalization() = default; - using Ptr = std::unique_ptr; - static Ptr Create() { return std::make_unique(); } - - LiteRtStatus LegalizeOp(const litert::Op& src, Qnn_OpConfig_t& dest, - GraphMapper& graph_mapper); - - LiteRtStatus ConfigureQnnOp(const litert::Op& src, Qnn_OpConfig_t& dest); - - private: - // Requantization: legalize quantize to QNN Convert Op. - // Quantization range is not within QNN cast Op range. - LiteRtStatus LegalizeQuantizeOpAsConvertOp(const litert::Op& src, - Qnn_OpConfig_t& dest); - - // Ignore Requantization: legalize quantize to QNN Cast Op. - // Quantization range is within QNN cast Op range. Directly use QNN Cast Op. - LiteRtStatus LegalizeQuantizeOpAsCastOp(const litert::Op& src, - Qnn_OpConfig_t& dest); - - // Quantization: legalize quantize to QNN Quantize Op. - LiteRtStatus LegalizeQuantizeOpAsQuantizeOp(const litert::Op& src, - Qnn_OpConfig_t& dest); - - // Counter to ensure unique op names. 
- uint32_t op_counter_ = 0; -}; - -} // namespace litert::qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_QUANTIZE_OP_LEGALIZATION_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/reshape_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/reshape_op_legalization.cc deleted file mode 100644 index 1127b0f3188cf3..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/reshape_op_legalization.cc +++ /dev/null @@ -1,82 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/reshape_op_legalization.h" - -#include - -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model_predicates.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/common.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h" - -namespace litert::qnn { - -static constexpr absl::string_view kQnnReshapeOpTypeName = "Reshape"; -static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw"; -static constexpr absl::string_view kReshapeOpFmt = "reshape_%d"; - -static constexpr int kReshapeOpInputSize = 1; -static constexpr int kReshapeOpOutputSize = 1; - -LiteRtStatus ReshapeOpLegalization::LegalizeOp(const litert::Op& src, - Qnn_OpConfig_t& dest, - GraphMapper& graph_mapper) { - if (src.Code() != kLiteRtOpCodeTflReshape) { - return kLiteRtStatusLegalizeNoMatch; - } - std::string op_name = absl::StrFormat(kReshapeOpFmt, op_counter_++); - LITERT_RETURN_IF_ERROR(SetOpInfo(op_name.c_str(), - kDefaultQnnOpPackageName.data(), - kQnnReshapeOpTypeName.data(), dest)); - DumpLegalization(*src.Get()); - // Look up op input tensors in scope. 
- const auto op_ins = src.Inputs(); - LITERT_STACK_ARRAY(Qnn_Tensor_t, qnn_op_ins, kReshapeOpInputSize, - QNN_TENSOR_INIT); - LITERT_RETURN_IF_ERROR( - graph_mapper.LookupInScope(op_ins.front().Get(), qnn_op_ins[0])); - - // Legalize op outputs and update scope. - - const auto op_outs = src.Outputs(); - LITERT_STACK_ARRAY(Qnn_Tensor_t, qnn_op_outs, kReshapeOpOutputSize, - QNN_TENSOR_INIT); - LITERT_RETURN_IF_ERROR( - graph_mapper.LegalizeAndRegister(op_outs.front().Get(), qnn_op_outs[0])); - LITERT_RETURN_IF_ERROR( - graph_mapper.PushToScope(op_outs.front().Get(), qnn_op_outs[0])); - - dest.v1.numOfInputs = kReshapeOpInputSize; - dest.v1.inputTensors = qnn_op_ins; - - dest.v1.numOfOutputs = kReshapeOpOutputSize; - dest.v1.outputTensors = qnn_op_outs; - - LITERT_RETURN_STATUS_IF_QNN_NOT_OK( - graph_mapper.Qnn().Api()->graphAddNode(graph_mapper.QnnGraph(), dest)); - - LITERT_LOG(LITERT_INFO, "Legalized reshape op", ""); - return kLiteRtStatusOk; -} - -} // namespace litert::qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/reshape_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/reshape_op_legalization.h deleted file mode 100644 index e8553639fc0906..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/reshape_op_legalization.h +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_RESHAPE_OP_LEGALIZATION_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_RESHAPE_OP_LEGALIZATION_H_ - -#include -#include - -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h" - -namespace litert::qnn { - -class ReshapeOpLegalization : public Legalization { - public: - ReshapeOpLegalization() = default; - ~ReshapeOpLegalization() = default; - using Ptr = std::unique_ptr; - static Ptr Create() { return std::make_unique(); } - - LiteRtStatus LegalizeOp(const litert::Op& src, Qnn_OpConfig_t& dest, - GraphMapper& graph_mapper); - - private: - // Counter to ensure unique op names. - uint32_t op_counter_ = 0; -}; - -} // namespace litert::qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_RESHAPE_OP_LEGALIZATION_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/rsqrt_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/rsqrt_op_legalization.cc deleted file mode 100644 index 363434821d6d08..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/rsqrt_op_legalization.cc +++ /dev/null @@ -1,51 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/rsqrt_op_legalization.h" - -#include - -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h" - -namespace litert::qnn { - -static constexpr absl::string_view kQnnRsqrtOpTypeName = "ElementWiseRsqrt"; -static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw"; -static constexpr absl::string_view kRsqrtOpFmt = "rsqrt_%d"; - -LiteRtStatus RsqrtOpLegalization::LegalizeOp(const litert::Op& src, - Qnn_OpConfig_t& dest, - GraphMapper& graph_mapper) { - if (src.Code() != kLiteRtOpCodeTflRsqrt) { - return kLiteRtStatusLegalizeNoMatch; - } - std::string op_name = absl::StrFormat(kRsqrtOpFmt, op_counter_++); - LITERT_RETURN_IF_ERROR(SetOpInfo(op_name.c_str(), - kDefaultQnnOpPackageName.data(), - kQnnRsqrtOpTypeName.data(), dest)); - LITERT_RETURN_IF_ERROR(LegalizeSimpleOp(src, dest, graph_mapper)); - LITERT_LOG(LITERT_INFO, "Legalized rsqrt op", ""); - return 
kLiteRtStatusOk; -} - -} // namespace litert::qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/rsqrt_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/rsqrt_op_legalization.h deleted file mode 100644 index 5971e9f98cd5b1..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/rsqrt_op_legalization.h +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_RSQRT_OP_LEGALIZATION_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_RSQRT_OP_LEGALIZATION_H_ - -#include -#include - -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h" - -namespace litert::qnn { - -class RsqrtOpLegalization : public Legalization { - public: - RsqrtOpLegalization() = default; - ~RsqrtOpLegalization() = default; - using Ptr = std::unique_ptr; - static Ptr Create() { return std::make_unique(); } - - LiteRtStatus LegalizeOp(const litert::Op& src, Qnn_OpConfig_t& dest, - GraphMapper& graph_mapper); - - private: - // Counter to ensure unique op names. - uint32_t op_counter_ = 0; -}; - -} // namespace litert::qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_RSQRT_OP_LEGALIZATION_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/select_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/select_op_legalization.cc deleted file mode 100644 index 9c6da052221bcc..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/select_op_legalization.cc +++ /dev/null @@ -1,55 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/select_op_legalization.h" - -#include - -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h" - -namespace litert::qnn { - -static constexpr absl::string_view kQnnSelectOpTypeName = "ElementWiseSelect"; -static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw"; -static constexpr absl::string_view kSelectOpFmt = "select_%d"; - -LiteRtStatus SelectOpLegalization::LegalizeOp(const Op& src, - Qnn_OpConfig_t& dest, - GraphMapper& graph_mapper) { - if (src.Code() != kLiteRtOpCodeTflSelect && - src.Code() != kLiteRtOpCodeTflSelectV2) { - return kLiteRtStatusLegalizeNoMatch; - } - std::string op_name = absl::StrFormat(kSelectOpFmt, op_counter_++); - LITERT_RETURN_IF_ERROR(SetOpInfo(op_name.c_str(), - kDefaultQnnOpPackageName.data(), - kQnnSelectOpTypeName.data(), dest)); - LITERT_RETURN_IF_ERROR(LegalizeSimpleOp(src, dest, graph_mapper)); - - return 
kLiteRtStatusOk; - LITERT_RETURN_IF_ERROR(LegalizeSimpleOp(src, dest, graph_mapper)); - LITERT_LOG(LITERT_INFO, "Legalized select op", ""); - return kLiteRtStatusOk; -} - -} // namespace litert::qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/select_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/select_op_legalization.h deleted file mode 100644 index 526498a4bb4b51..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/select_op_legalization.h +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_SELECT_OP_LEGALIZATION_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_SELECT_OP_LEGALIZATION_H_ - -#include -#include - -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h" - -namespace litert::qnn { - -class SelectOpLegalization : public Legalization { - public: - SelectOpLegalization() = default; - ~SelectOpLegalization() = default; - using Ptr = std::unique_ptr; - static Ptr Create() { return std::make_unique(); } - - LiteRtStatus LegalizeOp(const Op& src, Qnn_OpConfig_t& dest, - GraphMapper& graph_mapper); - - private: - // Counter to ensure unique op names. - uint32_t op_counter_ = 0; -}; - -} // namespace litert::qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_SELECT_OP_LEGALIZATION_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/sin_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/sin_op_legalization.cc deleted file mode 100644 index 17932971f8cee4..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/sin_op_legalization.cc +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/sin_op_legalization.h" - -#include - -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h" - -namespace litert::qnn { - -static constexpr absl::string_view kQnnSinOpTypeName = "ElementWiseSin"; -static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw"; -static constexpr absl::string_view kSinOpFmt = "sin_%d"; - -LiteRtStatus SinOpLegalization::LegalizeOp(const Op& src, Qnn_OpConfig_t& dest, - GraphMapper& graph_mapper) { - if (src.Code() != kLiteRtOpCodeTflSin) { - return kLiteRtStatusLegalizeNoMatch; - } - std::string op_name = absl::StrFormat(kSinOpFmt, op_counter_++); - LITERT_RETURN_IF_ERROR(SetOpInfo(op_name.c_str(), - kDefaultQnnOpPackageName.data(), - kQnnSinOpTypeName.data(), dest)); - LITERT_RETURN_IF_ERROR(LegalizeSimpleOp(src, dest, graph_mapper)); - LITERT_LOG(LITERT_INFO, "Legalized sin op", ""); - return kLiteRtStatusOk; -} - -} // namespace litert::qnn diff --git 
a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/sin_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/sin_op_legalization.h deleted file mode 100644 index e87296eeb10fb2..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/sin_op_legalization.h +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_SIN_OP_LEGALIZATION_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_SIN_OP_LEGALIZATION_H_ - -#include -#include - -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h" - -namespace litert::qnn { - -class SinOpLegalization : public Legalization { - public: - SinOpLegalization() = default; - ~SinOpLegalization() = default; - using UniquePtr = std::unique_ptr; - static UniquePtr Create() { return std::make_unique(); } - - LiteRtStatus LegalizeOp(const Op& src, Qnn_OpConfig_t& dest, - GraphMapper& graph_mapper); - - 
private: - // Counter to ensure unique op names. - uint32_t op_counter_ = 0; -}; - -} // namespace litert::qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_SIN_OP_LEGALIZATION_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/slice_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/slice_op_legalization.cc deleted file mode 100644 index 6749bd654eb51c..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/slice_op_legalization.cc +++ /dev/null @@ -1,160 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/slice_op_legalization.h" - -#include -#include - -#include -#include - -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model_predicates.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/common.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager.h" - -namespace litert::qnn { - -static constexpr absl::string_view kQnnSliceOpTypeName = "StridedSlice"; -static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw"; -static constexpr absl::string_view kSliceOpFmt = "slice_%d"; - -static constexpr int kSliceOpInputSize = 1; -static constexpr int kSliceOpOutputSize = 1; -static constexpr int kSliceOpParamSize = 1; -// QNN StridedSlice op packs "start", "end", and "stride" into a single tensor -// param "ranges". 
-static constexpr int kRangesParamArgSize = 3; -static constexpr int kRangesParamRank = 2; - -LiteRtStatus SliceOpLegalization::LegalizeOp(const Op& src, - Qnn_OpConfig_t& dest, - GraphMapper& graph_mapper) { - if (src.Code() != kLiteRtOpCodeTflSlice) { - return kLiteRtStatusLegalizeNoMatch; - } - DumpLegalization(*src.Get()); - std::string op_name = absl::StrFormat(kSliceOpFmt, op_counter_++); - LITERT_RETURN_IF_ERROR(SetOpInfo(op_name.c_str(), - kDefaultQnnOpPackageName.data(), - kQnnSliceOpTypeName.data(), dest)); - - // QNN strided slice op expects 1 input tensor. - const auto op_ins = src.Inputs(); - LITERT_STACK_ARRAY(Qnn_Tensor_t, qnn_op_ins, kSliceOpInputSize, - QNN_TENSOR_INIT); - LITERT_RETURN_IF_ERROR( - graph_mapper.LookupInScope(op_ins.front().Get(), qnn_op_ins[0])); - - // QNN strided slice op expects 1 output tensor. - const auto op_outs = src.Outputs(); - LITERT_STACK_ARRAY(Qnn_Tensor_t, qnn_op_outs, kSliceOpOutputSize, - QNN_TENSOR_INIT); - LITERT_RETURN_IF_ERROR( - graph_mapper.LegalizeAndRegister(op_outs.front().Get(), qnn_op_outs[0])); - LITERT_RETURN_IF_ERROR( - graph_mapper.PushToScope(op_outs.front().Get(), qnn_op_outs[0])); - - const auto& src_input_tensor = op_ins.front(); - auto src_input_tensor_type = src_input_tensor.RankedTensorType(); - if (!src_input_tensor_type) { - LITERT_LOG(LITERT_ERROR, "%s", - src_input_tensor_type.Error().Message().c_str()); - return src_input_tensor_type.Error().Status(); - } - - auto src_input_tensor_rank = src_input_tensor_type->Layout().Rank(); - - // Prepare qnn strided slice parameters. - - auto src_begin_indices = op_ins.at(1).WeightsData(); - if (!src_begin_indices) { - return src_begin_indices.Error().Status(); - } - - auto src_size_indices = op_ins.at(2).WeightsData(); - if (!src_size_indices) { - return src_size_indices.Error().Status(); - } - - // Check if src_begin_indices and src_size_indices are weights tensors. 
- if (src_begin_indices->empty() || src_size_indices->empty()) { - return kLiteRtStatusErrorInvalidLegalization; - } - - LITERT_STACK_ARRAY(int32_t, range_tensor_data, - src_input_tensor_rank* kRangesParamArgSize, - /*init value*/ 0); - for (int i = 0; i < src_input_tensor_rank; ++i) { - // Copy begin, end, and stride values from src_begin_indices and - // src_size_indices to range_tensor_data. Stride is always 1. - range_tensor_data[i * kRangesParamArgSize] = src_begin_indices->at(i); - range_tensor_data[i * kRangesParamArgSize + 1] = - src_begin_indices->at(i) + src_size_indices->at(i); - range_tensor_data[i * kRangesParamArgSize + 2] = 1; - } - - Qnn_ClientBuffer_t range_tensor_client_buf = BuildDefaultClientBuffer(); - range_tensor_client_buf.data = range_tensor_data; - range_tensor_client_buf.dataSize = - src_input_tensor_rank * kRangesParamArgSize * sizeof(int32_t); - - // Construct the const tensor "ranges". - Qnn_Tensor_t range_tensor = BuildDefaultTensor(); - graph_mapper.AssignTensorName(range_tensor); - range_tensor.v2.dataType = QNN_DATATYPE_INT_32; - range_tensor.v2.type = QNN_TENSOR_TYPE_STATIC; - range_tensor.v2.rank = kRangesParamRank; - range_tensor.v2.dimensions = new uint32_t[kRangesParamRank]; - range_tensor.v2.dimensions[0] = src_input_tensor_rank; - range_tensor.v2.dimensions[1] = kRangesParamArgSize; - range_tensor.v2.memType = QNN_TENSORMEMTYPE_RAW; - range_tensor.v2.clientBuf = range_tensor_client_buf; - range_tensor.v2.isDynamicDimensions = nullptr; - LITERT_RETURN_STATUS_IF_QNN_NOT_OK( - graph_mapper.Qnn().Api()->tensorCreateGraphTensor(graph_mapper.QnnGraph(), - &range_tensor)); - - Qnn_Param_t range_param = BuildDefaultParam(); - range_param.paramType = QNN_PARAMTYPE_TENSOR; - range_param.name = "ranges"; - range_param.tensorParam = range_tensor; - - Qnn_Param_t strided_slice_params[] = {range_param}; - dest.v1.inputTensors = qnn_op_ins; - dest.v1.numOfInputs = kSliceOpInputSize; - dest.v1.outputTensors = qnn_op_outs; - 
dest.v1.numOfOutputs = kSliceOpOutputSize; - dest.v1.numOfParams = kSliceOpParamSize; - dest.v1.params = strided_slice_params; - - LITERT_RETURN_STATUS_IF_QNN_NOT_OK( - graph_mapper.Qnn().Api()->graphAddNode(graph_mapper.QnnGraph(), dest)); - - LITERT_LOG(LITERT_INFO, "Legalized slice op", ""); - - return kLiteRtStatusOk; -} - -} // namespace litert::qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/slice_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/slice_op_legalization.h deleted file mode 100644 index 1430d1e1fa43ca..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/slice_op_legalization.h +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_SLICE_OP_LEGALIZATION_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_SLICE_OP_LEGALIZATION_H_ - -#include -#include - -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h" - -namespace litert::qnn { - -class SliceOpLegalization : public Legalization { - public: - SliceOpLegalization() = default; - ~SliceOpLegalization() = default; - using Ptr = std::unique_ptr; - static Ptr Create() { return std::make_unique(); } - - LiteRtStatus LegalizeOp(const Op& src, Qnn_OpConfig_t& dest, - GraphMapper& graph_mapper); - - private: - uint32_t op_counter_ = 0; -}; - -} // namespace litert::qnn -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_SLICE_OP_LEGALIZATION_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/softmax_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/softmax_op_legalization.cc deleted file mode 100644 index c974e6f462fda2..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/softmax_op_legalization.cc +++ /dev/null @@ -1,99 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/softmax_op_legalization.h" - -#include - -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/c/litert_options.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/common.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h" - -namespace litert::qnn { - -static constexpr absl::string_view kQnnSoftmaxOpTypeName = "Softmax"; -static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw"; -static constexpr absl::string_view kSoftmaxOpFmt = "softmax_%d"; - -static constexpr int kSoftmaxOpInputSize = 1; -static constexpr int kSoftmaxOpOutputSize = 1; -static constexpr int kSoftmaxOpParamSize = 1; - -LiteRtStatus SoftmaxOpLegalization::LegalizeOp(const Op& src, - Qnn_OpConfig_t& dest, - GraphMapper& graph_mapper) { - if (src.Code() != kLiteRtOpCodeTflSoftmax) { - return 
kLiteRtStatusLegalizeNoMatch; - } - DumpLegalization(*src.Get()); - std::string op_name = absl::StrFormat(kSoftmaxOpFmt, op_counter_++); - LITERT_RETURN_IF_ERROR(SetOpInfo(op_name.c_str(), - kDefaultQnnOpPackageName.data(), - kQnnSoftmaxOpTypeName.data(), dest)); - - // QNN reduce softmax op expects 1 input tensor. - const auto op_ins = src.Inputs(); - LITERT_STACK_ARRAY(Qnn_Tensor_t, qnn_op_ins, kSoftmaxOpInputSize, - QNN_TENSOR_INIT); - LITERT_RETURN_IF_ERROR( - graph_mapper.LookupInScope(op_ins.front().Get(), qnn_op_ins[0])); - - // QNN softmax op expects 1 output tensor. - const auto op_outs = src.Outputs(); - LITERT_STACK_ARRAY(Qnn_Tensor_t, qnn_op_outs, kSoftmaxOpOutputSize, - QNN_TENSOR_INIT); - LITERT_RETURN_IF_ERROR( - graph_mapper.LegalizeAndRegister(op_outs.front().Get(), qnn_op_outs[0])); - LITERT_RETURN_IF_ERROR( - graph_mapper.PushToScope(op_outs.front().Get(), qnn_op_outs[0])); - - // Prepare QNN reduce softmax parameters. - - // Extract beta option from softmax op. - float beta; - LITERT_RETURN_IF_ERROR(LiteRtGetSoftmaxBetaOption(src.Get(), &beta)); - Qnn_Param_t beta_param = BuildDefaultParam(); - beta_param.paramType = QNN_PARAMTYPE_SCALAR; - beta_param.name = "beta"; - Qnn_Scalar_t keep_dims_scalar = QNN_SCALAR_INIT; - keep_dims_scalar.dataType = QNN_DATATYPE_FLOAT_32; - keep_dims_scalar.floatValue = beta; - beta_param.scalarParam = keep_dims_scalar; - - Qnn_Param_t reduce_softmax_params[] = {beta_param}; - dest.v1.inputTensors = qnn_op_ins; - dest.v1.numOfInputs = kSoftmaxOpInputSize; - dest.v1.outputTensors = qnn_op_outs; - dest.v1.numOfOutputs = kSoftmaxOpOutputSize; - dest.v1.numOfParams = kSoftmaxOpParamSize; - dest.v1.params = reduce_softmax_params; - - LITERT_RETURN_STATUS_IF_QNN_NOT_OK( - graph_mapper.Qnn().Api()->graphAddNode(graph_mapper.QnnGraph(), dest)); - - LITERT_LOG(LITERT_INFO, "Legalized softmax op", ""); - return kLiteRtStatusOk; -} - -} // namespace litert::qnn diff --git 
a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/softmax_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/softmax_op_legalization.h deleted file mode 100644 index b4ecb005003c91..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/softmax_op_legalization.h +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_SOFTMAX_OP_LEGALIZATION_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_SOFTMAX_OP_LEGALIZATION_H_ - -#include -#include - -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h" - -namespace litert::qnn { - -class SoftmaxOpLegalization : public Legalization { - public: - SoftmaxOpLegalization() = default; - ~SoftmaxOpLegalization() = default; - using Ptr = std::unique_ptr; - static Ptr Create() { return std::make_unique(); } - - LiteRtStatus LegalizeOp(const Op& src, Qnn_OpConfig_t& dest, - GraphMapper& graph_mapper); - - private: - // Counter to ensure unique op names. - uint32_t op_counter_ = 0; -}; - -} // namespace litert::qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_SOFTMAX_OP_LEGALIZATION_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/sub_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/sub_op_legalization.cc deleted file mode 100644 index 09ff1cbbc4dcfe..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/sub_op_legalization.cc +++ /dev/null @@ -1,51 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/sub_op_legalization.h" - -#include - -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h" - -namespace litert::qnn { - -static constexpr absl::string_view kQnnSubOpTypeName = "ElementWiseSubtract"; -static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw"; -static constexpr absl::string_view kSubOpFmt = "sub_%d"; - -LiteRtStatus SubOpLegalization::LegalizeOp(const litert::Op& src, - Qnn_OpConfig_t& dest, - GraphMapper& graph_mapper) { - if (src.Code() != kLiteRtOpCodeTflSub) { - return kLiteRtStatusLegalizeNoMatch; - } - std::string op_name = absl::StrFormat(kSubOpFmt, op_counter_++); - LITERT_RETURN_IF_ERROR(SetOpInfo(op_name.c_str(), - kDefaultQnnOpPackageName.data(), - kQnnSubOpTypeName.data(), dest)); - LITERT_RETURN_IF_ERROR(LegalizeSimpleOp(src, dest, graph_mapper)); - LITERT_LOG(LITERT_INFO, "Legalized sub op", ""); - return kLiteRtStatusOk; -} - 
-} // namespace litert::qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/sub_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/sub_op_legalization.h deleted file mode 100644 index 3f05f8e04a7d3e..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/sub_op_legalization.h +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_SUB_OP_LEGALIZATION_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_SUB_OP_LEGALIZATION_H_ - -#include -#include - -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h" - -namespace litert::qnn { - -class SubOpLegalization : public Legalization { - public: - SubOpLegalization() = default; - ~SubOpLegalization() = default; - using Ptr = std::unique_ptr; - static Ptr Create() { return std::make_unique(); } - - LiteRtStatus LegalizeOp(const litert::Op& src, Qnn_OpConfig_t& dest, - GraphMapper& graph_mapper); - - private: - // Counter to ensure unique op names. - uint32_t op_counter_ = 0; -}; - -} // namespace litert::qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_SUB_OP_LEGALIZATION_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/sum_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/sum_op_legalization.cc deleted file mode 100644 index 40fe0c10f878c0..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/sum_op_legalization.cc +++ /dev/null @@ -1,146 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/sum_op_legalization.h" - -#include -#include - -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/c/litert_options.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model_predicates.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/common.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h" - -namespace litert::qnn { - -static constexpr absl::string_view kQnnSumOpTypeName = "ReduceSum"; -static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw"; -static constexpr absl::string_view kSumOpFmt = "sum_%d"; - -static constexpr int kReduceSumOpInputSize = 1; -static constexpr int 
kReduceSumOpOutputSize = 1; -static constexpr int kReduceSumOpParamSize = 1; -static constexpr int kReduceSumOpParamRank = 1; - -LiteRtStatus SumOpLegalization::LegalizeOp(const Op& src, Qnn_OpConfig_t& dest, - GraphMapper& graph_mapper) { - if (src.Code() != kLiteRtOpCodeTflSum) { - return kLiteRtStatusLegalizeNoMatch; - } - DumpLegalization(*src.Get()); - std::string op_name = absl::StrFormat(kSumOpFmt, op_counter_++); - LITERT_RETURN_IF_ERROR(SetOpInfo(op_name.c_str(), - kDefaultQnnOpPackageName.data(), - kQnnSumOpTypeName.data(), dest)); - - // QNN reduce sum op expects 1 input tensor. - LITERT_STACK_ARRAY(Qnn_Tensor_t, qnn_op_ins, kReduceSumOpInputSize, - QNN_TENSOR_INIT); - LITERT_RETURN_IF_ERROR( - graph_mapper.LookupInScope(src.Inputs().front().Get(), qnn_op_ins[0])); - - // QNN sum op expects 1 output tensor. - LITERT_STACK_ARRAY(Qnn_Tensor_t, qnn_op_outs, kReduceSumOpOutputSize, - QNN_TENSOR_INIT); - LITERT_RETURN_IF_ERROR(graph_mapper.LegalizeAndRegister( - src.Outputs().front().Get(), qnn_op_outs[0])); - LITERT_RETURN_IF_ERROR( - graph_mapper.PushToScope(src.Outputs().front().Get(), qnn_op_outs[0])); - - // Prepare QNN reduce sum parameters. - const auto inputs = src.Inputs(); - const auto& src_axes = inputs.at(1); - - // Check if src_axes are weights tensors. 
- if (!src_axes.HasWeights()) { - LITERT_LOG(LITERT_ERROR, "Sum op axes are not weights tensors"); - return kLiteRtStatusErrorInvalidLegalization; - } - - auto src_axes_tensor_type = src_axes.RankedTensorType(); - if (!src_axes_tensor_type) { - LITERT_LOG(LITERT_ERROR, "%s", - src_axes_tensor_type.Error().Message().c_str()); - return src_axes_tensor_type.Error().Status(); - } - - int32_t dest_axes_size = src_axes_tensor_type->Layout().Dimensions()[0]; - auto src_axes_data = src_axes.Weights().Bytes(); - Qnn_ClientBuffer_t axes_tensor_client_buf = BuildDefaultClientBuffer(); - axes_tensor_client_buf.data = (void*)src_axes_data.data(); - axes_tensor_client_buf.dataSize = src_axes_data.size(); - - // Extract keepdims option from sum op. - bool keep_dims; - LITERT_RETURN_IF_ERROR(LiteRtGetSumKeepDimsOption(src.Get(), &keep_dims)); - - // Construct the scalar "keep_dims" param. - if (keep_dims) { - Qnn_Param_t range_param = BuildDefaultParam(); - range_param.paramType = QNN_PARAMTYPE_SCALAR; - range_param.name = "keep_dims"; - Qnn_Scalar_t keep_dims_scalar = QNN_SCALAR_INIT; - keep_dims_scalar.dataType = QNN_DATATYPE_BOOL_8; - keep_dims_scalar.bool8Value = true; - range_param.scalarParam = keep_dims_scalar; - } - - // Construct the const tensor "axes". 
- Qnn_Tensor_t range_tensor = BuildDefaultTensor(); - graph_mapper.AssignTensorName(range_tensor); - range_tensor.v2.dataType = QNN_DATATYPE_INT_32; - range_tensor.v2.type = QNN_TENSOR_TYPE_STATIC; - range_tensor.v2.rank = kReduceSumOpParamRank; - range_tensor.v2.dimensions = new uint32_t[kReduceSumOpParamRank]; - range_tensor.v2.dimensions[0] = dest_axes_size; - range_tensor.v2.memType = QNN_TENSORMEMTYPE_RAW; - range_tensor.v2.clientBuf = axes_tensor_client_buf; - range_tensor.v2.isDynamicDimensions = nullptr; - LITERT_RETURN_STATUS_IF_QNN_NOT_OK( - graph_mapper.Qnn().Api()->tensorCreateGraphTensor(graph_mapper.QnnGraph(), - &range_tensor)); - - Qnn_Param_t range_param = BuildDefaultParam(); - range_param.paramType = QNN_PARAMTYPE_TENSOR; - range_param.name = "axes"; - range_param.tensorParam = range_tensor; - - Qnn_Param_t reduce_sum_params[] = {range_param}; - dest.v1.inputTensors = qnn_op_ins; - dest.v1.numOfInputs = kReduceSumOpInputSize; - dest.v1.outputTensors = qnn_op_outs; - dest.v1.numOfOutputs = kReduceSumOpOutputSize; - dest.v1.numOfParams = kReduceSumOpParamSize; - dest.v1.params = reduce_sum_params; - - LITERT_RETURN_STATUS_IF_QNN_NOT_OK( - graph_mapper.Qnn().Api()->graphAddNode(graph_mapper.QnnGraph(), dest)); - - LITERT_LOG(LITERT_INFO, "Legalized sum op", ""); - return kLiteRtStatusOk; -} - -} // namespace litert::qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/sum_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/sum_op_legalization.h deleted file mode 100644 index a50e946ad069b8..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/sum_op_legalization.h +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_SUM_OP_LEGALIZATION_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_SUM_OP_LEGALIZATION_H_ - -#include -#include - -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h" - -namespace litert::qnn { - -class SumOpLegalization : public Legalization { - public: - SumOpLegalization() = default; - ~SumOpLegalization() = default; - using Ptr = std::unique_ptr; - static Ptr Create() { return std::make_unique(); } - - LiteRtStatus LegalizeOp(const Op& src, Qnn_OpConfig_t& dest, - GraphMapper& graph_mapper); - - private: - // Counter to ensure unique op names. 
- uint32_t op_counter_ = 0; -}; - -} // namespace litert::qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_SUM_OP_LEGALIZATION_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/tanh_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/tanh_op_legalization.cc deleted file mode 100644 index 121c564c1e95fd..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/tanh_op_legalization.cc +++ /dev/null @@ -1,51 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/tanh_op_legalization.h" - -#include - -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h" - -namespace litert::qnn { - -static constexpr absl::string_view kQnnTanhOpTypeName = "Tanh"; -static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw"; -static constexpr absl::string_view kTanhOpFmt = "tanh_%d"; - -LiteRtStatus TanhOpLegalization::LegalizeOp(const litert::Op& src, - Qnn_OpConfig_t& dest, - GraphMapper& graph_mapper) { - if (src.Code() != kLiteRtOpCodeTflTanh) { - return kLiteRtStatusLegalizeNoMatch; - } - std::string op_name = absl::StrFormat(kTanhOpFmt, op_counter_++); - LITERT_RETURN_IF_ERROR(SetOpInfo(op_name.c_str(), - kDefaultQnnOpPackageName.data(), - kQnnTanhOpTypeName.data(), dest)); - LITERT_RETURN_IF_ERROR(LegalizeSimpleOp(src, dest, graph_mapper)); - LITERT_LOG(LITERT_INFO, "Legalized tanh op", ""); - return kLiteRtStatusOk; -} - -} // namespace litert::qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/tanh_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/tanh_op_legalization.h deleted file mode 100644 index 486e321ae8e2d3..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/tanh_op_legalization.h +++ /dev/null @@ -1,49 
+0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_TANH_OP_LEGALIZATION_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_TANH_OP_LEGALIZATION_H_ - -#include -#include - -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h" - -namespace litert::qnn { - -class TanhOpLegalization : public Legalization { - public: - TanhOpLegalization() = default; - ~TanhOpLegalization() = default; - using Ptr = std::unique_ptr; - static Ptr Create() { return std::make_unique(); } - - LiteRtStatus LegalizeOp(const litert::Op& src, Qnn_OpConfig_t& dest, - GraphMapper& graph_mapper); - - private: - // Counter to ensure unique op names. 
- uint32_t op_counter_ = 0; -}; - -} // namespace litert::qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_TANH_OP_LEGALIZATION_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/transpose_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/transpose_op_legalization.cc deleted file mode 100644 index 487ecce2e66d79..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/transpose_op_legalization.cc +++ /dev/null @@ -1,121 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/transpose_op_legalization.h" - -#include -#include - -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/common.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h" - -namespace litert::qnn { - -static constexpr absl::string_view kQnnTransposeOpTypeName = "Transpose"; -static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw"; -static constexpr absl::string_view kTransposeOpFmt = "transpose_%d"; - -static constexpr int kTransposeOpInputSize = 1; -static constexpr int kTransposeOpOutputSize = 1; -static constexpr int kTransposeOpParamSize = 1; -static constexpr int kTransposeOpParamRank = 1; - -LiteRtStatus TransposeOpLegalization::LegalizeOp(const Op& src, - Qnn_OpConfig_t& dest, - GraphMapper& graph_mapper) { - if (src.Code() != kLiteRtOpCodeTflTranspose) { - return kLiteRtStatusLegalizeNoMatch; - } - DumpLegalization(*src.Get()); - std::string op_name = absl::StrFormat(kTransposeOpFmt, op_counter_++); - LITERT_RETURN_IF_ERROR(SetOpInfo(op_name.c_str(), - kDefaultQnnOpPackageName.data(), - kQnnTransposeOpTypeName.data(), dest)); - - // QNN transpose op expects 1 input tensor. 
- const auto op_ins = src.Inputs(); - LITERT_STACK_ARRAY(Qnn_Tensor_t, qnn_op_ins, kTransposeOpInputSize, - QNN_TENSOR_INIT); - LITERT_RETURN_IF_ERROR( - graph_mapper.LookupInScope(op_ins.front().Get(), qnn_op_ins[0])); - - // QNN transpose op expects 1 output tensor. - const auto op_outs = src.Outputs(); - LITERT_STACK_ARRAY(Qnn_Tensor_t, qnn_op_outs, kTransposeOpOutputSize, - QNN_TENSOR_INIT); - LITERT_RETURN_IF_ERROR( - graph_mapper.LegalizeAndRegister(op_outs.front().Get(), qnn_op_outs[0])); - LITERT_RETURN_IF_ERROR( - graph_mapper.PushToScope(op_outs.front().Get(), qnn_op_outs[0])); - - // Prepare QNN transpose parameters. - auto perm = Tensor(op_ins.at(1).Get()); - - // Check if src_axes are weights tensors. - if (!perm.HasWeights()) { - return kLiteRtStatusErrorInvalidLegalization; - } - auto perm_data = perm.Weights().Bytes(); - int32_t dest_axes_size = perm_data.size(); - Qnn_ClientBuffer_t perm_tensor_client_buf = BuildDefaultClientBuffer(); - perm_tensor_client_buf.data = (void*)perm_data.data(); - perm_tensor_client_buf.dataSize = dest_axes_size; - - // Construct the const tensor "perm". 
- Qnn_Tensor_t perm_tensor = BuildDefaultTensor(); - graph_mapper.AssignTensorName(perm_tensor); - perm_tensor.v2.dataType = QNN_DATATYPE_INT_32; - perm_tensor.v2.type = QNN_TENSOR_TYPE_STATIC; - perm_tensor.v2.rank = kTransposeOpParamRank; - perm_tensor.v2.dimensions = new uint32_t[kTransposeOpParamRank]; - perm_tensor.v2.dimensions[0] = dest_axes_size; - perm_tensor.v2.memType = QNN_TENSORMEMTYPE_RAW; - perm_tensor.v2.clientBuf = perm_tensor_client_buf; - perm_tensor.v2.isDynamicDimensions = nullptr; - LITERT_RETURN_STATUS_IF_QNN_NOT_OK( - graph_mapper.Qnn().Api()->tensorCreateGraphTensor(graph_mapper.QnnGraph(), - &perm_tensor)); - - Qnn_Param_t perm_param = BuildDefaultParam(); - perm_param.paramType = QNN_PARAMTYPE_TENSOR; - perm_param.name = "perm"; - perm_param.tensorParam = perm_tensor; - - Qnn_Param_t transpose_params[] = {perm_param}; - dest.v1.inputTensors = qnn_op_ins; - dest.v1.numOfInputs = kTransposeOpInputSize; - dest.v1.outputTensors = qnn_op_outs; - dest.v1.numOfOutputs = kTransposeOpOutputSize; - dest.v1.numOfParams = kTransposeOpParamSize; - dest.v1.params = transpose_params; - - LITERT_RETURN_STATUS_IF_QNN_NOT_OK( - graph_mapper.Qnn().Api()->graphAddNode(graph_mapper.QnnGraph(), dest)); - - LITERT_LOG(LITERT_INFO, "Legalized transpose op", ""); - return kLiteRtStatusOk; -} - -} // namespace litert::qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/transpose_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/transpose_op_legalization.h deleted file mode 100644 index 39d7fc645c8e80..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/transpose_op_legalization.h +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_TRANSPOSE_OP_LEGALIZATION_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_TRANSPOSE_OP_LEGALIZATION_H_ - -#include -#include - -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h" - -namespace litert::qnn { - -class TransposeOpLegalization : public Legalization { - public: - TransposeOpLegalization() = default; - ~TransposeOpLegalization() = default; - using Ptr = std::unique_ptr; - static Ptr Create() { return std::make_unique(); } - - LiteRtStatus LegalizeOp(const Op& src, Qnn_OpConfig_t& dest, - GraphMapper& graph_mapper); - - private: - // Counter to ensure unique op names. 
- uint32_t op_counter_ = 0; -}; - -} // namespace litert::qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_TRANSPOSE_OP_LEGALIZATION_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.cc deleted file mode 100644 index 5cd0646a907fc0..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.cc +++ /dev/null @@ -1,121 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h" - -#include -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/tools/dump.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/common.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager.h" - -namespace litert::qnn { - -using ::litert::internal::Dump; -using ::litert::internal::DumpOptions; - -// Dump source Op details. -void DumpLegalization(const LiteRtOpT& op) { - std::ostringstream dump; - // TODO Make dump tools part of stable api. - Dump(op, dump); - DumpOptions(op, dump); - std::string s = dump.str(); - LITERT_LOG(LITERT_INFO, "%s", s.data()); -} - -LiteRtStatus LegalizeSimpleOp(const Op& src, Qnn_OpConfig_t& dest, - GraphMapper& graph_mapper) { - DumpLegalization(*src.Get()); - // Look up op input tensors in scope. - const auto op_ins = src.Inputs(); - LITERT_STACK_ARRAY(Qnn_Tensor_t, qnn_op_ins, op_ins.size(), QNN_TENSOR_INIT); - - Qnn_Tensor_t* cur_qnn_op_in = qnn_op_ins; - for (const auto& op_in : op_ins) { - LITERT_RETURN_IF_ERROR( - graph_mapper.LookupInScope(op_in.Get(), *cur_qnn_op_in)); - ++cur_qnn_op_in; - } - - // Legalize op outputs and update scope. 
- - const auto op_outs = src.Outputs(); - LITERT_STACK_ARRAY(Qnn_Tensor_t, qnn_op_outs, op_outs.size(), - QNN_TENSOR_INIT); - - Qnn_Tensor_t* cur_qnn_op_out = qnn_op_outs; - for (const auto& op_out : op_outs) { - LITERT_RETURN_IF_ERROR( - graph_mapper.LegalizeAndRegister(op_out.Get(), *cur_qnn_op_out)); - LITERT_RETURN_IF_ERROR( - graph_mapper.PushToScope(op_out.Get(), *cur_qnn_op_out)); - ++cur_qnn_op_out; - } - dest.v1.numOfInputs = op_ins.size(); - dest.v1.inputTensors = qnn_op_ins; - - dest.v1.numOfOutputs = op_outs.size(); - dest.v1.outputTensors = qnn_op_outs; - - LITERT_RETURN_STATUS_IF_QNN_NOT_OK( - graph_mapper.Qnn().Api()->graphAddNode(graph_mapper.QnnGraph(), dest)); - - return kLiteRtStatusOk; -} - -LiteRtStatus BuildAndRegisterQnnNativeTensor(Qnn_DataType_t param_data_type, - uint32_t rank, uint32_t* dims, - GraphMapper& graph_mapper, - Qnn_Tensor_t& tensor) { - graph_mapper.AssignTensorName(tensor); - tensor.v2.dataType = param_data_type; - tensor.v2.type = QNN_TENSOR_TYPE_NATIVE; - tensor.v2.rank = rank; - tensor.v2.dimensions = dims; - LITERT_RETURN_STATUS_IF_QNN_NOT_OK( - graph_mapper.Qnn().Api()->tensorCreateGraphTensor(graph_mapper.QnnGraph(), - &tensor)); - return kLiteRtStatusOk; -} - -LiteRtStatus BuildAndRegisterQnnOp(uint32_t input_size, Qnn_Tensor_t* op_ins, - uint32_t output_size, Qnn_Tensor_t* op_outs, - Qnn_OpConfig_t& op, uint32_t param_size, - Qnn_Param_t* params, - GraphMapper& graph_mapper) { - op.v1.numOfInputs = input_size; - op.v1.inputTensors = op_ins; - op.v1.numOfOutputs = output_size; - op.v1.outputTensors = op_outs; - op.v1.numOfParams = param_size; - op.v1.params = params; - - LITERT_RETURN_STATUS_IF_QNN_NOT_OK( - graph_mapper.Qnn().Api()->graphAddNode(graph_mapper.QnnGraph(), op)); - - return kLiteRtStatusOk; -} - -} // namespace litert::qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h 
deleted file mode 100644 index fb80708537b7e9..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h +++ /dev/null @@ -1,117 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_UTIL_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_UTIL_H_ - -#include - -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/common.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" - -namespace litert::qnn { - -// Use this function to legalize a LiteRtOp to a Qnn Op when: -// 1. Source input/output tensor and destination input/ouptut tensor are 1 : 1 -// mapped -// 2. Assigning params to destination OP does not depending on input tensor of -// source OP. -LiteRtStatus LegalizeSimpleOp(const Op& src, Qnn_OpConfig_t& dest, - GraphMapper& graph_mapper); - -// Dump source Op details. 
-void DumpLegalization(const LiteRtOpT& op); - -// Build and register a QNN native tensor in the QNN graph. -LiteRtStatus BuildAndRegisterQnnNativeTensor(Qnn_DataType_t param_data_type, - uint32_t rank, uint32_t* dims, - GraphMapper& graph_mapper, - Qnn_Tensor_t& tensor); - -// Build and register a QNN op in the QNN graph. -LiteRtStatus BuildAndRegisterQnnOp(uint32_t input_size, Qnn_Tensor_t* op_ins, - uint32_t output_size, Qnn_Tensor_t* op_outs, - Qnn_OpConfig_t& op, uint32_t param_size, - Qnn_Param_t* params, - GraphMapper& graph_mapper); - -// Build and register a QNN tensor param in the QNN graph. -template -LiteRtStatus BuildQnnTesnorParam(T* param_data, uint32_t* param_dims, - Qnn_DataType_t param_data_type, - uint32_t param_rank, const char* param_name, - GraphMapper& graph_mapper, - Qnn_Param_t& param) { - // Build ClientBuffer for the param tensor. - Qnn_ClientBuffer_t tensor_client_buf = BuildDefaultClientBuffer(); - tensor_client_buf.data = param_data; - tensor_client_buf.dataSize = sizeof(param_data); - - // Build QNN param tensor. - Qnn_Tensor_t param_tensor = BuildDefaultTensor(); - graph_mapper.AssignTensorName(param_tensor); - param_tensor.v2.dataType = param_data_type; - param_tensor.v2.type = QNN_TENSOR_TYPE_STATIC; - param_tensor.v2.rank = param_rank; - param_tensor.v2.dimensions = param_dims; - param_tensor.v2.memType = QNN_TENSORMEMTYPE_RAW; - param_tensor.v2.clientBuf = tensor_client_buf; - - // Register param tensor in QNN graph. - LITERT_RETURN_STATUS_IF_QNN_NOT_OK( - graph_mapper.Qnn().Api()->tensorCreateGraphTensor(graph_mapper.QnnGraph(), - ¶m_tensor)); - param.paramType = QNN_PARAMTYPE_TENSOR; - param.name = param_name; - param.tensorParam = param_tensor; - return kLiteRtStatusOk; -} - -template -LiteRtStatus BuildQnnScalarParam(T& param_data, Qnn_DataType_t param_data_type, - const char* param_name, - GraphMapper& graph_mapper, - Qnn_Param_t& param) { - // Build QNN scalar. 
- Qnn_Scalar_t scalar = QNN_SCALAR_INIT; - scalar.dataType = param_data_type; - - // Build QNN scalar param. - switch (param_data_type) { - case QNN_DATATYPE_BOOL_8: - scalar.bool8Value = param_data; - break; - case QNN_DATATYPE_UINT_32: - scalar.uint32Value = param_data; - break; - case QNN_DATATYPE_INT_32: - scalar.int32Value = param_data; - break; - default: - return kLiteRtStatusErrorUnsupported; - } - param.paramType = QNN_PARAMTYPE_SCALAR; - param.name = param_name; - param.scalarParam = scalar; - return kLiteRtStatusOk; -} - -} // namespace litert::qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_UTIL_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/qnn_compiler_plugin.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/qnn_compiler_plugin.cc deleted file mode 100644 index a26a91a030d23b..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/qnn_compiler_plugin.cc +++ /dev/null @@ -1,423 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "third_party/qairt/latest/include/QNN/HTP/QnnHtpDevice.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/qnn_compose_graph.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager.h" - -using ::litert::qnn::QnnManager; -using LiteRtBufferId = uint32_t; -using LiteRtContextHandleIdx = uint32_t; -using WeightSharingMap = - absl::flat_hash_map; - -// -// Configurations -// - -namespace { - -constexpr char kPluginManufacturer[] = "Qualcomm"; -constexpr LiteRtParamIndex kDefaultPartitionIndex = 0; - -// clang-format off -constexpr std::pair kPluginSocModels[] = { - {"V68", QNN_HTP_DEVICE_ARCH_V68}, - {"V69", QNN_HTP_DEVICE_ARCH_V69}, - {"V73", QNN_HTP_DEVICE_ARCH_V73}, - {"V75", QNN_HTP_DEVICE_ARCH_V75}, - {"V79", QNN_HTP_DEVICE_ARCH_V79}, -}; - -constexpr const char* kSocModelsSupportsWeightSharing[] = { - "V73", - "V75", - "V79", -}; -// clang-format on - -static constexpr absl::string_view kEntryPointNameFmt = "qnn_partition_%d"; - -constexpr auto 
kNumPluginSocModels = - sizeof(kPluginSocModels) / sizeof(kPluginSocModels[0]); - -std::optional FindSocModel( - absl::string_view soc_model_name) { - std::optional soc_model; - for (auto i = 0; i < kNumPluginSocModels; ++i) { - if (soc_model_name == kPluginSocModels[i].first) { - soc_model = kPluginSocModels[i].second; - break; - } - } - return soc_model; -} - -bool IsWeightSharingSupported(absl::string_view soc_model_name) { - return std::find(std::begin(kSocModelsSupportsWeightSharing), - std::end(kSocModelsSupportsWeightSharing), - soc_model_name) != std::end(kSocModelsSupportsWeightSharing); -} - -} // namespace - -LiteRtStatus LiteRtGetCompilerPluginVersion(LiteRtApiVersion* api_version) { - if (api_version == nullptr) { - return kLiteRtStatusErrorInvalidArgument; - } - api_version->major = LITERT_API_VERSION_MAJOR; - api_version->minor = LITERT_API_VERSION_MINOR; - api_version->patch = LITERT_API_VERSION_PATCH; - return kLiteRtStatusOk; -} - -const char* LiteRtGetCompilerPluginSocManufacturer() { - return kPluginManufacturer; -} - -LiteRtStatus LiteRtGetCompilerPluginSupportedHardware( - LiteRtCompilerPlugin compiler_plugin, - LiteRtHwAccelerators* supported_hardware) { - if (!compiler_plugin || !supported_hardware) { - return kLiteRtStatusErrorInvalidArgument; - } - *supported_hardware = kLiteRtHwAcceleratorNpu; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetNumCompilerPluginSupportedSocModels( - LiteRtCompilerPlugin compiler_plugin, - LiteRtParamIndex* num_supported_soc_models) { - if (!compiler_plugin || !num_supported_soc_models) { - return kLiteRtStatusErrorInvalidArgument; - } - *num_supported_soc_models = kNumPluginSocModels; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetCompilerPluginSupportedSocModel( - LiteRtCompilerPlugin compiler_plugin, LiteRtParamIndex soc_model_idx, - const char** soc_model_name) { - if (!compiler_plugin || !soc_model_name) { - return kLiteRtStatusErrorInvalidArgument; - } else if (soc_model_idx < 0 || 
soc_model_idx >= kNumPluginSocModels) { - return kLiteRtStatusErrorInvalidArgument; - } - *soc_model_name = kPluginSocModels[soc_model_idx].first; - return kLiteRtStatusOk; -} - -// -// Compiled Result Definition -// - -struct LiteRtCompiledResultT { - std::vector> context_bin; - std::vector graph_names; -}; - -LiteRtStatus LiteRtGetCompiledResultByteCode( - LiteRtCompiledResult compiled_result, LiteRtParamIndex byte_code_idx, - const void** byte_code, size_t* byte_code_size) { - if (!compiled_result || !byte_code || !byte_code_size) { - return kLiteRtStatusErrorInvalidArgument; - } - - *byte_code = compiled_result->context_bin[byte_code_idx].data(); - *byte_code_size = compiled_result->context_bin[byte_code_idx].size(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetCompiledResultCallInfo( - LiteRtCompiledResult compiled_result, LiteRtParamIndex call_idx, - const void** call_info, size_t* call_info_size, - LiteRtParamIndex* byte_code_idx) { - if (!compiled_result || !call_info || !call_info_size) { - return kLiteRtStatusErrorInvalidArgument; - } else if (call_idx >= compiled_result->graph_names.size()) { - return kLiteRtStatusErrorIndexOOB; - } - - *call_info = compiled_result->graph_names.at(call_idx).data(); - *call_info_size = compiled_result->graph_names.at(call_idx).size(); - *byte_code_idx = 0; - - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetNumCompiledResultCalls( - LiteRtCompiledResult compiled_result, LiteRtParamIndex* num_calls) { - if (!compiled_result || !num_calls) { - return kLiteRtStatusErrorInvalidArgument; - } - *num_calls = compiled_result->graph_names.size(); - return kLiteRtStatusOk; -} - -void LiteRtDestroyCompiledResult(LiteRtCompiledResult compiled_result) { - delete compiled_result; -} - -LiteRtStatus LiteRtCompiledResultNumByteCodeModules( - LiteRtCompiledResult compiled_result, LiteRtParamIndex* num_byte_code) { - *num_byte_code = compiled_result->context_bin.size(); - return kLiteRtStatusOk; -} - -// -// Plugin Definition 
-// - -// Plugins can hold state. -struct LiteRtCompilerPluginT { - // A "key-only" flag will have an empty string as the value. - using Flag = std::pair; - std::vector flags; -}; - -LiteRtStatus LiteRtCompilerPluginSetFlags(LiteRtCompilerPlugin compiler_plugin, - LiteRtParamIndex num_flags, - const char** keys, - const char** values) { - auto& flags = compiler_plugin->flags; - flags.resize(num_flags); - for (int i = 0; i < num_flags; ++i) { - auto& flag = flags[i]; - flag.first = std::string(keys[i]); - flag.second = std::string(values[i]); - } - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtCreateCompilerPlugin(LiteRtCompilerPlugin* compiler_plugin) { - auto* plugin = new LiteRtCompilerPluginT; - *compiler_plugin = plugin; - return kLiteRtStatusOk; -} - -void LiteRtDestroyCompilerPlugin(LiteRtCompilerPlugin compiler_plugin) { - delete compiler_plugin; -} - -LiteRtStatus LiteRtCompilerPluginPartition(LiteRtCompilerPlugin compiler_plugin, - const char* soc_model, - LiteRtSubgraph subgraph, - LiteRtOpList selected_ops) { - ::litert::Subgraph graph(subgraph); - - auto backend_configs = QnnManager::DefaultBackendConfigs(); - // TODO: pass soc_model as parameter - auto qnn_manager = QnnManager::Create(backend_configs, std::nullopt, - {QNN_HTP_DEVICE_ARCH_V75}); - if (!qnn_manager) { - LITERT_LOG(LITERT_ERROR, "%s", qnn_manager.Error().Message().data()); - return qnn_manager.Error().Status(); - } - LITERT_LOG(LITERT_INFO, "%s", "QNN manager created"); - - for (const auto& op : graph.Ops()) { - // default constructed, won't add tensor to QNN - ::qnn::TensorPool tensor_pool; - std::vector<::qnn::TensorWrapperRef> input_tensors; - for (const auto& input : op.Inputs()) { - ::qnn::TensorWrapper* res{nullptr}; - LITERT_RETURN_IF_ERROR( - litert::qnn::ConvertTensor(input, tensor_pool, res)); - input_tensors.emplace_back(*res); - } - - std::vector<::qnn::TensorWrapperRef> output_tensors; - for (const auto& output : op.Outputs()) { - ::qnn::TensorWrapper* res{nullptr}; - 
LITERT_RETURN_IF_ERROR( - litert::qnn::ConvertTensor(output, tensor_pool, res)); - output_tensors.emplace_back(*res); - } - - std::vector<::qnn::OpWrapper> op_wrappers; - LITERT_RETURN_IF_ERROR(litert::qnn::ConvertOp( - op, tensor_pool, input_tensors, output_tensors, op_wrappers)); - tensor_pool.ForEach([](::qnn::TensorWrapper& tensor_wrapper) { - // TODO(chunhsue): Use compile interface to get useQInt16AsQUint16. - constexpr bool useQInt16AsQUint16 = true; - if constexpr (useQInt16AsQUint16) { - tensor_wrapper.ConvertQint16ToQuint16(); - } - }); - // Empty op_wrappers means the op is not supported by QNN. - if (op_wrappers.empty()) { - continue; - } - if (std::all_of( - op_wrappers.begin(), op_wrappers.end(), - [&qnn_manager](::qnn::OpWrapper& op_wrapper) -> bool { - return kLiteRtStatusOk == - (*qnn_manager)->ValidateOp(op_wrapper.GetOpConfig()); - })) { - LITERT_RETURN_IF_ERROR( - // Use default partition index if vendor doesn't support multiple - // partitions. - LiteRtPushOp(selected_ops, op.Get(), kDefaultPartitionIndex)); - } - } - - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtCompilerPluginCompile( - LiteRtCompilerPlugin compiler_plugin, const char* soc_model, - LiteRtModel partitions, LiteRtCompiledResult* compiled_result) { - auto model = litert::Model::CreateFromNonOwnedHandle(partitions); - const auto num_partitions = model.NumSubgraphs(); - - LITERT_LOG(LITERT_INFO, - "Starting QNN Compilation for %d subgraphs, soc_model=%s", - num_partitions, soc_model); - - auto opt_soc_model = soc_model ? 
FindSocModel(soc_model) : std::nullopt; - if (opt_soc_model) { - LITERT_LOG(LITERT_ERROR, "Compiling QNN architecture: %d", *opt_soc_model); - } else if (soc_model) { - LITERT_LOG(LITERT_ERROR, "Unexpected SoC model: %s", soc_model); - return kLiteRtStatusErrorInvalidArgument; - } - - auto result = std::make_unique(); - // Prepare one context binary per partition, since each partition is a - // separate subgraph that maps to a single Dispatch Op in the compiled the - // model. - result->context_bin.resize(num_partitions); - - // Initialize SDK and load qnn shared libraries. - LITERT_LOG(LITERT_INFO, "%s", "Creating QNN manager"); - auto backend_configs = QnnManager::DefaultBackendConfigs(); - auto qnn_manager = QnnManager::Create( - backend_configs, /*shared_library_dir=*/std::nullopt, opt_soc_model); - if (!qnn_manager) { - LITERT_LOG(LITERT_ERROR, "%s", qnn_manager.Error().Message().c_str()); - return qnn_manager.Error().Status(); - } - LITERT_LOG(LITERT_INFO, "%s", "QNN manager created"); - - // Map of LiteRt buffer id to context handle index. - // This map memerizes the last context handle index of a weight was registered - // in. - WeightSharingMap weight_sharing_map; - LiteRtContextHandleIdx next_context_handle_idx = 0; - - std::vector context_handles; - - // Compile each partition (subgraph) individually. - for (int partition_idx = 0; partition_idx < num_partitions; ++partition_idx) { - LiteRtContextHandleIdx context_handle_idx = next_context_handle_idx; - uint64_t largest_weight_size = 0; - // Check all weights in this subgraph, see if any of them were previously - // seen and added to existing qnn context, use the largest weight size to - // determine which context to use. 
- for (const auto& op : model.Subgraph(partition_idx)->Ops()) { - for (const auto& input : op.Inputs()) { - if (input.IsConstant()) { - auto buffer_id = input.Weights().Get()->GetBufferId(); - auto it = weight_sharing_map.find(buffer_id); - if (it != weight_sharing_map.end()) { - if (input.Weights().Get()->Buffer().Size() >= largest_weight_size) { - context_handle_idx = it->second; - largest_weight_size = input.Weights().Get()->Buffer().Size(); - } - } - } - } - } - // If we didn't find a existing context handle for this subgraph, create a - // new one. - if (context_handle_idx == next_context_handle_idx) { - // Initialize context. - LITERT_LOG(LITERT_INFO, "%s", "Creating context handle"); - // We enable weight sharing by default, this could lead to issue when - // support legacy SoC. - // TODO: use option to control weight sharing. - auto context_configs = QnnManager::WeightSharingContextConfigs(); - if (!IsWeightSharingSupported(soc_model)) { - context_configs = QnnManager::DefaultContextConfigs(); - } - auto context_handle = - (*qnn_manager)->CreateContextHandle(context_configs); - if (!context_handle) { - LITERT_LOG(LITERT_ERROR, "%s", - context_handle.Error().Message().c_str()); - return context_handle.Error().Status(); - } - context_handles.push_back(std::move(context_handle.Value())); - LITERT_LOG(LITERT_INFO, "%s", "Context handle created"); - ++next_context_handle_idx; - } - // Set context handle index for all weight buffers in this subgraph. - for (const auto& op : model.Subgraph(partition_idx)->Ops()) { - for (const auto& input : op.Inputs()) { - if (input.IsConstant()) { - auto buffer_id = input.Weights().Get()->GetBufferId(); - weight_sharing_map[buffer_id] = context_handle_idx; - } - } - } - - // Compose graphs. 
- LITERT_LOG(LITERT_INFO, "%s", "Composing graph"); - std::string& entry_point_name = result->graph_names.emplace_back(); - entry_point_name = absl::StrFormat(kEntryPointNameFmt, partition_idx); - LiteRtSubgraph partition = model.Subgraph(partition_idx)->Get(); - LITERT_RETURN_IF_ERROR(litert::qnn::ComposeGraph( - **qnn_manager, context_handles[context_handle_idx].get(), partition, - entry_point_name)); - LITERT_LOG(LITERT_INFO, "%s", "Graph composed"); - } - - // Generate context binary. - result->context_bin.resize(next_context_handle_idx); - for (int i = 0; i < next_context_handle_idx; ++i) { - LITERT_LOG(LITERT_INFO, "%s", "Generating context binary"); - LITERT_RETURN_IF_ERROR((*qnn_manager) - ->GenerateContextBinary(context_handles[i].get(), - result->context_bin[i])); - LITERT_LOG(LITERT_INFO, "Context binary %d generated", i); - } - *compiled_result = result.release(); - - return kLiteRtStatusOk; -} diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/qnn_compiler_plugin_test.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/qnn_compiler_plugin_test.cc deleted file mode 100644 index 2b6016a0578d15..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/qnn_compiler_plugin_test.cc +++ /dev/null @@ -1,399 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include -#include -#include - -#include -#include "absl/strings/string_view.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/core/model/model.h" -#include "tensorflow/lite/experimental/litert/test/common.h" -#include "tensorflow/lite/experimental/litert/test/matchers.h" -#include "tensorflow/lite/experimental/litert/test/test_models.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin.h" -#include "tensorflow/lite/experimental/litert/vendors/cc/litert_compiler_plugin.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/quantize_op_legalization.h" - -namespace litert { -namespace { - -using ::testing::Values; - -// clang-format off -// TODO: Add support and uncomment these models. 
-const auto kSupportedOps = - Values( - "rms_norm_composite.tflite", - "simple_add_op.tflite", - "simple_div_op.tflite", - "simple_mul_op.tflite", - "simple_rsqrt_op.tflite", - "simple_slice_op.tflite", - "simple_sub_op.tflite", - "simple_sum_op.tflite", - "simple_tanh_op.tflite", - "simple_reshape_op.tflite", - "simple_batch_matmul_op.tflite", - "rms_norm.tflite", - "simple_concatenation_op.tflite", - "simple_softmax_op.tflite", - "simple_cast_op.tflite", - "simple_transpose_op.tflite", - "simple_sin_op.tflite", - "simple_cos_op.tflite", - "simple_select_op.tflite", - "simple_select_v2_op.tflite", - "simple_fully_connected_op.tflite", - "fully_connected_3d.tflite", - "simple_embedding_lookup_op.tflite", - "simple_logical_and_op.tflite", - "simple_less_op.tflite", - "simple_greater_op.tflite", - "simple_gelu_op.tflite", - "simple_dynamic_update_slice_op.tflite", - "simple_pack_op.tflite", - "simple_gather_op.tflite", - "simple_mean_op.tflite", - "simple_split_op.tflite", - "simple_average_poll_2d.tflite", - "simple_conv_2d_op.tflite", - "simple_depth_to_space_op.tflite", - "simple_depthwise_conv_2d_op.tflite", - "simple_hard_swish_op.tflite", - "simple_leaky_relu_op.tflite", - "simple_resize_bilinear_op.tflite", - "simple_space_to_depth_op.tflite", - "simple_resize_nearest_neighbor_op.tflite", - "simple_relu_op.tflite", - kFeedForwardModel, - kKeyEinsumModel, - kQueryEinsumModel, - kValueEinsumModel, - kAttnVecEinsumModel, - kROPEModel, - kLookUpROPEModel, - kRMSNormModel, - kSDPAModel, - kAttentionModel, - kTransformerBlockModel, - kQSimpleMul16x16Model, - kQMulAdd16x16Model, - kQQueryEinsum16x8Model, - kQKeyEinsum16x8Model, - kQVauleEinsum16x8Model, - kQAttnVecEinsum16x8Model - ); - -const auto kSupportedSocModels = Values( - "V68", - "V69", - "V73", - "V75", - "V79" -); -// clang-format on - -TEST(TestQnnPlugin, GetConfigInfo) { - EXPECT_STREQ(LiteRtGetCompilerPluginSocManufacturer(), "Qualcomm"); - - auto plugin = CreatePlugin(); - - LiteRtParamIndex 
num_supported_soc_models; - LITERT_ASSERT_OK(LiteRtGetNumCompilerPluginSupportedSocModels( - plugin.get(), &num_supported_soc_models)); - ASSERT_EQ(num_supported_soc_models, 5); - - const char* config_id; - LITERT_ASSERT_OK( - LiteRtGetCompilerPluginSupportedSocModel(plugin.get(), 0, &config_id)); - EXPECT_STREQ(config_id, "V68"); -} - -TEST(TestQnnPlugin, PartitionMulOps) { - auto plugin = CreatePlugin(); - auto model = testing::LoadTestFileModel("one_mul.tflite"); - - LiteRtOpListT selected_op_list; - LITERT_ASSERT_OK(LiteRtCompilerPluginPartition( - plugin.get(), /*soc_model=*/nullptr, model.Subgraph(0)->Get(), - &selected_op_list)); - const auto selected_ops = selected_op_list.Values(); - - ASSERT_EQ(selected_ops.size(), 1); - EXPECT_EQ(selected_ops[0].first->OpCode(), kLiteRtOpCodeTflMul); -} - -TEST(TestQnnPlugin, CompileMulSubgraph) { - auto plugin = CreatePlugin(); - auto model = testing::LoadTestFileModel("one_mul.tflite"); - - LiteRtCompiledResult compiled; - LITERT_ASSERT_OK( - LiteRtCompilerPluginCompile(plugin.get(), "V75", model.Get(), &compiled)); - - const void* byte_code; - size_t byte_code_size; - - LITERT_ASSERT_OK(LiteRtGetCompiledResultByteCode( - compiled, /*byte_code_idx=*/0, &byte_code, &byte_code_size)); - - absl::string_view byte_code_string(reinterpret_cast(byte_code), - byte_code_size); - ASSERT_FALSE(byte_code_string.empty()); - - const void* op_data; - size_t op_data_size; - LiteRtParamIndex byte_code_idx; - - LITERT_ASSERT_OK(LiteRtGetCompiledResultCallInfo( - compiled, /*call_idx=*/0, &op_data, &op_data_size, &byte_code_idx)); - - absl::string_view op_data_string(reinterpret_cast(op_data), - op_data_size); - ASSERT_EQ("qnn_partition_0", op_data_string); - - LiteRtDestroyCompiledResult(compiled); -} - -TEST(TestQnnPlugin, ShareContextBinary) { - auto plugin = CreatePlugin(); - auto model = testing::LoadTestFileModel("cst_multi_subgraph.tflite"); - - LiteRtCompiledResult compiled; - LITERT_ASSERT_OK( - 
LiteRtCompilerPluginCompile(plugin.get(), "V75", model.Get(), &compiled)); - uint64_t num_byte_code; - LITERT_ASSERT_OK( - LiteRtCompiledResultNumByteCodeModules(compiled, &num_byte_code)); - ASSERT_EQ(num_byte_code, 1); - - LiteRtDestroyCompiledResult(compiled); -} - -TEST(TestQnnPlugin, NotShareContextBinary) { - auto plugin = CreatePlugin(); - auto model = testing::LoadTestFileModel("multi_subgraph.tflite"); - - LiteRtCompiledResult compiled; - LITERT_ASSERT_OK( - LiteRtCompilerPluginCompile(plugin.get(), "V75", model.Get(), &compiled)); - uint64_t num_byte_code; - LITERT_ASSERT_OK( - LiteRtCompiledResultNumByteCodeModules(compiled, &num_byte_code)); - ASSERT_EQ(num_byte_code, 3); - - LiteRtDestroyCompiledResult(compiled); -} - -TEST(TestLegalization, QuantizeOpLegalizedToCastOp) { - static constexpr absl::string_view kQnnOpName = "Cast"; - static constexpr int kSUFixed8OffsetDiff = 128; - const auto input_quantization_params = MakePerTensorQuantization( - /*scale=*/1.0f, /*zero_point=*/0); - const auto output_quantization_params = MakePerTensorQuantization( - /*scale=*/1.0f, /*zero_point=*/kSUFixed8OffsetDiff); - LiteRtOpT quantize_op; - LiteRtTensorT input_tensor; - LiteRtTensorT output_tensor; - // Set quantization params, tensor type for input and output tensors. 
- input_tensor.SetQarams(input_quantization_params); - TensorType input_tensor_type = - MakeRankedTensorType(kLiteRtElementTypeInt8, {1, 1}); - input_tensor.SetType(input_tensor_type); - output_tensor.SetQarams(output_quantization_params); - TensorType output_tensor_type = - MakeRankedTensorType(kLiteRtElementTypeUInt8, {1, 1}); - output_tensor.SetType(output_tensor_type); - quantize_op.Inputs().push_back(&input_tensor); - quantize_op.Outputs().push_back(&output_tensor); - quantize_op.SetOpCode(kLiteRtOpCodeTflQuantize); - - qnn::QuantizeOpLegalization legalization; - Qnn_OpConfig_t legalized_qnn_op = qnn::BuildDefaultOp(); - litert::Op litert_quantize_op(&quantize_op); - LITERT_ASSERT_OK( - legalization.ConfigureQnnOp(litert_quantize_op, legalized_qnn_op)); - absl::string_view qnn_op_name(legalized_qnn_op.v1.typeName); - EXPECT_EQ(qnn_op_name, kQnnOpName); -} - -TEST(TestLegalization, QuantizeOpLegalizedToConvertOp) { - static constexpr absl::string_view kQnnOpName = "Convert"; - static constexpr int kSUFixed8OffsetDiff = 0; - const auto input_quantization_params = MakePerTensorQuantization( - /*scale=*/1.0f, /*zero_point=*/0); - const auto output_quantization_params = MakePerTensorQuantization( - /*scale=*/1.0f, /*zero_point=*/kSUFixed8OffsetDiff); - LiteRtOpT quantize_op; - LiteRtTensorT input_tensor; - LiteRtTensorT output_tensor; - // Set quantization params, tensor type for input and output tensors. 
- input_tensor.SetQarams(input_quantization_params); - TensorType input_tensor_type = - MakeRankedTensorType(kLiteRtElementTypeInt8, {1, 1}); - input_tensor.SetType(input_tensor_type); - output_tensor.SetQarams(output_quantization_params); - TensorType output_tensor_type = - MakeRankedTensorType(kLiteRtElementTypeUInt8, {1, 1}); - output_tensor.SetType(output_tensor_type); - quantize_op.Inputs().push_back(&input_tensor); - quantize_op.Outputs().push_back(&output_tensor); - quantize_op.SetOpCode(kLiteRtOpCodeTflQuantize); - - qnn::QuantizeOpLegalization legalization; - Qnn_OpConfig_t legalized_qnn_op = qnn::BuildDefaultOp(); - litert::Op litert_quantize_op(&quantize_op); - LITERT_ASSERT_OK( - legalization.ConfigureQnnOp(litert_quantize_op, legalized_qnn_op)); - absl::string_view qnn_op_name(legalized_qnn_op.v1.typeName); - EXPECT_EQ(qnn_op_name, kQnnOpName); -} - -TEST(TestLegalization, QuantizeOpLegalizedToQuantizeOp) { - static constexpr absl::string_view kQnnOpName = "Quantize"; - const auto output_quantization_params = MakePerTensorQuantization( - /*scale=*/1.0f, /*zero_point=*/0); - LiteRtOpT quantize_op; - LiteRtTensorT input_tensor; - LiteRtTensorT output_tensor; - // Set quantization params, tensor type for input and output tensors. 
- TensorType input_tensor_type = - MakeRankedTensorType(kLiteRtElementTypeFloat32, {1, 1}); - input_tensor.SetType(input_tensor_type); - output_tensor.SetQarams(output_quantization_params); - TensorType output_tensor_type = - MakeRankedTensorType(kLiteRtElementTypeInt16, {1, 1}); - output_tensor.SetType(output_tensor_type); - quantize_op.Inputs().push_back(&input_tensor); - quantize_op.Outputs().push_back(&output_tensor); - quantize_op.SetOpCode(kLiteRtOpCodeTflQuantize); - - qnn::QuantizeOpLegalization legalization; - Qnn_OpConfig_t legalized_qnn_op = qnn::BuildDefaultOp(); - litert::Op litert_quantize_op(&quantize_op); - LITERT_ASSERT_OK( - legalization.ConfigureQnnOp(litert_quantize_op, legalized_qnn_op)); - absl::string_view qnn_op_name(legalized_qnn_op.v1.typeName); - EXPECT_EQ(qnn_op_name, kQnnOpName); -} - -class QnnPlyginSupportedSocCompilationTest - : public ::testing::TestWithParam {}; - -TEST_P(QnnPlyginSupportedSocCompilationTest, CompileMulSubgraph) { - auto plugin = CreatePlugin(); - auto model = testing::LoadTestFileModel("one_mul.tflite"); - auto soc_model = GetParam(); - - LiteRtCompiledResult compiled; - LITERT_ASSERT_OK(LiteRtCompilerPluginCompile(plugin.get(), soc_model.c_str(), - model.Get(), &compiled)); - - const void* byte_code; - size_t byte_code_size; - - LITERT_ASSERT_OK(LiteRtGetCompiledResultByteCode( - compiled, /*byte_code_idx=*/0, &byte_code, &byte_code_size)); - - absl::string_view byte_code_string(reinterpret_cast(byte_code), - byte_code_size); - ASSERT_FALSE(byte_code_string.empty()); - - const void* op_data; - size_t op_data_size; - LiteRtParamIndex byte_code_idx; - - LITERT_ASSERT_OK(LiteRtGetCompiledResultCallInfo( - compiled, /*call_idx=*/0, &op_data, &op_data_size, &byte_code_idx)); - - absl::string_view op_data_string(reinterpret_cast(op_data), - op_data_size); - ASSERT_EQ("qnn_partition_0", op_data_string); - - LiteRtDestroyCompiledResult(compiled); -} - -INSTANTIATE_TEST_SUITE_P(SupportedOpsTest, 
QnnPlyginSupportedSocCompilationTest, - kSupportedSocModels); - -class QnnPluginOpValidationTest : public ::testing::TestWithParam { -}; - -TEST_P(QnnPluginOpValidationTest, SupportedOpsTest) { - LITERT_LOG(LITERT_INFO, "Validating TFLite model: %s", GetParam().c_str()); - auto plugin = CreatePlugin(); - auto model = testing::LoadTestFileModel(GetParam()); - - const auto subgraph = model.MainSubgraph(); - LiteRtSubgraph litert_subgraph = subgraph->Get(); - - LiteRtOpListT selected_ops; - LITERT_ASSERT_OK(LiteRtCompilerPluginPartition( - plugin.get(), /*soc_model=*/nullptr, litert_subgraph, &selected_ops)); - - EXPECT_EQ(selected_ops.Values().size(), litert_subgraph->Ops().size()); -} - -INSTANTIATE_TEST_SUITE_P(SupportedOpsTest, QnnPluginOpValidationTest, - kSupportedOps); - -class QnnPluginOpCompatibilityTest - : public ::testing::TestWithParam {}; - -TEST_P(QnnPluginOpCompatibilityTest, SupportedOpsTest) { - LITERT_LOG(LITERT_INFO, "Testing TFLite model: %s", GetParam().c_str()); - auto plugin = CreatePlugin(); - auto model = testing::LoadTestFileModel(GetParam()); - - LiteRtCompiledResult compiled; - LITERT_ASSERT_OK( - LiteRtCompilerPluginCompile(plugin.get(), "V75", model.Get(), &compiled)); - - const void* byte_code; - size_t byte_code_size; - - LITERT_ASSERT_OK(LiteRtGetCompiledResultByteCode( - compiled, /*byte_code_idx=*/0, &byte_code, &byte_code_size)); - - absl::string_view byte_code_string(reinterpret_cast(byte_code), - byte_code_size); - ASSERT_FALSE(byte_code_string.empty()); - - const void* op_data; - size_t op_data_size; - LiteRtParamIndex byte_code_idx; - - LITERT_ASSERT_OK(LiteRtGetCompiledResultCallInfo( - compiled, /*call_idx=*/0, &op_data, &op_data_size, &byte_code_idx)); - - absl::string_view op_data_string(reinterpret_cast(op_data), - op_data_size); - ASSERT_EQ("qnn_partition_0", op_data_string); - - LiteRtDestroyCompiledResult(compiled); -} - -INSTANTIATE_TEST_SUITE_P(SupportedOpsTest, QnnPluginOpCompatibilityTest, - kSupportedOps); - -} // 
namespace -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/qnn_compose_graph.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/qnn_compose_graph.cc deleted file mode 100644 index 1c1a7fca31408d..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/qnn_compose_graph.cc +++ /dev/null @@ -1,758 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/qnn_compose_graph.h" - -#include -#include -#include - -#include -#include -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "third_party/qairt/latest/include/QNN/QnnCommon.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" -#include "tensorflow/lite/experimental/litert/c/litert_options.h" -#include "tensorflow/lite/experimental/litert/cc/litert_element_type.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include 
"tensorflow/lite/experimental/litert/core/model/model.h" -#include "tensorflow/lite/experimental/litert/tools/dump.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/common.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/cast_op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/concatenation_op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/conv2d_op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/depthwise_conv2d_op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/dynamic_update_slice_op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/elementwise_op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/embedding_lookup_op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/fully_connected_op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/fully_connected_op_builder_htp.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/gather_op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/gelu_op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/hard_swish_op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/leaky_relu_op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/matmul_op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/mean_op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include 
"tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/pack_op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/pool2d_op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/quantize_op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/reduce_op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/relu6_op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/relu_op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/reshape_op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/resize_op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/rms_norm_op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/select_op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/slice_op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/softmax_op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/spatial_transform_op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/split_op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/tanh_op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/transpose_op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/common.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/quantize_params_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" -#include 
"tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager.h" - -namespace litert::qnn { - -using ::litert::internal::Dump; - -LiteRtStatus ConvertPaddingType(const uint32_t litert_padding, - ::qnn::PaddingType& qnn_padding) { - switch (litert_padding) { - case 0: { - qnn_padding = ::qnn::PaddingType::Same; - break; - } - case 1: { - qnn_padding = ::qnn::PaddingType::Valid; - break; - } - default: { - return kLiteRtStatusErrorUnsupported; - } - } - return kLiteRtStatusOk; -} - -LiteRtStatus ConvertDataType(const litert::ElementType litert_type, - const bool is_quantized, - Qnn_DataType_t& qnn_type) { - qnn_type = QNN_DATATYPE_UNDEFINED; - switch (litert_type) { - case litert::ElementType::Bool: - qnn_type = QNN_DATATYPE_BOOL_8; - break; - case litert::ElementType::Int4: - qnn_type = QNN_DATATYPE_SFIXED_POINT_4; - break; - case litert::ElementType::Int8: - qnn_type = - is_quantized ? QNN_DATATYPE_SFIXED_POINT_8 : QNN_DATATYPE_INT_8; - break; - case litert::ElementType::Int16: - qnn_type = - is_quantized ? QNN_DATATYPE_SFIXED_POINT_16 : QNN_DATATYPE_INT_16; - break; - case litert::ElementType::Int32: - qnn_type = - is_quantized ? QNN_DATATYPE_SFIXED_POINT_32 : QNN_DATATYPE_INT_32; - break; - case litert::ElementType::Int64: - qnn_type = QNN_DATATYPE_INT_64; - break; - case litert::ElementType::UInt8: - qnn_type = - is_quantized ? QNN_DATATYPE_UFIXED_POINT_8 : QNN_DATATYPE_UINT_8; - break; - case litert::ElementType::UInt16: - qnn_type = - is_quantized ? QNN_DATATYPE_UFIXED_POINT_16 : QNN_DATATYPE_UINT_16; - break; - case litert::ElementType::UInt32: - qnn_type = - is_quantized ? 
QNN_DATATYPE_UFIXED_POINT_32 : QNN_DATATYPE_UINT_32; - break; - case litert::ElementType::UInt64: - qnn_type = QNN_DATATYPE_UINT_64; - break; - case litert::ElementType::Float16: - qnn_type = QNN_DATATYPE_FLOAT_16; - break; - case litert::ElementType::Float32: - qnn_type = QNN_DATATYPE_FLOAT_32; - break; - case litert::ElementType::Float64: - qnn_type = QNN_DATATYPE_FLOAT_64; - break; - default: - return kLiteRtStatusErrorUnsupported; - } - return kLiteRtStatusOk; -} - -LiteRtStatus ConvertTensor(const litert::Tensor& litert_tensor, - ::qnn::TensorPool& tensor_pool, - ::qnn::TensorWrapper*& tensor_wrapper, - bool is_tensor_read_and_write) { - tensor_wrapper = nullptr; - - if (litert_tensor.TypeId() != kLiteRtRankedTensorType) { - return kLiteRtStatusErrorInvalidArgument; - } - - const auto ranked_tensor_type = litert_tensor.RankedTensorType(); - if (!ranked_tensor_type) { - LITERT_LOG(LITERT_ERROR, "%s", ranked_tensor_type.Error().Message().data()); - return ranked_tensor_type.Error().Status(); - } - - Qnn_DataType_t qnn_data_type; - LITERT_RETURN_IF_ERROR(ConvertDataType(ranked_tensor_type->ElementType(), - litert_tensor.HasQuantization(), - qnn_data_type)); - - std::vector dimentions; - const auto litert_layout = ranked_tensor_type->Layout(); - if (litert_layout.Rank() == 0) { - dimentions.resize(1, 1); - } else { - dimentions.resize(litert_layout.Rank()); - for (size_t i = 0; i < dimentions.size(); ++i) { - dimentions[i] = litert_layout.Dimensions()[i]; - } - } - - ::qnn::QuantizeParamsWrapperVariant quantize_params; - switch (litert_tensor.QTypeId()) { - case kLiteRtQuantizationPerTensor: { - const auto per_tensor_quant = litert_tensor.PerTensorQuantization(); - quantize_params.emplace<::qnn::ScaleOffsetQuantizeParamsWrapper>( - per_tensor_quant.scale, per_tensor_quant.zero_point); - break; - } - case kLiteRtQuantizationPerChannel: { - const auto per_channel_quant = litert_tensor.PerChannelQuantization(); - // convert zero points from std::int64_t to 
std::int32_t - std::vector zero_points(per_channel_quant.num_channels); - for (size_t i = 0; i < zero_points.size(); ++i) { - zero_points[i] = per_channel_quant.zero_points[i]; - } - quantize_params.emplace<::qnn::AxisScaleOffsetQuantizeParamsWrapper>( - per_channel_quant.quantized_dimension, - absl::Span{per_channel_quant.scales, - per_channel_quant.num_channels}, - absl::Span{zero_points.data(), - zero_points.size()}); - break; - } - case kLiteRtQuantizationBlockWise: { - LITERT_LOG(LITERT_ERROR, "Unsupported quantization type."); - return kLiteRtStatusErrorInvalidArgument; - } - case kLiteRtQuantizationNone: - default: - break; - } - - if (litert_tensor.IsSubgraphInput()) { - auto& res = tensor_pool.CreateInputTensor(qnn_data_type, quantize_params, - dimentions); - tensor_wrapper = &res; - } else if (litert_tensor.IsSubgraphOutput() || is_tensor_read_and_write) { - auto& res = tensor_pool.CreateOutpuTensor(qnn_data_type, quantize_params, - dimentions); - tensor_wrapper = &res; - } else if (litert_tensor.IsConstant()) { - LITERT_ENSURE(litert_tensor.HasWeights(), - kLiteRtStatusErrorInvalidLegalization, - "Empty weights for constant tensor."); - auto& res = tensor_pool.CreateStaticTensor( - qnn_data_type, quantize_params, dimentions, - litert_tensor.Weights().Bytes().size(), - reinterpret_cast(litert_tensor.Weights().Bytes().data())); - tensor_wrapper = &res; - } else { - auto& res = tensor_pool.CreateNativeTensor(qnn_data_type, quantize_params, - dimentions); - tensor_wrapper = &res; - } - return kLiteRtStatusOk; -} - -LiteRtStatus ConvertOp( - const litert::Op& litert_op, ::qnn::TensorPool& tensor_pool, - const std::vector<::qnn::TensorWrapperRef>& input_tensors, - const std::vector<::qnn::TensorWrapperRef>& output_tensors, - std::vector<::qnn::OpWrapper>& op_wrappers) { - switch (litert_op.Code()) { - case LiteRtOpCode::kLiteRtOpCodeTflCast: { - op_wrappers = - ::qnn::BuildCastOp(tensor_pool, input_tensors, output_tensors); - break; - } - case 
LiteRtOpCode::kLiteRtOpCodeTflConcatenation: { - int32_t axis{}; - LITERT_RETURN_IF_ERROR( - LiteRtGetConcatenationAxisOption(litert_op.Get(), &axis)); - op_wrappers = ::qnn::BuildConcatenationOp(tensor_pool, input_tensors, - output_tensors, axis); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflAdd: { - uint32_t fused_activation{}; - LITERT_RETURN_IF_ERROR(LiteRtGetAddFusedActivationOption( - litert_op.Get(), &fused_activation)); - op_wrappers = ::qnn::BuildElementwiseAddOp(tensor_pool, input_tensors, - output_tensors); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflLogicalAnd: { - op_wrappers = ::qnn::BuildElementwiseAndOp(tensor_pool, input_tensors, - output_tensors); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflCos: { - op_wrappers = ::qnn::BuildElementwiseCosOp(tensor_pool, input_tensors, - output_tensors); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflDiv: { - uint32_t fused_activation{}; - LITERT_RETURN_IF_ERROR(LiteRtGetDivFusedActivationOption( - litert_op.Get(), &fused_activation)); - op_wrappers = ::qnn::BuildElementwiseDivOp(tensor_pool, input_tensors, - output_tensors); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflGreater: { - op_wrappers = ::qnn::BuildElementwiseGreaterOp(tensor_pool, input_tensors, - output_tensors); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflLess: { - op_wrappers = ::qnn::BuildElementwiseLessOp(tensor_pool, input_tensors, - output_tensors); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflMul: { - uint32_t fused_activation{}; - LITERT_RETURN_IF_ERROR(LiteRtGetMulFusedActivationOption( - litert_op.Get(), &fused_activation)); - op_wrappers = ::qnn::BuildElementwiseMulOp(tensor_pool, input_tensors, - output_tensors); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflRsqrt: { - op_wrappers = ::qnn::BuildElementwiseRsqrtOp(tensor_pool, input_tensors, - output_tensors); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflSin: { - op_wrappers = ::qnn::BuildElementwiseSinOp(tensor_pool, input_tensors, - 
output_tensors); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflSquaredDifference: { - op_wrappers = ::qnn::BuildElementwiseSquaredDifferenceOp( - tensor_pool, input_tensors, output_tensors); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflSquare: { - op_wrappers = ::qnn::BuildElementwiseSquareOp(tensor_pool, input_tensors, - output_tensors); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflSub: { - uint32_t fused_activation{}; - LITERT_RETURN_IF_ERROR(LiteRtGetSubFusedActivationOption( - litert_op.Get(), &fused_activation)); - op_wrappers = ::qnn::BuildElementwiseSubOp(tensor_pool, input_tensors, - output_tensors); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflMinimum: { - op_wrappers = ::qnn::BuildElementwiseMinimumOp(tensor_pool, input_tensors, - output_tensors); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflMaximum: { - op_wrappers = ::qnn::BuildElementwiseMaximumOp(tensor_pool, input_tensors, - output_tensors); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflEmbeddingLookup: { - op_wrappers = ::qnn::BuildEmbeddingLookupOp(tensor_pool, input_tensors, - output_tensors); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflFullyConnected: { - uint32_t fused_activation{}; - LITERT_RETURN_IF_ERROR(LiteRtGetFullyConnectedFusedActivationOption( - litert_op.Get(), &fused_activation)); - bool keep_num_dims{}; - LITERT_RETURN_IF_ERROR(LiteRtGetFullyConnectedKeepNumDimsOption( - litert_op.Get(), &keep_num_dims)); - // TODO(jiunkaiy): Use compile interface to get useHtpPreferencs. 
- constexpr LiteRtQnnOptions qnn_options = LITERT_QNN_OPTIONS_INIT; - if (qnn_options.useHtpPreferencs) { - op_wrappers = ::qnn::BuildFullyConnectedOpHtp( - tensor_pool, input_tensors, output_tensors, keep_num_dims); - } - if (op_wrappers.empty()) { - op_wrappers = ::qnn::BuildFullyConnectedOp( - tensor_pool, input_tensors, output_tensors, keep_num_dims); - } - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflGather: { - int32_t axis{}; - LITERT_RETURN_IF_ERROR(LiteRtGetGatherAxisOption(litert_op.Get(), &axis)); - int32_t batch_dims{}; - LITERT_RETURN_IF_ERROR( - LiteRtGetGatherBatchDimsOption(litert_op.Get(), &batch_dims)); - op_wrappers = ::qnn::BuildGatherOp(tensor_pool, input_tensors, - output_tensors, axis, batch_dims); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflGelu: { - op_wrappers = - ::qnn::BuildGeluOp(tensor_pool, input_tensors, output_tensors); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflRelu: { - op_wrappers = - ::qnn::BuildReluOp(tensor_pool, input_tensors, output_tensors); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflRelu6: { - op_wrappers = - ::qnn::BuildRelu6Op(tensor_pool, input_tensors, output_tensors); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflBatchMatmul: { - bool adj_x{}; - LITERT_RETURN_IF_ERROR( - LiteRtGetBatchMatmulAdjXOption(litert_op.Get(), &adj_x)); - bool adj_y{}; - LITERT_RETURN_IF_ERROR( - LiteRtGetBatchMatmulAdjYOption(litert_op.Get(), &adj_y)); - op_wrappers = ::qnn::BuildMatmulOp(tensor_pool, input_tensors, - output_tensors, adj_x, adj_y); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflMean: { - bool keep_dims{}; - LITERT_RETURN_IF_ERROR( - LiteRtGetMeanKeepDimsOption(litert_op.Get(), &keep_dims)); - op_wrappers = ::qnn::BuildMeanOp(tensor_pool, input_tensors, - output_tensors, keep_dims); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflQuantize: { - op_wrappers = - ::qnn::BuildQuantizeOp(tensor_pool, input_tensors, output_tensors); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflDequantize: { - 
op_wrappers = - ::qnn::BuildDequantizeOp(tensor_pool, input_tensors, output_tensors); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflSum: { - bool keep_dims{}; - LITERT_RETURN_IF_ERROR( - LiteRtGetSumKeepDimsOption(litert_op.Get(), &keep_dims)); - op_wrappers = ::qnn::BuildReduceSumOp(tensor_pool, input_tensors, - output_tensors, keep_dims); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflReshape: { - op_wrappers = - ::qnn::BuildReshapeOp(tensor_pool, input_tensors, output_tensors); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflSelect: - case LiteRtOpCode::kLiteRtOpCodeTflSelectV2: { - op_wrappers = - ::qnn::BuildSelectOp(tensor_pool, input_tensors, output_tensors); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflSlice: { - op_wrappers = - ::qnn::BuildSliceOp(tensor_pool, input_tensors, output_tensors); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflSoftmax: { - float beta{}; - LITERT_RETURN_IF_ERROR( - LiteRtGetSoftmaxBetaOption(litert_op.Get(), &beta)); - op_wrappers = ::qnn::BuildSoftmaxOp(tensor_pool, input_tensors, - output_tensors, beta); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflSplit: { - int32_t num_splits{}; - LITERT_RETURN_IF_ERROR( - LiteRtGetSplitNumSplitsOption(litert_op.Get(), &num_splits)); - op_wrappers = ::qnn::BuildSplitOp(tensor_pool, input_tensors, - output_tensors, num_splits); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflTanh: { - op_wrappers = - ::qnn::BuildTanhOp(tensor_pool, input_tensors, output_tensors); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflTranspose: { - op_wrappers = - ::qnn::BuildTransposeOp(tensor_pool, input_tensors, output_tensors); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflPack: { - int32_t axis{}; - LiteRtGetPackAxisOption(litert_op.Get(), &axis); - op_wrappers = - ::qnn::BuildPackOp(tensor_pool, input_tensors, output_tensors, axis); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflDynamicUpdateSlice: { - op_wrappers = ::qnn::BuildDynamicUpdateSliceOp(tensor_pool, input_tensors, - 
output_tensors); - break; - } - case LiteRtOpCode::kLiteRtOpCodeShloComposite: { - // TODO(yunandrew): Support custom epsilon for RMS Norm. - float epsilon = 9.99999997E-7; - op_wrappers = ::qnn::BuildRmsNormOp(tensor_pool, input_tensors, - output_tensors, epsilon); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflConv2d: { - uint32_t padding; - LITERT_RETURN_IF_ERROR( - LiteRtGetConv2dPaddingOption(litert_op.Get(), &padding)); - int32_t stride_w; - LITERT_RETURN_IF_ERROR( - LiteRtGetConv2dStrideWOption(litert_op.Get(), &stride_w)); - int32_t stride_h; - LITERT_RETURN_IF_ERROR( - LiteRtGetConv2dStrideHOption(litert_op.Get(), &stride_h)); - int32_t dilation_w_factor; - LITERT_RETURN_IF_ERROR( - LiteRtGetConv2dDilationWOption(litert_op.Get(), &dilation_w_factor)); - int32_t dilation_h_factor; - LITERT_RETURN_IF_ERROR( - LiteRtGetConv2dDilationWOption(litert_op.Get(), &dilation_h_factor)); - - ::qnn::PaddingType qnn_padding; - LITERT_RETURN_IF_ERROR(ConvertPaddingType(padding, qnn_padding)); - op_wrappers = ::qnn::BuildConv2dOp( - tensor_pool, input_tensors, output_tensors, stride_h, stride_w, - dilation_h_factor, dilation_w_factor, qnn_padding); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflDepthwiseConv2d: { - uint32_t padding; - LITERT_RETURN_IF_ERROR( - LiteRtGetDepthwiseConv2dPaddingOption(litert_op.Get(), &padding)); - int32_t stride_w; - LITERT_RETURN_IF_ERROR( - LiteRtGetDepthwiseConv2dStrideWOption(litert_op.Get(), &stride_w)); - int32_t stride_h; - LITERT_RETURN_IF_ERROR( - LiteRtGetDepthwiseConv2dStrideHOption(litert_op.Get(), &stride_h)); - int32_t dilation_w_factor; - LITERT_RETURN_IF_ERROR(LiteRtGetDepthwiseConv2dDilationWOption( - litert_op.Get(), &dilation_w_factor)); - int32_t dilation_h_factor; - LITERT_RETURN_IF_ERROR(LiteRtGetDepthwiseConv2dDilationHOptions( - litert_op.Get(), &dilation_h_factor)); - - ::qnn::PaddingType qnn_padding; - LITERT_RETURN_IF_ERROR(ConvertPaddingType(padding, qnn_padding)); - op_wrappers = 
::qnn::BuildDepthwiseConv2dOp( - tensor_pool, input_tensors, output_tensors, stride_h, stride_w, - dilation_h_factor, dilation_w_factor, qnn_padding); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflAveragePool2d: { - uint32_t padding; - LITERT_RETURN_IF_ERROR( - LiteRtGetAveragePool2dPaddingOption(litert_op.Get(), &padding)); - int32_t stride_w; - LITERT_RETURN_IF_ERROR( - LiteRtGetAveragePool2dStrideWOption(litert_op.Get(), &stride_w)); - int32_t stride_h; - LITERT_RETURN_IF_ERROR( - LiteRtGetAveragePool2dStrideHOption(litert_op.Get(), &stride_h)); - int32_t filter_width; - LITERT_RETURN_IF_ERROR(LiteRtGetAveragePool2dFilterWidthOption( - litert_op.Get(), &filter_width)); - int32_t filter_height; - LITERT_RETURN_IF_ERROR(LiteRtGetAveragePool2dFilterHeightOption( - litert_op.Get(), &filter_height)); - - ::qnn::PaddingType qnn_padding; - LITERT_RETURN_IF_ERROR(ConvertPaddingType(padding, qnn_padding)); - op_wrappers = ::qnn::BuildAveragePoolOp( - tensor_pool, input_tensors, output_tensors, stride_h, stride_w, - filter_height, filter_width, qnn_padding); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflDepthToSpace: { - int32_t block_size; - LITERT_RETURN_IF_ERROR( - LiteRtGetDepthToSpaceBlockSizeOption(litert_op.Get(), &block_size)); - op_wrappers = ::qnn::BuildDepthToSpaceOp(tensor_pool, input_tensors, - output_tensors, block_size); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflSpaceToDepth: { - int32_t block_size; - LITERT_RETURN_IF_ERROR( - LiteRtGetSpaceToDepthBlockSizeOption(litert_op.Get(), &block_size)); - op_wrappers = ::qnn::BuildSpaceToDepthOp(tensor_pool, input_tensors, - output_tensors, block_size); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflHardSwish: { - op_wrappers = - ::qnn::BuildHardSwishOp(tensor_pool, input_tensors, output_tensors); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflLeakyRelu: { - float alpha; - LITERT_RETURN_IF_ERROR( - LiteRtGetLeakyReluAlphaOption(litert_op.Get(), &alpha)); - op_wrappers = 
::qnn::BuildLeakyReluOp(tensor_pool, input_tensors, - output_tensors, alpha); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflResizeBilinear: { - bool align_corners; - LITERT_RETURN_IF_ERROR(LiteRtGetResizeBilinearAlignCornersOption( - litert_op.Get(), &align_corners)); - bool half_pixel_centers; - LITERT_RETURN_IF_ERROR(LiteRtGetResizeBilinearHalfPixelCenterOption( - litert_op.Get(), &half_pixel_centers)); - op_wrappers = ::qnn::BuildResizeBilinearOp(tensor_pool, input_tensors, - output_tensors, align_corners, - half_pixel_centers); - break; - } - case LiteRtOpCode::kLiteRtOpCodeTflResizeNearestNeighbor: { - bool align_corners; - LITERT_RETURN_IF_ERROR(LiteRtGetResizeNearestNeighborAlignCornersOption( - litert_op.Get(), &align_corners)); - bool half_pixel_centers; - LITERT_RETURN_IF_ERROR( - LiteRtGetResizeNearestNeighborHalfPixelCenterOption( - litert_op.Get(), &half_pixel_centers)); - op_wrappers = ::qnn::BuildResizeNearestOp(tensor_pool, input_tensors, - output_tensors, align_corners, - half_pixel_centers); - break; - } - default: { - LITERT_LOG(LITERT_ERROR, - "LiteRT Op Code: %d is not supported in Qualcomm Compiler.", - litert_op.Code()); - } - } - return kLiteRtStatusOk; -} - -LiteRtStatus MapGraph(QnnManager& qnn, Qnn_ContextHandle_t context_handle, - LiteRtSubgraph subgraph, - absl::string_view qnn_graph_name) { - GraphMapper graph_mapper(subgraph, qnn, context_handle); - LITERT_RETURN_IF_ERROR(graph_mapper.IsLiteRtSubgraphSupported()); - LITERT_RETURN_IF_ERROR(graph_mapper.InitQnnGraph(qnn_graph_name)); - - // - // Legalize subgraph inputs and update tensors in scope - // - - ::qnn::TensorPool tensor_pool; - absl::flat_hash_map - litert_tensor_to_wrapper; - - for (const auto& subgraph_input : graph_mapper.Graph().Inputs()) { - ::qnn::TensorWrapper* tensor_wrapper{nullptr}; - LITERT_RETURN_IF_ERROR( - ConvertTensor(subgraph_input, tensor_pool, tensor_wrapper)); - litert_tensor_to_wrapper.emplace(subgraph_input.Get(), tensor_wrapper); - } - - for (const 
auto& subgraph_output : graph_mapper.Graph().Outputs()) { - graph_mapper.RegisterOutput(subgraph_output.Get()); - } - // - // Topologically traverse graph, legalizing and updating tensors in scope - // - - // TODO: make ConvertOp accept a vector and append OpWrapper in it. - std::vector<::qnn::OpWrapper> graph_op_wrappers; - std::ostringstream dump; - for (const auto& op : graph_mapper.Graph().Ops()) { - // Dump op info. - dump.clear(); - Dump(*op.Get(), dump); - std::string s = dump.str(); - LITERT_LOG(LITERT_VERBOSE, "%s", s.data()); - - std::vector<::qnn::TensorWrapperRef> input_tensors; - for (const auto& input : op.Inputs()) { - if (const auto it = litert_tensor_to_wrapper.find(input.Get()); - it == litert_tensor_to_wrapper.end()) { - ::qnn::TensorWrapper* tensor_wrapper{nullptr}; - LITERT_RETURN_IF_ERROR( - ConvertTensor(input, tensor_pool, tensor_wrapper)); - // add into map to capture re-used static tensor - litert_tensor_to_wrapper.emplace(input.Get(), tensor_wrapper); - input_tensors.emplace_back(*tensor_wrapper); - } else { - input_tensors.emplace_back(*(it->second)); - } - } - - std::vector<::qnn::TensorWrapperRef> output_tensors; - for (const auto& output : op.Outputs()) { - bool is_tensor_read_and_write = graph_mapper.IsTensorOutput(output.Get()); - ::qnn::TensorWrapper* tensor_wrapper{nullptr}; - LITERT_RETURN_IF_ERROR(ConvertTensor(output, tensor_pool, tensor_wrapper, - is_tensor_read_and_write)); - litert_tensor_to_wrapper.emplace(output.Get(), tensor_wrapper); - output_tensors.emplace_back(*tensor_wrapper); - } - - std::vector<::qnn::OpWrapper> op_wrappers; - LITERT_RETURN_IF_ERROR( - ConvertOp(op, tensor_pool, input_tensors, output_tensors, op_wrappers)); - std::move(op_wrappers.begin(), op_wrappers.end(), - std::back_inserter(graph_op_wrappers)); - } - // Insert all tensors into Qnn graph and update the id of Qnn_Tensor_t inside. 
- tensor_pool.ForEach( - [&qnn, &graph_mapper](::qnn::TensorWrapper& tensor_wrapper) { - // TODO(chunhsue): Use compile interface to get useQInt16AsQUint16. - constexpr bool useQInt16AsQUint16 = true; - if constexpr (useQInt16AsQUint16) { - tensor_wrapper.ConvertQint16ToQuint16(); - } - qnn.Api()->tensorCreateGraphTensor(graph_mapper.QnnGraph(), - &tensor_wrapper.GetQnnTensor()); - }); - // Then op can be added into Qnn graph after the tensor ids are updated. - for (auto& op_wrapper : graph_op_wrappers) { - qnn.Api()->graphAddNode(graph_mapper.QnnGraph(), op_wrapper.GetOpConfig()); - } - - LITERT_RETURN_STATUS_IF_QNN_NOT_OK(graph_mapper.Finalize()); - - return kLiteRtStatusOk; -} - -//===----------------------------------------------------------------------===// -// -// [WIP] LiteRT SUBGRAPH -> QNN GRAPH -// -// Core driver for IR translation. Traverses LiteRt Subgraph, iteratively -// "legalizing" (mapping) LiteRt entities to their QNN counterpart. -// -// APPROACH: -// -// To support the general case we will need a driver loop that either -// traverses input recursively through edges or just iterates topologically. -// -// The algorithm is pretty straightforward: -// * Store mapping between already evaluated LiteRtTensors and their -// newly constructed Qnn Tensor counterpart. -// * Look up QNN Tensors when setting QNN Op inputs. -// * Add new QNN Tensor when setting QNN Op outputs. -// -// NOTES ON QNN API: -// -// After QNN Tensors are registered in the context, they need only -// be stored as their ID. QNN Tensor and "id" : uint32_t are used -// interchangeably. 
-// -//===----------------------------------------------------------------------===// - -LiteRtStatus ComposeGraph(QnnManager& qnn, Qnn_ContextHandle_t context_handle, - LiteRtSubgraph subgraph, - absl::string_view qnn_graph_name) { - LITERT_RETURN_IF_ERROR( - MapGraph(qnn, context_handle, subgraph, qnn_graph_name)); - return kLiteRtStatusOk; -} - -} // namespace litert::qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/qnn_compose_graph.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/qnn_compose_graph.h deleted file mode 100644 index 3c43e1901acb02..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/qnn_compose_graph.h +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_QNN_COMPOSE_GRAPH_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_QNN_COMPOSE_GRAPH_H_ - -#include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager.h" - -namespace litert::qnn { - -LiteRtStatus ConvertDataType(const litert::ElementType litert_type, - const bool is_quantized, Qnn_DataType_t& qnn_type); - -LiteRtStatus ConvertTensor(const litert::Tensor& litert_tensor, - ::qnn::TensorPool& tensor_pool, - ::qnn::TensorWrapper*& tensor_wrapper, - bool is_tensor_read_and_write = false); - -LiteRtStatus ConvertOp( - const litert::Op& litert_op, ::qnn::TensorPool& tensor_pool, - const std::vector<::qnn::TensorWrapperRef>& input_tensors, - const std::vector<::qnn::TensorWrapperRef>& output_tensors, - std::vector<::qnn::OpWrapper>& op_wrappers); - -// Composes a new QNN Graph from given LiteRt Graph. Qnn Graph is written to -// context behind "qnn". Uses given graph_name to name entry point. 
-LiteRtStatus ComposeGraph(QnnManager& qnn, Qnn_ContextHandle_t context_handle, - LiteRtSubgraph subgraph, - absl::string_view qnn_graph_name); - -} // namespace litert::qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_QNN_COMPOSE_GRAPH_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/context_binary_info.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/context_binary_info.cc deleted file mode 100644 index 366f3a228b06d1..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/context_binary_info.cc +++ /dev/null @@ -1,216 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/context_binary_info.h" - -#include -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnCommon.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "third_party/qairt/latest/include/QNN/System/QnnSystemContext.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_tensor.h" - -namespace litert { -namespace qnn { - -namespace { - -Expected InsertQnnTensors(int num_qnn_tensors, Qnn_Tensor_t* qnn_tensors, - std::vector* tensors) { - tensors->clear(); - tensors->reserve(num_qnn_tensors); - for (auto i = 0; i < num_qnn_tensors; ++i) { - auto tensor = QnnTensor::Create(qnn_tensors[i]); - if (!tensor) { - return Unexpected(tensor.Error()); - } - tensors->push_back(std::move(*tensor)); - } - return {}; -} - -Expected InsertQnnGraphInfos( - int num_qnn_graph_infos, QnnSystemContext_GraphInfo_t* qnn_graph_infos, - std::vector* graphs) { - graphs->clear(); - graphs->reserve(num_qnn_graph_infos); - for (auto i = 0; i < num_qnn_graph_infos; ++i) { - auto graph = GraphInfo::Create(qnn_graph_infos[i]); - if (!graph) { - return Unexpected(graph.Error()); - } - graphs->push_back(std::move(*graph)); - } - - return {}; -} - -} // namespace - -Expected GraphInfo::Create( - const QnnSystemContext_GraphInfo_t& graph_info) { - GraphInfo info; - auto status = info.Init(graph_info); - if (status) { - return info; - } else { - return Unexpected(status.Error()); - } -} - -Expected GraphInfo::Init(const QnnSystemContext_GraphInfo_t& graph_info) { - if (graph_info.version == QNN_SYSTEM_CONTEXT_GRAPH_INFO_VERSION_1) { - const auto& graph_info_ = graph_info.graphInfoV1; - name_ = 
graph_info_.graphName; - LITERT_LOG(LITERT_INFO, "Found qnn graph: %s", name_.c_str()); - - if (auto status = InsertQnnTensors(graph_info_.numGraphInputs, - graph_info_.graphInputs, &inputs_); - !status) { - return Unexpected(status.Error()); - } - if (auto status = InsertQnnTensors(graph_info_.numGraphOutputs, - graph_info_.graphOutputs, &outputs_); - !status) { - return Unexpected(status.Error()); - } - - } else if (graph_info.version == QNN_SYSTEM_CONTEXT_GRAPH_INFO_VERSION_2) { - const auto& graph_info_ = graph_info.graphInfoV2; - name_ = graph_info_.graphName; - LITERT_LOG(LITERT_INFO, "Found qnn graph: %s", name_.c_str()); - - if (auto status = InsertQnnTensors(graph_info_.numGraphInputs, - graph_info_.graphInputs, &inputs_); - !status) { - return Unexpected(status.Error()); - } - if (auto status = InsertQnnTensors(graph_info_.numGraphOutputs, - graph_info_.graphOutputs, &outputs_); - !status) { - return Unexpected(status.Error()); - } - } else if (graph_info.version == QNN_SYSTEM_CONTEXT_GRAPH_INFO_VERSION_3) { - const auto& graph_info_ = graph_info.graphInfoV3; - name_ = graph_info_.graphName; - LITERT_LOG(LITERT_INFO, "Found qnn graph: %s", name_.c_str()); - - if (auto status = InsertQnnTensors(graph_info_.numGraphInputs, - graph_info_.graphInputs, &inputs_); - !status) { - return Unexpected(status.Error()); - } - if (auto status = InsertQnnTensors(graph_info_.numGraphOutputs, - graph_info_.graphOutputs, &outputs_); - !status) { - return Unexpected(status.Error()); - } - - } else { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Unsupported graph info version."); - } - return {}; -} - -Expected ContextBinaryInfo::Init( - const QnnSystemContext_BinaryInfo_t& binary_info) { - if (binary_info.version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_1) { - const auto& context_binary_info = binary_info.contextBinaryInfoV1; - if (auto status = InsertQnnTensors(context_binary_info.numContextTensors, - context_binary_info.contextTensors, - &context_tensors_); - 
!status) { - return Unexpected(status.Error()); - } - if (auto status = InsertQnnGraphInfos(context_binary_info.numGraphs, - context_binary_info.graphs, &graphs_); - !status) { - return Unexpected(status.Error()); - } - - } else if (binary_info.version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_2) { - const auto& context_binary_info = binary_info.contextBinaryInfoV2; - if (auto status = InsertQnnTensors(context_binary_info.numContextTensors, - context_binary_info.contextTensors, - &context_tensors_); - !status) { - return Unexpected(status.Error()); - } - if (auto status = InsertQnnGraphInfos(context_binary_info.numGraphs, - context_binary_info.graphs, &graphs_); - !status) { - return Unexpected(status.Error()); - } - } else if (binary_info.version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_3) { - const auto& context_binary_info = binary_info.contextBinaryInfoV3; - if (auto status = InsertQnnTensors(context_binary_info.numContextTensors, - context_binary_info.contextTensors, - &context_tensors_); - !status) { - return Unexpected(status.Error()); - } - if (auto status = InsertQnnGraphInfos(context_binary_info.numGraphs, - context_binary_info.graphs, &graphs_); - !status) { - return Unexpected(status.Error()); - } - } else { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Unsupported context binary version."); - } - return {}; -} - -Expected ContextBinaryInfo::Create( - QnnManager& qnn, const void* exec_bytecode_ptr, size_t exec_bytecode_size) { - auto system_context_handle = qnn.CreateSystemContextHandle(); - if (!system_context_handle) { - return Unexpected(system_context_handle.Error()); - } - - const QnnSystemContext_BinaryInfo_t* binary_info = nullptr; - Qnn_ContextBinarySize_t binary_info_size = 0; - if (auto status = qnn.SystemApi()->systemContextGetBinaryInfo( - system_context_handle->get(), const_cast(exec_bytecode_ptr), - exec_bytecode_size, &binary_info, &binary_info_size); - status != QNN_SUCCESS) { - LITERT_LOG(LITERT_ERROR, "Failed to get context 
binary info: %d", status); - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to get context binary info"); - } - - if (!binary_info) { - LITERT_LOG(LITERT_ERROR, "Null binary info", ""); - return Unexpected(kLiteRtStatusErrorRuntimeFailure, "Null binary info"); - } - - ContextBinaryInfo info; - auto status = info.Init(*binary_info); - - if (status) { - return info; - } else { - return Unexpected(status.Error()); - } -} - -} // namespace qnn -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/context_binary_info.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/context_binary_info.h deleted file mode 100644 index e1e11dfa19f375..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/context_binary_info.h +++ /dev/null @@ -1,66 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CONTEXT_BINARY_INFO_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CONTEXT_BINARY_INFO_H_ - -#include -#include -#include - -#include "absl/strings/string_view.h" -#include "third_party/qairt/latest/include/QNN/QnnInterface.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_tensor.h" - -namespace litert::qnn { - -class GraphInfo { - public: - static Expected Create( - const QnnSystemContext_GraphInfo_t& graph_info); - const std::string& Name() const { return name_; } - const std::vector& Inputs() const { return inputs_; } - const std::vector& Outputs() const { return outputs_; } - - private: - GraphInfo() = default; - Expected Init(const QnnSystemContext_GraphInfo_t& graph_info); - std::string name_; - std::vector inputs_; - std::vector outputs_; -}; - -class ContextBinaryInfo { - public: - static Expected Create(QnnManager& qnn, - const void* exec_bytecode_ptr, - size_t exec_bytecode_size); - const std::vector& ContextTensors() const { - return context_tensors_; - } - const std::vector& Graphs() const { return graphs_; } - - private: - ContextBinaryInfo() = default; - Expected Init(const QnnSystemContext_BinaryInfo_t& binary_info); - std::vector context_tensors_; - std::vector graphs_; -}; - -} // namespace litert::qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CONTEXT_BINARY_INFO_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/BUILD b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/BUILD deleted file mode 100644 index 902bb5b5b49bee..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/BUILD +++ /dev/null @@ -1,27 +0,0 @@ -# Copyright (c) Qualcomm Innovation Center, Inc. 
All Rights Reserved. -# SPDX-License-Identifier: Apache-2.0 - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//tensorflow/lite/experimental/litert/vendors/qualcomm:__subpackages__"], -) - -cc_library( - name = "tensor_pool", - srcs = ["tensor_pool.cc"], - hdrs = ["tensor_pool.h"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - deps = [ - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:quantize_params_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) - -cc_library( - name = "common", - hdrs = ["common.h"], -) diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/BUILD b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/BUILD deleted file mode 100644 index 1c0f0367e2dab3..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/BUILD +++ /dev/null @@ -1,574 +0,0 @@ -# Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved. -# SPDX-License-Identifier: Apache-2.0 - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//tensorflow/lite/experimental/litert/vendors/qualcomm:__subpackages__"], -) - -cc_library( - name = "op_builder", - srcs = ["op_builder.cc"], - hdrs = ["op_builder.h"], - tags = [ - # Don't build/test in OS until qnn is available. 
- "nobuilder", - ], - deps = [ - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:op_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) - -cc_library( - name = "fully_connected_op_builder_htp", - srcs = ["fully_connected_op_builder_htp.cc"], - hdrs = ["fully_connected_op_builder_htp.h"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - deps = [ - ":op_builder", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core:tensor_pool", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils:log", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:op_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:quantize_params_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) - -cc_library( - name = "elementwise_op_builder", - srcs = ["elementwise_op_builder.cc"], - hdrs = ["elementwise_op_builder.h"], - tags = [ - # Don't build/test in OS until qnn is available. 
- "nobuilder", - ], - deps = [ - ":op_builder", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core:tensor_pool", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:op_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:param_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:quantize_params_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) - -cc_library( - name = "cast_op_builder", - srcs = ["cast_op_builder.cc"], - hdrs = ["cast_op_builder.h"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - deps = [ - ":op_builder", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core:tensor_pool", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:op_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) - -cc_library( - name = "concatenation_op_builder", - srcs = ["concatenation_op_builder.cc"], - hdrs = ["concatenation_op_builder.h"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - deps = [ - ":op_builder", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core:tensor_pool", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:op_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) - -cc_library( - name = "embedding_lookup_op_builder", - srcs = ["embedding_lookup_op_builder.cc"], - hdrs = ["embedding_lookup_op_builder.h"], - tags = [ - # Don't build/test in OS until qnn is available. 
- "nobuilder", - ], - deps = [ - ":op_builder", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core:tensor_pool", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils:log", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:op_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) - -cc_library( - name = "fully_connected_op_builder", - srcs = ["fully_connected_op_builder.cc"], - hdrs = ["fully_connected_op_builder.h"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - deps = [ - ":op_builder", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core:tensor_pool", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:op_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) - -cc_library( - name = "gather_op_builder", - srcs = ["gather_op_builder.cc"], - hdrs = ["gather_op_builder.h"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - deps = [ - ":op_builder", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core:tensor_pool", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils:log", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:op_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) - -cc_library( - name = "gelu_op_builder", - srcs = ["gelu_op_builder.cc"], - hdrs = ["gelu_op_builder.h"], - tags = [ - # Don't build/test in OS until qnn is available. 
- "nobuilder", - ], - deps = [ - ":op_builder", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core:tensor_pool", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:op_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) - -cc_library( - name = "relu_op_builder", - srcs = ["relu_op_builder.cc"], - hdrs = ["relu_op_builder.h"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - deps = [ - ":op_builder", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core:tensor_pool", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:op_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) - -cc_library( - name = "relu6_op_builder", - srcs = ["relu6_op_builder.cc"], - hdrs = ["relu6_op_builder.h"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - deps = [ - ":op_builder", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core:tensor_pool", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:op_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) - -cc_library( - name = "matmul_op_builder", - srcs = ["matmul_op_builder.cc"], - hdrs = ["matmul_op_builder.h"], - tags = [ - # Don't build/test in OS until qnn is available. 
- "nobuilder", - ], - deps = [ - ":op_builder", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core:tensor_pool", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:op_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) - -cc_library( - name = "mean_op_builder", - srcs = ["mean_op_builder.cc"], - hdrs = ["mean_op_builder.h"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - deps = [ - ":op_builder", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core:tensor_pool", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils:log", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:op_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) - -cc_library( - name = "quantize_op_builder", - srcs = ["quantize_op_builder.cc"], - hdrs = ["quantize_op_builder.h"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - deps = [ - ":op_builder", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core:tensor_pool", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:op_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) - -cc_library( - name = "reduce_op_builder", - srcs = ["reduce_op_builder.cc"], - hdrs = ["reduce_op_builder.h"], - tags = [ - # Don't build/test in OS until qnn is available. 
- "nobuilder", - ], - deps = [ - ":op_builder", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core:tensor_pool", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils:log", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:op_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) - -cc_library( - name = "reshape_op_builder", - srcs = ["reshape_op_builder.cc"], - hdrs = ["reshape_op_builder.h"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - deps = [ - ":op_builder", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core:tensor_pool", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:op_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) - -cc_library( - name = "select_op_builder", - srcs = ["select_op_builder.cc"], - hdrs = ["select_op_builder.h"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - deps = [ - ":op_builder", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core:tensor_pool", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:op_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) - -cc_library( - name = "slice_op_builder", - srcs = ["slice_op_builder.cc"], - hdrs = ["slice_op_builder.h"], - tags = [ - # Don't build/test in OS until qnn is available. 
- "nobuilder", - ], - deps = [ - ":op_builder", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core:tensor_pool", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils:log", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:op_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) - -cc_library( - name = "softmax_op_builder", - srcs = ["softmax_op_builder.cc"], - hdrs = ["softmax_op_builder.h"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - deps = [ - ":op_builder", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core:tensor_pool", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:op_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) - -cc_library( - name = "split_op_builder", - srcs = ["split_op_builder.cc"], - hdrs = ["split_op_builder.h"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - deps = [ - ":op_builder", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core:tensor_pool", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils:log", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:op_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) - -cc_library( - name = "tanh_op_builder", - srcs = ["tanh_op_builder.cc"], - hdrs = ["tanh_op_builder.h"], - tags = [ - # Don't build/test in OS until qnn is available. 
- "nobuilder", - ], - deps = [ - ":op_builder", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core:tensor_pool", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:op_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) - -cc_library( - name = "transpose_op_builder", - srcs = ["transpose_op_builder.cc"], - hdrs = ["transpose_op_builder.h"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - deps = [ - ":op_builder", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core:tensor_pool", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils:log", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:op_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) - -cc_library( - name = "pack_op_builder", - srcs = ["pack_op_builder.cc"], - hdrs = ["pack_op_builder.h"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - deps = [ - ":op_builder", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core:tensor_pool", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:op_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) - -cc_library( - name = "dynamic_update_slice_op_builder", - srcs = ["dynamic_update_slice_op_builder.cc"], - hdrs = ["dynamic_update_slice_op_builder.h"], - tags = [ - # Don't build/test in OS until qnn is available. 
- "nobuilder", - ], - deps = [ - ":op_builder", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core:tensor_pool", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:op_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:quantize_params_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) - -cc_library( - name = "rms_norm_op_builder", - srcs = ["rms_norm_op_builder.cc"], - hdrs = ["rms_norm_op_builder.h"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - deps = [ - ":op_builder", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core:tensor_pool", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:op_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:quantize_params_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) - -cc_library( - name = "conv2d_op_builder", - srcs = ["conv2d_op_builder.cc"], - hdrs = ["conv2d_op_builder.h"], - tags = [ - # Don't build/test in OS until qnn is available. 
- "nobuilder", - ], - deps = [ - ":op_builder", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core:tensor_pool", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils:log", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:op_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:quantize_params_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) - -cc_library( - name = "pool2d_op_builder", - srcs = ["pool2d_op_builder.cc"], - hdrs = ["pool2d_op_builder.h"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - deps = [ - ":op_builder", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core:tensor_pool", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:op_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:quantize_params_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) - -cc_library( - name = "spatial_transform_op_builder", - srcs = ["spatial_transform_op_builder.cc"], - hdrs = ["spatial_transform_op_builder.h"], - tags = [ - # Don't build/test in OS until qnn is available. 
- "nobuilder", - ], - deps = [ - ":op_builder", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core:tensor_pool", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:op_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:quantize_params_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) - -cc_library( - name = "resize_op_builder", - srcs = ["resize_op_builder.cc"], - hdrs = ["resize_op_builder.h"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - deps = [ - ":op_builder", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core:tensor_pool", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils:log", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:op_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:quantize_params_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) - -cc_library( - name = "leaky_relu_op_builder", - srcs = ["leaky_relu_op_builder.cc"], - hdrs = ["leaky_relu_op_builder.h"], - tags = [ - # Don't build/test in OS until qnn is available. 
- "nobuilder", - ], - deps = [ - ":op_builder", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core:tensor_pool", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils:log", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:op_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:quantize_params_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) - -cc_library( - name = "hard_swish_op_builder", - srcs = ["hard_swish_op_builder.cc"], - hdrs = ["hard_swish_op_builder.h"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - deps = [ - ":op_builder", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core:tensor_pool", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:op_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:quantize_params_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) - -cc_library( - name = "depthwise_conv2d_op_builder", - srcs = ["depthwise_conv2d_op_builder.cc"], - hdrs = ["depthwise_conv2d_op_builder.h"], - tags = [ - # Don't build/test in OS until qnn is available. 
- "nobuilder", - ], - deps = [ - ":op_builder", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core:tensor_pool", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:op_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:quantize_params_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/cast_op_builder.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/cast_op_builder.cc deleted file mode 100644 index 361b6007572528..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/cast_op_builder.cc +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. -// All Rights Reserved. - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/cast_op_builder.h" - -#include - -#include "third_party/qairt/latest/include/QNN/QnnOpDef.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildCastOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs) { - std::vector res; - - auto& op = CreateOpWrapper(res, QNN_OP_CAST); - op.AddInputTensor(inputs[0]); - op.AddOutputTensor(outputs[0]); - - return res; -} - -} // namespace qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/cast_op_builder.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/cast_op_builder.h deleted file mode 100644 index 4de521da983870..00000000000000 --- 
a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/cast_op_builder.h +++ /dev/null @@ -1,20 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. -// All Rights Reserved. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_CAST_OP_BUILDER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_CAST_OP_BUILDER_H_ - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildCastOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs); - -} // namespace qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_CAST_OP_BUILDER_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/concatenation_op_builder.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/concatenation_op_builder.cc deleted file mode 100644 index c75d985dbbd2bd..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/concatenation_op_builder.cc +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. -// All Rights Reserved. 
- -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/concatenation_op_builder.h" - -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnOpDef.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildConcatenationOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs, const std::int32_t axis) { - std::vector res; - - auto& concat_op = CreateOpWrapper(res, QNN_OP_CONCAT); - for (const auto& input : inputs) { - concat_op.AddInputTensor(input); - } - concat_op.AddOutputTensor(outputs[0]); - - std::uint32_t adjusted_axis = - (axis >= 0) ? axis : axis + inputs[0].get().GetRank(); - concat_op.AddScalarParam(QNN_OP_CONCAT_PARAM_AXIS, - adjusted_axis); - - return res; -} - -} // namespace qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/concatenation_op_builder.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/concatenation_op_builder.h deleted file mode 100644 index ed0784e27a913a..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/concatenation_op_builder.h +++ /dev/null @@ -1,20 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. -// All Rights Reserved. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_CONCATENATION_OP_BUILDER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_CONCATENATION_OP_BUILDER_H_ - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildConcatenationOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs, const std::int32_t axis); - -} // namespace qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_CONCATENATION_OP_BUILDER_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/conv2d_op_builder.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/conv2d_op_builder.cc deleted file mode 100644 index a41132440e15bc..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/conv2d_op_builder.cc +++ /dev/null @@ -1,166 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved. 
-// SPDX-License-Identifier: Apache-2.0 - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/conv2d_op_builder.h" - -#include -#include -#include -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnOpDef.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/log.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/quantize_params_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -namespace { -constexpr size_t kInputIndex = 0; -constexpr size_t kFilterIndex = 1; -constexpr size_t kBiasIndex = 2; -constexpr size_t kOutputIndex = 0; -constexpr size_t kBatchIndex = 0; -constexpr size_t kHeightIndex = 1; -constexpr size_t kWidthIndex = 2; -constexpr size_t kChannelIndex = 3; - -} // namespace - -std::vector BuildConv2dOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs, const std::uint32_t stride_h, - const std::uint32_t stride_w, const std::uint32_t dilation_h, - const std::uint32_t dilation_w, const PaddingType padding_type) { - std::vector res; - - // transpose filter - TensorWrapper& filter_tensor = inputs[kFilterIndex]; - const std::vector& filters_dims = filter_tensor.GetDims(); - auto& filter_quant_params = filter_tensor.GetQuantParams(); - std::vector permute_dims{filters_dims[1], filters_dims[2], - filters_dims[3], filters_dims[0]}; - if (std::holds_alternative( - filter_quant_params)) { - auto& axis_quant_params = - std::get(filter_quant_params); - const std::array new_axis{3, 0, 1, 2}; - 
axis_quant_params.SetAxis(new_axis[axis_quant_params.GetAxis()]); - } - - size_t filter_bytes = filter_tensor.GetTensorBytes(); - TensorWrapper* transposed_filter_tensor = nullptr; - if (filter_tensor.IsTensorStatic() && - filter_tensor.GetDataType() == - Qnn_DataType_t::QNN_DATATYPE_SFIXED_POINT_8) { - auto filter_data = filter_tensor.GetStaticTensorData(); - std::vector transpose_weight_int8; - TransposeFromOHWIToHWIO(filter_data.value(), filters_dims, - transpose_weight_int8); - transposed_filter_tensor = &(tensor_pool.CreateStaticTensor( - filter_tensor.GetDataType(), filter_quant_params, permute_dims, - filter_bytes, transpose_weight_int8.data())); - } else if (filter_tensor.IsTensorStatic() && - filter_tensor.GetDataType() == - Qnn_DataType_t::QNN_DATATYPE_UFIXED_POINT_8) { - auto filter_data = filter_tensor.GetStaticTensorData(); - std::vector transpose_weight_uint8; - TransposeFromOHWIToHWIO(filter_data.value(), filters_dims, - transpose_weight_uint8); - transposed_filter_tensor = &(tensor_pool.CreateStaticTensor( - filter_tensor.GetDataType(), filter_quant_params, permute_dims, - filter_bytes, transpose_weight_uint8.data())); - } else { - transposed_filter_tensor = - &(tensor_pool.CloneNativeTensorFrom(filter_tensor, permute_dims)); - - const std::vector permute_shape{4}; - const std::array permute_data{kHeightIndex, kWidthIndex, - kChannelIndex, kBatchIndex}; - auto& permute_tensor = tensor_pool.CreateStaticTensor( - QNN_DATATYPE_UINT_32, QuantizeParamsWrapperVariant{}, permute_shape, - sizeof(decltype(permute_data)::value_type) * permute_data.size(), - permute_data.data()); - - OpWrapper& transpose_op = CreateOpWrapper(res, QNN_OP_TRANSPOSE); - transpose_op.AddInputTensor(filter_tensor); - transpose_op.AddOutputTensor(*transposed_filter_tensor); - transpose_op.AddTensorParam(QNN_OP_TRANSPOSE_PARAM_PERM, permute_tensor); - } - - // conv - OpWrapper& conv_op = CreateOpWrapper(res, QNN_OP_CONV_2D); - TensorWrapper& input_tensor = inputs[kInputIndex]; - 
conv_op.AddInputTensor(input_tensor); - conv_op.AddInputTensor(*transposed_filter_tensor); - if (inputs.size() - 1 >= kBiasIndex) { - TensorWrapper& bias_tensor = inputs[kBiasIndex]; - // QNN only support per-tensor quant for bias, - // and the scale and offset are both zero. - bias_tensor.ConvertAxisScaleOffsetToScaleOffset(); - conv_op.AddInputTensor(bias_tensor); - } - - TensorWrapper& output_tensor = outputs[kOutputIndex]; - conv_op.AddOutputTensor(output_tensor); - // TODO: fused activation - - // stride param - const std::array stride_data{stride_h, stride_w}; - const std::vector stride_shape{2}; - auto& stride_tensor = tensor_pool.CreateStaticTensor( - QNN_DATATYPE_UINT_32, QuantizeParamsWrapperVariant{}, stride_shape, - sizeof(decltype(stride_data)::value_type) * stride_data.size(), - stride_data.data()); - conv_op.AddTensorParam(QNN_OP_CONV_2D_PARAM_STRIDE, stride_tensor); - - // dilation param - const std::array dilation_data{dilation_h, dilation_w}; - const std::vector dilation_shape{2}; - auto& dilation_tensor = tensor_pool.CreateStaticTensor( - QNN_DATATYPE_UINT_32, QuantizeParamsWrapperVariant{}, dilation_shape, - sizeof(decltype(dilation_data)::value_type) * dilation_data.size(), - dilation_data.data()); - conv_op.AddTensorParam(QNN_OP_CONV_2D_PARAM_DILATION, dilation_tensor); - - // padding param - const auto [padding_before_height, padding_after_height] = - ComputePaddingBeforeAfter(input_tensor.GetDim(kHeightIndex), - filter_tensor.GetDim(kHeightIndex), stride_h, - dilation_h, padding_type); - const auto [padding_before_width, padding_after_width] = - ComputePaddingBeforeAfter(input_tensor.GetDim(kWidthIndex), - filter_tensor.GetDim(kWidthIndex), stride_w, - dilation_w, padding_type); - const std::array padding_data = { - padding_before_height, padding_after_height, padding_before_width, - padding_after_width}; - const std::vector padding_shape{2, 2}; - auto& padding_tensor = tensor_pool.CreateStaticTensor( - QNN_DATATYPE_UINT_32, 
QuantizeParamsWrapperVariant{}, padding_shape, - sizeof(decltype(padding_data)::value_type) * padding_data.size(), - padding_data.data()); - conv_op.AddTensorParam(QNN_OP_CONV_2D_PARAM_PAD_AMOUNT, padding_tensor); - - // group param - if ((input_tensor.GetDim(kChannelIndex) % - filter_tensor.GetDim(kChannelIndex)) != 0) { - QNN_LOG_WARNING( - "The channels of the input tensor cannot be evenly divided by the " - "channels of the filter tensor."); - } - if (const std::uint32_t groups = input_tensor.GetDim(kChannelIndex) / - filter_tensor.GetDim(kChannelIndex); - groups > 1) { - conv_op.AddScalarParam(QNN_OP_CONV_2D_PARAM_GROUP, groups); - } - - return res; -} - -} // namespace qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/conv2d_op_builder.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/conv2d_op_builder.h deleted file mode 100644 index 7cdd99bec46c40..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/conv2d_op_builder.h +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved. 
-// SPDX-License-Identifier: Apache-2.0 - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_CONV2D_OP_BUILDER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_CONV2D_OP_BUILDER_H_ - -#include -#include - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildConv2dOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs, const std::uint32_t stride_h, - const std::uint32_t stride_w, const std::uint32_t dilation_h, - const std::uint32_t dilation_w, const PaddingType padding); - -} // namespace qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_CONV2D_OP_BUILDER_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/depthwise_conv2d_op_builder.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/depthwise_conv2d_op_builder.cc deleted file mode 100644 index 3d4840eb6c390a..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/depthwise_conv2d_op_builder.cc +++ /dev/null @@ -1,118 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved. 
-// SPDX-License-Identifier: Apache-2.0 - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/depthwise_conv2d_op_builder.h" - -#include -#include -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnOpDef.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/quantize_params_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -namespace { -constexpr size_t kInputIndex = 0; -constexpr size_t kFilterIndex = 1; -constexpr size_t kBiasIndex = 2; -constexpr size_t kOutputIndex = 0; -constexpr size_t kBatchIndex = 0; -constexpr size_t kHeightIndex = 1; -constexpr size_t kWidthIndex = 2; -constexpr size_t kChannelIndex = 3; - -} // namespace - -std::vector BuildDepthwiseConv2dOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs, const std::uint32_t stride_h, - const std::uint32_t stride_w, const std::uint32_t dilation_h, - const std::uint32_t dilation_w, const PaddingType padding_type) { - std::vector res; - - // reshape filter - TensorWrapper& filter_tensor = inputs[kFilterIndex]; - - // 1HWC to HW1C, only need reshape instead of transpose. 
- const std::vector reshape_dims{ - filter_tensor.GetDim(kHeightIndex), filter_tensor.GetDim(kWidthIndex), - filter_tensor.GetDim(kBatchIndex), filter_tensor.GetDim(kChannelIndex)}; - TensorWrapper* reshaped_filter_tensor = nullptr; - if (filter_tensor.IsTensorStatic()) { - reshaped_filter_tensor = - &(tensor_pool.CloneStaticTensorFrom(filter_tensor, reshape_dims)); - } else { - reshaped_filter_tensor = - &(tensor_pool.CloneNativeTensorFrom(filter_tensor, reshape_dims)); - - OpWrapper& reshape_op = CreateOpWrapper(res, QNN_OP_RESHAPE); - reshape_op.AddInputTensor(filter_tensor); - reshape_op.AddOutputTensor(*reshaped_filter_tensor); - } - - // conv - OpWrapper& conv_op = CreateOpWrapper(res, QNN_OP_DEPTH_WISE_CONV_2D); - TensorWrapper& input_tensor = inputs[kInputIndex]; - conv_op.AddInputTensor(input_tensor); - conv_op.AddInputTensor(*reshaped_filter_tensor); - if (inputs.size() - 1 >= kBiasIndex) { - TensorWrapper& bias_tensor = inputs[kBiasIndex]; - // QNN only support per-tensor quant for bias, - // and the scale and offset are both zero. 
- bias_tensor.ConvertAxisScaleOffsetToScaleOffset(); - conv_op.AddInputTensor(bias_tensor); - } - - TensorWrapper& output_tensor = outputs[kOutputIndex]; - conv_op.AddOutputTensor(output_tensor); - // TODO: fused activation - - // stride param - const std::array stride_data{stride_h, stride_w}; - const std::vector stride_shape{2}; - auto& stride_tensor = tensor_pool.CreateStaticTensor( - QNN_DATATYPE_UINT_32, QuantizeParamsWrapperVariant{}, stride_shape, - sizeof(decltype(stride_data)::value_type) * stride_data.size(), - stride_data.data()); - conv_op.AddTensorParam(QNN_OP_DEPTH_WISE_CONV_2D_PARAM_STRIDE, stride_tensor); - - // dilation param - const std::array dilation_data{dilation_h, dilation_w}; - const std::vector dilation_shape{2}; - auto& dilation_tensor = tensor_pool.CreateStaticTensor( - QNN_DATATYPE_UINT_32, QuantizeParamsWrapperVariant{}, dilation_shape, - sizeof(decltype(dilation_data)::value_type) * dilation_data.size(), - dilation_data.data()); - conv_op.AddTensorParam(QNN_OP_DEPTH_WISE_CONV_2D_PARAM_DILATION, - dilation_tensor); - - // padding param - const auto [padding_before_height, padding_after_height] = - ComputePaddingBeforeAfter(input_tensor.GetDim(kHeightIndex), - filter_tensor.GetDim(kHeightIndex), stride_h, - dilation_h, padding_type); - const auto [padding_before_width, padding_after_width] = - ComputePaddingBeforeAfter(input_tensor.GetDim(kWidthIndex), - filter_tensor.GetDim(kWidthIndex), stride_w, - dilation_w, padding_type); - const std::array padding_data = { - padding_before_height, padding_after_height, padding_before_width, - padding_after_width}; - const std::vector padding_shape{2, 2}; - auto& padding_tensor = tensor_pool.CreateStaticTensor( - QNN_DATATYPE_UINT_32, QuantizeParamsWrapperVariant{}, padding_shape, - sizeof(decltype(padding_data)::value_type) * padding_data.size(), - padding_data.data()); - conv_op.AddTensorParam(QNN_OP_CONV_2D_PARAM_PAD_AMOUNT, padding_tensor); - - return res; -} - -} // namespace qnn diff --git 
a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/depthwise_conv2d_op_builder.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/depthwise_conv2d_op_builder.h deleted file mode 100644 index 32419352844a5b..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/depthwise_conv2d_op_builder.h +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_DEPTHWISE_CONV2D_OP_BUILDER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_DEPTHWISE_CONV2D_OP_BUILDER_H_ - -#include -#include - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildDepthwiseConv2dOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs, const std::uint32_t stride_h, - const std::uint32_t stride_w, const std::uint32_t dilation_h, - const std::uint32_t dilation_w, const PaddingType padding); - -} // namespace qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_DEPTHWISE_CONV2D_OP_BUILDER_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/dynamic_update_slice_op_builder.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/dynamic_update_slice_op_builder.cc deleted file mode 100644 index b356188becb799..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/dynamic_update_slice_op_builder.cc +++ /dev/null @@ -1,135 +0,0 @@ -// Copyright (c) Qualcomm Innovation 
Center, Inc. -// All Rights Reserved. - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/dynamic_update_slice_op_builder.h" - -#include -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnOpDef.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/quantize_params_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { -namespace { -constexpr int kInputIdx = 0; -constexpr int kUpdateIdx = 1; -constexpr int kIndicesIdx = 2; -constexpr int kOutputIdx = 0; -} // namespace - -std::vector BuildDynamicUpdateSliceOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs) { - std::vector res; - // Dynamic Update Slice: - // in[0] operand: [1, 64, 4, 64] - // in[1] updates: [1, 1, 4, 64] - // in[2] start_indices: [4] -> data: [0, x, 0, 0] - - // reduceSum and reshape in[2] -> index tensor - - // Create static tensor table - // shape: [64] - // data: [0,...,63] - - // QNN ElementWiseNotEqual: - // in[0]: table - // in[1]: index tensor - // out[0]: condition tensor - - // reshape condition tensor due to QNN broadcast rules - // in[0]: [64] - // out[0]: [64, 1, 1] - - // QNN ElementWiseSelect: - // in[0] condition: [64, 1, 1] - // in[1] input: [1, 64, 4, 64] - // in[2] updates: [1, 1, 4, 64] - - // CAUTION!!! 
only support Gemma2 use case now - - auto& input_tensor = inputs[kInputIdx].get(); - auto& update_tensor = inputs[kUpdateIdx].get(); - auto& indices_tensor = inputs[kIndicesIdx].get(); - auto& output_tensor = outputs[kOutputIdx].get(); - - if (input_tensor.GetRank() != update_tensor.GetRank()) { - LITERT_LOG(LITERT_ERROR, "%s", - "QNN LiteRT Delegate only supports Dynamic Update Slice when " - "operand and updates have the same rank."); - return {}; - } - - if (indices_tensor.GetDataType() != QNN_DATATYPE_INT_32) { - LITERT_LOG(LITERT_ERROR, "%s", - "Dynamic Update Slice only supports QNN_DATATYPE_INT_32 " - "start_indices."); - return {}; - } - - // reduce sum - auto& reduce_sum_op = CreateOpWrapper(res, QNN_OP_REDUCE_SUM); - reduce_sum_op.AddInputTensor(indices_tensor); - - std::vector axis_data = {0}; - TensorWrapper& axis_tensor = tensor_pool.CreateStaticTensor( - QNN_DATATYPE_UINT_32, QuantizeParamsWrapperVariant{}, {1}, - sizeof(std::uint32_t), axis_data.data()); - reduce_sum_op.AddTensorParam(QNN_OP_REDUCE_SUM_PARAM_AXES, axis_tensor); - - // create intermediate tensor - TensorWrapper& one_dim_index = - tensor_pool.CloneNativeTensorFrom(indices_tensor, {1}); - reduce_sum_op.AddOutputTensor(one_dim_index); - - // ElementwiseNotEqual - // get table dims from in[0]->Dims[1] - if (input_tensor.GetRank() < 2) { - LITERT_LOG(LITERT_ERROR, "%s", - "Dynamic Update Slice only supports operand tensor rank >= 2"); - return {}; - } - uint32_t table_size = input_tensor.GetDim(1); - std::vector static_table_dims = {table_size}; - std::vector table_data(table_size); - std::iota(table_data.begin(), table_data.end(), 0); - - // create static table tensor - TensorWrapper& static_table = tensor_pool.CreateStaticTensor( - QNN_DATATYPE_INT_32, QuantizeParamsWrapperVariant{}, static_table_dims, - table_size * sizeof(std::int32_t), table_data.data()); - - OpWrapper& not_equal_op = CreateOpWrapper(res, QNN_OP_ELEMENT_WISE_NOT_EQUAL); - not_equal_op.AddInputTensor(static_table); - 
not_equal_op.AddInputTensor(one_dim_index); - - TensorWrapper& not_equal_out = tensor_pool.CreateNativeTensor( - QNN_DATATYPE_BOOL_8, QuantizeParamsWrapperVariant{}, static_table_dims); - not_equal_op.AddOutputTensor(not_equal_out); - - // reshape not equal output to [N, 1, 1] - OpWrapper& reshape_op = CreateOpWrapper(res, QNN_OP_RESHAPE); - - reshape_op.AddInputTensor(not_equal_out); - TensorWrapper& reshape_out = - tensor_pool.CloneNativeTensorFrom(not_equal_out, {table_size, 1, 1}); - reshape_op.AddOutputTensor(reshape_out); - - // Select - OpWrapper& select_op = CreateOpWrapper(res, QNN_OP_ELEMENT_WISE_SELECT); - - select_op.AddInputTensor(reshape_out); - select_op.AddInputTensor(input_tensor); - select_op.AddInputTensor(update_tensor); - select_op.AddOutputTensor(output_tensor); - return res; -} - -} // namespace qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/dynamic_update_slice_op_builder.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/dynamic_update_slice_op_builder.h deleted file mode 100644 index c5a74c1a7c5ed6..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/dynamic_update_slice_op_builder.h +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. -// All Rights Reserved. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_DYNAMIC_UPDATE_SLICE_OP_BUILDER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_DYNAMIC_UPDATE_SLICE_OP_BUILDER_H_ - -#include - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildDynamicUpdateSliceOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs); - -} // namespace qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_DYNAMIC_UPDATE_SLICE_OP_BUILDER_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/elementwise_op_builder.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/elementwise_op_builder.cc deleted file mode 100644 index 38ed759d53fd21..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/elementwise_op_builder.cc +++ /dev/null @@ -1,233 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. -// All Rights Reserved. 
- -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/elementwise_op_builder.h" - -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnOpDef.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildElementwiseAddOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs) { - std::vector res; - - auto& elementwise_op = CreateOpWrapper(res, QNN_OP_ELEMENT_WISE_ADD); - for (const auto& input : inputs) { - elementwise_op.AddInputTensor(input); - } - elementwise_op.AddOutputTensor(outputs[0]); - - // TODO: fused activation - return res; -} - -std::vector BuildElementwiseSubOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs) { - std::vector res; - - auto& elementwise_op = CreateOpWrapper(res, QNN_OP_ELEMENT_WISE_SUBTRACT); - for (const auto& input : inputs) { - elementwise_op.AddInputTensor(input); - } - elementwise_op.AddOutputTensor(outputs[0]); - - // TODO: fused activation - return res; -} - -std::vector BuildElementwiseMulOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs) { - std::vector res; - - auto& elementwise_op = CreateOpWrapper(res, QNN_OP_ELEMENT_WISE_MULTIPLY); - for (const auto& input : inputs) { - elementwise_op.AddInputTensor(input); - } - elementwise_op.AddOutputTensor(outputs[0]); - - // TODO: fused activation - return res; -} - -std::vector BuildElementwiseDivOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs) { - std::vector res; - - auto& elementwise_op = CreateOpWrapper(res, QNN_OP_ELEMENT_WISE_DIVIDE); - for (const auto& 
input : inputs) { - elementwise_op.AddInputTensor(input); - } - elementwise_op.AddOutputTensor(outputs[0]); - - // TODO: fused activation - return res; -} - -std::vector BuildElementwiseSinOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs) { - std::vector res; - - auto& elementwise_op = CreateOpWrapper(res, QNN_OP_ELEMENT_WISE_SIN); - for (const auto& input : inputs) { - elementwise_op.AddInputTensor(input); - } - elementwise_op.AddOutputTensor(outputs[0]); - - return res; -} - -std::vector BuildElementwiseCosOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs) { - std::vector res; - - auto& elementwise_op = CreateOpWrapper(res, QNN_OP_ELEMENT_WISE_COS); - for (const auto& input : inputs) { - elementwise_op.AddInputTensor(input); - } - elementwise_op.AddOutputTensor(outputs[0]); - - return res; -} - -std::vector BuildElementwiseRsqrtOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs) { - std::vector res; - - auto& elementwise_op = CreateOpWrapper(res, QNN_OP_ELEMENT_WISE_RSQRT); - for (const auto& input : inputs) { - elementwise_op.AddInputTensor(input); - } - elementwise_op.AddOutputTensor(outputs[0]); - - return res; -} - -std::vector BuildElementwiseSquareOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs) { - std::vector res; - - OpWrapper& elementwise_op = - CreateOpWrapper(res, QNN_OP_ELEMENT_WISE_MULTIPLY); - elementwise_op.AddInputTensor(inputs[0]); - elementwise_op.AddInputTensor(inputs[0]); - elementwise_op.AddOutputTensor(outputs[0]); - - return res; -} - -std::vector BuildElementwiseSquaredDifferenceOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs) { - std::vector res; - - auto& elementwise_op = - CreateOpWrapper(res, QNN_OP_ELEMENT_WISE_SQUARED_DIFFERENCE); - for (const auto& input : inputs) { - elementwise_op.AddInputTensor(input); - } - 
elementwise_op.AddOutputTensor(outputs[0]); - - return res; -} - -std::vector BuildElementwiseLessOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs) { - std::vector res; - - auto& elementwise_op = CreateOpWrapper(res, QNN_OP_ELEMENT_WISE_BINARY); - for (const auto& input : inputs) { - elementwise_op.AddInputTensor(input); - } - elementwise_op.AddOutputTensor(outputs[0]); - elementwise_op.AddScalarParam( - QNN_OP_ELEMENT_WISE_BINARY_PARAM_OPERATION, - QNN_OP_ELEMENT_WISE_BINARY_OPERATION_LESS); - - return res; -} - -std::vector BuildElementwiseGreaterOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs) { - std::vector res; - - auto& elementwise_op = CreateOpWrapper(res, QNN_OP_ELEMENT_WISE_BINARY); - for (const auto& input : inputs) { - elementwise_op.AddInputTensor(input); - } - elementwise_op.AddOutputTensor(outputs[0]); - elementwise_op.AddScalarParam( - QNN_OP_ELEMENT_WISE_BINARY_PARAM_OPERATION, - QNN_OP_ELEMENT_WISE_BINARY_OPERATION_GREATER); - - return res; -} - -std::vector BuildElementwiseAndOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs) { - std::vector res; - - auto& elementwise_op = CreateOpWrapper(res, QNN_OP_ELEMENT_WISE_BINARY); - for (const auto& input : inputs) { - elementwise_op.AddInputTensor(input); - } - elementwise_op.AddOutputTensor(outputs[0]); - elementwise_op.AddScalarParam( - QNN_OP_ELEMENT_WISE_BINARY_PARAM_OPERATION, - QNN_OP_ELEMENT_WISE_BINARY_OPERATION_AND); - - return res; -} - -std::vector BuildElementwiseMinimumOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs) { - std::vector res; - - auto& elementwise_op = CreateOpWrapper(res, QNN_OP_ELEMENT_WISE_BINARY); - for (const auto& input : inputs) { - elementwise_op.AddInputTensor(input); - } - elementwise_op.AddOutputTensor(outputs[0]); - elementwise_op.AddScalarParam( - QNN_OP_ELEMENT_WISE_BINARY_PARAM_OPERATION, - 
QNN_OP_ELEMENT_WISE_BINARY_OPERATION_MINIMUM); - - return res; -} - -std::vector BuildElementwiseMaximumOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs) { - std::vector res; - - auto& elementwise_op = CreateOpWrapper(res, QNN_OP_ELEMENT_WISE_BINARY); - for (const auto& input : inputs) { - elementwise_op.AddInputTensor(input); - } - elementwise_op.AddOutputTensor(outputs[0]); - elementwise_op.AddScalarParam( - QNN_OP_ELEMENT_WISE_BINARY_PARAM_OPERATION, - QNN_OP_ELEMENT_WISE_BINARY_OPERATION_MAXIMUM); - - return res; -} - -} // namespace qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/elementwise_op_builder.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/elementwise_op_builder.h deleted file mode 100644 index 7953ce93c26b17..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/elementwise_op_builder.h +++ /dev/null @@ -1,73 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. -// All Rights Reserved. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_ELEMENTWISE_OP_BUILDER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_ELEMENTWISE_OP_BUILDER_H_ - -#include - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildElementwiseAddOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs); - -std::vector BuildElementwiseSubOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs); - -std::vector BuildElementwiseMulOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs); - -std::vector BuildElementwiseDivOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs); - -std::vector BuildElementwiseSinOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs); - -std::vector BuildElementwiseCosOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs); - -std::vector BuildElementwiseRsqrtOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs); - -std::vector BuildElementwiseSquareOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs); - -std::vector BuildElementwiseSquaredDifferenceOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs); - -std::vector BuildElementwiseLessOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs); - -std::vector BuildElementwiseGreaterOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs); - -std::vector BuildElementwiseAndOp( - TensorPool& tensor_pool, const std::vector& inputs, - const 
std::vector& outputs); - -std::vector BuildElementwiseMinimumOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs); - -std::vector BuildElementwiseMaximumOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs); - -} // namespace qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_ELEMENTWISE_OP_BUILDER_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/embedding_lookup_op_builder.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/embedding_lookup_op_builder.cc deleted file mode 100644 index f33c5167ef3404..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/embedding_lookup_op_builder.cc +++ /dev/null @@ -1,68 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/embedding_lookup_op_builder.h" - -#include -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnOpDef.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/log.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { -namespace { -constexpr int kTableIdx = 1; -constexpr int kIndicesIdx = 0; -constexpr int kOutputIdx = 0; -} // namespace - -std::vector BuildEmbeddingLookupOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs) { - std::vector res; - - TensorWrapper& table_tensor = inputs[kTableIdx]; - TensorWrapper& indices_tensor = inputs[kIndicesIdx]; - TensorWrapper& output_tensor 
= outputs[kOutputIdx]; - - auto& gather_op = CreateOpWrapper(res, QNN_OP_GATHER); - // Case: QInt8 table with QInt16 output - if (table_tensor.IsQuant8() && output_tensor.IsQuant16()) { - QNN_LOG_WARNING( - "The data type of embedding lookup table is int8, but output data type " - "is int16. Int8 table will be cast to int16."); - std::vector int16_data; - size_t data_len = table_tensor.GetTensorNumElements(); - auto int8_data = table_tensor.GetStaticTensorData(); - if (!int8_data.has_value()) { - QNN_LOG_ERROR("Embedding lookup get int8 table failed."); - return res; - } - int16_data.reserve(data_len); - for (int i = 0; i < data_len; ++i) { - int16_data.emplace_back(static_cast((*int8_data)[i])); - } - - TensorWrapper& int16_table_tensor = tensor_pool.CreateStaticTensor( - output_tensor.GetDataType(), table_tensor.GetQuantParams(), - table_tensor.GetDims(), - sizeof(decltype(int16_data)::value_type) * int16_data.size(), - reinterpret_cast(int16_data.data())); - - gather_op.AddInputTensor(int16_table_tensor); - } else { - gather_op.AddInputTensor(table_tensor); - } - - gather_op.AddInputTensor(indices_tensor); - gather_op.AddOutputTensor(output_tensor); - gather_op.AddScalarParam(QNN_OP_GATHER_PARAM_AXIS, 0); - return res; -} - -} // namespace qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/embedding_lookup_op_builder.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/embedding_lookup_op_builder.h deleted file mode 100644 index 175f65dac0a5e0..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/embedding_lookup_op_builder.h +++ /dev/null @@ -1,20 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. -// All Rights Reserved. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_EMBEDDING_LOOKUP_OP_BUILDER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_EMBEDDING_LOOKUP_OP_BUILDER_H_ - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildEmbeddingLookupOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs); - -} // namespace qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_EMBEDDING_LOOKUP_OP_BUILDER_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/fully_connected_op_builder.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/fully_connected_op_builder.cc deleted file mode 100644 index 2b471d5f7ce5d9..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/fully_connected_op_builder.cc +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. -// All Rights Reserved. 
- -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/fully_connected_op_builder.h" - -#include -#include -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnOpDef.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -namespace { -constexpr int kBiasIdx = 2; -} - -std::vector BuildFullyConnectedOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs, const bool keep_num_dims) { - std::vector res; - OpWrapper& fully_connected_op = CreateOpWrapper(res, QNN_OP_FULLY_CONNECTED); - - TensorWrapper& input_tensor = inputs[0]; - fully_connected_op.AddInputTensor(input_tensor); - TensorWrapper& weight_tensor = inputs[1]; - fully_connected_op.AddInputTensor(weight_tensor); - if (inputs.size() - 1 >= kBiasIdx) { - TensorWrapper& bias_tensor = inputs[kBiasIdx]; - fully_connected_op.AddInputTensor(bias_tensor); - } - - TensorWrapper& output_tensor = outputs[0]; - if (keep_num_dims) { - auto& input_dims = input_tensor.GetDims(); - std::uint32_t input_size = std::accumulate( - input_dims.begin(), input_dims.end(), 1, std::multiplies<>()); - const std::uint32_t num_units = weight_tensor.GetDim(0); - const std::uint32_t num_input_elem = weight_tensor.GetDim(1); - - // input_size must be divisible by num_input_elem. This should be validated - // by QNN. 
- const std::uint32_t batch_size = input_size / num_input_elem; - // QNN output should always be rank 2 - qnn::TensorWrapper& fully_connected_out = tensor_pool.CloneNativeTensorFrom( - output_tensor, {batch_size, num_units}); - - fully_connected_op.AddOutputTensor(fully_connected_out); - // TODO: fused activation - - qnn::OpWrapper& reshape_op = CreateOpWrapper(res, QNN_OP_RESHAPE); - reshape_op.AddInputTensor(fully_connected_out); - reshape_op.AddOutputTensor(output_tensor); - } else { - fully_connected_op.AddOutputTensor(outputs[0]); - // TODO: fused activation - } - - return res; -} - -} // namespace qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/fully_connected_op_builder.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/fully_connected_op_builder.h deleted file mode 100644 index 3031be6f3002b8..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/fully_connected_op_builder.h +++ /dev/null @@ -1,22 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. -// All Rights Reserved. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_FULLY_CONNECTED_OP_BUILDER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_FULLY_CONNECTED_OP_BUILDER_H_ - -#include - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildFullyConnectedOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs, const bool keep_num_dims); - -} // namespace qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_FULLY_CONNECTED_OP_BUILDER_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/fully_connected_op_builder_htp.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/fully_connected_op_builder_htp.cc deleted file mode 100644 index a0e56116518f4b..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/fully_connected_op_builder_htp.cc +++ /dev/null @@ -1,132 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. -// All Rights Reserved. 
- -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/fully_connected_op_builder_htp.h" - -#include -#include -#include -#include -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnOpDef.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/log.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/quantize_params_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildFullyConnectedOpHtp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs, const bool keep_num_dims) { - std::vector res; - QNN_LOG_INFO("[FullyConnected Optimization] FC -> CONV2D"); - // TFLite FC Input: [1, k, n] and Weight: [m, n] - // QNN Conv2D Input: - // [batch, height, width, channel_in] - // -> [1, 1, k, n] - // QNN Conv2D Weight: - // [filter_height, filter_width, channel_in / group, channel_out] - // -> [1, 1, n, m] - bool is_supported = inputs[0].get().GetRank() == 3 && inputs.size() == 2 && - inputs[1].get().IsTensorStatic(); - if (!is_supported) { - QNN_LOG_INFO("[FullyConnected Optimization] FAILURE: Unsupported Input"); - return res; - } - - // TFLite FC -> QNN CONV2D: - // Reshape -> Conv2D -> Reshpae - TensorWrapper& input_tensor = inputs[0]; - TensorWrapper& weight_tensor = inputs[1]; - TensorWrapper& output_tensor = outputs[0]; - // Reshape - qnn::OpWrapper& reshape_op_1 = CreateOpWrapper(res, QNN_OP_RESHAPE); - reshape_op_1.AddInputTensor(input_tensor); - std::vector conv_input_dims = input_tensor.GetDims(); - 
conv_input_dims.insert(conv_input_dims.begin() + 1, 1); - qnn::TensorWrapper& conv_input_tensor = - tensor_pool.CloneNativeTensorFrom(input_tensor, conv_input_dims); - reshape_op_1.AddOutputTensor(conv_input_tensor); - // Conv2D Input, Weight, and Output - OpWrapper& conv_op = CreateOpWrapper(res, QNN_OP_CONV_2D); - conv_op.AddInputTensor(conv_input_tensor); - auto& quant_params = weight_tensor.GetQuantParams(); - if (std::holds_alternative( - quant_params)) { - auto& axis_quant_param = - std::get(quant_params); - axis_quant_param.SetAxis(3); - } - std::vector weight_dims{1, 1, weight_tensor.GetDim(1), - weight_tensor.GetDim(0)}; - size_t weight_bytes = weight_tensor.GetTensorBytes(); - const std::vector transpose_dim{weight_tensor.GetDim(0), 1, 1, - weight_tensor.GetDim(1)}; - TensorWrapper* weight; - if (weight_tensor.GetDataType() == QNN_DATATYPE_SFIXED_POINT_8) { - std::vector conv_weight; - auto fc_weight = weight_tensor.GetStaticTensorData(); - TransposeFromOHWIToHWIO(fc_weight.value(), transpose_dim, conv_weight); - weight = &(tensor_pool.CreateStaticTensor( - weight_tensor.GetDataType(), quant_params, weight_dims, weight_bytes, - conv_weight.data())); - } else if (weight_tensor.GetDataType() == QNN_DATATYPE_SFIXED_POINT_16) { - std::vector conv_weight; - auto fc_weight = weight_tensor.GetStaticTensorData(); - TransposeFromOHWIToHWIO(fc_weight.value(), transpose_dim, conv_weight); - weight = &(tensor_pool.CreateStaticTensor( - weight_tensor.GetDataType(), quant_params, weight_dims, weight_bytes, - conv_weight.data())); - } else if (weight_tensor.GetDataType() == QNN_DATATYPE_UFIXED_POINT_16) { - std::vector conv_weight; - auto fc_weight = weight_tensor.GetStaticTensorData(); - TransposeFromOHWIToHWIO(fc_weight.value(), transpose_dim, conv_weight); - weight = &(tensor_pool.CreateStaticTensor( - weight_tensor.GetDataType(), quant_params, weight_dims, weight_bytes, - conv_weight.data())); - } else if (weight_tensor.GetDataType() == QNN_DATATYPE_FLOAT_32) { - 
std::vector conv_weight; - auto fc_weight = weight_tensor.GetStaticTensorData(); - TransposeFromOHWIToHWIO(fc_weight.value(), transpose_dim, conv_weight); - weight = &(tensor_pool.CreateStaticTensor( - weight_tensor.GetDataType(), quant_params, weight_dims, weight_bytes, - conv_weight.data())); - } else { - QNN_LOG_INFO( - "[FullyConnected Optimization] FAILURE: Upsupported Weight Datatype"); - return {}; - } - conv_op.AddInputTensor(*weight); - qnn::TensorWrapper& conv_out = tensor_pool.CloneNativeTensorFrom( - output_tensor, {conv_input_dims[0], conv_input_dims[1], - conv_input_dims[2], weight_dims[3]}); - conv_op.AddOutputTensor(conv_out); - // Conv2D Stride - const std::array stride_data{1, 1}; - const std::vector stride_shape{2}; - auto& stride_tensor = tensor_pool.CreateStaticTensor( - QNN_DATATYPE_UINT_32, QuantizeParamsWrapperVariant{}, stride_shape, - sizeof(std::uint32_t) * stride_data.size(), stride_data.data()); - conv_op.AddTensorParam(QNN_OP_DEPTH_WISE_CONV_2D_PARAM_STRIDE, stride_tensor); - // Conv2D Padding - const std::array padding_data = {0, 0, 0, 0}; - const std::vector padding_shape{2, 2}; - auto& padding_tensor = tensor_pool.CreateStaticTensor( - QNN_DATATYPE_UINT_32, QuantizeParamsWrapperVariant{}, padding_shape, - sizeof(std::uint32_t) * padding_data.size(), padding_data.data()); - conv_op.AddTensorParam(QNN_OP_CONV_2D_PARAM_PAD_AMOUNT, padding_tensor); - - // Reshape - qnn::OpWrapper& reshape_op_2 = CreateOpWrapper(res, QNN_OP_RESHAPE); - reshape_op_2.AddInputTensor(conv_out); - reshape_op_2.AddOutputTensor(output_tensor); - return res; -} - -} // namespace qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/fully_connected_op_builder_htp.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/fully_connected_op_builder_htp.h deleted file mode 100644 index ccf8371fe9755e..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/fully_connected_op_builder_htp.h +++ 
/dev/null @@ -1,21 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. -// All Rights Reserved. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_FULLY_CONNECTED_OP_BUILDER_HTP_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_FULLY_CONNECTED_OP_BUILDER_HTP_H_ - -#include - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildFullyConnectedOpHtp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs, const bool keep_num_dims); - -} // namespace qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_FULLY_CONNECTED_OP_BUILDER_HTP_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/gather_op_builder.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/gather_op_builder.cc deleted file mode 100644 index e20d7b31fad8bb..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/gather_op_builder.cc +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved. 
-// SPDX-License-Identifier: Apache-2.0 - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/gather_op_builder.h" - -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnOpDef.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/log.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildGatherOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs, const std::int32_t axis, - const std::int32_t batch_dims) { - std::vector res; - - if (batch_dims != 0) { - QNN_LOG_ERROR("The batch dimension of Gather OP is not equal to 0."); - return res; - } - - auto& gather_op = CreateOpWrapper(res, QNN_OP_GATHER); - for (const auto& input : inputs) { - gather_op.AddInputTensor(input); - } - for (const auto& output : outputs) { - gather_op.AddOutputTensor(output); - } - const std::int32_t adjusted_axis = - axis >= 0 ? axis : axis + inputs[0].get().GetRank(); - gather_op.AddScalarParam(QNN_OP_GATHER_PARAM_AXIS, - adjusted_axis); - - return res; -} - -} // namespace qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/gather_op_builder.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/gather_op_builder.h deleted file mode 100644 index 00b078c4f36e7e..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/gather_op_builder.h +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. -// All Rights Reserved. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_GATHER_OP_BUILDER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_GATHER_OP_BUILDER_H_ - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildGatherOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs, const std::int32_t axis, - const std::int32_t batch_dims); - -} // namespace qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_GATHER_OP_BUILDER_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/gelu_op_builder.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/gelu_op_builder.cc deleted file mode 100644 index 8f382292b53bc5..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/gelu_op_builder.cc +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. -// All Rights Reserved. 
- -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/gelu_op_builder.h" - -#include - -#include "third_party/qairt/latest/include/QNN/QnnOpDef.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildGeluOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs) { - std::vector res; - - CreateSimpleActivationOp(res, QNN_OP_GELU, inputs[0], outputs[0]); - - return res; -} - -} // namespace qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/gelu_op_builder.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/gelu_op_builder.h deleted file mode 100644 index 77a72154ee89a9..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/gelu_op_builder.h +++ /dev/null @@ -1,20 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. -// All Rights Reserved. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_GELU_OP_BUILDER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_GELU_OP_BUILDER_H_ - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildGeluOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs); - -} // namespace qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_GELU_OP_BUILDER_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/hard_swish_op_builder.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/hard_swish_op_builder.cc deleted file mode 100644 index be77996eb660e6..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/hard_swish_op_builder.cc +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved. 
-// SPDX-License-Identifier: Apache-2.0 - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/hard_swish_op_builder.h" - -#include -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnOpDef.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -namespace { -constexpr size_t kInputIndex = 0; -constexpr size_t kOutputIndex = 0; -} // namespace - -std::vector BuildHardSwishOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs) { - std::vector res; - - OpWrapper& hard_swish_op = CreateOpWrapper(res, QNN_OP_ELEMENT_WISE_NEURON); - hard_swish_op.AddInputTensor(inputs[kInputIndex]); - hard_swish_op.AddOutputTensor(outputs[kOutputIndex]); - hard_swish_op.AddScalarParam( - QNN_OP_ELEMENT_WISE_NEURON_PARAM_OPERATION, - QNN_OP_ELEMENT_WISE_NEURON_OPERATION_HARD_SWISH); - - return res; -} - -} // namespace qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/hard_swish_op_builder.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/hard_swish_op_builder.h deleted file mode 100644 index 9a0a6c3254d327..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/hard_swish_op_builder.h +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved. 
-// SPDX-License-Identifier: Apache-2.0 - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_HARD_SWISH_OP_BUILDER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_HARD_SWISH_OP_BUILDER_H_ - -#include - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildHardSwishOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs); - -} // namespace qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_HARD_SWISH_OP_BUILDER_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/leaky_relu_op_builder.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/leaky_relu_op_builder.cc deleted file mode 100644 index b6ece2ed343655..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/leaky_relu_op_builder.cc +++ /dev/null @@ -1,101 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved. 
-// SPDX-License-Identifier: Apache-2.0 - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/leaky_relu_op_builder.h" - -#include -#include -#include -#include -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnOpDef.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/log.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/quantize_params_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -namespace { -constexpr size_t kInputIndex = 0; -constexpr size_t kOutputIndex = 0; - -template -TensorWrapper& CreateAlphaTensor( - TensorPool& tensor_pool, const Qnn_DataType_t data_type, - const QuantizeParamsWrapperVariant& quant_param, const T alpha) { - const std::vector alpha_dims{1}; - const std::array alpha_data{alpha}; - return tensor_pool.CreateStaticTensor(data_type, quant_param, alpha_dims, - sizeof(T) * alpha_data.size(), - alpha_data.data()); -} - -} // namespace -std::vector BuildLeakyReluOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs, const float alpha) { - std::vector res; - - OpWrapper& leaky_relu_op = CreateOpWrapper(res, QNN_OP_PRELU); - TensorWrapper& input_tensor = inputs[kInputIndex]; - leaky_relu_op.AddInputTensor(input_tensor); - leaky_relu_op.AddOutputTensor(outputs[kOutputIndex]); - - if (std::holds_alternative( - input_tensor.GetQuantParams())) { - TensorWrapper& alpha_tensor = - CreateAlphaTensor(tensor_pool, input_tensor.GetDataType(), - input_tensor.GetQuantParams(), alpha); - 
leaky_relu_op.AddInputTensor(alpha_tensor); - } else if (std::holds_alternative( - input_tensor.GetQuantParams())) { - QuantizeParamsWrapperVariant quant_param; - quant_param.emplace(std::max(alpha, 0.0f), - 0); - - switch (input_tensor.GetDataType()) { - case QNN_DATATYPE_UFIXED_POINT_8: { - TensorWrapper& alpha_tensor = CreateAlphaTensor( - tensor_pool, input_tensor.GetDataType(), quant_param, 1); - leaky_relu_op.AddInputTensor(alpha_tensor); - break; - } - case QNN_DATATYPE_SFIXED_POINT_8: { - TensorWrapper& alpha_tensor = CreateAlphaTensor( - tensor_pool, input_tensor.GetDataType(), quant_param, 1); - leaky_relu_op.AddInputTensor(alpha_tensor); - break; - } - case QNN_DATATYPE_UFIXED_POINT_16: { - TensorWrapper& alpha_tensor = CreateAlphaTensor( - tensor_pool, input_tensor.GetDataType(), quant_param, 1); - leaky_relu_op.AddInputTensor(alpha_tensor); - break; - } - case QNN_DATATYPE_SFIXED_POINT_16: { - TensorWrapper& alpha_tensor = CreateAlphaTensor( - tensor_pool, input_tensor.GetDataType(), quant_param, 1); - leaky_relu_op.AddInputTensor(alpha_tensor); - break; - } - default: { - QNN_LOG_ERROR( - "Unsupported QNN data type when creating alpha tensor for " - "per-tensor quantization."); - break; - } - } - } else { - QNN_LOG_ERROR("Unsupported quantization type for LeakyRelu op."); - } - - return res; -} - -} // namespace qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/leaky_relu_op_builder.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/leaky_relu_op_builder.h deleted file mode 100644 index 99f400a7285355..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/leaky_relu_op_builder.h +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved. 
-// SPDX-License-Identifier: Apache-2.0 - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_LEAKY_RELU_OP_BUILDER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_LEAKY_RELU_OP_BUILDER_H_ - -#include - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildLeakyReluOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs, const float alpha); - -} // namespace qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_LEAKY_RELU_OP_BUILDER_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/matmul_op_builder.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/matmul_op_builder.cc deleted file mode 100644 index 9833c36fe71fec..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/matmul_op_builder.cc +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. -// All Rights Reserved. 
- -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/matmul_op_builder.h" - -#include - -#include "third_party/qairt/latest/include/QNN/QnnOpDef.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildMatmulOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs, const bool adj_x, - const bool adj_y) { - std::vector res; - - auto& matmul_op = CreateOpWrapper(res, QNN_OP_MAT_MUL); - for (const auto& input : inputs) { - matmul_op.AddInputTensor(input); - } - matmul_op.AddOutputTensor(outputs[0]); - matmul_op.AddScalarParam(QNN_OP_MAT_MUL_PARAM_TRANSPOSE_IN0, adj_x); - matmul_op.AddScalarParam(QNN_OP_MAT_MUL_PARAM_TRANSPOSE_IN1, adj_y); - - return res; -} - -} // namespace qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/matmul_op_builder.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/matmul_op_builder.h deleted file mode 100644 index 40958ebb9c4db2..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/matmul_op_builder.h +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. -// All Rights Reserved. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_MATMUL_OP_BUILDER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_MATMUL_OP_BUILDER_H_ - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildMatmulOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs, const bool adj_x, - const bool adj_y); - -} // namespace qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_MATMUL_OP_BUILDER_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/mean_op_builder.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/mean_op_builder.cc deleted file mode 100644 index 3495ee24efa6f9..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/mean_op_builder.cc +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved. 
-// SPDX-License-Identifier: Apache-2.0 - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/mean_op_builder.h" - -#include -#include -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnOpDef.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/log.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildMeanOp(TensorPool& tensor_pool, - const std::vector& inputs, - const std::vector& outputs, - const bool keep_dim) { - std::vector res; - - TensorWrapper& axis_tensor = inputs[1]; - if (!axis_tensor.IsTensorStatic() || axis_tensor.GetRank() != 1) { - QNN_LOG_ERROR( - "The axis tensor is not static, or the rank of axis tensor is not " - "equal to 1."); - return res; - } - - TensorWrapper& input_tensor = inputs[0]; - - auto axis_data = axis_tensor.GetStaticTensorData(); - if (!axis_data.has_value()) { - QNN_LOG_ERROR("Get axis_data failed."); - return res; - } - std::vector adjusted_axis_data; - for (size_t i = 0; i < axis_tensor.GetDim(0); ++i) { - std::uint32_t adjusted_axis = - (*axis_data)[i] >= 0 ? 
(*axis_data)[i] - : (*axis_data)[i] + input_tensor.GetRank(); - if (std::find(adjusted_axis_data.begin(), adjusted_axis_data.end(), - adjusted_axis) == adjusted_axis_data.end()) { - adjusted_axis_data.emplace_back(adjusted_axis); - } - } - TensorWrapper& adjusted_axis_tensor = tensor_pool.CreateStaticTensor( - QNN_DATATYPE_UINT_32, axis_tensor.GetQuantParams(), - {static_cast(adjusted_axis_data.size())}, - sizeof(std::uint32_t) * adjusted_axis_data.size(), - adjusted_axis_data.data()); - - auto& reduce_op = CreateOpWrapper(res, QNN_OP_REDUCE_MEAN); - reduce_op.AddInputTensor(input_tensor); - reduce_op.AddOutputTensor(outputs[0]); - reduce_op.AddTensorParam(QNN_OP_REDUCE_MEAN_PARAM_AXES, adjusted_axis_tensor); - reduce_op.AddScalarParam(QNN_OP_REDUCE_MEAN_PARAM_KEEP_DIMS, keep_dim); - - return res; -} - -} // namespace qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/mean_op_builder.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/mean_op_builder.h deleted file mode 100644 index 50127647c90c10..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/mean_op_builder.h +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. -// All Rights Reserved. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_MEAN_OP_BUILDER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_MEAN_OP_BUILDER_H_ - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildMeanOp(TensorPool& tensor_pool, - const std::vector& inputs, - const std::vector& outputs, - const bool keep_dims); - -} // namespace qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_MEAN_OP_BUILDER_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.cc deleted file mode 100644 index 8687927d3875b2..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.cc +++ /dev/null @@ -1,107 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. -// All Rights Reserved. 
- -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" - -#include -#include -#include -#include - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::pair ComputePaddingBeforeAfter( - const std::uint32_t input_size, const std::uint32_t filter_size, - const std::uint32_t stride, const std::uint32_t dilation_rate, - const PaddingType padding_type) { - // padding_before, padding_after - std::pair result{0, 0}; - if (stride == 0) { - QNN_LOG_ERROR("Stride is 0"); - return result; - } - - std::uint32_t output_size{}; - std::uint32_t effective_filter_size = (filter_size - 1) * dilation_rate + 1; - - switch (padding_type) { - case PaddingType::Same: - output_size = (input_size + stride - 1) / stride; - break; - case PaddingType::Valid: - output_size = (input_size + stride - effective_filter_size) / stride; - break; - default: // PaddingType::Unknown - QNN_LOG_ERROR("Unknown padding type"); - return result; - } - - std::uint32_t total_padding = - (output_size - 1) * stride + effective_filter_size - input_size; - result.first = total_padding / 2; - result.second = result.first + total_padding % 2; - return result; -} - -OpWrapper& CreateOpWrapper(std::vector& ops, const char* op_type) { - const auto op_count = ops.size(); - const auto name = "op_type_" + std::string(op_type) + "_op_count_" + - std::to_string(op_count); - return ops.emplace_back(std::move(name), op_type); -} - -OpWrapper& CreateSimpleActivationOp(std::vector& ops, - const char* op_type, - const TensorWrapper& input_tensor, - const TensorWrapper& output_tensor) { - auto& ret = CreateOpWrapper(ops, op_type); - ret.AddInputTensor(input_tensor); - ret.AddOutputTensor(output_tensor); - return ret; -} - -/* -LiteRtStatus OpMapper::AddFusedActivationNode( - const tflite::ActivationFunctionType activation, - const 
TensorWrapper& input_tensor, const TensorWrapper& output_tensor) { - switch (activation) { - case tflite::ActivationFunctionType_RELU: { - OpWrapper& activation_op = - CreateSimpleActivationOp(QNN_OP_RELU, input_tensor, output_tensor); - break; - } - case tflite::ActivationFunctionType_RELU_N1_TO_1: { - OpWrapper& activation_op = CreateSimpleActivationOp( - QNN_OP_RELU_MIN_MAX, input_tensor, output_tensor); - activation_op.AddScalarParam(QNN_OP_RELU_MIN_MAX_PARAM_MIN_VALUE, - -1.f); - activation_op.AddScalarParam(QNN_OP_RELU_MIN_MAX_PARAM_MAX_VALUE, - 1.f); - break; - } - case tflite::ActivationFunctionType_RELU6: { - OpWrapper& activation_op = CreateSimpleActivationOp( - QNN_OP_RELU_MIN_MAX, input_tensor, output_tensor); - activation_op.AddScalarParam(QNN_OP_RELU_MIN_MAX_PARAM_MIN_VALUE, - 0.f); - activation_op.AddScalarParam(QNN_OP_RELU_MIN_MAX_PARAM_MAX_VALUE, - 6.f); - break; - } - case tflite::ActivationFunctionType_TANH: { - OpWrapper& activation_op = - CreateSimpleActivationOp(QNN_OP_TANH, input_tensor, output_tensor); - break; - } - default: - return kLiteRtStatusErrorUnsupported; - } - - return kLiteRtStatusOk; -} -*/ - -} // namespace qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h deleted file mode 100644 index 2888c3e84262c2..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. -// All Rights Reserved. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_OP_BUILDER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_OP_BUILDER_H_ - -#include -#include -#include - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -enum class PaddingType { - Unknown = 0, - Same, - Valid, -}; - -std::pair ComputePaddingBeforeAfter( - const std::uint32_t input_size, const std::uint32_t filter_size, - const std::uint32_t stride, const std::uint32_t dilation_rate, - const PaddingType padding_type); - -OpWrapper& CreateOpWrapper(std::vector& ops, const char* op_type); - -OpWrapper& CreateSimpleActivationOp(std::vector& ops, - const char* op_type, - const TensorWrapper& input_tensor, - const TensorWrapper& output_tensor); - -} // namespace qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_OP_BUILDER_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/pack_op_builder.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/pack_op_builder.cc deleted file mode 100644 index 97dc4c5c9561b7..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/pack_op_builder.cc +++ /dev/null @@ -1,53 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. -// All Rights Reserved. 
- -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/pack_op_builder.h" - -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnOpDef.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildPackOp(TensorPool& tensor_pool, - const std::vector& inputs, - const std::vector& outputs, - const int32_t axis) { - std::vector res; - - // pack op with only one input would violate op definition of qnn - // we'll replace it with reshape op - if (inputs.size() == 1) { - auto& op = CreateOpWrapper(res, QNN_OP_RESHAPE); - op.AddInputTensor(inputs[0]); - op.AddOutputTensor(outputs[0]); - return res; - } - - if (outputs[0].get().GetRank() != inputs[0].get().GetRank() + 1) { - auto& concat_op = CreateOpWrapper(res, QNN_OP_CONCAT); - for (const auto& input : inputs) { - concat_op.AddInputTensor(input); - } - concat_op.AddOutputTensor(outputs[0]); - } else { - auto& pack_op = CreateOpWrapper(res, QNN_OP_PACK); - for (const auto& input : inputs) { - pack_op.AddInputTensor(input); - } - std::uint32_t adjusted_axis = - axis < 0 ? 
axis + inputs[0].get().GetRank() : axis; - pack_op.AddScalarParam(QNN_OP_PACK_PARAM_AXIS, - adjusted_axis); - pack_op.AddOutputTensor(outputs[0]); - } - - return res; -} - -} // namespace qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/pack_op_builder.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/pack_op_builder.h deleted file mode 100644 index b0e39cc74ccd2f..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/pack_op_builder.h +++ /dev/null @@ -1,23 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. -// All Rights Reserved. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_PACK_OP_BUILDER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_PACK_OP_BUILDER_H_ - -#include -#include - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildPackOp(TensorPool& tensor_pool, - const std::vector& inputs, - const std::vector& outputs, - const int32_t axis); - -} // namespace qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_PACK_OP_BUILDER_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/pool2d_op_builder.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/pool2d_op_builder.cc deleted file mode 100644 index b4b42d743a0cf1..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/pool2d_op_builder.cc +++ /dev/null @@ -1,113 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved. 
-// SPDX-License-Identifier: Apache-2.0 - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/pool2d_op_builder.h" - -#include -#include -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnOpDef.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/quantize_params_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -namespace { - -constexpr size_t kInputIndex = 0; -constexpr size_t kOutputIndex = 0; -constexpr size_t kHeightIndex = 1; -constexpr size_t kWidthIndex = 2; - -std::vector BuildPool2dOp( - TensorPool& tensor_pool, const char* op_type, const char* filter_param_name, - const char* stride_param_name, const char* padding_param_name, - const std::vector& inputs, - const std::vector& outputs, - const std::uint32_t stride_height, const std::uint32_t stride_width, - const std::uint32_t filter_height, const std::uint32_t filter_width, - const PaddingType padding_type) { - std::vector res; - - OpWrapper& pool_op = CreateOpWrapper(res, op_type); - - TensorWrapper& input_tensor = inputs[kInputIndex]; - pool_op.AddInputTensor(input_tensor); - - // filter param - const std::vector filter_shape{2}; - const std::array filter_data{filter_height, filter_width}; - TensorWrapper& filter_tensor = tensor_pool.CreateStaticTensor( - QNN_DATATYPE_UINT_32, QuantizeParamsWrapperVariant{}, filter_shape, - sizeof(decltype(filter_data)::value_type) * filter_data.size(), - filter_data.data()); - pool_op.AddTensorParam(filter_param_name, filter_tensor); - - // stride param - const std::vector stride_shape{2}; - const 
std::array stride_data{stride_height, stride_width}; - TensorWrapper& stride_tensor = tensor_pool.CreateStaticTensor( - QNN_DATATYPE_UINT_32, QuantizeParamsWrapperVariant{}, stride_shape, - sizeof(decltype(stride_data)::value_type) * stride_data.size(), - stride_data.data()); - pool_op.AddTensorParam(stride_param_name, stride_tensor); - - // padding - const auto [padding_before_height, padding_after_height] = - ComputePaddingBeforeAfter(input_tensor.GetDim(kHeightIndex), - filter_height, stride_height, 1, padding_type); - const auto [padding_before_width, padding_after_width] = - ComputePaddingBeforeAfter(input_tensor.GetDim(kWidthIndex), filter_width, - stride_width, 1, padding_type); - const std::vector padding_shape{2, 2}; - const std::array padding_data{ - padding_before_height, padding_after_height, padding_before_width, - padding_after_width}; - TensorWrapper& padding_tensor = tensor_pool.CreateStaticTensor( - QNN_DATATYPE_UINT_32, QuantizeParamsWrapperVariant{}, padding_shape, - sizeof(decltype(padding_data)::value_type) * padding_data.size(), - padding_data.data()); - pool_op.AddTensorParam(padding_param_name, padding_tensor); - - TensorWrapper& output_tensor = outputs[kOutputIndex]; - pool_op.AddOutputTensor(output_tensor); - // TODO: fused activation - - return res; -} - -} // namespace - -std::vector BuildMaxPoolOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs, - const std::uint32_t stride_height, const std::uint32_t stride_width, - const std::uint32_t filter_height, const std::uint32_t filter_width, - const PaddingType padding_type) { - return BuildPool2dOp( - tensor_pool, QNN_OP_POOL_MAX_2D, QNN_OP_POOL_MAX_2D_PARAM_FILTER_SIZE, - QNN_OP_POOL_MAX_2D_PARAM_STRIDE, QNN_OP_POOL_MAX_2D_PARAM_PAD_AMOUNT, - inputs, outputs, stride_height, stride_width, filter_height, filter_width, - padding_type); -} - -std::vector BuildAveragePoolOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs, - 
const std::uint32_t stride_height, const std::uint32_t stride_width, - const std::uint32_t filter_height, const std::uint32_t filter_width, - const PaddingType padding_type) { - return BuildPool2dOp( - tensor_pool, QNN_OP_POOL_AVG_2D, QNN_OP_POOL_AVG_2D_PARAM_FILTER_SIZE, - QNN_OP_POOL_AVG_2D_PARAM_STRIDE, QNN_OP_POOL_AVG_2D_PARAM_PAD_AMOUNT, - inputs, outputs, stride_height, stride_width, filter_height, filter_width, - padding_type); -} - -} // namespace qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/pool2d_op_builder.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/pool2d_op_builder.h deleted file mode 100644 index cb8da0e7a19589..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/pool2d_op_builder.h +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_POOL2D_OP_BUILDER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_POOL2D_OP_BUILDER_H_ - -#include -#include - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildMaxPoolOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs, - const std::uint32_t stride_height, const std::uint32_t stride_width, - const std::uint32_t filter_height, const std::uint32_t filter_width, - const PaddingType padding_type); - -std::vector BuildAveragePoolOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs, - const std::uint32_t 
stride_height, const std::uint32_t stride_width, - const std::uint32_t filter_height, const std::uint32_t filter_width, - const PaddingType padding_type); - -} // namespace qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_POOL2D_OP_BUILDER_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/quantize_op_builder.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/quantize_op_builder.cc deleted file mode 100644 index 70c4b610336118..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/quantize_op_builder.cc +++ /dev/null @@ -1,56 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. -// All Rights Reserved. - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/quantize_op_builder.h" - -#include - -#include "third_party/qairt/latest/include/QNN/QnnOpDef.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildQuantizeOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs) { - std::vector res; - - const char* qnn_op = nullptr; - if (inputs[0].get().IsPerTensorQuantWithOffsetDiff(outputs[0].get())) { - qnn_op = QNN_OP_CAST; - } else if ((inputs[0].get().IsQuant8() || inputs[0].get().IsQuant16()) && - (outputs[0].get().IsQuant8() || outputs[0].get().IsQuant16())) { - qnn_op = QNN_OP_CONVERT; - } else { - qnn_op = QNN_OP_QUANTIZE; - } - - auto& quantize_op = CreateOpWrapper(res, qnn_op); - quantize_op.AddInputTensor(inputs[0]); - quantize_op.AddOutputTensor(outputs[0]); - - return res; -} - -std::vector BuildDequantizeOp( - TensorPool& tensor_pool, 
const std::vector& inputs, - const std::vector& outputs) { - std::vector res; - const char* qnn_op = nullptr; - if (inputs[0].get().IsF16() && outputs[0].get().IsF32()) { - qnn_op = QNN_OP_CAST; - } else { - qnn_op = QNN_OP_DEQUANTIZE; - } - - auto& quantize_op = CreateOpWrapper(res, qnn_op); - quantize_op.AddInputTensor(inputs[0]); - quantize_op.AddOutputTensor(outputs[0]); - - return res; -} - -} // namespace qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/quantize_op_builder.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/quantize_op_builder.h deleted file mode 100644 index 2b2cfd923202bc..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/quantize_op_builder.h +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. -// All Rights Reserved. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_QUANTIZE_OP_BUILDER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_QUANTIZE_OP_BUILDER_H_ - -#include - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildQuantizeOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs); - -std::vector BuildDequantizeOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs); - -} // namespace qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_QUANTIZE_OP_BUILDER_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/reduce_op_builder.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/reduce_op_builder.cc deleted file mode 100644 index b978f10450213b..00000000000000 --- 
a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/reduce_op_builder.cc +++ /dev/null @@ -1,66 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/reduce_op_builder.h" - -#include -#include -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnOpDef.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/log.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildReduceSumOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs, const bool keep_dims) { - std::vector res; - - TensorWrapper& axis_tensor = inputs[1]; - if (!axis_tensor.IsTensorStatic() || axis_tensor.GetRank() != 1) { - QNN_LOG_ERROR( - "The axis tensor is not static, or the rank of axis tensor is not " - "equal to 1."); - return res; - } - - TensorWrapper& input_tensor = inputs[0]; - - auto axis_data = axis_tensor.GetStaticTensorData(); - if (!axis_data.has_value()) { - QNN_LOG_ERROR("Get axis_data failed."); - return res; - } - std::vector adjusted_axis_data; - for (size_t i = 0; i < axis_tensor.GetDim(0); ++i) { - std::uint32_t adjusted_axis = - (*axis_data)[i] >= 0 ? 
(*axis_data)[i] - : (*axis_data)[i] + input_tensor.GetRank(); - if (std::find(adjusted_axis_data.begin(), adjusted_axis_data.end(), - adjusted_axis) == adjusted_axis_data.end()) { - adjusted_axis_data.emplace_back(adjusted_axis); - } - } - TensorWrapper& adjusted_axis_tensor = tensor_pool.CreateStaticTensor( - QNN_DATATYPE_UINT_32, axis_tensor.GetQuantParams(), - {static_cast(adjusted_axis_data.size())}, - sizeof(std::uint32_t) * adjusted_axis_data.size(), - adjusted_axis_data.data()); - - OpWrapper& reduce_op = CreateOpWrapper(res, QNN_OP_REDUCE_SUM); - reduce_op.AddInputTensor(input_tensor); - reduce_op.AddOutputTensor(outputs[0]); - reduce_op.AddTensorParam(QNN_OP_REDUCE_SUM_PARAM_AXES, adjusted_axis_tensor); - reduce_op.AddScalarParam(QNN_OP_REDUCE_SUM_PARAM_KEEP_DIMS, keep_dims); - - return res; -} - -} // namespace qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/reduce_op_builder.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/reduce_op_builder.h deleted file mode 100644 index cb43106587d91e..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/reduce_op_builder.h +++ /dev/null @@ -1,20 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. -// All Rights Reserved. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_REDUCE_OP_BUILDER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_REDUCE_OP_BUILDER_H_ - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildReduceSumOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs, const bool keep_dims); - -} // namespace qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_REDUCE_OP_BUILDER_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/relu6_op_builder.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/relu6_op_builder.cc deleted file mode 100644 index ed9330211bf41f..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/relu6_op_builder.cc +++ /dev/null @@ -1,24 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. -// All Rights Reserved. 
- -#include - -#include "third_party/qairt/latest/include/QNN/QnnOpDef.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildRelu6Op( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs) { - std::vector res; - - CreateSimpleActivationOp(res, QNN_OP_RELU6, inputs[0], outputs[0]); - - return res; -} - -} // namespace qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/relu6_op_builder.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/relu6_op_builder.h deleted file mode 100644 index 6261da7fd1b80d..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/relu6_op_builder.h +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. -// All Rights Reserved. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_RELU6_OP_BUILDER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_RELU6_OP_BUILDER_H_ - -#include - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildRelu6Op( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs); - -} // namespace qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_RELU6_OP_BUILDER_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/relu_op_builder.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/relu_op_builder.cc deleted file mode 100644 index bfbfb37c8dd247..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/relu_op_builder.cc +++ /dev/null @@ -1,24 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. -// All Rights Reserved. 
- -#include - -#include "third_party/qairt/latest/include/QNN/QnnOpDef.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildReluOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs) { - std::vector res; - - CreateSimpleActivationOp(res, QNN_OP_RELU, inputs[0], outputs[0]); - - return res; -} - -} // namespace qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/relu_op_builder.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/relu_op_builder.h deleted file mode 100644 index 3d2d5da8f2fa7a..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/relu_op_builder.h +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. -// All Rights Reserved. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_RELU_OP_BUILDER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_RELU_OP_BUILDER_H_ - -#include - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildReluOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs); - -} // namespace qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_RELU_OP_BUILDER_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/reshape_op_builder.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/reshape_op_builder.cc deleted file mode 100644 index a51711dfb5ac59..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/reshape_op_builder.cc +++ /dev/null @@ -1,28 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. -// All Rights Reserved. 
- -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/reshape_op_builder.h" - -#include - -#include "third_party/qairt/latest/include/QNN/QnnOpDef.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildReshapeOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs) { - std::vector res; - - auto& reshape_op = CreateOpWrapper(res, QNN_OP_RESHAPE); - reshape_op.AddInputTensor(inputs[0]); - reshape_op.AddOutputTensor(outputs[0]); - - return res; -} - -} // namespace qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/reshape_op_builder.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/reshape_op_builder.h deleted file mode 100644 index 6b14ad38bbd01b..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/reshape_op_builder.h +++ /dev/null @@ -1,20 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. -// All Rights Reserved. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_RESHAPE_OP_BUILDER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_RESHAPE_OP_BUILDER_H_ - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildReshapeOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs); - -} // namespace qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_RESHAPE_OP_BUILDER_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/resize_op_builder.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/resize_op_builder.cc deleted file mode 100644 index c0a1f173b423b0..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/resize_op_builder.cc +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved. 
-// SPDX-License-Identifier: Apache-2.0 - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/resize_op_builder.h" - -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnOpDef.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -namespace { -constexpr size_t kInputIndex = 0; -constexpr size_t kOutputIndex = 0; - -std::vector BuildResizeOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs, const char* op_type, - const char* align_corners_param, const char* half_pixel_centers_param, - const bool align_corners, const bool half_pixel_centers) { - std::vector res; - - auto& resize_op = CreateOpWrapper(res, op_type); - resize_op.AddInputTensor(inputs[kInputIndex]); - resize_op.AddOutputTensor(outputs[kOutputIndex]); - resize_op.AddScalarParam(align_corners_param, align_corners); - resize_op.AddScalarParam(half_pixel_centers_param, half_pixel_centers); - - return res; -} -} // namespace - -std::vector BuildResizeBilinearOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs, const bool align_corners, - const bool half_pixel_centers) { - return BuildResizeOp(tensor_pool, inputs, outputs, QNN_OP_RESIZE_BILINEAR, - QNN_OP_RESIZE_BILINEAR_PARAM_ALIGN_CORNERS, - QNN_OP_RESIZE_BILINEAR_PARAM_HALF_PIXEL_CENTERS, - align_corners, half_pixel_centers); -} - -std::vector BuildResizeNearestOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs, const bool align_corners, - const bool half_pixel_centers) { - return BuildResizeOp(tensor_pool, inputs, outputs, - QNN_OP_RESIZE_NEAREST_NEIGHBOR, - 
QNN_OP_RESIZE_NEAREST_NEIGHBOR_PARAM_ALIGN_CORNERS, - QNN_OP_RESIZE_NEAREST_NEIGHBOR_PARAM_HALF_PIXEL_CENTERS, - align_corners, half_pixel_centers); -} - -} // namespace qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/resize_op_builder.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/resize_op_builder.h deleted file mode 100644 index c24e889ee9f0a2..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/resize_op_builder.h +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_RESIZE_OP_BUILDER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_RESIZE_OP_BUILDER_H_ - -#include - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildResizeBilinearOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs, const bool align_corners, - const bool half_pixel_centers); - -std::vector BuildResizeNearestOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs, const bool align_corners, - const bool half_pixel_centers); - -} // namespace qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_RESIZE_OP_BUILDER_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/rms_norm_op_builder.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/rms_norm_op_builder.cc deleted file mode 100644 index fc88f639b76684..00000000000000 --- 
a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/rms_norm_op_builder.cc +++ /dev/null @@ -1,85 +0,0 @@ -// Copyright 2025 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/rms_norm_op_builder.h" - -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnOpDef.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/quantize_params_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -static constexpr int kInputIndex = 0; -static constexpr int kGammaIndex = 1; - -std::vector BuildRmsNormOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs, const float epsilon) { - std::vector res; - - auto& rms_norm_op = CreateOpWrapper(res, QNN_OP_RMS_NORM); - for (const auto& input : inputs) { - rms_norm_op.AddInputTensor(input); - } - - // Constructs axis param tensor. 
- std::vector axis_data; - axis_data.emplace_back(inputs[kInputIndex].get().GetRank() - 1); - TensorWrapper& axis_tensor = tensor_pool.CreateStaticTensor( - QNN_DATATYPE_UINT_32, inputs[kInputIndex].get().GetQuantParams(), {1}, - sizeof(std::uint32_t) * axis_data.size(), axis_data.data()); - - if (inputs[kGammaIndex].get().GetDataType() == QNN_DATATYPE_FLOAT_32) { - // Construct float beta static all 0 tensor. - std::vector beta_data( - inputs[kGammaIndex].get().GetTensorNumElements(), 0); - TensorWrapper& beta_tensor = tensor_pool.CreateStaticTensor( - inputs[kGammaIndex].get().GetDataType(), - inputs[kGammaIndex].get().GetQuantParams(), - inputs[kGammaIndex].get().GetDims(), sizeof(float) * beta_data.size(), - beta_data.data()); - rms_norm_op.AddInputTensor(beta_tensor); - } else { - // Construct uint8_t beta static all 0 tensor. - std::vector beta_data( - inputs[kGammaIndex].get().GetTensorNumElements(), 0); - - // Offset needs to be 0, scale does not matter since data is 0 - ScaleOffsetQuantizeParamsWrapper q_param(0.00001, 0); - - TensorWrapper& beta_tensor = tensor_pool.CreateStaticTensor( - QNN_DATATYPE_UFIXED_POINT_8, q_param, - inputs[kGammaIndex].get().GetDims(), sizeof(uint8_t) * beta_data.size(), - beta_data.data()); - rms_norm_op.AddInputTensor(beta_tensor); - } - - rms_norm_op.AddScalarParam(QNN_OP_RMS_NORM_PARAM_EPSILON, epsilon); - rms_norm_op.AddTensorParam(QNN_OP_RMS_NORM_PARAM_AXES, axis_tensor); - rms_norm_op.AddOutputTensor(outputs[0]); - - return res; -} - -} // namespace qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/rms_norm_op_builder.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/rms_norm_op_builder.h deleted file mode 100644 index f97e35fd58717e..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/rms_norm_op_builder.h +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright 2025 Google LLC. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_RMS_NORM_OP_BUILDER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_RMS_NORM_OP_BUILDER_H_ - -#include - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildRmsNormOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs, const float epsilon); - -} // namespace qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_RMS_NORM_OP_BUILDER_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/select_op_builder.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/select_op_builder.cc deleted file mode 100644 index 3312ae3d2e8d96..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/select_op_builder.cc +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. -// All Rights Reserved. 
- -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/select_op_builder.h" - -#include - -#include "third_party/qairt/latest/include/QNN/QnnOpDef.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildSelectOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs) { - std::vector res; - - auto& select_op = CreateOpWrapper(res, QNN_OP_ELEMENT_WISE_SELECT); - for (const auto& input : inputs) { - select_op.AddInputTensor(input); - } - select_op.AddOutputTensor(outputs[0]); - - return res; -} - -} // namespace qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/select_op_builder.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/select_op_builder.h deleted file mode 100644 index e5a4431f99ddb9..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/select_op_builder.h +++ /dev/null @@ -1,20 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. -// All Rights Reserved. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_SELECT_OP_BUILDER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_SELECT_OP_BUILDER_H_ - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildSelectOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs); - -} // namespace qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_SELECT_OP_BUILDER_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/slice_op_builder.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/slice_op_builder.cc deleted file mode 100644 index af94e9ab833e22..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/slice_op_builder.cc +++ /dev/null @@ -1,75 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved. 
-// SPDX-License-Identifier: Apache-2.0 - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/slice_op_builder.h" - -#include -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnOpDef.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/log.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -namespace { -constexpr int kDefaultStrideValue = 1; -constexpr int kSizeNegative = -1; -constexpr int kRangeNumElements = 3; -} // namespace - -std::vector BuildSliceOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs) { - std::vector res; - - TensorWrapper& input_tensor = inputs[0]; - TensorWrapper& begin_tensor = inputs[1]; - TensorWrapper& size_tensor = inputs[2]; - if (!begin_tensor.IsTensorStatic() || !size_tensor.IsTensorStatic()) { - QNN_LOG_ERROR( - "The begin tensor and size tensor of Slice OP is not static."); - return res; - } - - const auto input_rank = input_tensor.GetRank(); - auto begin_data = begin_tensor.GetStaticTensorData(); - if (!begin_data.has_value()) { - QNN_LOG_ERROR("Get begin_data failed."); - return res; - } - auto size_data = size_tensor.GetStaticTensorData(); - if (!size_data.has_value()) { - QNN_LOG_ERROR("Get size_data failed."); - return res; - } - std::vector range_data; - range_data.reserve(input_rank * kRangeNumElements); - for (size_t i = 0; i < input_rank; ++i) { - range_data.emplace_back((*begin_data)[i]); - if ((*size_data)[i] == kSizeNegative) { - range_data.emplace_back(input_tensor.GetDim(i)); - } else { - 
range_data.emplace_back((*begin_data)[i] + (*size_data)[i]); - } - range_data.emplace_back(kDefaultStrideValue); - } - TensorWrapper& range_tensor = tensor_pool.CreateStaticTensor( - QNN_DATATYPE_INT_32, begin_tensor.GetQuantParams(), - {input_rank, kRangeNumElements}, sizeof(std::int32_t) * range_data.size(), - range_data.data()); - - auto& slice_op = CreateOpWrapper(res, QNN_OP_STRIDED_SLICE); - slice_op.AddTensorParam(QNN_OP_STRIDED_SLICE_PARAM_RANGES, range_tensor); - slice_op.AddInputTensor(input_tensor); - slice_op.AddOutputTensor(outputs[0]); - - return res; -} - -} // namespace qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/slice_op_builder.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/slice_op_builder.h deleted file mode 100644 index 7eb9c013dcccfd..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/slice_op_builder.h +++ /dev/null @@ -1,19 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. -// All Rights Reserved. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_SLICE_OP_BUILDER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_SLICE_OP_BUILDER_H_ - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildSliceOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs); -} // namespace qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_SLICE_OP_BUILDER_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/softmax_op_builder.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/softmax_op_builder.cc deleted file mode 100644 index 5d3e226b011846..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/softmax_op_builder.cc +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. -// All Rights Reserved. 
- -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/softmax_op_builder.h" - -#include - -#include "third_party/qairt/latest/include/QNN/QnnOpDef.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildSoftmaxOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs, const float beta) { - std::vector res; - - auto& softmax_op = CreateOpWrapper(res, QNN_OP_SOFTMAX); - softmax_op.AddInputTensor(inputs[0]); - softmax_op.AddOutputTensor(outputs[0]); - softmax_op.AddScalarParam(QNN_OP_SOFTMAX_PARAM_BETA, beta); - - return res; -} - -} // namespace qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/softmax_op_builder.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/softmax_op_builder.h deleted file mode 100644 index bac0ea1c0d76d1..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/softmax_op_builder.h +++ /dev/null @@ -1,19 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. -// All Rights Reserved. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_SOFTMAX_OP_BUILDER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_SOFTMAX_OP_BUILDER_H_ - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildSoftmaxOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs, const float beta); -} // namespace qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_SOFTMAX_OP_BUILDER_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/spatial_transform_op_builder.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/spatial_transform_op_builder.cc deleted file mode 100644 index 9f77d75c18ae01..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/spatial_transform_op_builder.cc +++ /dev/null @@ -1,63 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved. 
-// SPDX-License-Identifier: Apache-2.0 - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/spatial_transform_op_builder.h" - -#include -#include -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnOpDef.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/quantize_params_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -namespace { -constexpr size_t kInputIndex = 0; -constexpr size_t kOutputIndex = 0; - -std::vector BuildSpatialTransformOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs, const char* op_type, - const char* block_param, const std::uint32_t block_size) { - std::vector res; - - auto& spatial_transform_op = CreateOpWrapper(res, op_type); - spatial_transform_op.AddInputTensor(inputs[kInputIndex]); - spatial_transform_op.AddOutputTensor(outputs[kOutputIndex]); - const std::array block_data = {block_size, block_size}; - const std::vector block_dims{2}; - auto& block_tensor = tensor_pool.CreateStaticTensor( - QNN_DATATYPE_UINT_32, QuantizeParamsWrapperVariant{}, block_dims, - sizeof(decltype(block_dims)::value_type) * block_dims.size(), - block_data.data()); - spatial_transform_op.AddTensorParam(block_param, block_tensor); - - return res; -} -} // namespace - -std::vector BuildDepthToSpaceOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs, - const std::uint32_t block_size) { - return BuildSpatialTransformOp( - tensor_pool, inputs, outputs, QNN_OP_DEPTH_TO_SPACE, - QNN_OP_DEPTH_TO_SPACE_PARAM_BLOCK_SIZE, 
block_size); -} - -std::vector BuildSpaceToDepthOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs, - const std::uint32_t block_size) { - return BuildSpatialTransformOp( - tensor_pool, inputs, outputs, QNN_OP_SPACE_TO_DEPTH, - QNN_OP_SPACE_TO_DEPTH_PARAM_BLOCK_SIZE, block_size); -} -} // namespace qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/spatial_transform_op_builder.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/spatial_transform_op_builder.h deleted file mode 100644 index c2e7c5e19c68fd..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/spatial_transform_op_builder.h +++ /dev/null @@ -1,28 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_SPATIAL_TRANSFORM_OP_BUILDER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_SPATIAL_TRANSFORM_OP_BUILDER_H_ - -#include -#include - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildDepthToSpaceOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs, - const std::uint32_t block_size); - -std::vector BuildSpaceToDepthOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs, - const std::uint32_t block_size); - -} // namespace qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_SPATIAL_TRANSFORM_OP_BUILDER_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/split_op_builder.cc 
b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/split_op_builder.cc deleted file mode 100644 index 4bdb6322bd0e0b..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/split_op_builder.cc +++ /dev/null @@ -1,69 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. -// All Rights Reserved. - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/split_op_builder.h" - -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnOpDef.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/log.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -namespace { -constexpr int kSplitIndexRank = 1; -constexpr int kinputAxisIndex = 0; -} // namespace - -std::vector BuildSplitOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs, - const std::uint32_t num_splits) { - std::vector res; - - const TensorWrapper& axis_tensor = inputs[kinputAxisIndex]; - if (!axis_tensor.IsTensorStatic()) { - return res; - } - - const TensorWrapper& input_tensor = inputs[kSplitIndexRank]; - auto axis_data = axis_tensor.GetStaticTensorData(); - if (!axis_data.has_value()) { - QNN_LOG_ERROR("Get axis_data failed."); - return res; - } - std::uint32_t axis = (*axis_data)[0] >= 0 - ? 
(*axis_data)[0] - : (*axis_data)[0] + input_tensor.GetRank(); - - const std::uint32_t slice_size = input_tensor.GetDim(axis) / num_splits; - // The split_indice will do N cuts, split the dimension into N+1 clips - // so 0 will not be included in the split_indice - // for example, when we split 12 into 4 clip, the split index will be {3,6,9} - std::vector split_indice; - split_indice.reserve(num_splits); - for (int i = 1; i < num_splits; i++) { - split_indice.emplace_back(static_cast(i * slice_size)); - } - TensorWrapper& split_indice_tensor = tensor_pool.CreateStaticTensor( - QNN_DATATYPE_UINT_32, axis_tensor.GetQuantParams(), {num_splits - 1}, - sizeof(std::uint32_t) * split_indice.size(), split_indice.data()); - - auto& split_op = CreateOpWrapper(res, QNN_OP_SPLIT); - split_op.AddInputTensor(input_tensor); - for (const auto& output : outputs) { - split_op.AddOutputTensor(output); - } - split_op.AddScalarParam(QNN_OP_SPLIT_PARAM_AXIS, axis); - split_op.AddTensorParam(QNN_OP_SPLIT_PARAM_SPLIT_INDEX, split_indice_tensor); - - return res; -} - -} // namespace qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/split_op_builder.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/split_op_builder.h deleted file mode 100644 index 76fafd15cba35c..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/split_op_builder.h +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. -// All Rights Reserved. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_SPLIT_OP_BUILDER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_SPLIT_OP_BUILDER_H_ - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildSplitOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs, - const std::uint32_t num_splits); - -} // namespace qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_SPLIT_OP_BUILDER_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/tanh_op_builder.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/tanh_op_builder.cc deleted file mode 100644 index 221ebf796c52e0..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/tanh_op_builder.cc +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. -// All Rights Reserved. 
- -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/tanh_op_builder.h" - -#include - -#include "third_party/qairt/latest/include/QNN/QnnOpDef.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildTanhOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs) { - std::vector res; - - CreateSimpleActivationOp(res, QNN_OP_TANH, inputs[0], outputs[0]); - - return res; -} - -} // namespace qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/tanh_op_builder.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/tanh_op_builder.h deleted file mode 100644 index 1ede3ba202baf3..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/tanh_op_builder.h +++ /dev/null @@ -1,20 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. -// All Rights Reserved. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_TANH_OP_BUILDER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_TANH_OP_BUILDER_H_ - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildTanhOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs); - -} // namespace qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_TANH_OP_BUILDER_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/transpose_op_builder.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/transpose_op_builder.cc deleted file mode 100644 index 5f1415ffadf8fd..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/transpose_op_builder.cc +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved. 
-// SPDX-License-Identifier: Apache-2.0 - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/transpose_op_builder.h" - -#include - -#include "third_party/qairt/latest/include/QNN/QnnOpDef.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/log.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildTransposeOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs) { - std::vector res; - - TensorWrapper& perm_tensor = inputs[1]; - if (!perm_tensor.IsTensorStatic()) { - QNN_LOG_ERROR("The param 'perm' of Transpose OP is not static."); - return res; - } - - auto& transpose_op = CreateOpWrapper(res, QNN_OP_TRANSPOSE); - transpose_op.AddInputTensor(inputs[0]); - transpose_op.AddOutputTensor(outputs[0]); - transpose_op.AddTensorParam( - QNN_OP_TRANSPOSE_PARAM_PERM, - tensor_pool.CloneStaticTensorFrom(perm_tensor, QNN_DATATYPE_UINT_32)); - - return res; -} - -} // namespace qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/transpose_op_builder.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/transpose_op_builder.h deleted file mode 100644 index 7f32710f29b309..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/transpose_op_builder.h +++ /dev/null @@ -1,20 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. -// All Rights Reserved. 
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_TRANSPOSE_OP_BUILDER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_TRANSPOSE_OP_BUILDER_H_ - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/builders/op_builder.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -std::vector BuildTransposeOp( - TensorPool& tensor_pool, const std::vector& inputs, - const std::vector& outputs); - -} // namespace qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_BUILDERS_TRANSPOSE_OP_BUILDER_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/common.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/common.h deleted file mode 100644 index 7fd072eaff9b1d..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/common.h +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_COMMON_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_COMMON_H_ - -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus - -typedef enum LiteRtQnnLogLevel { // NOLINT(modernize-use-using) - /// Disable delegate and QNN backend logging messages. - kLogOff = 0, - kLogLevelError = 1, - kLogLevelWarn = 2, - kLogLevelInfo = 3, - kLogLevelVerbose = 4, - kLogLevelDebug = 5, -} LiteRtQnnLogLevel; - -typedef struct { // NOLINT(modernize-use-using) - /// Apply HTP-friendly op builder. - bool useHtpPreferencs; - /// This option will treat quantized int16 tensor as quantized uint16 tensor - /// for better backend compatibility. 
- bool useQInt16AsQUint16; -} LiteRtQnnOptions; - -// clang-format off -#define LITERT_QNN_OPTIONS_INIT \ - { \ - false, /*useHtpPreferencs*/ \ - true, /*useQInt16AsQUint16*/ \ - } -// clang-format on -#ifdef __cplusplus -} -#endif // __cplusplus - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_COMMON_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.cc deleted file mode 100644 index 27cce37e3f2c4d..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.cc +++ /dev/null @@ -1,92 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h" - -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/quantize_params_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -TensorPool::TensorPool() = default; - -TensorWrapper& TensorPool::CreateInputTensor( - Qnn_DataType_t data_type, const QuantizeParamsWrapperVariant& quant_params, - const std::vector& dimentions) { - const auto id = tensor_wrappers_.size(); - auto& back = tensor_wrappers_.emplace_back( - id, QNN_TENSOR_TYPE_APP_WRITE, data_type, quant_params, dimentions); - return back; -} - -TensorWrapper& TensorPool::CreateOutpuTensor( - Qnn_DataType_t data_type, const QuantizeParamsWrapperVariant& quant_params, - const std::vector& dimentions) { - const auto id = tensor_wrappers_.size(); - auto& back = tensor_wrappers_.emplace_back( - id, QNN_TENSOR_TYPE_APP_READ, data_type, quant_params, dimentions); - return back; -} - -TensorWrapper& TensorPool::CreateNativeTensor( - Qnn_DataType_t data_type, const QuantizeParamsWrapperVariant& 
quant_params, - const std::vector& dimentions) { - const auto id = tensor_wrappers_.size(); - auto& back = tensor_wrappers_.emplace_back( - id, QNN_TENSOR_TYPE_NATIVE, data_type, quant_params, dimentions); - return back; -} - -TensorWrapper& TensorPool::CreateStaticTensor( - Qnn_DataType_t data_type, const QuantizeParamsWrapperVariant& quant_params, - const std::vector& dimentions, std::uint32_t bytes, - const void* data) { - const auto id = tensor_wrappers_.size(); - auto& back = - tensor_wrappers_.emplace_back(id, QNN_TENSOR_TYPE_STATIC, data_type, - quant_params, dimentions, bytes, data); - return back; -} - -TensorWrapper& TensorPool::CloneNativeTensorFrom(const TensorWrapper& src) { - const auto id = tensor_wrappers_.size(); - auto& back = tensor_wrappers_.emplace_back( - id, QNN_TENSOR_TYPE_NATIVE, src.GetDataType(), src.quantize_params_, - src.dimentions_); - return back; -} - -TensorWrapper& TensorPool::CloneNativeTensorFrom( - const TensorWrapper& src, const std::vector& dimentions) { - const auto id = tensor_wrappers_.size(); - auto& back = tensor_wrappers_.emplace_back(id, QNN_TENSOR_TYPE_NATIVE, - src.GetDataType(), - src.quantize_params_, dimentions); - return back; -} - -TensorWrapper& TensorPool::CloneStaticTensorFrom(const TensorWrapper& src, - Qnn_DataType_t data_type) { - const auto id = tensor_wrappers_.size(); - auto& back = tensor_wrappers_.emplace_back( - id, QNN_TENSOR_TYPE_STATIC, data_type, src.quantize_params_, - src.dimentions_, src.owned_data_.size(), src.owned_data_.data()); - return back; -} - -TensorWrapper& TensorPool::CloneStaticTensorFrom( - const TensorWrapper& src, const std::vector& dimentions) { - const auto id = tensor_wrappers_.size(); - auto& back = tensor_wrappers_.emplace_back( - id, QNN_TENSOR_TYPE_STATIC, src.qnn_tensor_.v2.dataType, - src.quantize_params_, dimentions, src.qnn_tensor_.v2.clientBuf.dataSize, - src.qnn_tensor_.v2.clientBuf.data); - - return back; -} - -} // namespace qnn diff --git 
a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h deleted file mode 100644 index a21199ad2e40c2..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/tensor_pool.h +++ /dev/null @@ -1,66 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_TENSOR_POOL_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_TENSOR_POOL_H_ - -#include -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/quantize_params_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -class TensorPool { - public: - TensorPool(); - - TensorWrapper& CreateInputTensor( - Qnn_DataType_t data_type, - const QuantizeParamsWrapperVariant& quant_params, - const std::vector& dimentions); - - TensorWrapper& CreateOutpuTensor( - Qnn_DataType_t data_type, - const QuantizeParamsWrapperVariant& quant_params, - const std::vector& dimentions); - - TensorWrapper& CreateNativeTensor( - Qnn_DataType_t data_type, - const QuantizeParamsWrapperVariant& quant_params, - const std::vector& dimentions); - - TensorWrapper& CreateStaticTensor( - Qnn_DataType_t data_type, - const QuantizeParamsWrapperVariant& quant_params, - const std::vector& dimentions, std::uint32_t bytes, - const void* data); - - TensorWrapper& CloneNativeTensorFrom(const TensorWrapper& src); - - TensorWrapper& CloneNativeTensorFrom( - const TensorWrapper& src, const std::vector& dimentions); - - TensorWrapper& CloneStaticTensorFrom(const TensorWrapper& src, - Qnn_DataType_t data_type); - - TensorWrapper& CloneStaticTensorFrom( - const TensorWrapper& src, const std::vector& dimentions); - - template - void 
ForEach(UnaryFunc f) { - for (auto& tensor_wrapper : tensor_wrappers_) { - f(tensor_wrapper); - } - } - - private: - std::list tensor_wrappers_{}; -}; - -} // namespace qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_TENSOR_POOL_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/BUILD b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/BUILD deleted file mode 100644 index 3ce72dec755646..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/BUILD +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved. -# SPDX-License-Identifier: Apache-2.0 - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//tensorflow/lite/experimental/litert/vendors/qualcomm:__subpackages__"], -) - -cc_library( - name = "log", - srcs = select({ - "//tensorflow:android": ["log_android.cc"], - "//conditions:default": ["log_default.cc"], - }), - hdrs = ["log.h"], - linkopts = select({ - "//tensorflow:android": ["-llog"], - "//conditions:default": [], - }), - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - deps = [ - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core:common", - ], -) - -cc_library( - name = "miscs", - srcs = ["miscs.cc"], - hdrs = ["miscs.h"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - deps = [ - "@com_google_absl//absl/types:span", - ], -) - -cc_test( - name = "utils_test", - srcs = [ - "utils_test.cc", - ], - tags = [ - # Tests with ungrte deps do not currently work on forge. - "no-remote-exec", - "notap", - # Don't build/test in OS until qnn is available. - "nobuilder", - "no_oss", - # Sanitizer runtime doesn't work with anything that loads libQnnHtp.so. 
- "nosan", - ], - deps = [ - ":log", - ":miscs", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core:common", - "@com_google_googletest//:gtest_main", - ], -) diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/log.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/log.h deleted file mode 100644 index f89b4131dea4b6..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/log.h +++ /dev/null @@ -1,56 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_UTILS_LOG_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_UTILS_LOG_H_ - -#include - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/common.h" - -namespace qnn { - -class QNNLogger { - public: - // Logging hook that takes variadic args. - static void Log(LiteRtQnnLogLevel severity, const char* format, ...); - - // Set file descriptor - static void SetLogFilePointer(FILE* fp); - - // Set log level - static void SetLogLevel(LiteRtQnnLogLevel log_level); - - private: - // NOLINTBEGIN(cppcoreguidelines-avoid-non-const-global-variables) - static FILE* log_file_pointer_; - static LiteRtQnnLogLevel log_level_; - // NOLINTEND(cppcoreguidelines-avoid-non-const-global-variables) -}; -} // namespace qnn - -// NOLINTNEXTLINE(cppcoreguidelines-macro-usage) -#define QNN_LOG_VERBOSE(format, ...) \ - ::qnn::QNNLogger::Log(kLogLevelVerbose, ("VERBOSE: [Qnn] " format), \ - ##__VA_ARGS__); - -// NOLINTNEXTLINE(cppcoreguidelines-macro-usage) -#define QNN_LOG_INFO(format, ...) \ - ::qnn::QNNLogger::Log(kLogLevelInfo, ("INFO: [Qnn] " format), ##__VA_ARGS__); - -// NOLINTNEXTLINE(cppcoreguidelines-macro-usage) -#define QNN_LOG_WARNING(format, ...) 
\ - ::qnn::QNNLogger::Log(kLogLevelWarn, ("WARNING: [Qnn] " format), \ - ##__VA_ARGS__); - -// NOLINTNEXTLINE(cppcoreguidelines-macro-usage) -#define QNN_LOG_ERROR(format, ...) \ - ::qnn::QNNLogger::Log(kLogLevelError, ("ERROR: [Qnn] " format), \ - ##__VA_ARGS__); - -// NOLINTNEXTLINE(cppcoreguidelines-macro-usage) -#define QNN_LOG_DEBUG(format, ...) \ - ::qnn::QNNLogger::Log(kLogLevelDebug, ("DEBUG: [Qnn] " format), \ - ##__VA_ARGS__); - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_UTILS_LOG_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/log_android.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/log_android.cc deleted file mode 100644 index ec13856cda945b..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/log_android.cc +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -#include - -#include "log.h" - -namespace qnn { -namespace { - -int GetPlatformSeverity(LiteRtQnnLogLevel severity) { - switch (severity) { - case kLogLevelError: - return ANDROID_LOG_ERROR; - case kLogLevelWarn: - return ANDROID_LOG_WARN; - case kLogLevelInfo: - return ANDROID_LOG_INFO; - case kLogLevelVerbose: - return ANDROID_LOG_VERBOSE; - default: - return ANDROID_LOG_DEBUG; - } -} - -} // namespace - -// NOLINTBEGIN(cppcoreguidelines-avoid-non-const-global-variables) -FILE* QNNLogger::log_file_pointer_ = stderr; -LiteRtQnnLogLevel QNNLogger::log_level_ = kLogLevelInfo; -// NOLINTEND(cppcoreguidelines-avoid-non-const-global-variables) -void QNNLogger::SetLogFilePointer(FILE* fp) { log_file_pointer_ = fp; } -void QNNLogger::SetLogLevel(LiteRtQnnLogLevel log_level) { - log_level_ = log_level; -} -// NOLINTNEXTLINE(cert-dcl50-cpp) -void QNNLogger::Log(LiteRtQnnLogLevel severity, const char* format, ...) 
{ - if (severity > log_level_) { - return; - } - - // Pass to LogFormatted - va_list args; - va_start(args, format); - - // First log to Android's explicit log(cat) API. - va_list args_copy; - va_copy(args_copy, args); - __android_log_vprint(GetPlatformSeverity(severity), "qnn", format, args_copy); - va_end(args_copy); - - // Print to file pointer. - vfprintf(log_file_pointer_, format, args); - fputc('\n', log_file_pointer_); - - va_end(args); -} -} // namespace qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/log_default.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/log_default.cc deleted file mode 100644 index 6d9067d26d61a3..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/log_default.cc +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -#include - -#include - -#include "log.h" - -namespace qnn { - -// NOLINTBEGIN(cppcoreguidelines-avoid-non-const-global-variables) -FILE* QNNLogger::log_file_pointer_ = stderr; -LiteRtQnnLogLevel QNNLogger::log_level_ = kLogLevelInfo; -// NOLINTEND(cppcoreguidelines-avoid-non-const-global-variables) -void QNNLogger::SetLogFilePointer(FILE* fp) { log_file_pointer_ = fp; } -void QNNLogger::SetLogLevel(LiteRtQnnLogLevel log_level) { - log_level_ = log_level; -} -// NOLINTNEXTLINE(cert-dcl50-cpp) -void QNNLogger::Log(LiteRtQnnLogLevel severity, const char* format, ...) { - if (severity > log_level_) { - return; - } - - // Pass to LogFormatted - va_list args; - va_start(args, format); - - // Print to file pointer. 
- vfprintf(log_file_pointer_, format, args); - fputc('\n', log_file_pointer_); - - va_end(args); -} -} // namespace qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/miscs.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/miscs.cc deleted file mode 100644 index e07ef251adcc10..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/miscs.cc +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/miscs.h" - -#include -#include - -#include "absl/types/span.h" - -namespace qnn { -void ConvertDataFromInt16toUInt16(absl::Span src, - std::vector& dst) { - dst.clear(); - dst.reserve(src.size()); - for (const auto& data : src) { - dst.emplace_back(data + kUint16ZeroPoint); - } -} - -void ConvertDataFromUInt16toInt16(absl::Span src, - std::vector& dst) { - dst.clear(); - dst.reserve(src.size()); - for (const auto& data : src) { - dst.emplace_back(data - kUint16ZeroPoint); - } -} - -} // namespace qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/miscs.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/miscs.h deleted file mode 100644 index 7b12cc09eecf3d..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/miscs.h +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved. 
-// SPDX-License-Identifier: Apache-2.0 - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_UTILS_MISCS_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_UTILS_MISCS_H_ - -#include -#include -#include -#include -#include - -#include "absl/types/span.h" - -namespace qnn { - -constexpr uint32_t kUint16ZeroPoint = -std::numeric_limits::min(); - -template -inline constexpr bool always_false = false; - -template -T Quantize(const float val, const float scale, const int32_t zero_point) { - static_assert(std::is_integral::value, - "Integral required in Quantize function."); - return std::round(val / scale) + zero_point; -} - -template -float Dequantize(const T val, const float scale, const int32_t zero_point) { - static_assert(std::is_integral::value, - "Integral required in Dequantize function."); - return scale * (val - zero_point); -} - -void ConvertDataFromInt16toUInt16(absl::Span src, - std::vector& dst); - -void ConvertDataFromUInt16toInt16(absl::Span src, - std::vector& dst); - -} // namespace qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_UTILS_MISCS_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/utils_test.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/utils_test.cc deleted file mode 100644 index c8953157ada8fb..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/utils_test.cc +++ /dev/null @@ -1,150 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved. 
-// SPDX-License-Identifier: Apache-2.0 - -#include -#include -#include -#include -#include -#include -#include - -#include -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/common.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/log.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/miscs.h" - -namespace qnn { -namespace { - -bool IsPrefix(std::string_view prefix, std::string_view full) { - return prefix == full.substr(0, prefix.size()); -} - -bool CheckLoggoing(const std::string log_path, LiteRtQnnLogLevel log_level) { - std::ifstream fin(log_path); - std::string msg; - while (std::getline(fin, msg)) { - // Log severity: DEBUG > VERBOSE > INFO > WARN > ERROR - switch (log_level) { - case kLogOff: - if (IsPrefix("ERROR:", msg)) return false; - [[fallthrough]]; - case kLogLevelError: - if (IsPrefix("WARNING:", msg)) return false; - [[fallthrough]]; - case kLogLevelWarn: - if (IsPrefix("INFO:", msg)) return false; - [[fallthrough]]; - case kLogLevelInfo: - if (IsPrefix("VERBOSE:", msg)) return false; - [[fallthrough]]; - case kLogLevelVerbose: - if (IsPrefix("DEBUG:", msg)) return false; - [[fallthrough]]; - default: - break; - } - } - return true; -} - -} // namespace - -class LiteRtLog : public ::testing::TestWithParam {}; -INSTANTIATE_TEST_SUITE_P(, LiteRtLog, - ::testing::Values(kLogOff, kLogLevelError, - kLogLevelWarn, kLogLevelInfo, - kLogLevelVerbose, kLogLevelDebug)); - -TEST_P(LiteRtLog, SanityTest) { - // Create temp file for log - std::filesystem::path temp_path = - std::filesystem::temp_directory_path() / "temp.log"; - std::ofstream fout(temp_path); - ASSERT_TRUE(fout.is_open()); - - // Set log file pointer - FILE* file_ptr = fopen(temp_path.c_str(), "w"); - ASSERT_NE(file_ptr, nullptr); - qnn::QNNLogger::SetLogFilePointer(file_ptr); - - // Set log_level and print message to file - LiteRtQnnLogLevel log_level = GetParam(); - 
qnn::QNNLogger::SetLogLevel(log_level); - QNN_LOG_VERBOSE("This is a verbose message."); - QNN_LOG_INFO("This is an info message."); - QNN_LOG_WARNING("This is a warning message."); - QNN_LOG_ERROR("This is an error message."); - QNN_LOG_DEBUG("This is a debug message."); - qnn::QNNLogger::SetLogFilePointer(stderr); - fclose(file_ptr); - - // Check logging messages are as expected - ASSERT_EQ(CheckLoggoing(temp_path.string(), log_level), true); - - // Delete the temporary log file - std::filesystem::remove(temp_path); -} - -TEST(MiscTest, TestAlwaysFalse) { - ASSERT_FALSE(::qnn::always_false); - ASSERT_FALSE(::qnn::always_false); - ASSERT_FALSE(::qnn::always_false); - ASSERT_FALSE(::qnn::always_false); - ASSERT_FALSE(::qnn::always_false); - ASSERT_FALSE(::qnn::always_false); - ASSERT_FALSE(::qnn::always_false); - ASSERT_FALSE(::qnn::always_false); - ASSERT_FALSE(::qnn::always_false); - ASSERT_FALSE(::qnn::always_false); - ASSERT_FALSE(::qnn::always_false); - ASSERT_FALSE(::qnn::always_false); - ASSERT_FALSE(::qnn::always_false); - ASSERT_FALSE(::qnn::always_false); -} - -TEST(MiscTests, Quantize) { - float val = 1; - float scale = 0.1; - int32_t zero_point = 1; - auto q_val = Quantize(val, scale, zero_point); - EXPECT_EQ(q_val, 11); -} - -TEST(MiscTests, Dequantize) { - std::int8_t q_val = 11; - float scale = 0.1; - int32_t zero_point = 1; - auto val = Dequantize(q_val, scale, zero_point); - EXPECT_FLOAT_EQ(val, 1); -} - -TEST(MiscTests, ConvertDataFromInt16toUInt16) { - constexpr int16_t int16_data[4] = {0, 1, 2, 3}; - size_t data_len = sizeof(int16_data) / sizeof(int16_data[0]); - absl::Span int16_span(int16_data, data_len); - std::vector uint16_data; - - ConvertDataFromInt16toUInt16(int16_span, uint16_data); - EXPECT_EQ(uint16_data[0], 32768); - EXPECT_EQ(uint16_data[1], 32769); - EXPECT_EQ(uint16_data[2], 32770); - EXPECT_EQ(uint16_data[3], 32771); -} - -TEST(MiscTests, ConvertDataFromUInt16toInt16) { - constexpr uint16_t uint16_data[4] = {32768, 32769, 32770, 
32771}; - size_t data_len = sizeof(uint16_data) / sizeof(uint16_data[0]); - absl::Span uint16_span(uint16_data, data_len); - std::vector int16_data; - - ConvertDataFromUInt16toInt16(uint16_span, int16_data); - EXPECT_EQ(int16_data[0], 0); - EXPECT_EQ(int16_data[1], 1); - EXPECT_EQ(int16_data[2], 2); - EXPECT_EQ(int16_data[3], 3); -} - -} // namespace qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/BUILD b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/BUILD deleted file mode 100644 index e904d2a9c4efb3..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/BUILD +++ /dev/null @@ -1,69 +0,0 @@ -# Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved. -# SPDX-License-Identifier: Apache-2.0 - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//tensorflow/lite/experimental/litert/vendors/qualcomm:__subpackages__"], -) - -cc_library( - name = "quantize_params_wrapper", - srcs = ["quantize_params_wrapper.cc"], - hdrs = ["quantize_params_wrapper.h"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - deps = [ - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - ], -) - -cc_library( - name = "tensor_wrapper", - srcs = ["tensor_wrapper.cc"], - hdrs = ["tensor_wrapper.h"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - deps = [ - ":quantize_params_wrapper", - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils:log", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils:miscs", - ], -) - -cc_library( - name = "param_wrapper", - srcs = ["param_wrapper.cc"], - hdrs = ["param_wrapper.h"], - tags = [ - # Don't build/test in OS until qnn is available. 
- "nobuilder", - ], - deps = [ - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils:log", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils:miscs", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) - -cc_library( - name = "op_wrapper", - srcs = ["op_wrapper.cc"], - hdrs = ["op_wrapper.h"], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - deps = [ - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:param_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.cc deleted file mode 100644 index 43ac6a1a0704f9..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.cc +++ /dev/null @@ -1,82 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved. 
-// SPDX-License-Identifier: Apache-2.0 - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" - -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnOpDef.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -OpWrapper::OpWrapper(std::string name, const char* op_type) - : type_name_{op_type}, name_{std::move(name)} {} - -OpWrapper::OpWrapper(OpWrapper&& other) - : type_name_{other.type_name_}, - name_{std::move(other.name_)}, - input_tensors_{std::move(other.input_tensors_)}, - output_tensors_{std::move(other.output_tensors_)}, - scalar_params_{std::move(other.scalar_params_)}, - tensor_params_{std::move(other.tensor_params_)}, - qnn_input_tensors_{std::move(other.qnn_input_tensors_)}, - qnn_output_tensors_{std::move(other.qnn_output_tensors_)}, - qnn_params_{std::move(other.qnn_params_)} {} - -OpWrapper::~OpWrapper() = default; - -void OpWrapper::AddInputTensor(const TensorWrapper& tensor) { - input_tensors_.emplace_back(tensor); -} - -void OpWrapper::AddOutputTensor(const TensorWrapper& tensor) { - output_tensors_.emplace_back(tensor); -} - -void OpWrapper::AddTensorParam(const char* name, const TensorWrapper& tensor) { - tensor_params_.emplace_back(name, tensor); -} - -Qnn_OpConfig_t OpWrapper::GetOpConfig() { - Qnn_OpConfig_t qnn_op = QNN_OPCONFIG_INIT; - qnn_op.v1.packageName = QNN_OP_PACKAGE_NAME_QTI_AISW; - qnn_op.v1.typeName = type_name_; - qnn_op.v1.name = name_.data(); - // input tensors - qnn_input_tensors_.reserve(input_tensors_.size()); - qnn_input_tensors_.clear(); - for (const auto& input_tensor : input_tensors_) { - auto& back = qnn_input_tensors_.emplace_back(); - input_tensor.get().CloneTo(back); - } - qnn_op.v1.numOfInputs = qnn_input_tensors_.size(); - qnn_op.v1.inputTensors = qnn_input_tensors_.data(); - // output tensors - 
qnn_output_tensors_.reserve(output_tensors_.size()); - qnn_output_tensors_.clear(); - for (const auto& output_tensor : output_tensors_) { - auto& back = qnn_output_tensors_.emplace_back(); - output_tensor.get().CloneTo(back); - } - qnn_op.v1.numOfOutputs = qnn_output_tensors_.size(); - qnn_op.v1.outputTensors = qnn_output_tensors_.data(); - // params - qnn_params_.reserve(scalar_params_.size() + tensor_params_.size()); - qnn_params_.clear(); - for (const auto& scalar_param : scalar_params_) { - auto& back = qnn_params_.emplace_back(); - scalar_param.CloneTo(back); - } - for (const auto& tensor_param : tensor_params_) { - auto& back = qnn_params_.emplace_back(); - tensor_param.CloneTo(back); - } - qnn_op.v1.numOfParams = qnn_params_.size(); - qnn_op.v1.params = qnn_params_.data(); - return qnn_op; -} - -} // namespace qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h deleted file mode 100644 index 62858fb2ec2421..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h +++ /dev/null @@ -1,55 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved. 
-// SPDX-License-Identifier: Apache-2.0 - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_WRAPPERS_OP_WRAPPER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_WRAPPERS_OP_WRAPPER_H_ - -#include -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/param_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -class OpWrapper final { - public: - explicit OpWrapper(std::string name, const char* op_type); - - OpWrapper(const OpWrapper& other) = delete; - - OpWrapper(OpWrapper&& other); - - ~OpWrapper(); - - void AddInputTensor(const TensorWrapper& tensor); - - void AddOutputTensor(const TensorWrapper& tensor); - - template - void AddScalarParam(const char* name, const T data, - const bool is_quant = false) { - scalar_params_.emplace_back(name, data, is_quant); - } - - void AddTensorParam(const char* name, const TensorWrapper& tensor); - - Qnn_OpConfig_t GetOpConfig(); - - private: - const char* type_name_{nullptr}; - std::string name_{}; // human readable name - std::vector> input_tensors_{}; - std::vector> output_tensors_{}; - std::vector scalar_params_{}; - std::vector tensor_params_{}; - std::vector qnn_input_tensors_{}; - std::vector qnn_output_tensors_{}; - std::vector qnn_params_{}; -}; - -} // namespace qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_WRAPPERS_OP_WRAPPER_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/param_wrapper.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/param_wrapper.cc deleted file mode 100644 index 9be8b2b4d635c2..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/param_wrapper.cc +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved. 
-// SPDX-License-Identifier: Apache-2.0 - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/param_wrapper.h" - -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -void ScalarParamWrapper::CloneTo(Qnn_Param_t& dst) const { - dst.name = name_; - dst.paramType = QNN_PARAMTYPE_SCALAR; - dst.scalarParam = qnn_scalar_; -} - -TensorParamWrapper::TensorParamWrapper(const char* name, - const TensorWrapper& tensor) - : name_{name}, tensor_{tensor} {} - -void TensorParamWrapper::CloneTo(Qnn_Param_t& dst) const { - dst.name = name_; - dst.paramType = QNN_PARAMTYPE_TENSOR; - tensor_.CloneTo(dst.tensorParam); -} - -} // namespace qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/param_wrapper.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/param_wrapper.h deleted file mode 100644 index 9dbc63102cf6f2..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/param_wrapper.h +++ /dev/null @@ -1,78 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved. 
-// SPDX-License-Identifier: Apache-2.0 - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_WRAPPERS_PARAM_WRAPPER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_WRAPPERS_PARAM_WRAPPER_H_ - -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/miscs.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { - -class ScalarParamWrapper { - public: - template - explicit ScalarParamWrapper(const char* name, const T data, - const bool is_quant) - : name_{name} { - if constexpr (std::is_same_v) { - qnn_scalar_.dataType = QNN_DATATYPE_BOOL_8; - qnn_scalar_.bool8Value = data; - } else if constexpr (std::is_same_v) { - qnn_scalar_.dataType = - is_quant ? QNN_DATATYPE_UFIXED_POINT_8 : QNN_DATATYPE_UINT_8; - qnn_scalar_.uint8Value = data; - } else if constexpr (std::is_same_v) { - qnn_scalar_.dataType = - is_quant ? QNN_DATATYPE_SFIXED_POINT_8 : QNN_DATATYPE_INT_8; - qnn_scalar_.int8Value = data; - } else if constexpr (std::is_same_v) { - qnn_scalar_.dataType = - is_quant ? QNN_DATATYPE_UFIXED_POINT_16 : QNN_DATATYPE_UINT_16; - qnn_scalar_.uint16Value = data; - } else if constexpr (std::is_same_v) { - qnn_scalar_.dataType = - is_quant ? QNN_DATATYPE_SFIXED_POINT_16 : QNN_DATATYPE_INT_16; - qnn_scalar_.int16Value = data; - } else if constexpr (std::is_same_v) { - qnn_scalar_.dataType = - is_quant ? QNN_DATATYPE_UFIXED_POINT_32 : QNN_DATATYPE_UINT_32; - qnn_scalar_.uint32Value = data; - } else if constexpr (std::is_same_v) { - qnn_scalar_.dataType = - is_quant ? 
QNN_DATATYPE_SFIXED_POINT_32 : QNN_DATATYPE_INT_32; - qnn_scalar_.int32Value = data; - } else if constexpr (std::is_same_v) { - qnn_scalar_.dataType = QNN_DATATYPE_FLOAT_32; - qnn_scalar_.floatValue = data; - } else { - static_assert(::qnn::always_false, - "Unsupported data type for scalar param."); - } - } - - void CloneTo(Qnn_Param_t& dst) const; - - private: - const char* name_ = nullptr; - Qnn_Scalar_t qnn_scalar_ = QNN_SCALAR_INIT; -}; - -class TensorParamWrapper { - public: - explicit TensorParamWrapper(const char* name, const TensorWrapper& tensor); - - void CloneTo(Qnn_Param_t& dst) const; - - private: - const char* name_ = nullptr; - const TensorWrapper& tensor_; -}; - -} // namespace qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_WRAPPERS_PARAM_WRAPPER_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/quantize_params_wrapper.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/quantize_params_wrapper.cc deleted file mode 100644 index ce327633207ef5..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/quantize_params_wrapper.cc +++ /dev/null @@ -1,114 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. -// All Rights Reserved. 
- -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/quantize_params_wrapper.h" - -#include -#include -#include -#include -#include - -#include "absl/types/span.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" - -namespace qnn { - -UndefinedQuantizeParamsWrapper::UndefinedQuantizeParamsWrapper() = default; - -UndefinedQuantizeParamsWrapper::UndefinedQuantizeParamsWrapper( - const UndefinedQuantizeParamsWrapper&) = default; - -UndefinedQuantizeParamsWrapper::UndefinedQuantizeParamsWrapper( - UndefinedQuantizeParamsWrapper&&) = default; - -void UndefinedQuantizeParamsWrapper::CloneTo(Qnn_QuantizeParams_t& dst) { - dst = qnn_quantize_param_; -} - -ScaleOffsetQuantizeParamsWrapper::ScaleOffsetQuantizeParamsWrapper( - const float scale, const std::int32_t zero_point) { - qnn_quantize_param_.encodingDefinition = QNN_DEFINITION_DEFINED; - qnn_quantize_param_.quantizationEncoding = - QNN_QUANTIZATION_ENCODING_SCALE_OFFSET; - qnn_quantize_param_.scaleOffsetEncoding.scale = scale; - qnn_quantize_param_.scaleOffsetEncoding.offset = -1 * zero_point; -} - -ScaleOffsetQuantizeParamsWrapper::ScaleOffsetQuantizeParamsWrapper( - const ScaleOffsetQuantizeParamsWrapper&) = default; - -ScaleOffsetQuantizeParamsWrapper::ScaleOffsetQuantizeParamsWrapper( - ScaleOffsetQuantizeParamsWrapper&&) = default; - -void ScaleOffsetQuantizeParamsWrapper::CloneTo(Qnn_QuantizeParams_t& dst) { - dst = qnn_quantize_param_; -} - -AxisScaleOffsetQuantizeParamsWrapper::AxisScaleOffsetQuantizeParamsWrapper( - const std::int32_t axis, const absl::Span scales, - const absl::Span zero_points) - : scale_offsets_(scales.size()) { - assert(scales.size() == zero_points.size()); - for (size_t i = 0; i < scale_offsets_.size(); ++i) { - scale_offsets_[i].scale = scales[i]; - scale_offsets_[i].offset = -1 * zero_points[i]; - } - - qnn_quantize_param_.encodingDefinition = QNN_DEFINITION_DEFINED; - qnn_quantize_param_.quantizationEncoding = - 
QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET; - qnn_quantize_param_.axisScaleOffsetEncoding.axis = axis; - qnn_quantize_param_.axisScaleOffsetEncoding.numScaleOffsets = - scale_offsets_.size(); - qnn_quantize_param_.axisScaleOffsetEncoding.scaleOffset = - scale_offsets_.data(); -} - -AxisScaleOffsetQuantizeParamsWrapper::AxisScaleOffsetQuantizeParamsWrapper( - const AxisScaleOffsetQuantizeParamsWrapper& rhs) - : qnn_quantize_param_{rhs.qnn_quantize_param_}, - scale_offsets_{rhs.scale_offsets_} { - qnn_quantize_param_.axisScaleOffsetEncoding.scaleOffset = - scale_offsets_.data(); -} - -AxisScaleOffsetQuantizeParamsWrapper::AxisScaleOffsetQuantizeParamsWrapper( - AxisScaleOffsetQuantizeParamsWrapper&& rhs) - : qnn_quantize_param_{rhs.qnn_quantize_param_}, - scale_offsets_{std::move(rhs.scale_offsets_)} { - qnn_quantize_param_.axisScaleOffsetEncoding.scaleOffset = - scale_offsets_.data(); -} - -void AxisScaleOffsetQuantizeParamsWrapper::CloneTo(Qnn_QuantizeParams_t& dst) { - dst = qnn_quantize_param_; -} - -std::int32_t AxisScaleOffsetQuantizeParamsWrapper::GetAxis() const { - return qnn_quantize_param_.axisScaleOffsetEncoding.axis; -} - -void AxisScaleOffsetQuantizeParamsWrapper::SetAxis(const std::int32_t axis) { - qnn_quantize_param_.axisScaleOffsetEncoding.axis = axis; -} - -void AxisScaleOffsetQuantizeParamsWrapper::GetScales( - std::vector& scales) const { - scales.clear(); - scales.reserve(scale_offsets_.size()); - for (size_t i = 0; i < scale_offsets_.size(); ++i) { - scales.emplace_back(scale_offsets_[i].scale); - } -} - -void AxisScaleOffsetQuantizeParamsWrapper::GetZeroPoints( - std::vector& zero_points) const { - zero_points.clear(); - zero_points.reserve(scale_offsets_.size()); - for (size_t i = 0; i < scale_offsets_.size(); ++i) { - zero_points.emplace_back(-1 * scale_offsets_[i].offset); - } -} - -} // namespace qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/quantize_params_wrapper.h 
b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/quantize_params_wrapper.h deleted file mode 100644 index ee209ef4c7d2f9..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/quantize_params_wrapper.h +++ /dev/null @@ -1,86 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. -// All Rights Reserved. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_WRAPPERS_QUANTIZE_PARAMS_WRAPPER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_WRAPPERS_QUANTIZE_PARAMS_WRAPPER_H_ - -#include -#include -#include -#include "absl/types/span.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" - -namespace qnn { - -class UndefinedQuantizeParamsWrapper final { - public: - UndefinedQuantizeParamsWrapper(); - - UndefinedQuantizeParamsWrapper(const UndefinedQuantizeParamsWrapper&); - - UndefinedQuantizeParamsWrapper(UndefinedQuantizeParamsWrapper&&); - - void CloneTo(Qnn_QuantizeParams_t& dst); - - private: - Qnn_QuantizeParams_t qnn_quantize_param_ = QNN_QUANTIZE_PARAMS_INIT; -}; - -class ScaleOffsetQuantizeParamsWrapper final { - public: - explicit ScaleOffsetQuantizeParamsWrapper(const float scale, - const std::int32_t zero_point); - - ScaleOffsetQuantizeParamsWrapper(const ScaleOffsetQuantizeParamsWrapper&); - - ScaleOffsetQuantizeParamsWrapper(ScaleOffsetQuantizeParamsWrapper&&); - - void CloneTo(Qnn_QuantizeParams_t& dst); - - float GetScale() const { - return qnn_quantize_param_.scaleOffsetEncoding.scale; - } - - std::int32_t GetZeroPoint() const { - return -1 * qnn_quantize_param_.scaleOffsetEncoding.offset; - } - - private: - Qnn_QuantizeParams_t qnn_quantize_param_ = QNN_QUANTIZE_PARAMS_INIT; -}; - -class AxisScaleOffsetQuantizeParamsWrapper final { - public: - explicit AxisScaleOffsetQuantizeParamsWrapper( - const std::int32_t axis, const absl::Span scales, - const absl::Span zero_points); - - AxisScaleOffsetQuantizeParamsWrapper( - const 
AxisScaleOffsetQuantizeParamsWrapper& rhs); - - AxisScaleOffsetQuantizeParamsWrapper( - AxisScaleOffsetQuantizeParamsWrapper&& rhs); - - void CloneTo(Qnn_QuantizeParams_t& dst); - - std::int32_t GetAxis() const; - - void SetAxis(const std::int32_t axis); - - void GetScales(std::vector& scales) const; - - void GetZeroPoints(std::vector& zero_points) const; - - private: - Qnn_QuantizeParams_t qnn_quantize_param_ = QNN_QUANTIZE_PARAMS_INIT; - std::vector scale_offsets_; -}; - -using QuantizeParamsWrapperVariant = - std::variant; - -} // namespace qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_WRAPPERS_QUANTIZE_PARAMS_WRAPPER_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.cc deleted file mode 100644 index 1e78b7922d866e..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.cc +++ /dev/null @@ -1,274 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved. 
-// SPDX-License-Identifier: Apache-2.0 - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "absl/types/span.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/log.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/miscs.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/quantize_params_wrapper.h" - -namespace qnn { - -std::size_t GetDataTypeSize(const Qnn_DataType_t data_type) { - std::size_t bytes = 0; - switch (data_type) { - case QNN_DATATYPE_INT_8: - case QNN_DATATYPE_UINT_8: - case QNN_DATATYPE_SFIXED_POINT_8: - case QNN_DATATYPE_UFIXED_POINT_8: - case QNN_DATATYPE_BOOL_8: - bytes = 1; - break; - case QNN_DATATYPE_INT_16: - case QNN_DATATYPE_UINT_16: - case QNN_DATATYPE_FLOAT_16: - case QNN_DATATYPE_SFIXED_POINT_16: - case QNN_DATATYPE_UFIXED_POINT_16: - bytes = 2; - break; - case QNN_DATATYPE_INT_32: - case QNN_DATATYPE_UINT_32: - case QNN_DATATYPE_FLOAT_32: - case QNN_DATATYPE_SFIXED_POINT_32: - case QNN_DATATYPE_UFIXED_POINT_32: - bytes = 4; - break; - case QNN_DATATYPE_INT_64: - case QNN_DATATYPE_UINT_64: - case QNN_DATATYPE_FLOAT_64: - bytes = 8; - break; - case QNN_DATATYPE_UNDEFINED: - case QNN_DATATYPE_SFIXED_POINT_4: - case QNN_DATATYPE_UFIXED_POINT_4: - default: - bytes = 0; - break; - } - return bytes; -} - -TensorWrapper::TensorWrapper() = default; - -TensorWrapper::TensorWrapper( - std::uint32_t id, Qnn_TensorType_t tensor_type, Qnn_DataType_t data_type, - const QuantizeParamsWrapperVariant& quantize_params, - const std::vector& dimentions) - : name_{std::to_string(id)}, - dimentions_{dimentions}, - quantize_params_{quantize_params} { - qnn_tensor_.v2.name = name_.c_str(); - qnn_tensor_.v2.type = tensor_type; - 
qnn_tensor_.v2.dataFormat = QNN_TENSOR_DATA_FORMAT_FLAT_BUFFER; - qnn_tensor_.v2.dataType = data_type; - std::visit( - [this](auto&& quantize_params) -> void { - quantize_params.CloneTo(qnn_tensor_.v2.quantizeParams); - }, - quantize_params_); - qnn_tensor_.v2.rank = dimentions_.size(); - qnn_tensor_.v2.dimensions = dimentions_.data(); - qnn_tensor_.v2.memType = QNN_TENSORMEMTYPE_RAW; -} - -TensorWrapper::TensorWrapper( - std::uint32_t id, Qnn_TensorType_t tensor_type, Qnn_DataType_t data_type, - const QuantizeParamsWrapperVariant& quantize_params, - const std::vector& dimentions, std::uint32_t bytes, - const void* data) - : TensorWrapper(id, tensor_type, data_type, quantize_params, dimentions) { - SetDataBy(bytes, data); -} - -TensorWrapper::TensorWrapper(const TensorWrapper& other) - : qnn_tensor_{other.qnn_tensor_}, - name_{other.name_}, - dimentions_{other.dimentions_}, - quantize_params_{other.quantize_params_}, - owned_data_{other.owned_data_} { - qnn_tensor_.v2.name = name_.c_str(); - qnn_tensor_.v2.dimensions = dimentions_.data(); - qnn_tensor_.v2.clientBuf.data = owned_data_.data(); - std::visit( - [this](auto&& quant_params) -> void { - quant_params.CloneTo(qnn_tensor_.v2.quantizeParams); - }, - quantize_params_); -} - -TensorWrapper::TensorWrapper(TensorWrapper&& other) - : qnn_tensor_{other.qnn_tensor_}, - name_{std::move(other.name_)}, - dimentions_{std::move(other.dimentions_)}, - quantize_params_{std::move(other.quantize_params_)}, - owned_data_{std::move(other.owned_data_)} { - qnn_tensor_.v2.name = name_.c_str(); - qnn_tensor_.v2.dimensions = dimentions_.data(); - qnn_tensor_.v2.clientBuf.data = owned_data_.data(); - std::visit( - [this](auto&& quant_params) -> void { - quant_params.CloneTo(qnn_tensor_.v2.quantizeParams); - }, - quantize_params_); -} - -TensorWrapper::~TensorWrapper() = default; - -std::uint32_t TensorWrapper::GetDim(size_t index) const { - return dimentions_[index]; -} - -Qnn_DataType_t TensorWrapper::GetDataType() const { - 
return qnn_tensor_.v2.dataType; -} - -void TensorWrapper::CloneTo(Qnn_Tensor_t& dst) const { dst = qnn_tensor_; } - -std::uint32_t TensorWrapper::GetRank() const { return qnn_tensor_.v2.rank; } - -Qnn_TensorType_t TensorWrapper::GetTensorType() const { - return qnn_tensor_.v2.type; -} - -std::uint32_t TensorWrapper::GetTensorNumElements() const { - return GetDims().empty() ? 0 - : std::accumulate(GetDims().begin(), GetDims().end(), - 1, std::multiplies<>()); -} - -size_t TensorWrapper::GetTensorBytes() const { - return GetDataTypeSize(GetDataType()) * GetTensorNumElements(); -} - -bool TensorWrapper::IsPerTensorQuantWithOffsetDiff( - const TensorWrapper& rhs) const { - const auto& lhs_quant = qnn_tensor_.v2.quantizeParams; - const auto& rhs_quant = rhs.qnn_tensor_.v2.quantizeParams; - - if (lhs_quant.encodingDefinition != QNN_DEFINITION_DEFINED || - rhs_quant.encodingDefinition != QNN_DEFINITION_DEFINED) { - return false; - } - - if (lhs_quant.quantizationEncoding != - QNN_QUANTIZATION_ENCODING_SCALE_OFFSET || - rhs_quant.quantizationEncoding != - QNN_QUANTIZATION_ENCODING_SCALE_OFFSET) { - return false; - } - - const auto lhs_scale = lhs_quant.scaleOffsetEncoding.scale; - const auto lhs_offset = lhs_quant.scaleOffsetEncoding.offset; - const auto rhs_scale = rhs_quant.scaleOffsetEncoding.scale; - const auto rhs_offset = rhs_quant.scaleOffsetEncoding.offset; - if ((GetDataType() == QNN_DATATYPE_SFIXED_POINT_8 && - rhs.GetDataType() == QNN_DATATYPE_UFIXED_POINT_8) || - (GetDataType() == QNN_DATATYPE_UFIXED_POINT_8 && - rhs.GetDataType() == QNN_DATATYPE_SFIXED_POINT_8)) { - constexpr int kSUFixed8OffsetDiff = 128; - if (std::fabs(lhs_scale - rhs_scale) < - std::numeric_limits::epsilon() && - std::abs(lhs_offset - rhs_offset) == kSUFixed8OffsetDiff) { - return true; - } - } else if ((GetDataType() == QNN_DATATYPE_SFIXED_POINT_16 && - rhs.GetDataType() == QNN_DATATYPE_UFIXED_POINT_16) || - (GetDataType() == QNN_DATATYPE_UFIXED_POINT_16 && - rhs.GetDataType() == 
QNN_DATATYPE_SFIXED_POINT_16)) { - constexpr int kSUFixed16OffsetDiff = 32768; - if (std::fabs(lhs_scale - rhs_scale) < - std::numeric_limits::epsilon() && - std::abs(lhs_offset - rhs_offset) == kSUFixed16OffsetDiff) { - return true; - } - } - return false; -} - -void TensorWrapper::SetDataBy(std::uint32_t bytes, const void* data) { - if (bytes != GetTensorBytes()) { - QNN_LOG_WARNING( - "Bytes: %d != GetTensorBytes(): %d, use GetTensorBytes() instead.", - bytes, GetTensorBytes()); - bytes = GetTensorBytes(); - } - owned_data_.resize(bytes); - std::memcpy(owned_data_.data(), reinterpret_cast(data), bytes); - qnn_tensor_.v2.clientBuf.dataSize = owned_data_.size(); - qnn_tensor_.v2.clientBuf.data = owned_data_.data(); -} - -void TensorWrapper::ConvertQint16ToQuint16() { - if (GetDataType() != QNN_DATATYPE_SFIXED_POINT_16) { - return; - } - - // adjust static data - if (IsTensorStatic()) { - auto int16_data = GetStaticTensorData(); - if (!int16_data.has_value()) { - QNN_LOG_ERROR( - "Cannot convert static QInt16 data to QUint16 data failed since " - "GetStaticTensorData failed."); - return; - } - QNN_LOG_DEBUG("Converting static tensor data from QInt16 to QUint16..."); - std::vector uint16_data; - ConvertDataFromInt16toUInt16((*int16_data), uint16_data); - std::memcpy(owned_data_.data(), - reinterpret_cast(uint16_data.data()), - GetTensorBytes()); - qnn_tensor_.v2.clientBuf.dataSize = owned_data_.size(); - qnn_tensor_.v2.clientBuf.data = owned_data_.data(); - } - - // adjust quant param; - if (IsPerTensorQuant()) { - const auto& q_param = - std::get(GetQuantParams()); - quantize_params_.emplace( - q_param.GetScale(), q_param.GetZeroPoint() + kUint16ZeroPoint); - - } else if (IsPerChannelQuant()) { - const auto& q_param = - std::get(GetQuantParams()); - std::int32_t axis = q_param.GetAxis(); - std::vector scales; - q_param.GetScales(scales); - std::vector zero_points; - q_param.GetZeroPoints(zero_points); - std::for_each(zero_points.begin(), zero_points.end(), - 
[](std::int32_t& val) { val += kUint16ZeroPoint; }); - quantize_params_.emplace( - axis, absl::MakeSpan(scales), absl::MakeSpan(zero_points)); - } - - std::visit( - [this](auto&& quantize_params) -> void { - quantize_params.CloneTo(qnn_tensor_.v2.quantizeParams); - }, - quantize_params_); - - // change data type here since GetStaticTensorData checks data type - qnn_tensor_.v2.dataType = QNN_DATATYPE_UFIXED_POINT_16; - QNN_LOG_DEBUG( - "QNN does not fully support QInt16 now, converting to QUint16 for better " - "compatibility."); -} - -} // namespace qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h deleted file mode 100644 index 5a079868d2b98e..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h +++ /dev/null @@ -1,349 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_WRAPPERS_TENSOR_WRAPPER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_WRAPPERS_TENSOR_WRAPPER_H_ - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "absl/types/span.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/log.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/miscs.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/quantize_params_wrapper.h" - -namespace qnn { - -// Get the Qnn_DataType_t associated with given C++ type. -template -inline constexpr Qnn_DataType_t GetQnnDataType(const bool is_quant) { - if constexpr (std::is_same_v) { - return QNN_DATATYPE_BOOL_8; - } else if constexpr (std::is_same_v) { - return is_quant ? 
QNN_DATATYPE_UFIXED_POINT_8 : QNN_DATATYPE_UINT_8; - } else if constexpr (std::is_same_v) { - return is_quant ? QNN_DATATYPE_SFIXED_POINT_8 : QNN_DATATYPE_INT_8; - } else if constexpr (std::is_same_v) { - return is_quant ? QNN_DATATYPE_UFIXED_POINT_16 : QNN_DATATYPE_UINT_16; - } else if constexpr (std::is_same_v) { - return is_quant ? QNN_DATATYPE_SFIXED_POINT_16 : QNN_DATATYPE_INT_16; - } else if constexpr (std::is_same_v) { - return is_quant ? QNN_DATATYPE_UFIXED_POINT_32 : QNN_DATATYPE_UINT_32; - } else if constexpr (std::is_same_v) { - return is_quant ? QNN_DATATYPE_SFIXED_POINT_32 : QNN_DATATYPE_INT_32; - } else if constexpr (std::is_same_v) { - return QNN_DATATYPE_FLOAT_32; - } else { - static_assert(always_false, "Uknown C++ type"); - } - return QNN_DATATYPE_UNDEFINED; -} - -std::size_t GetDataTypeSize(const Qnn_DataType_t data_type); - -template -void TransposeFromOHWIToHWIO(absl::Span weight_data, - const std::vector& weight_dims, - std::vector& weight_data_transpose) { - weight_data_transpose.resize(weight_data.size()); - uint32_t output = weight_dims[0]; - uint32_t height = weight_dims[1]; - uint32_t width = weight_dims[2]; - uint32_t input = weight_dims[3]; - // OHWI->HWIO - uint32_t map_o = 0; - uint32_t map_w = 0; - uint32_t map_h = 0; - for (uint32_t index_o = 0; index_o < output; index_o++) { - map_o = index_o * height * width * input; - for (uint32_t index_h = 0; index_h < height; index_h++) { - map_h = index_h * width * input; - for (uint32_t index_w = 0; index_w < width; index_w++) { - map_w = index_w * input; - for (uint32_t index_i = 0; index_i < input; index_i++) { - T inval = weight_data[map_o + map_h + map_w + index_i]; - uint32_t index_transpose = index_h * width * input * output + - index_w * input * output + - index_i * output + index_o; - weight_data_transpose[index_transpose] = inval; - } - } - } - } -} - -class TensorWrapper final { - friend class TensorPool; - - public: - explicit TensorWrapper(); - - explicit 
TensorWrapper(std::uint32_t id, Qnn_TensorType_t tensor_type, - Qnn_DataType_t data_type, - const QuantizeParamsWrapperVariant& quantize_params, - const std::vector& dimentions); - - explicit TensorWrapper(std::uint32_t id, Qnn_TensorType_t tensor_type, - Qnn_DataType_t data_type, - const QuantizeParamsWrapperVariant& quantize_params, - const std::vector& dimentions, - std::uint32_t bytes, const void* data); - - TensorWrapper(const TensorWrapper& other); - - TensorWrapper(TensorWrapper&& other); - - ~TensorWrapper(); - - void CloneTo(Qnn_Tensor_t& dst) const; - - Qnn_Tensor_t& GetQnnTensor() { return qnn_tensor_; } - - std::uint32_t GetRank() const; - - std::uint32_t GetDim(size_t index) const; - - const std::vector& GetDims() const { return dimentions_; }; - - std::uint32_t GetTensorNumElements() const; - - const QuantizeParamsWrapperVariant& GetQuantParams() const { - return quantize_params_; - }; - - QuantizeParamsWrapperVariant& GetQuantParams() { return quantize_params_; }; - - bool IsQuant() const { - return !std::holds_alternative( - quantize_params_); - }; - - bool IsPerTensorQuant() const { - return std::holds_alternative( - quantize_params_); - } - - bool IsPerChannelQuant() const { - return std::holds_alternative( - quantize_params_); - } - - bool IsPerTensorQuantWithOffsetDiff(const TensorWrapper& rhs) const; - - bool IsQuant8() const { - return GetDataType() == QNN_DATATYPE_SFIXED_POINT_8 || - GetDataType() == QNN_DATATYPE_UFIXED_POINT_8; - } - - bool IsQuant16() const { - return GetDataType() == QNN_DATATYPE_SFIXED_POINT_16 || - GetDataType() == QNN_DATATYPE_UFIXED_POINT_16; - } - - bool IsF32() const { return GetDataType() == QNN_DATATYPE_FLOAT_32; } - bool IsF16() const { return GetDataType() == QNN_DATATYPE_FLOAT_16; } - - Qnn_DataType_t GetDataType() const; - - bool IsSubgraphInput() const { - return GetTensorType() == QNN_TENSOR_TYPE_APP_WRITE; - } - - bool IsSubgraphOutput() const { - return GetTensorType() == QNN_TENSOR_TYPE_APP_READ; - } - - 
bool IsTensorStatic() const { - return GetTensorType() == QNN_TENSOR_TYPE_STATIC; - } - - template - bool SetTensorData(absl::Span data) { - if (!IsSubgraphInput() && !IsTensorStatic()) { - QNN_LOG_ERROR( - "Cannot set tensor data of tensor type other than " - "QNN_TENSOR_TYPE_APP_WRITE or QNN_TENSOR_TYPE_STATIC."); - return false; - } - - size_t num_elements = GetTensorNumElements(); - if (!num_elements) { - QNN_LOG_ERROR("Cannot set tensor data, number of elements = 0"); - return false; - } - - size_t data_bytes = sizeof(T) * data.size(); - size_t tensor_bytes = GetTensorBytes(); - if (tensor_bytes > data_bytes) { - QNN_LOG_ERROR( - "Tensor bytes: %d > given data bytes: %d, SetTensorData failed.", - tensor_bytes, data_bytes); - return false; - } - if (tensor_bytes < data_bytes) { - QNN_LOG_WARNING( - "Tensor bytes : %d < given data bytes: %d, using only %d.", - tensor_bytes, data_bytes, tensor_bytes); - } - - if constexpr (std::is_same_v) { - if (qnn_tensor_.v2.dataType != QNN_DATATYPE_FLOAT_32) { - QNN_LOG_ERROR( - "Cannot set tensor data, setting float data on QNN data type %d.", - qnn_tensor_.v2.dataType); - return false; - } - } else if constexpr (std::is_same_v) { - if (qnn_tensor_.v2.dataType != QNN_DATATYPE_INT_8 && - qnn_tensor_.v2.dataType != QNN_DATATYPE_SFIXED_POINT_8) { - QNN_LOG_ERROR( - "Cannot set tensor data, setting std::int8_t data on QNN data type " - "%d.", - qnn_tensor_.v2.dataType); - return false; - } - } else if constexpr (std::is_same_v) { - if (qnn_tensor_.v2.dataType != QNN_DATATYPE_UINT_8 && - qnn_tensor_.v2.dataType != QNN_DATATYPE_UFIXED_POINT_8) { - QNN_LOG_ERROR( - "Cannot set tensor data, setting std::uint8_t data on QNN data " - "type %d.", - qnn_tensor_.v2.dataType); - return false; - } - } else if constexpr (std::is_same_v) { - if (qnn_tensor_.v2.dataType != QNN_DATATYPE_INT_16 && - qnn_tensor_.v2.dataType != QNN_DATATYPE_SFIXED_POINT_16) { - QNN_LOG_ERROR( - "Cannot set tensor data, setting std::int16_t data on QNN data " - 
"type %d.", - qnn_tensor_.v2.dataType); - return false; - } - } else if constexpr (std::is_same_v) { - if (qnn_tensor_.v2.dataType != QNN_DATATYPE_UINT_16 && - qnn_tensor_.v2.dataType != QNN_DATATYPE_UFIXED_POINT_16) { - QNN_LOG_ERROR( - "Cannot set tensor data, setting std::uint16_t data on QNN data " - "type %d.", - qnn_tensor_.v2.dataType); - return false; - } - - } else if constexpr (std::is_same_v) { - if (qnn_tensor_.v2.dataType != QNN_DATATYPE_INT_32 && - qnn_tensor_.v2.dataType != QNN_DATATYPE_SFIXED_POINT_32) { - QNN_LOG_ERROR( - "Cannot set tensor data, setting std::int32_t data on QNN data " - "type %d.", - qnn_tensor_.v2.dataType); - return false; - } - } else if constexpr (std::is_same_v) { - if (qnn_tensor_.v2.dataType != QNN_DATATYPE_UINT_32 && - qnn_tensor_.v2.dataType != QNN_DATATYPE_UFIXED_POINT_32) { - QNN_LOG_ERROR( - "Cannot set tensor data, setting std::uint32_t data on QNN data " - "type %d.", - qnn_tensor_.v2.dataType); - return false; - } - } else { - QNN_LOG_ERROR("Cannot set tensor data, unknown data type."); - return false; - } - - owned_data_.resize(tensor_bytes); - std::memcpy(owned_data_.data(), reinterpret_cast(data.data()), - tensor_bytes); - qnn_tensor_.v2.clientBuf.dataSize = owned_data_.size(); - qnn_tensor_.v2.clientBuf.data = owned_data_.data(); - return true; - } - - // Allocate memory on owned_data_ for output tensors - void AllocateOutputTensorBuffer() { - owned_data_.resize(GetTensorBytes()); - qnn_tensor_.v2.clientBuf.dataSize = owned_data_.size(); - qnn_tensor_.v2.clientBuf.data = owned_data_.data(); - } - - template - std::optional> GetStaticTensorData() const; - - void ConvertAxisScaleOffsetToScaleOffset() { - if (!std::holds_alternative( - quantize_params_)) { - return; - } - - quantize_params_.emplace(0.0, 0); - } - - size_t GetTensorBytes() const; - - void ConvertQint16ToQuint16(); - - private: - Qnn_TensorType_t GetTensorType() const; - - void SetDataBy(std::uint32_t bytes, const void* data); - - bool 
HasStaticData() const { - return qnn_tensor_.v2.clientBuf.dataSize != 0 && - qnn_tensor_.v2.clientBuf.data != nullptr; - } - - Qnn_Tensor_t qnn_tensor_{.version = QNN_TENSOR_VERSION_2, - .v2 = QNN_TENSOR_V2_INIT}; - std::string name_{}; - std::vector dimentions_{}; - QuantizeParamsWrapperVariant quantize_params_{}; - std::vector owned_data_{}; -}; - -using TensorWrapperRef = std::reference_wrapper; - -template -std::optional> TensorWrapper::GetStaticTensorData() const { - if (!IsTensorStatic()) { - QNN_LOG_ERROR( - "Cannot GetStaticTensorData() on a non-static tensor, tensor type %d.", - GetTensorType()); - return std::nullopt; - } - - if (GetDataType() != GetQnnDataType(IsQuant())) { - QNN_LOG_ERROR("GetStaticTensorData() with incorrect template type."); - return std::nullopt; - } - - if (!HasStaticData()) { - QNN_LOG_ERROR("Empty static tensor data."); - return std::nullopt; - } - - if (qnn_tensor_.v2.clientBuf.dataSize != GetTensorBytes()) { - QNN_LOG_ERROR("Tensor bytes != stored data bytes."); - return std::nullopt; - } - - uint32_t num_elements = qnn_tensor_.v2.clientBuf.dataSize / sizeof(T); - if (!num_elements) { - QNN_LOG_ERROR("No element in this tensor."); - return std::nullopt; - } - - return absl::MakeConstSpan( - reinterpret_cast(qnn_tensor_.v2.clientBuf.data), num_elements); -} -} // namespace qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CORE_WRAPPERS_TENSOR_WRAPPER_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tests/BUILD b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tests/BUILD deleted file mode 100644 index 6617e6ff0198e0..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tests/BUILD +++ /dev/null @@ -1,78 +0,0 @@ -# Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved. 
-# SPDX-License-Identifier: Apache-2.0 - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//tensorflow/lite/experimental/litert/vendors/qualcomm:__subpackages__"], -) - -cc_test( - name = "op_wrapper_test", - srcs = [ - "op_wrapper_test.cc", - ], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - deps = [ - "@com_google_googletest//:gtest_main", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:op_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:quantize_params_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) - -cc_test( - name = "tensor_wrapper_test", - srcs = [ - "tensor_wrapper_test.cc", - ], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - deps = [ - "@com_google_googletest//:gtest_main", - "@com_google_absl//absl/types:span", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils:miscs", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:quantize_params_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) - -cc_test( - name = "param_wrapper_test", - srcs = [ - "param_wrapper_test.cc", - ], - tags = [ - # Don't build/test in OS until qnn is available. 
- "nobuilder", - ], - deps = [ - "@com_google_googletest//:gtest_main", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:param_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:quantize_params_wrapper", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:tensor_wrapper", - ], -) - -cc_test( - name = "quantize_params_wrapper_test", - srcs = [ - "quantize_params_wrapper_test.cc", - ], - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - deps = [ - "@com_google_googletest//:gtest_main", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers:quantize_params_wrapper", - ], -) diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tests/op_wrapper_test.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tests/op_wrapper_test.cc deleted file mode 100644 index 60d121142ab5f7..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tests/op_wrapper_test.cc +++ /dev/null @@ -1,168 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved. 
-// SPDX-License-Identifier: Apache-2.0 - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/op_wrapper.h" - -#include -#include -#include -#include -#include -#include - -#include -#include "third_party/qairt/latest/include/QNN/QnnOpDef.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/quantize_params_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { -namespace { - -void EXPECT_TENSOR_EQ(Qnn_Tensor_t actual, Qnn_Tensor_t expected) { - EXPECT_EQ(actual.v2.id, expected.v2.id); - EXPECT_EQ(actual.v2.type, expected.v2.type); - EXPECT_EQ(actual.v2.dataFormat, expected.v2.dataFormat); - EXPECT_EQ(actual.v2.dataType, expected.v2.dataType); - EXPECT_EQ(actual.v2.quantizeParams.encodingDefinition, - expected.v2.quantizeParams.encodingDefinition); - EXPECT_EQ(actual.v2.rank, expected.v2.rank); - for (size_t i = 0; i < actual.v2.rank; i++) { - EXPECT_EQ(actual.v2.dimensions[i], expected.v2.dimensions[i]); - } - EXPECT_EQ(actual.v2.memType, expected.v2.memType); - EXPECT_EQ(actual.v2.clientBuf.dataSize, expected.v2.clientBuf.dataSize); - const auto* actual_data = - reinterpret_cast(actual.v2.clientBuf.data); - const auto* expected_data = - reinterpret_cast(expected.v2.clientBuf.data); - for (size_t i = 0; i < actual.v2.clientBuf.dataSize; i++) { - EXPECT_EQ(actual_data[i], expected_data[i]); - } -} - -TEST(OpWrapperTest, SanityTest) { - OpWrapper op_wrapper{"name", "OP_TYPE"}; - const Qnn_OpConfig_t& op_config = op_wrapper.GetOpConfig(); - EXPECT_EQ(op_config.version, QNN_OPCONFIG_VERSION_1); - - const Qnn_OpConfigV1_t& op_config_v1 = op_config.v1; - EXPECT_STREQ(op_config_v1.typeName, "OP_TYPE"); - EXPECT_STREQ(op_config_v1.packageName, QNN_OP_PACKAGE_NAME_QTI_AISW); - EXPECT_STREQ(op_config_v1.name, "name"); - EXPECT_EQ(op_config_v1.numOfInputs, 0); - 
EXPECT_EQ(op_config_v1.numOfOutputs, 0); - EXPECT_EQ(op_config_v1.numOfParams, 0); - EXPECT_EQ(op_config_v1.params, nullptr); - EXPECT_EQ(op_config_v1.inputTensors, nullptr); - EXPECT_EQ(op_config_v1.outputTensors, nullptr); -} - -TEST(OpWrapperTest, MoveCtorSanityTest) { - OpWrapper op_wrapper{"name", "OP_TYPE"}; - OpWrapper moved{std::move(op_wrapper)}; - const Qnn_OpConfig_t& op_config = moved.GetOpConfig(); - EXPECT_EQ(op_config.version, QNN_OPCONFIG_VERSION_1); - - const Qnn_OpConfigV1_t& op_config_v1 = op_config.v1; - EXPECT_STREQ(op_config_v1.typeName, "OP_TYPE"); - EXPECT_STREQ(op_config_v1.packageName, QNN_OP_PACKAGE_NAME_QTI_AISW); - EXPECT_STREQ(op_config_v1.name, "name"); - EXPECT_EQ(op_config_v1.numOfInputs, 0); - EXPECT_EQ(op_config_v1.numOfOutputs, 0); - EXPECT_EQ(op_config_v1.numOfParams, 0); - EXPECT_EQ(op_config_v1.params, nullptr); - EXPECT_EQ(op_config_v1.inputTensors, nullptr); - EXPECT_EQ(op_config_v1.outputTensors, nullptr); -} - -TEST(OpWrapperTest, OpConfigTest) { - std::vector dummy_dims = {1, 1, 3}; - std::vector data = {1, 2, 3}; - void* data_ptr = reinterpret_cast(data.data()); - const auto data_size = - std::accumulate(dummy_dims.begin(), dummy_dims.end(), - sizeof(decltype(data)::value_type), std::multiplies<>()); - - TensorWrapper tensor_wrapper{0, - QNN_TENSOR_TYPE_APP_WRITE, - QNN_DATATYPE_UFIXED_POINT_8, - QuantizeParamsWrapperVariant(), - dummy_dims, - static_cast(data_size), - data_ptr}; - - Qnn_Tensor_t golden_qnn_tensor; - tensor_wrapper.CloneTo(golden_qnn_tensor); - - std::uint8_t value = 255; - OpWrapper op_wrapper{"name", "OP_TYPE"}; - op_wrapper.AddInputTensor(tensor_wrapper); - op_wrapper.AddOutputTensor(tensor_wrapper); - op_wrapper.AddScalarParam("uint8_param", value, false); - op_wrapper.AddTensorParam("tensor_param", tensor_wrapper); - - Qnn_OpConfig_t op_config = op_wrapper.GetOpConfig(); - EXPECT_EQ(op_config.version, QNN_OPCONFIG_VERSION_1); - EXPECT_STREQ(op_config.v1.typeName, "OP_TYPE"); - 
EXPECT_STREQ(op_config.v1.packageName, QNN_OP_PACKAGE_NAME_QTI_AISW); - EXPECT_STREQ(op_config.v1.name, "name"); - - Qnn_OpConfigV1_t op_config_v1 = op_config.v1; - - EXPECT_EQ(op_config_v1.numOfInputs, 1); - EXPECT_EQ(op_config_v1.numOfOutputs, 1); - EXPECT_EQ(op_config_v1.numOfParams, 2); - EXPECT_TENSOR_EQ(op_config_v1.inputTensors[0], golden_qnn_tensor); - EXPECT_TENSOR_EQ(op_config_v1.outputTensors[0], golden_qnn_tensor); - EXPECT_EQ(op_config_v1.params[0].paramType, QNN_PARAMTYPE_SCALAR); - EXPECT_EQ(op_config_v1.params[0].name, "uint8_param"); - EXPECT_EQ(op_config_v1.params[0].scalarParam.dataType, QNN_DATATYPE_UINT_8); - EXPECT_EQ(op_config_v1.params[0].scalarParam.uint8Value, value); - EXPECT_EQ(op_config_v1.params[1].paramType, QNN_PARAMTYPE_TENSOR); - EXPECT_EQ(op_config_v1.params[1].name, "tensor_param"); - EXPECT_TENSOR_EQ(op_config_v1.params[1].tensorParam, golden_qnn_tensor); -} - -TEST(OpWrapperTest, MoveConstructorTest) { - std::vector dummy_dims = {1, 1, 3}; - std::vector data = {1, 2, 3}; - void* data_ptr = reinterpret_cast(data.data()); - TensorWrapper tensor_wrapper{0, - QNN_TENSOR_TYPE_APP_WRITE, - QNN_DATATYPE_UFIXED_POINT_8, - QuantizeParamsWrapperVariant(), - dummy_dims, - static_cast(data.size()), - data_ptr}; - Qnn_Tensor_t golden_qnn_tensor; - tensor_wrapper.CloneTo(golden_qnn_tensor); - std::uint8_t value = 255; - OpWrapper op_wrapper{"name", "OP_TYPE"}; - op_wrapper.AddInputTensor(tensor_wrapper); - op_wrapper.AddOutputTensor(tensor_wrapper); - op_wrapper.AddScalarParam("uint8_param", value, false); - op_wrapper.AddTensorParam("tensor_param", tensor_wrapper); - OpWrapper op_wrapper_move(std::move(op_wrapper)); - Qnn_OpConfig_t op_config = op_wrapper_move.GetOpConfig(); - EXPECT_EQ(op_config.version, QNN_OPCONFIG_VERSION_1); - EXPECT_STREQ(op_config.v1.typeName, "OP_TYPE"); - EXPECT_STREQ(op_config.v1.packageName, QNN_OP_PACKAGE_NAME_QTI_AISW); - EXPECT_STREQ(op_config.v1.name, "name"); - Qnn_OpConfigV1_t op_config_v1 = op_config.v1; - 
EXPECT_EQ(op_config_v1.numOfInputs, 1); - EXPECT_EQ(op_config_v1.numOfOutputs, 1); - EXPECT_EQ(op_config_v1.numOfParams, 2); - EXPECT_TENSOR_EQ(op_config_v1.inputTensors[0], golden_qnn_tensor); - EXPECT_TENSOR_EQ(op_config_v1.outputTensors[0], golden_qnn_tensor); - EXPECT_EQ(op_config_v1.params[0].paramType, QNN_PARAMTYPE_SCALAR); - EXPECT_EQ(op_config_v1.params[0].name, "uint8_param"); - EXPECT_EQ(op_config_v1.params[0].scalarParam.dataType, QNN_DATATYPE_UINT_8); - EXPECT_EQ(op_config_v1.params[0].scalarParam.uint8Value, value); - EXPECT_EQ(op_config_v1.params[1].paramType, QNN_PARAMTYPE_TENSOR); - EXPECT_EQ(op_config_v1.params[1].name, "tensor_param"); - EXPECT_TENSOR_EQ(op_config_v1.params[1].tensorParam, golden_qnn_tensor); -} - -} // namespace -} // namespace qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tests/param_wrapper_test.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tests/param_wrapper_test.cc deleted file mode 100644 index 1472e494306911..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tests/param_wrapper_test.cc +++ /dev/null @@ -1,232 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved. 
-// SPDX-License-Identifier: Apache-2.0 - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/param_wrapper.h" - -#include -#include -#include -#include -#include - -#include -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/quantize_params_wrapper.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -namespace qnn { -namespace { - -TEST(ScalarParamWrapperTest, BoolParamTest) { - ScalarParamWrapper bool_param{"bool_param", true, false}; - Qnn_Param_t bool_qnn_param = QNN_PARAM_INIT; - bool_param.CloneTo(bool_qnn_param); - EXPECT_EQ(bool_qnn_param.paramType, QNN_PARAMTYPE_SCALAR); - EXPECT_EQ(bool_qnn_param.name, "bool_param"); - EXPECT_EQ(bool_qnn_param.scalarParam.dataType, QNN_DATATYPE_BOOL_8); - EXPECT_EQ(bool_qnn_param.scalarParam.bool8Value, 1); -} - -TEST(ScalarParamWrapperTest, Uint8ParamTest) { - constexpr std::uint8_t value = 255; - ScalarParamWrapper uint8_param{"uint8_param", value, false}; - Qnn_Param_t uint8_qnn_param = QNN_PARAM_INIT; - uint8_param.CloneTo(uint8_qnn_param); - EXPECT_EQ(uint8_qnn_param.paramType, QNN_PARAMTYPE_SCALAR); - EXPECT_EQ(uint8_qnn_param.name, "uint8_param"); - EXPECT_EQ(uint8_qnn_param.scalarParam.dataType, QNN_DATATYPE_UINT_8); - EXPECT_EQ(uint8_qnn_param.scalarParam.uint8Value, value); -} - -TEST(ScalarParamWrapperTest, Int8ParamTest) { - constexpr std::int8_t value = -128; - ScalarParamWrapper int8_param{"int8_param", value, false}; - Qnn_Param_t int8_qnn_param = QNN_PARAM_INIT; - int8_param.CloneTo(int8_qnn_param); - EXPECT_EQ(int8_qnn_param.paramType, QNN_PARAMTYPE_SCALAR); - EXPECT_EQ(int8_qnn_param.name, "int8_param"); - EXPECT_EQ(int8_qnn_param.scalarParam.dataType, QNN_DATATYPE_INT_8); - EXPECT_EQ(int8_qnn_param.scalarParam.int8Value, value); -} - -TEST(ScalarParamWrapperTest, Uint16ParamTest) { - constexpr std::uint16_t value = 65535; - ScalarParamWrapper 
uint16_param{"uint16_param", value, false}; - Qnn_Param_t uint16_qnn_param = QNN_PARAM_INIT; - uint16_param.CloneTo(uint16_qnn_param); - EXPECT_EQ(uint16_qnn_param.paramType, QNN_PARAMTYPE_SCALAR); - EXPECT_EQ(uint16_qnn_param.name, "uint16_param"); - EXPECT_EQ(uint16_qnn_param.scalarParam.dataType, QNN_DATATYPE_UINT_16); - EXPECT_EQ(uint16_qnn_param.scalarParam.uint16Value, value); -} - -TEST(ScalarParamWrapperTest, Int16ParamTest) { - constexpr std::int16_t value = -32768; - ScalarParamWrapper int16_param{"int16_param", value, false}; - Qnn_Param_t int16_qnn_param = QNN_PARAM_INIT; - int16_param.CloneTo(int16_qnn_param); - EXPECT_EQ(int16_qnn_param.paramType, QNN_PARAMTYPE_SCALAR); - EXPECT_EQ(int16_qnn_param.name, "int16_param"); - EXPECT_EQ(int16_qnn_param.scalarParam.dataType, QNN_DATATYPE_INT_16); - EXPECT_EQ(int16_qnn_param.scalarParam.int16Value, value); -} - -TEST(ScalarParamWrapperTest, Uint32ParamTest) { - constexpr std::uint32_t value = 4294967295; - ScalarParamWrapper uint32_param{"uint32_param", value, false}; - Qnn_Param_t uint32_qnn_param = QNN_PARAM_INIT; - uint32_param.CloneTo(uint32_qnn_param); - EXPECT_EQ(uint32_qnn_param.paramType, QNN_PARAMTYPE_SCALAR); - EXPECT_EQ(uint32_qnn_param.name, "uint32_param"); - EXPECT_EQ(uint32_qnn_param.scalarParam.dataType, QNN_DATATYPE_UINT_32); - EXPECT_EQ(uint32_qnn_param.scalarParam.uint32Value, value); -} - -TEST(ScalarParamWrapperTest, Int32ParamTest) { - constexpr std::int32_t value = -2147483648; - ScalarParamWrapper int32_param{"int32_param", value, false}; - Qnn_Param_t int32_qnn_param = QNN_PARAM_INIT; - int32_param.CloneTo(int32_qnn_param); - EXPECT_EQ(int32_qnn_param.paramType, QNN_PARAMTYPE_SCALAR); - EXPECT_EQ(int32_qnn_param.name, "int32_param"); - EXPECT_EQ(int32_qnn_param.scalarParam.dataType, QNN_DATATYPE_INT_32); - EXPECT_EQ(int32_qnn_param.scalarParam.int32Value, value); -} - -TEST(ScalarParamWrapperTest, FloatParamTest) { - constexpr float value = 3.14f; - ScalarParamWrapper 
float_param{"float_param", value, false}; - Qnn_Param_t float_qnn_param = QNN_PARAM_INIT; - float_param.CloneTo(float_qnn_param); - EXPECT_EQ(float_qnn_param.paramType, QNN_PARAMTYPE_SCALAR); - EXPECT_EQ(float_qnn_param.name, "float_param"); - EXPECT_EQ(float_qnn_param.scalarParam.dataType, QNN_DATATYPE_FLOAT_32); - EXPECT_FLOAT_EQ(float_qnn_param.scalarParam.floatValue, value); -} - -TEST(ScalarParamWrapperTest, QuantizedBoolParamTest) { - ScalarParamWrapper bool_quant_param{"bool_quant_param", true, true}; - Qnn_Param_t bool_quant_qnn_param = QNN_PARAM_INIT; - bool_quant_param.CloneTo(bool_quant_qnn_param); - EXPECT_EQ(bool_quant_qnn_param.paramType, QNN_PARAMTYPE_SCALAR); - EXPECT_EQ(bool_quant_qnn_param.name, "bool_quant_param"); - EXPECT_EQ(bool_quant_qnn_param.scalarParam.dataType, QNN_DATATYPE_BOOL_8); - EXPECT_EQ(bool_quant_qnn_param.scalarParam.bool8Value, 1); -} - -TEST(ScalarParamWrapperTest, QuantizedUint8ParamTest) { - constexpr std::uint8_t value = 255; - ScalarParamWrapper uint8_quant_param{"uint8_quant_param", value, true}; - Qnn_Param_t uint8_quant_qnn_param = QNN_PARAM_INIT; - uint8_quant_param.CloneTo(uint8_quant_qnn_param); - EXPECT_EQ(uint8_quant_qnn_param.paramType, QNN_PARAMTYPE_SCALAR); - EXPECT_EQ(uint8_quant_qnn_param.name, "uint8_quant_param"); - EXPECT_EQ(uint8_quant_qnn_param.scalarParam.dataType, - QNN_DATATYPE_UFIXED_POINT_8); - EXPECT_EQ(uint8_quant_qnn_param.scalarParam.uint8Value, value); -} - -TEST(ScalarParamWrapperTest, QuantizedInt8ParamTest) { - constexpr std::int8_t value = -128; - ScalarParamWrapper int8_quant_param{"int8_quant_param", value, true}; - Qnn_Param_t int8_quant_qnn_param = QNN_PARAM_INIT; - int8_quant_param.CloneTo(int8_quant_qnn_param); - EXPECT_EQ(int8_quant_qnn_param.paramType, QNN_PARAMTYPE_SCALAR); - EXPECT_EQ(int8_quant_qnn_param.name, "int8_quant_param"); - EXPECT_EQ(int8_quant_qnn_param.scalarParam.dataType, - QNN_DATATYPE_SFIXED_POINT_8); - EXPECT_EQ(int8_quant_qnn_param.scalarParam.int8Value, value); 
-} - -TEST(ScalarParamWrapperTest, QuantizedUint16ParamTest) { - constexpr std::uint16_t value = 65535; - ScalarParamWrapper uint16_quant_param{"uint16_quant_param", value, true}; - Qnn_Param_t uint16_quant_qnn_param = QNN_PARAM_INIT; - uint16_quant_param.CloneTo(uint16_quant_qnn_param); - EXPECT_EQ(uint16_quant_qnn_param.paramType, QNN_PARAMTYPE_SCALAR); - EXPECT_EQ(uint16_quant_qnn_param.name, "uint16_quant_param"); - EXPECT_EQ(uint16_quant_qnn_param.scalarParam.dataType, - QNN_DATATYPE_UFIXED_POINT_16); - EXPECT_EQ(uint16_quant_qnn_param.scalarParam.uint16Value, value); -} - -TEST(ScalarParamWrapperTest, QuantizedInt16ParamTest) { - constexpr std::int16_t value = -32768; - ScalarParamWrapper int16_quant_param{"int16_quant_param", value, true}; - Qnn_Param_t int16_quant_qnn_param = QNN_PARAM_INIT; - int16_quant_param.CloneTo(int16_quant_qnn_param); - EXPECT_EQ(int16_quant_qnn_param.paramType, QNN_PARAMTYPE_SCALAR); - EXPECT_EQ(int16_quant_qnn_param.name, "int16_quant_param"); - EXPECT_EQ(int16_quant_qnn_param.scalarParam.dataType, - QNN_DATATYPE_SFIXED_POINT_16); - EXPECT_EQ(int16_quant_qnn_param.scalarParam.int16Value, value); -} - -TEST(ScalarParamWrapperTest, QuantizedUint32ParamTest) { - constexpr std::uint32_t value = 4294967295; - ScalarParamWrapper uint32_quant_param{"uint32_quant_param", value, true}; - Qnn_Param_t uint32_quant_qnn_param = QNN_PARAM_INIT; - uint32_quant_param.CloneTo(uint32_quant_qnn_param); - EXPECT_EQ(uint32_quant_qnn_param.paramType, QNN_PARAMTYPE_SCALAR); - EXPECT_EQ(uint32_quant_qnn_param.name, "uint32_quant_param"); - EXPECT_EQ(uint32_quant_qnn_param.scalarParam.dataType, - QNN_DATATYPE_UFIXED_POINT_32); - EXPECT_EQ(uint32_quant_qnn_param.scalarParam.uint32Value, value); -} - -TEST(ScalarParamWrapperTest, QuantizedInt32ParamTest) { - constexpr std::int32_t value = -2147483648; - ScalarParamWrapper int32_quant_param{"int32_quant_param", value, true}; - Qnn_Param_t int32_quant_qnn_param = QNN_PARAM_INIT; - 
int32_quant_param.CloneTo(int32_quant_qnn_param); - EXPECT_EQ(int32_quant_qnn_param.paramType, QNN_PARAMTYPE_SCALAR); - EXPECT_EQ(int32_quant_qnn_param.name, "int32_quant_param"); - EXPECT_EQ(int32_quant_qnn_param.scalarParam.dataType, - QNN_DATATYPE_SFIXED_POINT_32); - EXPECT_EQ(int32_quant_qnn_param.scalarParam.int32Value, value); -} - -TEST(ParamWrapperTest, TensorParamTest) { - std::vector dummy_dims = {1, 1, 3}; - std::vector data = {1, 2, 3}; - void* data_ptr = reinterpret_cast(data.data()); - - const auto data_size = - std::accumulate(dummy_dims.begin(), dummy_dims.end(), - sizeof(decltype(data)::value_type), std::multiplies<>()); - - TensorWrapper tensor_wrapper{0, - QNN_TENSOR_TYPE_STATIC, - QNN_DATATYPE_UFIXED_POINT_8, - QuantizeParamsWrapperVariant(), - dummy_dims, - static_cast(data_size), - data_ptr}; - - TensorParamWrapper tensor_param{"tensor_param", tensor_wrapper}; - - Qnn_Param_t qnn_tensor_param = QNN_PARAM_INIT; - tensor_param.CloneTo(qnn_tensor_param); - EXPECT_EQ(qnn_tensor_param.paramType, QNN_PARAMTYPE_TENSOR); - EXPECT_EQ(qnn_tensor_param.name, "tensor_param"); - - Qnn_Tensor_t& ref = qnn_tensor_param.tensorParam; - EXPECT_EQ(ref.v2.id, 0); - EXPECT_EQ(ref.v2.type, QNN_TENSOR_TYPE_STATIC); - EXPECT_EQ(ref.v2.dataFormat, QNN_TENSOR_DATA_FORMAT_FLAT_BUFFER); - EXPECT_EQ(ref.v2.dataType, QNN_DATATYPE_UFIXED_POINT_8); - EXPECT_EQ(ref.v2.quantizeParams.encodingDefinition, QNN_DEFINITION_UNDEFINED); - EXPECT_EQ(ref.v2.rank, dummy_dims.size()); - for (size_t i = 0; i < ref.v2.rank; i++) { - EXPECT_EQ(ref.v2.dimensions[i], dummy_dims[i]); - } - EXPECT_EQ(ref.v2.memType, QNN_TENSORMEMTYPE_RAW); - EXPECT_EQ(ref.v2.clientBuf.dataSize, data_size); - const auto* ref_data = - reinterpret_cast(ref.v2.clientBuf.data); - for (size_t i = 0; i < data.size(); i++) { - EXPECT_EQ(ref_data[i], data[i]); - } -} -} // namespace -} // namespace qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tests/quantize_params_wrapper_test.cc 
b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tests/quantize_params_wrapper_test.cc deleted file mode 100644 index 8ed03dc50689ea..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tests/quantize_params_wrapper_test.cc +++ /dev/null @@ -1,163 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/quantize_params_wrapper.h" - -#include -#include -#include -#include - -#include -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" - -namespace qnn { -namespace { - -TEST(UndefinedQuantizeParamsWrapperTest, DefaultConstructorTest) { - UndefinedQuantizeParamsWrapper wrapper; - Qnn_QuantizeParams_t dst = QNN_QUANTIZE_PARAMS_INIT; - wrapper.CloneTo(dst); - EXPECT_EQ(dst.encodingDefinition, QNN_DEFINITION_UNDEFINED); - EXPECT_EQ(dst.quantizationEncoding, QNN_QUANTIZATION_ENCODING_UNDEFINED); -} - -TEST(UndefinedQuantizeParamsWrapperTest, CopyConstructorTest) { - UndefinedQuantizeParamsWrapper wrapper1; - UndefinedQuantizeParamsWrapper wrapper2(wrapper1); - Qnn_QuantizeParams_t dst = QNN_QUANTIZE_PARAMS_INIT; - wrapper2.CloneTo(dst); - EXPECT_EQ(dst.encodingDefinition, QNN_DEFINITION_UNDEFINED); - EXPECT_EQ(dst.quantizationEncoding, QNN_QUANTIZATION_ENCODING_UNDEFINED); -} - -TEST(UndefinedQuantizeParamsWrapperTest, MoveConstructorTest) { - UndefinedQuantizeParamsWrapper wrapper1; - UndefinedQuantizeParamsWrapper wrapper2(std::move(wrapper1)); - Qnn_QuantizeParams_t dst = QNN_QUANTIZE_PARAMS_INIT; - wrapper2.CloneTo(dst); - EXPECT_EQ(dst.encodingDefinition, QNN_DEFINITION_UNDEFINED); - EXPECT_EQ(dst.quantizationEncoding, QNN_QUANTIZATION_ENCODING_UNDEFINED); -} - -TEST(ScaleOffsetQuantizeParamsWrapperTest, ConstructorTest) { - float scale = 1.5f; - std::int32_t zero_point = 10; - ScaleOffsetQuantizeParamsWrapper wrapper(scale, zero_point); - Qnn_QuantizeParams_t dst 
= QNN_QUANTIZE_PARAMS_INIT; - wrapper.CloneTo(dst); - EXPECT_EQ(dst.encodingDefinition, QNN_DEFINITION_DEFINED); - EXPECT_EQ(dst.quantizationEncoding, QNN_QUANTIZATION_ENCODING_SCALE_OFFSET); - EXPECT_FLOAT_EQ(dst.scaleOffsetEncoding.scale, scale); - EXPECT_EQ(dst.scaleOffsetEncoding.offset, -zero_point); -} - -TEST(ScaleOffsetQuantizeParamsWrapperTest, CopyConstructorTest) { - float scale = 1.5f; - std::int32_t zero_point = 10; - ScaleOffsetQuantizeParamsWrapper wrapper1(scale, zero_point); - ScaleOffsetQuantizeParamsWrapper wrapper2(wrapper1); - Qnn_QuantizeParams_t dst = QNN_QUANTIZE_PARAMS_INIT; - wrapper2.CloneTo(dst); - EXPECT_EQ(dst.encodingDefinition, QNN_DEFINITION_DEFINED); - EXPECT_EQ(dst.quantizationEncoding, QNN_QUANTIZATION_ENCODING_SCALE_OFFSET); - EXPECT_FLOAT_EQ(dst.scaleOffsetEncoding.scale, scale); - EXPECT_EQ(dst.scaleOffsetEncoding.offset, -zero_point); -} - -TEST(ScaleOffsetQuantizeParamsWrapperTest, MoveConstructorTest) { - float scale = 1.5f; - std::int32_t zero_point = 10; - ScaleOffsetQuantizeParamsWrapper wrapper1(scale, zero_point); - ScaleOffsetQuantizeParamsWrapper wrapper2(std::move(wrapper1)); - Qnn_QuantizeParams_t dst = QNN_QUANTIZE_PARAMS_INIT; - wrapper2.CloneTo(dst); - EXPECT_EQ(dst.encodingDefinition, QNN_DEFINITION_DEFINED); - EXPECT_EQ(dst.quantizationEncoding, QNN_QUANTIZATION_ENCODING_SCALE_OFFSET); - EXPECT_FLOAT_EQ(dst.scaleOffsetEncoding.scale, scale); - EXPECT_EQ(dst.scaleOffsetEncoding.offset, -zero_point); -} - -TEST(ScaleOffsetQuantizeParamsWrapperTest, GetterTest) { - float scale = 1.5f; - std::int32_t zero_point = 10; - ScaleOffsetQuantizeParamsWrapper wrapper(scale, zero_point); - EXPECT_FLOAT_EQ(wrapper.GetScale(), scale); - EXPECT_EQ(wrapper.GetZeroPoint(), zero_point); -} - -TEST(AxisScaleOffsetQuantizeParamsWrapperTest, ConstructorTest) { - std::int32_t axis = 1; - std::vector scales = {1.5f, 2.5f}; - std::vector zero_points = {10, 20}; - AxisScaleOffsetQuantizeParamsWrapper wrapper(axis, scales, zero_points); 
- Qnn_QuantizeParams_t dst = QNN_QUANTIZE_PARAMS_INIT; - wrapper.CloneTo(dst); - EXPECT_EQ(dst.encodingDefinition, QNN_DEFINITION_DEFINED); - EXPECT_EQ(dst.quantizationEncoding, - QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET); - EXPECT_EQ(dst.axisScaleOffsetEncoding.axis, axis); - EXPECT_EQ(dst.axisScaleOffsetEncoding.numScaleOffsets, scales.size()); - for (size_t i = 0; i < scales.size(); ++i) { - EXPECT_FLOAT_EQ(dst.axisScaleOffsetEncoding.scaleOffset[i].scale, - scales[i]); - EXPECT_EQ(dst.axisScaleOffsetEncoding.scaleOffset[i].offset, - -zero_points[i]); - } -} - -TEST(AxisScaleOffsetQuantizeParamsWrapperTest, CopyConstructorTest) { - std::int32_t axis = 1; - std::vector scales = {1.5f, 2.5f}; - std::vector zero_points = {10, 20}; - AxisScaleOffsetQuantizeParamsWrapper wrapper1(axis, scales, zero_points); - AxisScaleOffsetQuantizeParamsWrapper wrapper2(wrapper1); - Qnn_QuantizeParams_t dst = QNN_QUANTIZE_PARAMS_INIT; - wrapper2.CloneTo(dst); - EXPECT_EQ(dst.encodingDefinition, QNN_DEFINITION_DEFINED); - EXPECT_EQ(dst.quantizationEncoding, - QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET); - EXPECT_EQ(dst.axisScaleOffsetEncoding.axis, axis); - EXPECT_EQ(dst.axisScaleOffsetEncoding.numScaleOffsets, scales.size()); - for (size_t i = 0; i < scales.size(); ++i) { - EXPECT_FLOAT_EQ(dst.axisScaleOffsetEncoding.scaleOffset[i].scale, - scales[i]); - EXPECT_EQ(dst.axisScaleOffsetEncoding.scaleOffset[i].offset, - -zero_points[i]); - } -} - -TEST(AxisScaleOffsetQuantizeParamsWrapperTest, MoveConstructorTest) { - std::int32_t axis = 1; - std::vector scales = {1.5f, 2.5f}; - std::vector zero_points = {10, 20}; - AxisScaleOffsetQuantizeParamsWrapper wrapper1(axis, scales, zero_points); - AxisScaleOffsetQuantizeParamsWrapper wrapper2(std::move(wrapper1)); - Qnn_QuantizeParams_t dst = QNN_QUANTIZE_PARAMS_INIT; - wrapper2.CloneTo(dst); - EXPECT_EQ(dst.encodingDefinition, QNN_DEFINITION_DEFINED); - EXPECT_EQ(dst.quantizationEncoding, - QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET); - 
EXPECT_EQ(dst.axisScaleOffsetEncoding.axis, axis); - EXPECT_EQ(dst.axisScaleOffsetEncoding.numScaleOffsets, scales.size()); - for (size_t i = 0; i < scales.size(); ++i) { - EXPECT_FLOAT_EQ(dst.axisScaleOffsetEncoding.scaleOffset[i].scale, - scales[i]); - EXPECT_EQ(dst.axisScaleOffsetEncoding.scaleOffset[i].offset, - -zero_points[i]); - } -} -TEST(AxisScaleOffsetQuantizeParamsWrapperTest, GetterTest) { - std::int32_t axis = 1; - std::vector scales = {1.5f, 2.5f}; - std::vector zero_points = {10, 20}; - AxisScaleOffsetQuantizeParamsWrapper wrapper(axis, scales, zero_points); - std::vector scales_out; - wrapper.GetScales(scales_out); - EXPECT_EQ(scales, scales_out); - std::vector zero_points_out; - wrapper.GetZeroPoints(zero_points_out); - EXPECT_EQ(zero_points, zero_points_out); -} -} // namespace -} // namespace qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tests/tensor_wrapper_test.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tests/tensor_wrapper_test.cc deleted file mode 100644 index 68e1828181ef9b..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tests/tensor_wrapper_test.cc +++ /dev/null @@ -1,309 +0,0 @@ -// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved. 
-// SPDX-License-Identifier: Apache-2.0 - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h" - -#include -#include -#include -#include -#include -#include -#include - -#include -#include "absl/types/span.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/utils/miscs.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/core/wrappers/quantize_params_wrapper.h" - -namespace qnn { -namespace { - -TEST(TensorWrapperTest, SanityTest) { - TensorWrapper tensor_wrapper{}; - - EXPECT_EQ(tensor_wrapper.GetRank(), 0); - EXPECT_TRUE(tensor_wrapper.GetDims().empty()); - EXPECT_TRUE(std::holds_alternative( - tensor_wrapper.GetQuantParams())); - EXPECT_FALSE(tensor_wrapper.IsPerTensorQuantWithOffsetDiff(tensor_wrapper)); - EXPECT_FALSE(tensor_wrapper.IsQuant8()); - EXPECT_FALSE(tensor_wrapper.IsQuant16()); - EXPECT_EQ(tensor_wrapper.GetDataType(), QNN_DATATYPE_UNDEFINED); - EXPECT_FALSE(tensor_wrapper.IsSubgraphInput()); - EXPECT_FALSE(tensor_wrapper.IsSubgraphOutput()); - EXPECT_FALSE(tensor_wrapper.IsTensorStatic()); - EXPECT_EQ(tensor_wrapper.GetStaticTensorData(), std::nullopt); - std::vector data = {1, 2, 3}; - // expect no use, since tensor type not correct - tensor_wrapper.SetTensorData( - absl::MakeSpan(data.data(), data.size())); - EXPECT_EQ(tensor_wrapper.GetStaticTensorData(), std::nullopt); -} - -TEST(TensorWrapperTest, CopyTensorTest) { - std::vector dummy_dims = {1, 1, 3}; - ScaleOffsetQuantizeParamsWrapper q_param(1, 0); - TensorWrapper tensor_wrapper{0, QNN_TENSOR_TYPE_STATIC, - QNN_DATATYPE_UFIXED_POINT_8, q_param, - dummy_dims}; - TensorWrapper copied{tensor_wrapper}; - - EXPECT_EQ(copied.GetRank(), 3); - EXPECT_EQ(copied.GetDims(), dummy_dims); - EXPECT_TRUE(std::holds_alternative( - copied.GetQuantParams())); - EXPECT_FALSE(copied.IsPerTensorQuantWithOffsetDiff(copied)); - EXPECT_TRUE(copied.IsQuant8()); - 
EXPECT_FALSE(copied.IsQuant16()); - EXPECT_EQ(copied.GetDataType(), QNN_DATATYPE_UFIXED_POINT_8); - EXPECT_FALSE(copied.IsSubgraphInput()); - EXPECT_FALSE(copied.IsSubgraphOutput()); - EXPECT_TRUE(copied.IsTensorStatic()); - EXPECT_EQ(copied.GetStaticTensorData(), std::nullopt); - std::vector data = {1, 2, 3}; - copied.SetTensorData(absl::MakeSpan(data.data(), data.size())); - const auto tensor_data = copied.GetStaticTensorData(); - EXPECT_TRUE(tensor_data.has_value()); - for (size_t i = 0; i < data.size(); i++) { - EXPECT_EQ((*tensor_data)[i], data[i]); - } -} - -TEST(TensorWrapperTest, MoveTensorTest) { - std::vector dummy_dims = {1, 1, 3}; - ScaleOffsetQuantizeParamsWrapper q_param(1, 0); - std::vector data = {1, 2, 3}; - void* data_ptr = reinterpret_cast(data.data()); - TensorWrapper tensor_wrapper{0, - QNN_TENSOR_TYPE_STATIC, - QNN_DATATYPE_UFIXED_POINT_8, - q_param, - dummy_dims, - static_cast(data.size()), - data_ptr}; - TensorWrapper moved{tensor_wrapper}; - - EXPECT_EQ(moved.GetRank(), 3); - EXPECT_EQ(moved.GetDims(), dummy_dims); - EXPECT_TRUE(std::holds_alternative( - moved.GetQuantParams())); - EXPECT_FALSE(moved.IsPerTensorQuantWithOffsetDiff(moved)); - EXPECT_TRUE(moved.IsQuant8()); - EXPECT_FALSE(moved.IsQuant16()); - EXPECT_EQ(moved.GetDataType(), QNN_DATATYPE_UFIXED_POINT_8); - EXPECT_FALSE(moved.IsSubgraphInput()); - EXPECT_FALSE(moved.IsSubgraphOutput()); - EXPECT_TRUE(moved.IsTensorStatic()); - const auto tensor_data = moved.GetStaticTensorData(); - EXPECT_TRUE(tensor_data.has_value()); - for (size_t i = 0; i < data.size(); i++) { - EXPECT_EQ(tensor_data.value()[i], data[i]); - } -} - -TEST(TensorWrapperTest, QnnTensorTest) { - std::vector dummy_dims = {1, 1, 3}; - std::vector data = {1, 2, 3}; - void* data_ptr = reinterpret_cast(data.data()); - const auto data_size = - std::accumulate(dummy_dims.begin(), dummy_dims.end(), - sizeof(decltype(data)::value_type), std::multiplies<>()); - - TensorWrapper tensor_wrapper{0, - QNN_TENSOR_TYPE_APP_WRITE, 
- QNN_DATATYPE_UFIXED_POINT_8, - QuantizeParamsWrapperVariant(), - dummy_dims, - static_cast(data_size), - data_ptr}; - - Qnn_Tensor_t cloned; - tensor_wrapper.CloneTo(cloned); - EXPECT_EQ(cloned.version, QNN_TENSOR_VERSION_2); - EXPECT_EQ(cloned.v2.id, 0); - EXPECT_EQ(cloned.v2.type, QNN_TENSOR_TYPE_APP_WRITE); - EXPECT_EQ(cloned.v2.dataFormat, QNN_TENSOR_DATA_FORMAT_FLAT_BUFFER); - EXPECT_EQ(cloned.v2.dataType, QNN_DATATYPE_UFIXED_POINT_8); - EXPECT_EQ(cloned.v2.quantizeParams.encodingDefinition, - QNN_DEFINITION_UNDEFINED); - EXPECT_EQ(cloned.v2.rank, dummy_dims.size()); - for (size_t i = 0; i < cloned.v2.rank; i++) { - EXPECT_EQ(cloned.v2.dimensions[i], dummy_dims[i]); - } - EXPECT_EQ(cloned.v2.memType, QNN_TENSORMEMTYPE_RAW); - EXPECT_EQ(cloned.v2.clientBuf.dataSize, data_size); - const auto* cloned_data = - reinterpret_cast(cloned.v2.clientBuf.data); - for (size_t i = 0; i < data.size(); i++) { - EXPECT_EQ(cloned_data[i], data[i]); - } - - Qnn_Tensor_t& ref = tensor_wrapper.GetQnnTensor(); - EXPECT_EQ(ref.version, QNN_TENSOR_VERSION_2); - EXPECT_EQ(ref.v2.id, 0); - EXPECT_EQ(ref.v2.type, QNN_TENSOR_TYPE_APP_WRITE); - EXPECT_EQ(ref.v2.dataFormat, QNN_TENSOR_DATA_FORMAT_FLAT_BUFFER); - EXPECT_EQ(ref.v2.dataType, QNN_DATATYPE_UFIXED_POINT_8); - EXPECT_EQ(ref.v2.quantizeParams.encodingDefinition, QNN_DEFINITION_UNDEFINED); - EXPECT_EQ(ref.v2.rank, dummy_dims.size()); - for (size_t i = 0; i < ref.v2.rank; i++) { - EXPECT_EQ(ref.v2.dimensions[i], dummy_dims[i]); - } - EXPECT_EQ(ref.v2.memType, QNN_TENSORMEMTYPE_RAW); - EXPECT_EQ(ref.v2.clientBuf.dataSize, data_size); - const auto* ref_data = - reinterpret_cast(ref.v2.clientBuf.data); - for (size_t i = 0; i < data.size(); i++) { - EXPECT_EQ(ref_data[i], data[i]); - } -} - -TEST(TensorWrapperTest, IsPerTensorQuantWithOffsetDiff8BitTest) { - constexpr int kSUFixed8OffsetDiff = 128; - ScaleOffsetQuantizeParamsWrapper wrapper1(1, 0); - ScaleOffsetQuantizeParamsWrapper wrapper2(1, kSUFixed8OffsetDiff); - TensorWrapper 
tensor_wrapper0{0, - QNN_TENSOR_TYPE_STATIC, - QNN_DATATYPE_UFIXED_POINT_8, - QuantizeParamsWrapperVariant(wrapper1), - {}}; - TensorWrapper tensor_wrapper1{0, - QNN_TENSOR_TYPE_STATIC, - QNN_DATATYPE_SFIXED_POINT_8, - QuantizeParamsWrapperVariant(wrapper2), - {}}; - EXPECT_TRUE(tensor_wrapper0.IsPerTensorQuantWithOffsetDiff(tensor_wrapper1)); -} - -TEST(TensorWrapperTest, IsPerTensorQuantWithOffsetDiff16BitTest) { - constexpr int kSUFixed16OffsetDiff = 32768; - ScaleOffsetQuantizeParamsWrapper wrapper1(1, 0); - ScaleOffsetQuantizeParamsWrapper wrapper2(1, kSUFixed16OffsetDiff); - TensorWrapper tensor_wrapper0{0, - QNN_TENSOR_TYPE_STATIC, - QNN_DATATYPE_UFIXED_POINT_16, - QuantizeParamsWrapperVariant(wrapper1), - {}}; - TensorWrapper tensor_wrapper1{0, - QNN_TENSOR_TYPE_STATIC, - QNN_DATATYPE_SFIXED_POINT_16, - QuantizeParamsWrapperVariant(wrapper2), - {}}; - EXPECT_TRUE(tensor_wrapper0.IsPerTensorQuantWithOffsetDiff(tensor_wrapper1)); -} - -TEST(TensorWrapperTest, StaticTensorTest) { - TensorWrapper tensor_wrapper{0, - QNN_TENSOR_TYPE_STATIC, - QNN_DATATYPE_UNDEFINED, - QuantizeParamsWrapperVariant(), - {}}; - - EXPECT_TRUE(tensor_wrapper.IsTensorStatic()); - EXPECT_FALSE(tensor_wrapper.IsSubgraphInput()); - EXPECT_FALSE(tensor_wrapper.IsSubgraphOutput()); -} - -TEST(TensorWrapperTest, SubgraphInputTensorTest) { - TensorWrapper tensor_wrapper{0, - QNN_TENSOR_TYPE_APP_WRITE, - QNN_DATATYPE_UNDEFINED, - QuantizeParamsWrapperVariant(), - {}}; - - EXPECT_FALSE(tensor_wrapper.IsTensorStatic()); - EXPECT_TRUE(tensor_wrapper.IsSubgraphInput()); - EXPECT_FALSE(tensor_wrapper.IsSubgraphOutput()); -} - -TEST(TensorWrapperTest, SubgraphOutputTensorTest) { - TensorWrapper tensor_wrapper{0, - QNN_TENSOR_TYPE_APP_READ, - QNN_DATATYPE_UNDEFINED, - QuantizeParamsWrapperVariant(), - {}}; - - EXPECT_FALSE(tensor_wrapper.IsTensorStatic()); - EXPECT_FALSE(tensor_wrapper.IsSubgraphInput()); - EXPECT_TRUE(tensor_wrapper.IsSubgraphOutput()); -} - -TEST(TensorWrapperTest, 
GetStaticTensorDataNonStaticTest) { - std::vector dummy_dims = {1, 1, 3}; - ScaleOffsetQuantizeParamsWrapper q_param(1, 0); - TensorWrapper tensor_wrapper{0, QNN_TENSOR_TYPE_APP_WRITE, - QNN_DATATYPE_UFIXED_POINT_8, q_param, - dummy_dims}; - EXPECT_FALSE(tensor_wrapper.GetStaticTensorData().has_value()); -} - -TEST(TensorWrapperTest, GetStaticTensorDataTest) { - std::vector dummy_dims = {1, 1, 3}; - ScaleOffsetQuantizeParamsWrapper q_param(1, 0); - TensorWrapper tensor_wrapper{0, QNN_TENSOR_TYPE_STATIC, - QNN_DATATYPE_UFIXED_POINT_8, q_param, - dummy_dims}; - - EXPECT_FALSE(tensor_wrapper.GetStaticTensorData().has_value()); - EXPECT_FALSE(tensor_wrapper.GetStaticTensorData().has_value()); - EXPECT_FALSE(tensor_wrapper.GetStaticTensorData().has_value()); - std::vector data = {1, 2, 3}; - tensor_wrapper.SetTensorData( - absl::MakeSpan(data.data(), data.size())); - const auto tensor_data = - *(tensor_wrapper.GetStaticTensorData()); - for (size_t i = 0; i < data.size(); i++) { - EXPECT_EQ(tensor_data[i], data[i]); - } -} - -TEST(TensorWrapperTest, ConvertQint16ToQuint16Test) { - std::vector dummy_dims = {1, 1, 3}; - ScaleOffsetQuantizeParamsWrapper q_param(0.0001, 0); - TensorWrapper tensor_wrapper{0, QNN_TENSOR_TYPE_STATIC, - QNN_DATATYPE_SFIXED_POINT_16, q_param, - dummy_dims}; - - std::vector data = {1, 2, 3}; - const auto& int16_q_param_ref = tensor_wrapper.GetQuantParams(); - EXPECT_TRUE(std::holds_alternative( - int16_q_param_ref)); - const float int16_scale = - std::get(int16_q_param_ref).GetScale(); - const std::int32_t int16_zero_point = - std::get(int16_q_param_ref) - .GetZeroPoint(); - std::vector int16_data; - for (int i = 0; i < data.size(); ++i) { - int16_data.emplace_back( - Quantize(data[i], int16_scale, int16_zero_point)); - } - tensor_wrapper.SetTensorData( - absl::MakeSpan(int16_data.data(), int16_data.size())); - - tensor_wrapper.ConvertQint16ToQuint16(); - - const auto& uint16_q_param_ref = tensor_wrapper.GetQuantParams(); - 
EXPECT_TRUE(std::holds_alternative( - uint16_q_param_ref)); - const float uint16_scale = - std::get(uint16_q_param_ref).GetScale(); - const std::int32_t uint16_zero_point = - std::get(uint16_q_param_ref) - .GetZeroPoint(); - const auto uint16_data = - *(tensor_wrapper.GetStaticTensorData()); - std::vector deq_data; - for (size_t i = 0; i < data.size(); i++) { - deq_data.emplace_back( - Dequantize(uint16_data[i], uint16_scale, uint16_zero_point)); - } - ASSERT_EQ(data.size(), deq_data.size()); - for (size_t i = 0; i < data.size(); ++i) { - EXPECT_NEAR(data[i], deq_data[i], 1e-3); - } -} -} // namespace -} // namespace qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/BUILD b/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/BUILD deleted file mode 100644 index 2809dcc115188f..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/BUILD +++ /dev/null @@ -1,115 +0,0 @@ -# Copyright 2024 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -load("//tensorflow/lite/experimental/litert/build_common:litert_build_defs.bzl", "copy_file", "litert_dynamic_lib") - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], -) - -litert_dynamic_lib( - name = "dispatch_api", - srcs = [ - "dispatch_api.cc", - "litert_dispatch_device_context.cc", - "litert_dispatch_invocation_context.cc", - ], - hdrs = [ - "litert_dispatch_device_context.h", - "litert_dispatch_invocation_context.h", - "registry.h", - ], - copts = [ - "-Os", - "-fno-exceptions", - "-fno-unwind-tables", - "-fno-asynchronous-unwind-tables", - "-ffunction-sections", - "-fdata-sections", - ], - export_litert_only = True, - linkopts = select({ - "//tensorflow:android": ["-landroid"], - "//conditions:default": [], - }) + [ - "-Wl,-soname=libLiteRtDispatch_Qualcomm.so", - "-Wl,-lc++abi", - ], - shared_lib_name = "dispatch_api_so", - so_name = "libLiteRtDispatch_Qualcomm.so", - tags = [ - # Don't build/test in OS until qnn is available. - "nobuilder", - ], - visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], - deps = [ - "@com_google_absl//absl/log:absl_check", - "@com_google_absl//absl/strings:string_view", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/c:litert_runtime_c_api_shared_lib", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/core/util:tensor_type_util", - "//tensorflow/lite/experimental/litert/vendors/c:litert_dispatch_c_api", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:common", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:context_binary_info", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", - ], -) - -# This is cc_library target for `libLiteRtDispatch_Qualcomm.so`. 
-cc_library( - name = "dispatch_api_shared_lib", - srcs = [":dispatch_api_so"], -) - -# Copies the shared library so that it is available for use in test data as libLiteRtDispatch_Qualcomm.so. -copy_file( - name = "copy_dispatch_api_so", - src = "//tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch:dispatch_api_so", - target = "libLiteRtDispatch_Qualcomm.so", -) - -cc_test( - name = "dispatch_api_qualcomm_test", - srcs = [ - "dispatch_api_qualcomm_test.cc", - ], - data = [ - ":dispatch_api_so", - ], - linkopts = select({ - "//tensorflow:android": ["-landroid"], - "//conditions:default": [], - }), - linkstatic = 1, - tags = [ - "no-remote-exec", - "no_oss", - "notap", - ], - deps = [ - "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_tensor_buffer", - "//tensorflow/lite/experimental/litert/cc:litert_any", - "//tensorflow/lite/experimental/litert/core:filesystem", - "//tensorflow/lite/experimental/litert/test:common", - "//tensorflow/lite/experimental/litert/test:simple_model_npu", - "//tensorflow/lite/experimental/litert/vendors/c:litert_dispatch_c_api", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:absl_log", - "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest_main", - ], -) diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/dispatch_api.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/dispatch_api.cc deleted file mode 100644 index f377e1a26581e1..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/dispatch_api.cc +++ /dev/null @@ -1,292 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include - -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch_api.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_device_context.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_invocation_context.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager.h" - -namespace { - -using ::litert::qnn::QnnManager; - -static std::unique_ptr TheQnnManager; - -QnnManager& Qnn() { return *TheQnnManager; } - -char BuildId[256]; - -// ///////////////////////////////////////////////////////////////////////////// -// Basic Execution API -// ///////////////////////////////////////////////////////////////////////////// - -const char* GetSharedLibraryDir(const LiteRtDispatchOption* options, - int num_options) { - for (auto i = 0; i < num_options; ++i) { - auto& option = options[i]; - if (!strcmp(option.name, kDispatchOptionSharedLibraryDir)) { - return option.value.str_value; - } - } - return nullptr; -} - -LiteRtStatus 
Initialize(const LiteRtDispatchOption* options, int num_options) { - auto* shared_library_dir = GetSharedLibraryDir(options, num_options); - std::optional shared_library_dir_opt = - shared_library_dir ? std::make_optional(std::string(shared_library_dir)) - : std::nullopt; - - auto configs = QnnManager::DefaultBackendConfigs(); - if (auto qnn_manager = QnnManager::Create(configs, shared_library_dir_opt); - !qnn_manager) { - LITERT_LOG(LITERT_ERROR, "%s", qnn_manager.Error().Message().c_str()); - return qnn_manager.Error().Status(); - } else { - std::swap(TheQnnManager, *qnn_manager); - } - - Qnn_ApiVersion_t qnn_api_version; - if (auto status = Qnn().Api()->backendGetApiVersion(&qnn_api_version); - status != QNN_SUCCESS) { - LITERT_LOG(LITERT_ERROR, "Failed to get QNN API version: %d", status); - return kLiteRtStatusErrorRuntimeFailure; - } - - const char* build_id; - if (auto status = Qnn().Api()->backendGetBuildId(&build_id); - status != QNN_SUCCESS) { - LITERT_LOG(LITERT_ERROR, "Failed to get QNN build ID: %d", status); - return kLiteRtStatusErrorRuntimeFailure; - } - - snprintf(BuildId, sizeof(BuildId), - "Qualcomm Dispatch API version %d.%d.%d, QNN API version %d.%d.%d, " - "build id: %s", - LITERT_API_VERSION_MAJOR, LITERT_API_VERSION_MINOR, - LITERT_API_VERSION_PATCH, qnn_api_version.coreApiVersion.major, - qnn_api_version.coreApiVersion.minor, - qnn_api_version.coreApiVersion.patch, build_id); - BuildId[sizeof(BuildId) - 1] = 0; - - return kLiteRtStatusOk; -} - -LiteRtStatus GetVendorId(const char** vendor_id) { - *vendor_id = "Qualcomm"; - return kLiteRtStatusOk; -} - -LiteRtStatus GetBuildId(const char** build_id) { - *build_id = BuildId; - return kLiteRtStatusOk; -} - -LiteRtStatus GetCapabilities(int* capabilities) { - *capabilities = kLiteRtDispatchCapabilitiesBasic; - return kLiteRtStatusOk; -} - -LiteRtStatus DeviceContextCreate(LiteRtDispatchDeviceContext* device_context) { - if (auto context = LiteRtDispatchDeviceContextT::Create(Qnn()); context) { 
- *device_context = context->release(); - return kLiteRtStatusOk; - } else { - LITERT_LOG(LITERT_ERROR, "Failed to create device context: %s", - context.Error().Message().c_str()); - return context.Error().Status(); - } -} - -LiteRtStatus DeviceContextDestroy(LiteRtDispatchDeviceContext device_context) { - delete device_context; - return kLiteRtStatusOk; -} - -LiteRtStatus GetInputRequirements( - LiteRtDispatchInvocationContext invocation_context, int input_index, - const LiteRtRankedTensorType* tensor_type, - LiteRtTensorBufferRequirements* tensor_buffer_requirements) { - if (auto requirements = - invocation_context->GetInputRequirements(input_index, *tensor_type); - requirements) { - *tensor_buffer_requirements = *requirements; - return kLiteRtStatusOk; - } else { - LITERT_LOG(LITERT_ERROR, "Failed to get tensor buffer requirements: %s", - requirements.Error().Message().c_str()); - return requirements.Error().Status(); - } -} - -LiteRtStatus GetOutputRequirements( - LiteRtDispatchInvocationContext invocation_context, int output_index, - const LiteRtRankedTensorType* tensor_type, - LiteRtTensorBufferRequirements* tensor_buffer_requirements) { - if (auto requirements = - invocation_context->GetOutputRequirements(output_index, *tensor_type); - requirements) { - *tensor_buffer_requirements = *requirements; - return kLiteRtStatusOk; - } else { - LITERT_LOG(LITERT_ERROR, "Failed to get tensor buffer requirements: %s", - requirements.Error().Message().c_str()); - return requirements.Error().Status(); - } -} - -LiteRtStatus RegisterTensorBuffer( - LiteRtDispatchDeviceContext device_context, LiteRtTensorBuffer buffer, - LiteRtTensorBufferHandle* tensor_buffer_handle) { - if (auto status = device_context->RegisterTensorBuffer(buffer); !status) { - LITERT_LOG(LITERT_ERROR, "Failed to register buffer: %s", - status.Error().Message().c_str()); - return status.Error().Status(); - } else { - *tensor_buffer_handle = *status; - return kLiteRtStatusOk; - } -} - -LiteRtStatus 
UnregisterTensorBuffer(LiteRtDispatchDeviceContext device_context, - LiteRtTensorBufferHandle handle) { - if (auto status = device_context->UnregisterTensorBuffer(handle); !status) { - LITERT_LOG(LITERT_ERROR, "Failed to unregister buffer: %s", - status.Error().Message().c_str()); - return status.Error().Status(); - } else { - return kLiteRtStatusOk; - } -} - -LiteRtStatus InvocationContextCreate( - LiteRtDispatchDeviceContext device_context, - LiteRtDispatchExecutableType exec_type, - const LiteRtMemBuffer* exec_bytecode_buffer, const char* function_name, - int num_inputs, int num_outputs, - LiteRtDispatchInvocationContext* invocation_context) { - auto context = LiteRtDispatchInvocationContextT::Create( - Qnn(), *device_context, exec_bytecode_buffer, function_name); - if (!context) { - LITERT_LOG(LITERT_ERROR, "Failed to create context from context binary: %s", - context.Error().Message().c_str()); - return context.Error().Status(); - } - *invocation_context = context->release(); - device_context->SetInvocationContext(*invocation_context); - return kLiteRtStatusOk; -} - -LiteRtStatus InvocationContextDestroy( - LiteRtDispatchInvocationContext invocation_context) { - delete invocation_context; - return kLiteRtStatusOk; -} - -LiteRtStatus AttachInput(LiteRtDispatchInvocationContext invocation_context, - int graph_input_index, - LiteRtTensorBufferHandle tensor_buffer_handle) { - if (auto status = invocation_context->AttachInput(graph_input_index, - tensor_buffer_handle); - !status) { - LITERT_LOG(LITERT_ERROR, "Failed to attach input buffer: %s", - status.Error().Message().c_str()); - return status.Error().Status(); - } - return kLiteRtStatusOk; -} - -LiteRtStatus AttachOutput(LiteRtDispatchInvocationContext invocation_context, - int graph_output_index, - LiteRtTensorBufferHandle tensor_buffer_handle) { - if (auto status = invocation_context->AttachOutput(graph_output_index, - tensor_buffer_handle); - !status) { - LITERT_LOG(LITERT_ERROR, "Failed to attach output 
buffer: %s", - status.Error().Message().c_str()); - return status.Error().Status(); - } - return kLiteRtStatusOk; -} - -LiteRtStatus DetachInput(LiteRtDispatchInvocationContext invocation_context, - int graph_input_index, - LiteRtTensorBufferHandle tensor_buffer_handle) { - // Nothing to do here. - return kLiteRtStatusOk; -} - -LiteRtStatus DetachOutput(LiteRtDispatchInvocationContext invocation_context, - int graph_output_index, - LiteRtTensorBufferHandle tensor_buffer_handle) { - // Nothing to do here. - return kLiteRtStatusOk; -} - -LiteRtStatus Invoke(LiteRtDispatchInvocationContext invocation_context) { - if (auto status = invocation_context->Execute(); !status) { - LITERT_LOG(LITERT_ERROR, "Failed to execute invocation context: %s", - status.Error().Message().c_str()); - return status.Error().Status(); - } - return kLiteRtStatusOk; -} - -// ///////////////////////////////////////////////////////////////////////////// - -LiteRtDispatchInterface TheInterface = { - /*.initialize=*/Initialize, - /*.get_vendor_id=*/GetVendorId, - /*.get_build_id=*/GetBuildId, - /*.get_capabilities=*/GetCapabilities, - /*.device_context_create=*/DeviceContextCreate, - /*.device_context_destroy=*/DeviceContextDestroy, - /*.get_input_requirements=*/GetInputRequirements, - /*.get_output_requirements=*/GetOutputRequirements, - /*.register_tensor_buffer=*/RegisterTensorBuffer, - /*.unregister_tensor_buffer=*/UnregisterTensorBuffer, - /*.invocation_context_create=*/InvocationContextCreate, - /*.invocation_context_destroy=*/InvocationContextDestroy, - /*.attach_input=*/AttachInput, - /*.attach_output=*/AttachOutput, - /*.detach_input=*/DetachInput, - /*.detach_output=*/DetachOutput, - /*.invoke=*/Invoke, -}; - -LiteRtDispatchApi TheApi = { - /*.version=*/{/*.major=*/LITERT_API_VERSION_MAJOR, - /*.minor=*/LITERT_API_VERSION_MINOR, - /*.patch=*/LITERT_API_VERSION_PATCH}, - /*.interface=*/&TheInterface, - /*.async_interface=*/nullptr, - /*.graph_interface=*/nullptr, -}; - -} // namespace - 
-LiteRtStatus LiteRtDispatchGetApi(LiteRtDispatchApi* api) { - *api = TheApi; - return kLiteRtStatusOk; -} diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/dispatch_api_qualcomm_test.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/dispatch_api_qualcomm_test.cc deleted file mode 100644 index c1ae8d1c53d4e4..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/dispatch_api_qualcomm_test.cc +++ /dev/null @@ -1,544 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include -#include -#include - -#include -#include -#include "absl/log/absl_log.h" -#include "absl/log/log.h" -#include "absl/types/span.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h" -#include "tensorflow/lite/experimental/litert/cc/litert_any.h" -#include "tensorflow/lite/experimental/litert/core/filesystem.h" -#include "tensorflow/lite/experimental/litert/test/common.h" -#include "tensorflow/lite/experimental/litert/test/testdata/simple_model_test_vectors.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" - -using ::testing::Pointwise; - -TEST(Qualcomm, DispatchApiWithFastRpc) { -#if !defined(__ANDROID__) - GTEST_SKIP() - << "This test is specific to Android devices with a Qualcomm NPU"; -#endif - - LiteRtDispatchOption dispatch_option = { - /*.name=*/kDispatchOptionSharedLibraryDir, - /*.value=*/*litert::ToLiteRtAny(std::any("/data/local/tmp")), - }; - ASSERT_EQ( - LiteRtDispatchInitialize(/*options=*/&dispatch_option, /*num_options=*/1), - kLiteRtStatusOk); - - const char* vendor_id; - EXPECT_EQ(LiteRtDispatchGetVendorId(&vendor_id), kLiteRtStatusOk); - ABSL_LOG(INFO) << "vendor_id: " << vendor_id; - - const char* build_id; - EXPECT_EQ(LiteRtDispatchGetBuildId(&build_id), kLiteRtStatusOk); - ABSL_LOG(INFO) << "build_id: " << build_id; - - LiteRtApiVersion api_version; - EXPECT_EQ(LiteRtDispatchGetApiVersion(&api_version), kLiteRtStatusOk); - ABSL_LOG(INFO) << "api_version: " << api_version.major << "." - << api_version.minor << "." 
<< api_version.patch; - - int capabilities; - EXPECT_EQ(LiteRtDispatchGetCapabilities(&capabilities), kLiteRtStatusOk); - ABSL_LOG(INFO) << "capabilities: " << capabilities; - - LiteRtDispatchDeviceContext device_context = nullptr; - EXPECT_EQ(LiteRtDispatchDeviceContextCreate(&device_context), - kLiteRtStatusOk); - ABSL_LOG(INFO) << "device_context: " << device_context; - - auto model_file_name = - litert::testing::GetTestFilePath(kQualcommModelFileName); - auto model = litert::internal::LoadBinaryFile(model_file_name); - EXPECT_TRUE(model) << model.Error(); - ABSL_LOG(INFO) << "Loaded model " << model_file_name << ", " << model->Size() - << " bytes"; - - // /////////////////////////////////////////////////////////////////////////// - // Set up an invocation context for a given model. - // /////////////////////////////////////////////////////////////////////////// - - LiteRtMemBuffer exec_bytecode_buffer = {/*.fd=*/-1, - /*.base_addr=*/model->Data(), - /*.offset=*/0, - /*.size=*/model->Size()}; - LiteRtDispatchInvocationContext invocation_context = nullptr; - EXPECT_EQ(LiteRtDispatchInvocationContextCreate( - device_context, kLiteRtDispatchExecutableTypeMlModel, - &exec_bytecode_buffer, /*function_name=*/"simple", - /*num_inputs=*/2, /*num_outputs=*/1, &invocation_context), - kLiteRtStatusOk); - ABSL_LOG(INFO) << "Invocation context: " << invocation_context; - - // /////////////////////////////////////////////////////////////////////////// - // Determine tensor buffer requirements. 
- // /////////////////////////////////////////////////////////////////////////// - - int num_tensor_buffer_types; - LiteRtTensorBufferRequirements input_0_tensor_buffer_requirements; - EXPECT_EQ(LiteRtDispatchGetInputRequirements( - invocation_context, /*input_index=*/0, &kInput0TensorType, - &input_0_tensor_buffer_requirements), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtGetNumTensorBufferRequirementsSupportedBufferTypes( - input_0_tensor_buffer_requirements, &num_tensor_buffer_types), - kLiteRtStatusOk); - EXPECT_GE(num_tensor_buffer_types, 1); - LiteRtTensorBufferType input_0_tensor_buffer_type; - EXPECT_EQ(LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( - input_0_tensor_buffer_requirements, /*type_index=*/0, - &input_0_tensor_buffer_type), - kLiteRtStatusOk); - EXPECT_EQ(input_0_tensor_buffer_type, kLiteRtTensorBufferTypeFastRpc); - size_t input_0_tensor_buffer_size; - EXPECT_EQ( - LiteRtGetTensorBufferRequirementsBufferSize( - input_0_tensor_buffer_requirements, &input_0_tensor_buffer_size), - kLiteRtStatusOk); - EXPECT_GE(input_0_tensor_buffer_size, sizeof(kTestInput0Tensor)); - - LiteRtTensorBufferRequirements input_1_tensor_buffer_requirements; - EXPECT_EQ(LiteRtDispatchGetInputRequirements( - invocation_context, /*input_index=*/1, &kInput1TensorType, - &input_1_tensor_buffer_requirements), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtGetNumTensorBufferRequirementsSupportedBufferTypes( - input_1_tensor_buffer_requirements, &num_tensor_buffer_types), - kLiteRtStatusOk); - EXPECT_GE(num_tensor_buffer_types, 1); - LiteRtTensorBufferType input_1_tensor_buffer_type; - EXPECT_EQ(LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( - input_1_tensor_buffer_requirements, /*type_index=*/0, - &input_1_tensor_buffer_type), - kLiteRtStatusOk); - EXPECT_EQ(input_1_tensor_buffer_type, kLiteRtTensorBufferTypeFastRpc); - size_t input_1_tensor_buffer_size; - EXPECT_EQ( - LiteRtGetTensorBufferRequirementsBufferSize( - input_1_tensor_buffer_requirements, 
&input_1_tensor_buffer_size), - kLiteRtStatusOk); - EXPECT_GE(input_1_tensor_buffer_size, sizeof(kTestInput1Tensor)); - - LiteRtTensorBufferRequirements output_tensor_buffer_requirements; - EXPECT_EQ(LiteRtDispatchGetOutputRequirements( - invocation_context, /*output_index=*/0, &kOutputTensorType, - &output_tensor_buffer_requirements), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtGetNumTensorBufferRequirementsSupportedBufferTypes( - output_tensor_buffer_requirements, &num_tensor_buffer_types), - kLiteRtStatusOk); - EXPECT_GE(num_tensor_buffer_types, 1); - LiteRtTensorBufferType output_tensor_buffer_type; - EXPECT_EQ(LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( - output_tensor_buffer_requirements, /*type_index=*/0, - &output_tensor_buffer_type), - kLiteRtStatusOk); - EXPECT_EQ(output_tensor_buffer_type, kLiteRtTensorBufferTypeFastRpc); - size_t output_tensor_buffer_size; - EXPECT_EQ(LiteRtGetTensorBufferRequirementsBufferSize( - output_tensor_buffer_requirements, &output_tensor_buffer_size), - kLiteRtStatusOk); - EXPECT_GE(output_tensor_buffer_size, sizeof(kTestOutputTensor)); - - // /////////////////////////////////////////////////////////////////////////// - // Allocate tensor buffers. 
- // /////////////////////////////////////////////////////////////////////////// - - LiteRtTensorBuffer input_0_tensor_buffer; - EXPECT_EQ(LiteRtCreateManagedTensorBuffer( - input_0_tensor_buffer_type, &kInput0TensorType, - input_0_tensor_buffer_size, &input_0_tensor_buffer), - kLiteRtStatusOk); - - LiteRtTensorBuffer input_1_tensor_buffer; - EXPECT_EQ(LiteRtCreateManagedTensorBuffer( - input_1_tensor_buffer_type, &kInput1TensorType, - input_1_tensor_buffer_size, &input_1_tensor_buffer), - kLiteRtStatusOk); - - LiteRtTensorBuffer output_tensor_buffer; - EXPECT_EQ(LiteRtCreateManagedTensorBuffer( - output_tensor_buffer_type, &kOutputTensorType, - output_tensor_buffer_size, &output_tensor_buffer), - kLiteRtStatusOk); - - // /////////////////////////////////////////////////////////////////////////// - // Register tensor buffers. - // /////////////////////////////////////////////////////////////////////////// - - LiteRtTensorBufferHandle input_1_handle; - EXPECT_EQ(LiteRtDispatchRegisterTensorBuffer( - device_context, input_1_tensor_buffer, &input_1_handle), - kLiteRtStatusOk); - - LiteRtTensorBufferHandle input_0_handle; - EXPECT_EQ(LiteRtDispatchRegisterTensorBuffer( - device_context, input_0_tensor_buffer, &input_0_handle), - kLiteRtStatusOk); - - LiteRtTensorBufferHandle output_handle; - EXPECT_EQ(LiteRtDispatchRegisterTensorBuffer( - device_context, output_tensor_buffer, &output_handle), - kLiteRtStatusOk); - - // /////////////////////////////////////////////////////////////////////////// - // Attach tensor buffers. 
- // /////////////////////////////////////////////////////////////////////////// - - EXPECT_EQ(LiteRtDispatchAttachInput(invocation_context, - /*graph_input_index=*/0, input_0_handle), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtDispatchAttachInput(invocation_context, - /*graph_input_index=*/1, input_1_handle), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtDispatchAttachOutput(invocation_context, - /*graph_output_index=*/0, output_handle), - kLiteRtStatusOk); - - // /////////////////////////////////////////////////////////////////////////// - // Fill the input buffers with data. - // /////////////////////////////////////////////////////////////////////////// - - { - ABSL_LOG(INFO) << "Filling inputs with data"; - void* host_mem_addr; - - ASSERT_EQ(LiteRtLockTensorBuffer(input_0_tensor_buffer, &host_mem_addr), - kLiteRtStatusOk); - std::memcpy(host_mem_addr, kTestInput0Tensor, sizeof(kTestInput0Tensor)); - ASSERT_EQ(LiteRtUnlockTensorBuffer(input_0_tensor_buffer), kLiteRtStatusOk); - - ASSERT_EQ(LiteRtLockTensorBuffer(input_1_tensor_buffer, &host_mem_addr), - kLiteRtStatusOk); - std::memcpy(host_mem_addr, kTestInput1Tensor, sizeof(kTestInput1Tensor)); - ASSERT_EQ(LiteRtUnlockTensorBuffer(input_1_tensor_buffer), kLiteRtStatusOk); - } - - // /////////////////////////////////////////////////////////////////////////// - // Execute model. - // /////////////////////////////////////////////////////////////////////////// - - ABSL_LOG(INFO) << "Invoking execution..."; - EXPECT_EQ(LiteRtDispatchInvoke(invocation_context), kLiteRtStatusOk); - - // /////////////////////////////////////////////////////////////////////////// - // Check output for correctness. 
- // /////////////////////////////////////////////////////////////////////////// - - { - ABSL_LOG(INFO) << "Checking output..."; - void* host_mem_addr; - ASSERT_EQ(LiteRtLockTensorBuffer(output_tensor_buffer, &host_mem_addr), - kLiteRtStatusOk); - auto output = absl::MakeSpan(static_cast(host_mem_addr), - kTestOutputSize); - for (auto i = 0; i < kTestOutputSize; ++i) { - ABSL_LOG(INFO) << output[i] << "\t" << kTestOutputTensor[i]; - } - EXPECT_THAT(output, Pointwise(testing::FloatNear(1e-3), kTestOutputTensor)); - ASSERT_EQ(LiteRtUnlockTensorBuffer(output_tensor_buffer), kLiteRtStatusOk); - } - - // /////////////////////////////////////////////////////////////////////////// - // Clean up resources. - // /////////////////////////////////////////////////////////////////////////// - - EXPECT_EQ(LiteRtDispatchDetachInput(invocation_context, - /*graph_input_index=*/0, input_0_handle), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtDispatchDetachInput(invocation_context, - /*graph_input_index=*/1, input_1_handle), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtDispatchDetachOutput(invocation_context, - /*graph_output_index=*/0, output_handle), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtDispatchUnregisterTensorBuffer(device_context, output_handle), - kLiteRtStatusOk); - EXPECT_EQ( - LiteRtDispatchUnregisterTensorBuffer(device_context, input_1_handle), - kLiteRtStatusOk); - EXPECT_EQ( - LiteRtDispatchUnregisterTensorBuffer(device_context, input_0_handle), - kLiteRtStatusOk); - LiteRtDestroyTensorBuffer(output_tensor_buffer); - LiteRtDestroyTensorBuffer(input_1_tensor_buffer); - LiteRtDestroyTensorBuffer(input_0_tensor_buffer); - EXPECT_EQ(LiteRtDispatchInvocationContextDestroy(invocation_context), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtDispatchDeviceContextDestroy(device_context), - kLiteRtStatusOk); -} - -TEST(Qualcomm, DispatchApiWithDmaBuf) { -#if !defined(__ANDROID__) - GTEST_SKIP() - << "This test is specific to Android devices with a Qualcomm NPU"; -#endif - - 
EXPECT_EQ(LiteRtDispatchInitialize(/*options=*/nullptr, /*num_options=*/0), - kLiteRtStatusOk); - - const char* vendor_id; - EXPECT_EQ(LiteRtDispatchGetVendorId(&vendor_id), kLiteRtStatusOk); - ABSL_LOG(INFO) << "vendor_id: " << vendor_id; - - const char* build_id; - EXPECT_EQ(LiteRtDispatchGetBuildId(&build_id), kLiteRtStatusOk); - ABSL_LOG(INFO) << "build_id: " << build_id; - - LiteRtApiVersion api_version; - EXPECT_EQ(LiteRtDispatchGetApiVersion(&api_version), kLiteRtStatusOk); - ABSL_LOG(INFO) << "api_version: " << api_version.major << "." - << api_version.minor << "." << api_version.patch; - - int capabilities; - EXPECT_EQ(LiteRtDispatchGetCapabilities(&capabilities), kLiteRtStatusOk); - ABSL_LOG(INFO) << "capabilities: " << capabilities; - - LiteRtDispatchDeviceContext device_context = nullptr; - EXPECT_EQ(LiteRtDispatchDeviceContextCreate(&device_context), - kLiteRtStatusOk); - ABSL_LOG(INFO) << "device_context: " << device_context; - - auto model_file_name = - litert::testing::GetTestFilePath(kQualcommModelFileName); - auto model = litert::internal::LoadBinaryFile(model_file_name); - EXPECT_TRUE(model) << model.Error(); - ABSL_LOG(INFO) << "Loaded model " << model_file_name << ", " << model->Size() - << " bytes"; - - // /////////////////////////////////////////////////////////////////////////// - // Set up an invocation context for a given model. 
- // /////////////////////////////////////////////////////////////////////////// - - LiteRtMemBuffer exec_bytecode_buffer = {/*.fd=*/-1, - /*.base_addr=*/model->Data(), - /*.offset=*/0, - /*.size=*/model->Size()}; - LiteRtDispatchInvocationContext invocation_context = nullptr; - EXPECT_EQ(LiteRtDispatchInvocationContextCreate( - device_context, kLiteRtDispatchExecutableTypeMlModel, - &exec_bytecode_buffer, /*function_name=*/"simple", - /*num_inputs=*/2, /*num_outputs=*/1, &invocation_context), - kLiteRtStatusOk); - ABSL_LOG(INFO) << "Invocation context: " << invocation_context; - - // /////////////////////////////////////////////////////////////////////////// - // Determine tensor buffer requirements. - // /////////////////////////////////////////////////////////////////////////// - - int num_tensor_buffer_types; - LiteRtTensorBufferRequirements input_0_tensor_buffer_requirements; - EXPECT_EQ(LiteRtDispatchGetInputRequirements( - invocation_context, /*input_index=*/0, &kInput0TensorType, - &input_0_tensor_buffer_requirements), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtGetNumTensorBufferRequirementsSupportedBufferTypes( - input_0_tensor_buffer_requirements, &num_tensor_buffer_types), - kLiteRtStatusOk); - EXPECT_GE(num_tensor_buffer_types, 1); - LiteRtTensorBufferType input_0_tensor_buffer_type; - EXPECT_EQ(LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( - input_0_tensor_buffer_requirements, /*type_index=*/1, - &input_0_tensor_buffer_type), - kLiteRtStatusOk); - EXPECT_EQ(input_0_tensor_buffer_type, kLiteRtTensorBufferTypeDmaBuf); - size_t input_0_tensor_buffer_size; - EXPECT_EQ( - LiteRtGetTensorBufferRequirementsBufferSize( - input_0_tensor_buffer_requirements, &input_0_tensor_buffer_size), - kLiteRtStatusOk); - EXPECT_GE(input_0_tensor_buffer_size, sizeof(kTestInput0Tensor)); - - LiteRtTensorBufferRequirements input_1_tensor_buffer_requirements; - EXPECT_EQ(LiteRtDispatchGetInputRequirements( - invocation_context, /*input_index=*/1, &kInput1TensorType, - 
&input_1_tensor_buffer_requirements), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtGetNumTensorBufferRequirementsSupportedBufferTypes( - input_1_tensor_buffer_requirements, &num_tensor_buffer_types), - kLiteRtStatusOk); - EXPECT_GE(num_tensor_buffer_types, 1); - LiteRtTensorBufferType input_1_tensor_buffer_type; - EXPECT_EQ(LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( - input_1_tensor_buffer_requirements, /*type_index=*/1, - &input_1_tensor_buffer_type), - kLiteRtStatusOk); - EXPECT_EQ(input_1_tensor_buffer_type, kLiteRtTensorBufferTypeDmaBuf); - size_t input_1_tensor_buffer_size; - EXPECT_EQ( - LiteRtGetTensorBufferRequirementsBufferSize( - input_1_tensor_buffer_requirements, &input_1_tensor_buffer_size), - kLiteRtStatusOk); - EXPECT_GE(input_1_tensor_buffer_size, sizeof(kTestInput1Tensor)); - - LiteRtTensorBufferRequirements output_tensor_buffer_requirements; - EXPECT_EQ(LiteRtDispatchGetOutputRequirements( - invocation_context, /*output_index=*/0, &kOutputTensorType, - &output_tensor_buffer_requirements), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtGetNumTensorBufferRequirementsSupportedBufferTypes( - output_tensor_buffer_requirements, &num_tensor_buffer_types), - kLiteRtStatusOk); - EXPECT_GE(num_tensor_buffer_types, 1); - LiteRtTensorBufferType output_tensor_buffer_type; - EXPECT_EQ(LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( - output_tensor_buffer_requirements, /*type_index=*/1, - &output_tensor_buffer_type), - kLiteRtStatusOk); - EXPECT_EQ(output_tensor_buffer_type, kLiteRtTensorBufferTypeDmaBuf); - size_t output_tensor_buffer_size; - EXPECT_EQ(LiteRtGetTensorBufferRequirementsBufferSize( - output_tensor_buffer_requirements, &output_tensor_buffer_size), - kLiteRtStatusOk); - EXPECT_GE(output_tensor_buffer_size, sizeof(kTestOutputTensor)); - - // /////////////////////////////////////////////////////////////////////////// - // Allocate tensor buffers. 
- // /////////////////////////////////////////////////////////////////////////// - - LiteRtTensorBuffer input_0_tensor_buffer; - EXPECT_EQ(LiteRtCreateManagedTensorBuffer( - input_0_tensor_buffer_type, &kInput0TensorType, - input_0_tensor_buffer_size, &input_0_tensor_buffer), - kLiteRtStatusOk); - - LiteRtTensorBuffer input_1_tensor_buffer; - EXPECT_EQ(LiteRtCreateManagedTensorBuffer( - input_1_tensor_buffer_type, &kInput1TensorType, - input_1_tensor_buffer_size, &input_1_tensor_buffer), - kLiteRtStatusOk); - - LiteRtTensorBuffer output_tensor_buffer; - EXPECT_EQ(LiteRtCreateManagedTensorBuffer( - output_tensor_buffer_type, &kOutputTensorType, - output_tensor_buffer_size, &output_tensor_buffer), - kLiteRtStatusOk); - - // /////////////////////////////////////////////////////////////////////////// - // Register tensor buffers. - // /////////////////////////////////////////////////////////////////////////// - - LiteRtTensorBufferHandle input_1_handle; - EXPECT_EQ(LiteRtDispatchRegisterTensorBuffer( - device_context, input_1_tensor_buffer, &input_1_handle), - kLiteRtStatusOk); - - LiteRtTensorBufferHandle input_0_handle; - EXPECT_EQ(LiteRtDispatchRegisterTensorBuffer( - device_context, input_0_tensor_buffer, &input_0_handle), - kLiteRtStatusOk); - - LiteRtTensorBufferHandle output_handle; - EXPECT_EQ(LiteRtDispatchRegisterTensorBuffer( - device_context, output_tensor_buffer, &output_handle), - kLiteRtStatusOk); - - // /////////////////////////////////////////////////////////////////////////// - // Attach tensor buffers. 
- // /////////////////////////////////////////////////////////////////////////// - - EXPECT_EQ(LiteRtDispatchAttachInput(invocation_context, - /*graph_input_index=*/0, input_0_handle), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtDispatchAttachInput(invocation_context, - /*graph_input_index=*/1, input_1_handle), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtDispatchAttachOutput(invocation_context, - /*graph_output_index=*/0, output_handle), - kLiteRtStatusOk); - - // /////////////////////////////////////////////////////////////////////////// - // Fill the input buffers with data. - // /////////////////////////////////////////////////////////////////////////// - - { - ABSL_LOG(INFO) << "Filling inputs with data"; - void* host_mem_addr; - - ASSERT_EQ(LiteRtLockTensorBuffer(input_0_tensor_buffer, &host_mem_addr), - kLiteRtStatusOk); - std::memcpy(host_mem_addr, kTestInput0Tensor, sizeof(kTestInput0Tensor)); - ASSERT_EQ(LiteRtUnlockTensorBuffer(input_0_tensor_buffer), kLiteRtStatusOk); - - ASSERT_EQ(LiteRtLockTensorBuffer(input_1_tensor_buffer, &host_mem_addr), - kLiteRtStatusOk); - std::memcpy(host_mem_addr, kTestInput1Tensor, sizeof(kTestInput1Tensor)); - ASSERT_EQ(LiteRtUnlockTensorBuffer(input_1_tensor_buffer), kLiteRtStatusOk); - } - - // /////////////////////////////////////////////////////////////////////////// - // Execute model. - // /////////////////////////////////////////////////////////////////////////// - - ABSL_LOG(INFO) << "Invoking execution..."; - EXPECT_EQ(LiteRtDispatchInvoke(invocation_context), kLiteRtStatusOk); - - // /////////////////////////////////////////////////////////////////////////// - // Check output for correctness. 
- // /////////////////////////////////////////////////////////////////////////// - - { - ABSL_LOG(INFO) << "Checking output..."; - void* host_mem_addr; - ASSERT_EQ(LiteRtLockTensorBuffer(output_tensor_buffer, &host_mem_addr), - kLiteRtStatusOk); - auto output = absl::MakeSpan(static_cast(host_mem_addr), - kTestOutputSize); - for (auto i = 0; i < kTestOutputSize; ++i) { - ABSL_LOG(INFO) << output[i] << "\t" << kTestOutputTensor[i]; - } - EXPECT_THAT(output, Pointwise(testing::FloatNear(1e-3), kTestOutputTensor)); - ASSERT_EQ(LiteRtUnlockTensorBuffer(output_tensor_buffer), kLiteRtStatusOk); - } - - // /////////////////////////////////////////////////////////////////////////// - // Clean up resources. - // /////////////////////////////////////////////////////////////////////////// - - EXPECT_EQ(LiteRtDispatchDetachInput(invocation_context, - /*graph_input_index=*/0, input_0_handle), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtDispatchDetachInput(invocation_context, - /*graph_input_index=*/1, input_1_handle), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtDispatchDetachOutput(invocation_context, - /*graph_output_index=*/0, output_handle), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtDispatchUnregisterTensorBuffer(device_context, output_handle), - kLiteRtStatusOk); - EXPECT_EQ( - LiteRtDispatchUnregisterTensorBuffer(device_context, input_1_handle), - kLiteRtStatusOk); - EXPECT_EQ( - LiteRtDispatchUnregisterTensorBuffer(device_context, input_0_handle), - kLiteRtStatusOk); - LiteRtDestroyTensorBuffer(output_tensor_buffer); - LiteRtDestroyTensorBuffer(input_1_tensor_buffer); - LiteRtDestroyTensorBuffer(input_0_tensor_buffer); - EXPECT_EQ(LiteRtDispatchInvocationContextDestroy(invocation_context), - kLiteRtStatusOk); - EXPECT_EQ(LiteRtDispatchDeviceContextDestroy(device_context), - kLiteRtStatusOk); -} diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_device_context.cc 
b/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_device_context.cc deleted file mode 100644 index adf0ed86f80ca3..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_device_context.cc +++ /dev/null @@ -1,190 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_device_context.h" - -#include -#include - -#include "absl/log/absl_check.h" -#include "third_party/qairt/latest/include/QNN/HTP/QnnHtpMem.h" -#include "third_party/qairt/latest/include/QNN/QnnBackend.h" -#include "third_party/qairt/latest/include/QNN/QnnCommon.h" -#include "third_party/qairt/latest/include/QNN/QnnInterface.h" -#include "third_party/qairt/latest/include/QNN/QnnMem.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/common.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_invocation_context.h" -#include 
"tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager.h" - -using litert::Expected; -using litert::Unexpected; -using litert::qnn::QnnManager; - -Expected -LiteRtDispatchDeviceContextT::Create(QnnManager& qnn) { - return Ptr(new LiteRtDispatchDeviceContextT(qnn)); -} - -Expected LiteRtDispatchDeviceContextT::GetTensorBuffer( - LiteRtTensorBufferHandle tensor_buffer_handle) { - auto registry_entry = tensor_buffer_registry_.Get(tensor_buffer_handle); - if (!registry_entry) { - return Unexpected(registry_entry.Error()); - } - - return (*registry_entry)->tensor_buffer; -} - -Expected LiteRtDispatchDeviceContextT::GetMemHandle( - LiteRtTensorBufferHandle tensor_buffer_handle, const Qnn_Tensor_t& tensor) { - auto registry_entry = tensor_buffer_registry_.Get(tensor_buffer_handle); - if (!registry_entry) { - return Unexpected(registry_entry.Error()); - } - - if (!(*registry_entry)->qnn_mem_handle) { - auto qnn_mem_handle = - RegisterTensorBuffer((*registry_entry)->tensor_buffer, tensor); - if (!qnn_mem_handle) { - return Unexpected(qnn_mem_handle.Error()); - } - (*registry_entry)->qnn_mem_handle = *qnn_mem_handle; - } - - return (*registry_entry)->qnn_mem_handle; -} - -Expected LiteRtDispatchDeviceContextT::RegisterTensorBuffer( - LiteRtTensorBuffer tensor_buffer, const Qnn_Tensor_t& tensor) { - LiteRtTensorBufferType tensor_buffer_type; - if (auto status = - LiteRtGetTensorBufferType(tensor_buffer, &tensor_buffer_type); - status != kLiteRtStatusOk) { - return Unexpected(status, "Failed to get tensor buffer type"); - } - - size_t tensor_buffer_size; - if (auto status = - LiteRtGetTensorBufferSize(tensor_buffer, &tensor_buffer_size); - status != kLiteRtStatusOk) { - return Unexpected(status, "Failed to get tensor buffer size"); - } - - size_t tensor_buffer_offset; - if (auto status = - LiteRtGetTensorBufferOffset(tensor_buffer, &tensor_buffer_offset); - status != kLiteRtStatusOk) { - return Unexpected(status, "Failed to get tensor buffer offset"); - } - - 
LiteRtRankedTensorType tensor_type; - if (auto status = - LiteRtGetTensorBufferTensorType(tensor_buffer, &tensor_type); - status != kLiteRtStatusOk) { - return Unexpected(status, "Failed to get tensor buffer's type"); - } - - auto element_type = - static_cast(tensor_type.element_type); - Qnn_DataType_t tensor_data_type; - if (auto status = LegalizeElementType(element_type, &tensor_data_type); - status != kLiteRtStatusOk) { - return Unexpected(status, "Failed to legalize datatype"); - } - - uint32_t tensor_rank = tensor_type.layout.rank; - uint32_t* tensor_dimensions = reinterpret_cast( - const_cast(tensor_type.layout.dimensions)); - auto* tensor_strides = tensor_type.layout.strides; - if (tensor_strides != nullptr) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Tensor strides are not supported by QNN"); - } - - void* buffer_host_addr; - int buffer_fd; - (void)buffer_host_addr; - - switch (tensor_buffer_type) { - case kLiteRtTensorBufferTypeFastRpc: -#if LITERT_HAS_FASTRPC_SUPPORT - if (auto status = LiteRtGetTensorBufferFastRpcBuffer( - tensor_buffer, &buffer_host_addr, &buffer_fd); - status != kLiteRtStatusOk) { - return Unexpected(status, "Failed to get FastRPC buffer"); - } -#else - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "FastRPC support is missing on this platform"); -#endif // LRT_HAS_FASTRPC_SUPPORT - break; - - case kLiteRtTensorBufferTypeDmaBuf: -#if LITERT_HAS_DMABUF_SUPPORT - if (auto status = LiteRtGetTensorBufferDmaBufBuffer( - tensor_buffer, &buffer_host_addr, &buffer_fd); - status != kLiteRtStatusOk) { - return Unexpected(status, "Failed to get DMA-BUF buffer"); - } -#else - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "DmaBuf support is missing on this platform"); -#endif // LRT_HAS_DMABUF_SUPPORT - break; - - default: - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Unsupported tensor buffer type"); - } - - QnnMemHtp_Descriptor_t mem_htp_descriptor = {}; - mem_htp_descriptor.type = 
QNN_HTP_MEM_SHARED_BUFFER; - mem_htp_descriptor.size = tensor_buffer_size; - mem_htp_descriptor.sharedBufferConfig = - QnnHtpMem_SharedBufferConfig_t{buffer_fd, tensor_buffer_offset}; - - Qnn_MemDescriptor_t mem_descriptor = {}; - mem_descriptor.memShape = {tensor_rank, tensor_dimensions, nullptr}; - mem_descriptor.dataType = tensor_data_type; - mem_descriptor.memType = QNN_MEM_TYPE_CUSTOM; - mem_descriptor.customInfo = &mem_htp_descriptor; - - if (invocation_context_ == nullptr) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Missing invocation context"); - } - - Qnn_ContextHandle_t context_handle = invocation_context_->ContextHandle(); - - Qnn_MemHandle_t mem_handle = nullptr; - if (auto status = qnn_manager_.Api()->memRegister( - context_handle, &mem_descriptor, 1UL, &mem_handle); - status != QNN_SUCCESS) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to register tensor buffer"); - } - - if (!mem_handle) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to register buffer: null mem_handle"); - } - - return mem_handle; -} diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_device_context.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_device_context.h deleted file mode 100644 index bd375c5137fcba..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_device_context.h +++ /dev/null @@ -1,79 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_DISPATCH_LITERT_DISPATCH_DEVICE_CONTEXT_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_DISPATCH_LITERT_DISPATCH_DEVICE_CONTEXT_H_ - -#include "third_party/qairt/latest/include/QNN/QnnInterface.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/registry.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager.h" - -class LiteRtDispatchDeviceContextT { - public: - using Ptr = std::unique_ptr; - - ~LiteRtDispatchDeviceContextT() = default; - - static litert::Expected Create(litert::qnn::QnnManager& qnn_manager); - - litert::Expected RegisterTensorBuffer( - LiteRtTensorBuffer tensor_buffer) { - return tensor_buffer_registry_.Register( - TensorBufferRegistryEntry(tensor_buffer)); - } - - litert::Expected UnregisterTensorBuffer( - LiteRtTensorBufferHandle tensor_buffer_handle) { - return tensor_buffer_registry_.Unregister(tensor_buffer_handle); - } - - litert::Expected GetTensorBuffer( - LiteRtTensorBufferHandle tensor_buffer_handle); - - litert::Expected GetMemHandle( - LiteRtTensorBufferHandle tensor_buffer_handle, - const Qnn_Tensor_t& tensor); - - void SetInvocationContext( - LiteRtDispatchInvocationContextT* invocation_context) { - invocation_context_ = invocation_context; - } - - private: - struct TensorBufferRegistryEntry { - LiteRtTensorBuffer tensor_buffer; - Qnn_MemHandle_t qnn_mem_handle = nullptr; - explicit TensorBufferRegistryEntry(LiteRtTensorBuffer tensor_buffer_) - : tensor_buffer(tensor_buffer_) {} - }; - - using TensorBufferRegistry = 
litert::qnn::Registry; - - LiteRtDispatchDeviceContextT(litert::qnn::QnnManager& qnn_manager) - : qnn_manager_(qnn_manager) {} - - litert::Expected RegisterTensorBuffer( - LiteRtTensorBuffer tensor_buffer, const Qnn_Tensor_t& tensor); - - litert::qnn::QnnManager& qnn_manager_; - TensorBufferRegistry tensor_buffer_registry_; - LiteRtDispatchInvocationContextT* invocation_context_ = nullptr; -}; - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_DISPATCH_LITERT_DISPATCH_DEVICE_CONTEXT_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_invocation_context.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_invocation_context.cc deleted file mode 100644 index 6d05088cf06681..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_invocation_context.cc +++ /dev/null @@ -1,240 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_invocation_context.h" - -#include -#include -#include - -#include "absl/strings/string_view.h" -#include "third_party/qairt/latest/include/QNN/QnnCommon.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h" -#include "tensorflow/lite/experimental/litert/core/util/tensor_type_util.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/context_binary_info.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_device_context.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager.h" - -using litert::Expected; -using litert::Unexpected; -using litert::qnn::QnnManager; - -LiteRtDispatchInvocationContextT::LiteRtDispatchInvocationContextT( - litert::qnn::QnnManager& qnn_manager, - const litert::qnn::ContextBinaryInfo& context_binary_info, - LiteRtDispatchDeviceContextT& device_context, - QnnManager::ContextHandle&& context_handle, - Qnn_ProfileHandle_t profile_handle, int graph_index, - Qnn_GraphHandle_t graph_handle) - : qnn_manager_(qnn_manager), - device_context_(device_context), - context_handle_(std::move(context_handle)), - profile_handle_(profile_handle), - graph_index_(graph_index), - graph_handle_(graph_handle), - inputs_(context_binary_info.Graphs()[graph_index].Inputs()), - outputs_(context_binary_info.Graphs()[graph_index].Outputs()) {} - -Expected -LiteRtDispatchInvocationContextT::Create( - QnnManager& qnn, LiteRtDispatchDeviceContextT& device_context, - const LiteRtMemBuffer* exec_bytecode_buffer, const char* function_name) { - 
auto exec_bytecode_ptr = - static_cast(exec_bytecode_buffer->base_addr) + - exec_bytecode_buffer->offset; - auto context_binary_info = litert::qnn::ContextBinaryInfo::Create( - qnn, exec_bytecode_ptr, exec_bytecode_buffer->size); - if (!context_binary_info) { - return Unexpected(context_binary_info.Error()); - } - - const auto& graphs = context_binary_info->Graphs(); - if (graphs.empty()) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, "No graph found"); - } - - int graph_index = -1; - // If the function_name is not specified and there is only one graph, then - // take that graph. - if (absl::string_view(function_name).empty() && graphs.size() == 1) { - graph_index = 0; - const auto& graph = graphs[graph_index]; - function_name = graph.Name().c_str(); - } else { - for (auto i = 0; i < graphs.size(); ++i) { - const auto& graph = graphs[i]; - if (graph.Name() == absl::string_view(function_name)) { - graph_index = i; - break; - } - } - } - if (graph_index < 0) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Function name not found"); - } - - auto configs = QnnManager::DefaultContextConfigs(); - Qnn_ProfileHandle_t profile_handle = nullptr; - auto context_handle = qnn.CreateContextHandle( - configs, - absl::MakeSpan(static_cast(exec_bytecode_ptr), - exec_bytecode_buffer->size), - profile_handle); - if (!context_handle) { - return Unexpected(context_handle.Error()); - } - - Qnn_GraphHandle_t graph_handle; - if (auto status = qnn.Api()->graphRetrieve(context_handle->get(), - function_name, &graph_handle); - status != QNN_SUCCESS) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to retrieve graph"); - } - - return Ptr(new LiteRtDispatchInvocationContextT( - qnn, std::move(*context_binary_info), device_context, - std::move(*context_handle), profile_handle, graph_index, graph_handle)); -} - -namespace { - -Expected GetTensorBufferRequirements( - const LiteRtRankedTensorType& tensor_type) { - auto* tensor_strides = 
tensor_type.layout.strides; - if (tensor_strides != nullptr) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Tensor strides are not supported by QNN"); - } - - static constexpr std::array - kSupportedTensorBufferTypes = { - kLiteRtTensorBufferTypeFastRpc, - kLiteRtTensorBufferTypeDmaBuf, - }; - - auto buffer_size = litert::internal::GetNumPackedBytes(tensor_type); - if (!buffer_size) { - return Unexpected(buffer_size.Error()); - } - - LiteRtTensorBufferRequirements requirements; - if (auto status = LiteRtCreateTensorBufferRequirements( - kSupportedTensorBufferTypes.size(), - kSupportedTensorBufferTypes.data(), *buffer_size, /*num_strides=*/0, - /*strides=*/nullptr, &requirements); - status != kLiteRtStatusOk) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, "Not implemented"); - } - - return requirements; -} - -} // namespace - -Expected -LiteRtDispatchInvocationContextT::GetInputRequirements( - int input_index, const LiteRtRankedTensorType& tensor_type) { - return GetTensorBufferRequirements(tensor_type); -} - -Expected -LiteRtDispatchInvocationContextT::GetOutputRequirements( - int output_index, const LiteRtRankedTensorType& tensor_type) { - return GetTensorBufferRequirements(tensor_type); -} - -Expected LiteRtDispatchInvocationContextT::AttachInput( - int graph_input_index, LiteRtTensorBufferHandle tensor_buffer_handle) { - if (graph_input_index < 0 || graph_input_index >= inputs_.size()) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Invalid graph_input_index"); - } - - auto& tensor = inputs_[graph_input_index]; - return AttachBuffer(tensor.Tensor(), tensor_buffer_handle); -} - -Expected LiteRtDispatchInvocationContextT::AttachOutput( - int graph_output_index, LiteRtTensorBufferHandle tensor_buffer_handle) { - if (graph_output_index < 0 || graph_output_index >= outputs_.size()) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Invalid graph_output_index"); - } - - auto& tensor = outputs_[graph_output_index]; - return 
AttachBuffer(tensor.Tensor(), tensor_buffer_handle); -} - -Expected LiteRtDispatchInvocationContextT::AttachBuffer( - Qnn_Tensor_t& tensor, LiteRtTensorBufferHandle tensor_buffer_handle) { - auto tensor_buffer = device_context_.GetTensorBuffer(tensor_buffer_handle); - if (!tensor_buffer) { - return Unexpected(tensor_buffer.Error()); - } - - auto mem_handle = device_context_.GetMemHandle(tensor_buffer_handle, tensor); - if (!mem_handle) { - return Unexpected(mem_handle.Error()); - } - - if (tensor.version == QNN_TENSOR_VERSION_1) { - tensor.v1.memType = QNN_TENSORMEMTYPE_MEMHANDLE; - tensor.v1.memHandle = *mem_handle; - - } else if (tensor.version == QNN_TENSOR_VERSION_2) { - if (tensor.v2.isDynamicDimensions != nullptr) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Dynamic dimensions not yet supported"); - } - tensor.v2.memType = QNN_TENSORMEMTYPE_MEMHANDLE; - tensor.v2.memHandle = *mem_handle; - - } else { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Unsupported QNN tensor version"); - } - - return {}; -} - -Expected LiteRtDispatchInvocationContextT::Execute() { - const size_t num_ins = inputs_.size(); - LITERT_STACK_ARRAY(Qnn_Tensor_t, inputs, num_ins, QNN_TENSOR_INIT); - for (size_t i = 0; i < num_ins; ++i) { - *(inputs + i) = inputs_.at(i).Tensor(); - } - - const size_t num_outs = outputs_.size(); - LITERT_STACK_ARRAY(Qnn_Tensor_t, outputs, num_outs, QNN_TENSOR_INIT); - for (size_t i = 0; i < num_outs; ++i) { - *(outputs + i) = outputs_.at(i).Tensor(); - } - - if (auto status = qnn_manager_.Api()->graphExecute( - graph_handle_, inputs, num_ins, outputs, num_outs, - /*profileHandle=*/nullptr, /*signalHandle=*/nullptr); - status != QNN_SUCCESS) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to execute graph"); - } - - return {}; -} diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_invocation_context.h 
b/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_invocation_context.h deleted file mode 100644 index 17759238816302..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_invocation_context.h +++ /dev/null @@ -1,81 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_DISPATCH_LITERT_DISPATCH_INVOCATION_CONTEXT_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_DISPATCH_LITERT_DISPATCH_INVOCATION_CONTEXT_H_ - -#include - -#include "absl/strings/string_view.h" -#include "third_party/qairt/latest/include/QNN/QnnInterface.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/context_binary_info.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/registry.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager.h" - -class LiteRtDispatchDeviceContextT; - -class LiteRtDispatchInvocationContextT { - public: - using Ptr = std::unique_ptr; - - 
~LiteRtDispatchInvocationContextT() = default; - - static litert::Expected Create( - litert::qnn::QnnManager& qnn_manager, - LiteRtDispatchDeviceContextT& device_context, - const LiteRtMemBuffer* exec_bytecode_buffer, const char* function_name); - - litert::Expected GetInputRequirements( - int input_index, const LiteRtRankedTensorType& tensor_type); - litert::Expected GetOutputRequirements( - int output_index, const LiteRtRankedTensorType& tensor_type); - - litert::Expected AttachInput( - int graph_input_index, LiteRtTensorBufferHandle tensor_buffer_handle); - - litert::Expected AttachOutput( - int graph_output_index, LiteRtTensorBufferHandle tensor_buffer_handle); - - litert::Expected Execute(); - - Qnn_ContextHandle_t ContextHandle() { return context_handle_.get(); } - - private: - LiteRtDispatchInvocationContextT( - litert::qnn::QnnManager& qnn_manager, - const litert::qnn::ContextBinaryInfo& context_binary_info, - LiteRtDispatchDeviceContextT& device_context, - litert::qnn::QnnManager::ContextHandle&& context_handle, - Qnn_ProfileHandle_t profile_handle, int graph_index, - Qnn_GraphHandle_t graph_handle); - - litert::Expected AttachBuffer( - Qnn_Tensor_t& tensor, LiteRtTensorBufferHandle tensor_buffer_handle); - - litert::qnn::QnnManager& qnn_manager_; - LiteRtDispatchDeviceContextT& device_context_; - litert::qnn::QnnManager::ContextHandle context_handle_; - Qnn_ProfileHandle_t profile_handle_; - int graph_index_; - Qnn_GraphHandle_t graph_handle_; - std::vector inputs_; - std::vector outputs_; -}; - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_DISPATCH_LITERT_DISPATCH_INVOCATION_CONTEXT_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/registry.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/registry.h deleted file mode 100644 index 8a80e342568e32..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/registry.h +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright 2024 
Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_DISPATCH_REGISTRY_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_DISPATCH_REGISTRY_H_ - -#include - -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" - -namespace litert::qnn { - -template -class Registry { - public: - Expected Register(const V& value) { - // TODO: improve this linear search by keeping an index to the first unused - // element. - for (auto i = 0; i < entries_.size(); ++i) { - auto& entry = entries_[i]; - if (!entry.used) { - entry.value = value; - entry.used = true; - return static_cast(i); - } - } - // Grow the set of entries. 
- H handle = static_cast(entries_.size()); - entries_.emplace_back(value); - return handle; - } - - Expected Unregister(H handle) { - if (handle < 0 || handle >= entries_.size()) { - return Unexpected(kLiteRtStatusErrorNotFound, "Unexpected handle"); - } - entries_[handle].used = false; - return {}; - } - - Expected Get(H handle) { - if (handle < 0 || handle >= entries_.size()) { - return Unexpected(kLiteRtStatusErrorNotFound, "Unexpected handle"); - } - return &entries_[handle].value; - } - - private: - struct Entry { - V value; - bool used; - explicit Entry(const V& v) : value(v), used(true) {} - }; - - std::vector entries_; -}; - -} // namespace litert::qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_DISPATCH_REGISTRY_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_log.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_log.cc deleted file mode 100644 index a0967992192570..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_log.cc +++ /dev/null @@ -1,64 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_log.h" - -#include -#include -#include -#include - -#include "third_party/qairt/latest/include/QNN/QnnLog.h" - -namespace litert::qnn { -namespace { - -void DefaultStdOutLogger(const char* fmt, QnnLog_Level_t level, - uint64_t timestamp, va_list argp) { - const char* levelStr = ""; - switch (level) { - case QNN_LOG_LEVEL_ERROR: - levelStr = " ERROR "; - break; - case QNN_LOG_LEVEL_WARN: - levelStr = "WARNING"; - break; - case QNN_LOG_LEVEL_INFO: - levelStr = " INFO "; - break; - case QNN_LOG_LEVEL_DEBUG: - levelStr = " DEBUG "; - break; - case QNN_LOG_LEVEL_VERBOSE: - levelStr = "VERBOSE"; - break; - case QNN_LOG_LEVEL_MAX: - levelStr = "UNKNOWN"; - break; - } - char buffer1[256]; - char buffer2[256]; - double ms = timestamp; - snprintf(buffer1, sizeof(buffer1), "%8.1fms [%-7s] ", ms, levelStr); - buffer1[sizeof(buffer1) - 1] = 0; - vsnprintf(buffer2, sizeof(buffer2), fmt, argp); - buffer2[sizeof(buffer1) - 2] = 0; - std::cout << buffer1 << buffer2; -} - -} // namespace - -QnnLog_Callback_t GetDefaultStdOutLogger() { return DefaultStdOutLogger; } - -} // namespace litert::qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_log.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_log.h deleted file mode 100644 index 934a164b49f933..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_log.h +++ /dev/null @@ -1,28 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_QNN_LOG_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_QNN_LOG_H_ - -#include "third_party/qairt/latest/include/QNN/QnnLog.h" - -namespace litert::qnn { - -// Gets a default logger implementation to stdout. -// This is used when initializing qnn logging. -QnnLog_Callback_t GetDefaultStdOutLogger(); - -} // namespace litert::qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_QNN_LOG_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager.cc deleted file mode 100644 index 0094d76cb6a340..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager.cc +++ /dev/null @@ -1,411 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager.h" - -#include - -#include -#include // NOLINT -#include -#include -#include - -#include "absl/strings/match.h" -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "third_party/qairt/latest/include/QNN/HTP/QnnHtpContext.h" -#include "third_party/qairt/latest/include/QNN/HTP/QnnHtpDevice.h" -#include "third_party/qairt/latest/include/QNN/QnnBackend.h" -#include "third_party/qairt/latest/include/QNN/QnnCommon.h" -#include "third_party/qairt/latest/include/QNN/QnnContext.h" -#include "third_party/qairt/latest/include/QNN/QnnDevice.h" -#include "third_party/qairt/latest/include/QNN/QnnInterface.h" -#include "third_party/qairt/latest/include/QNN/QnnLog.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "third_party/qairt/latest/include/QNN/System/QnnSystemCommon.h" -#include "third_party/qairt/latest/include/QNN/System/QnnSystemContext.h" -#include "third_party/qairt/latest/include/QNN/System/QnnSystemInterface.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/cc/litert_shared_library.h" -#include "tensorflow/lite/experimental/litert/core/dynamic_loading.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/common.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_log.h" - -namespace litert::qnn { - -namespace { - -constexpr char kLibQnnGetProvidersSymbol[] = "QnnInterface_getProviders"; - -constexpr char kLibQnnSystemGetProvidersSymbol[] = - "QnnSystemInterface_getProviders"; - -typedef Qnn_ErrorHandle_t (*QnnInterfaceGetProvidersFn_t)( - const QnnInterface_t*** provider_list, uint32_t* num_providers); - 
-typedef Qnn_ErrorHandle_t (*QnnSystemInterfaceGetProvidersFn_t)( - const QnnSystemInterface_t***, uint32_t*); - -Expected> LoadProvidersFromLib( - SharedLibrary& lib) { - QnnInterfaceGetProvidersFn_t get_providers = nullptr; - LITERT_ASSIGN_OR_RETURN(get_providers, - lib.LookupSymbol( - kLibQnnGetProvidersSymbol)); - const QnnInterface_t** interface_providers = nullptr; - uint32_t num_providers = 0; - if (QNN_SUCCESS != get_providers(&interface_providers, &num_providers)) { - return Error(kLiteRtStatusErrorRuntimeFailure, "Failed to get providers"); - } - return absl::MakeSpan(interface_providers, num_providers); -} - -Expected> LoadSystemProvidersFromLib( - SharedLibrary& lib) { - LITERT_ASSIGN_OR_RETURN(QnnSystemInterfaceGetProvidersFn_t get_providers, - lib.LookupSymbol( - kLibQnnSystemGetProvidersSymbol)); - const QnnSystemInterface_t** interface_providers = nullptr; - uint32_t num_providers = 0; - if (QNN_SUCCESS != get_providers(&interface_providers, &num_providers)) { - return Error(kLiteRtStatusErrorRuntimeFailure, - "Failed to get system providers"); - } - return absl::MakeSpan(interface_providers, num_providers); -} - -} // namespace - -QnnManager::~QnnManager() { - (void)FreeDevice(); - (void)FreeBackend(); - (void)FreeLogging(); -} - -LiteRtStatus QnnManager::LoadLib(absl::string_view path) { - LITERT_LOG(LITERT_INFO, "Loading qnn shared library from \"%s\"", - path.data()); - LITERT_ASSIGN_OR_RETURN(lib_, - SharedLibrary::Load(path, RtldFlags::Default())); - LITERT_LOG(LITERT_INFO, "Loaded qnn shared library", ""); - return kLiteRtStatusOk; -} - -LiteRtStatus QnnManager::LoadSystemLib(absl::string_view path) { - LITERT_ASSIGN_OR_RETURN(lib_system_, - SharedLibrary::Load(path, RtldFlags::Default())); - return kLiteRtStatusOk; -} - -const QnnApi* QnnManager::Api() const { - if (interface_ == nullptr) { - return nullptr; - } - return &interface_->QNN_INTERFACE_VER_NAME; -} - -LiteRtStatus QnnManager::ResolveApi() { - if (!lib_.Loaded()) { - 
LITERT_LOG(LITERT_ERROR, "%s", - "Cannot resolve functions: libQnn*.so has not been loaded.\n"); - return kLiteRtStatusErrorDynamicLoading; - } - - LITERT_ASSIGN_OR_RETURN(auto providers, LoadProvidersFromLib(lib_)); - for (const auto& prov : providers) { - const bool major = - prov->apiVersion.coreApiVersion.major == QNN_API_VERSION_MAJOR; - - const bool minor = - prov->apiVersion.coreApiVersion.minor == QNN_API_VERSION_MINOR; - - const bool patch = - prov->apiVersion.coreApiVersion.patch == QNN_API_VERSION_PATCH; - - if (major && minor && patch) { - interface_ = prov; - break; - } - } - - if (interface_ == nullptr) { - LITERT_LOG(LITERT_ERROR, "%s", "No valid interface was provided\n"); - return kLiteRtStatusErrorDynamicLoading; - } - - return kLiteRtStatusOk; -} - -LiteRtStatus QnnManager::ResolveSystemApi() { - if (!lib_.Loaded()) { - LITERT_LOG(LITERT_ERROR, "%s", - "Cannot resolve functions: libQnn*.so has not been loaded.\n"); - return kLiteRtStatusErrorDynamicLoading; - } - - LITERT_ASSIGN_OR_RETURN(auto system_providers, - LoadSystemProvidersFromLib(lib_system_)); - for (const auto& system_prov : system_providers) { - const bool major = - system_prov->systemApiVersion.major == QNN_SYSTEM_API_VERSION_MAJOR; - - const bool minor = - system_prov->systemApiVersion.minor == QNN_SYSTEM_API_VERSION_MINOR; - - const bool patch = - system_prov->systemApiVersion.patch == QNN_SYSTEM_API_VERSION_PATCH; - - if (major && minor && patch) { - system_interface_ = system_prov; - break; - } - } - - if (system_interface_ == nullptr) { - LITERT_LOG(LITERT_ERROR, "%s", "No valid system interface was provided\n"); - return kLiteRtStatusErrorDynamicLoading; - } - - return kLiteRtStatusOk; -} - -const QnnSystemApi* QnnManager::SystemApi() const { - if (system_interface_ == nullptr) { - return nullptr; - } - return &system_interface_->QNN_SYSTEM_INTERFACE_VER_NAME; -} - -LiteRtStatus QnnManager::FreeLogging() { - if (log_handle_ != nullptr) { - if (QNN_SUCCESS != 
Api()->logFree(log_handle_)) { - LITERT_LOG(LITERT_ERROR, "%s", "Failed to free logging\n"); - return kLiteRtStatusErrorNotFound; - } - } - log_handle_ = nullptr; - return kLiteRtStatusOk; -} - -LiteRtStatus QnnManager::FreeBackend() { - if (backend_handle_ != nullptr) { - if (QNN_SUCCESS != Api()->backendFree(backend_handle_)) { - LITERT_LOG(LITERT_ERROR, "%s", "Failed to free backend\n"); - return kLiteRtStatusErrorNotFound; - } - } - backend_handle_ = nullptr; - return kLiteRtStatusOk; -} - -LiteRtStatus QnnManager::FreeDevice() { - if (device_handle_ != nullptr) { - if (QNN_SUCCESS != Api()->deviceFree(device_handle_)) { - LITERT_LOG(LITERT_ERROR, "%s", "Failed to free device\n"); - return kLiteRtStatusErrorNotFound; - } - } - device_handle_ = nullptr; - return kLiteRtStatusOk; -} - -LiteRtStatus QnnManager::GenerateContextBinary( - Qnn_ContextHandle_t context_handle, std::vector& buffer) { - Qnn_ContextBinarySize_t bin_size = 0; - if (QNN_SUCCESS != Api()->contextGetBinarySize(context_handle, &bin_size)) { - LITERT_LOG(LITERT_ERROR, "%s", "Failed to get context bin size\n"); - return kLiteRtStatusErrorNotFound; - } - buffer.clear(); - buffer.resize(bin_size); - - Qnn_ContextBinarySize_t written_bin_size = 0; - if (QNN_SUCCESS != Api()->contextGetBinary(context_handle, buffer.data(), - buffer.size(), - &written_bin_size)) { - LITERT_LOG(LITERT_ERROR, "%s", "Failed to generated context binary \n"); - return kLiteRtStatusErrorNotFound; - } - - LITERT_LOG(LITERT_INFO, "Serialized a context bin of size (bytes): %lu\n", - written_bin_size); - - return kLiteRtStatusOk; -} - -LiteRtStatus QnnManager::ValidateOp(const Qnn_OpConfig_t& op_config) { - if (Qnn_ErrorHandle_t error = - Api()->backendValidateOpConfig(BackendHandle(), op_config); - QNN_SUCCESS != error) { - LITERT_LOG(LITERT_ERROR, "Failed to validate op %s\n, error: %lld", - op_config.v1.name, static_cast(error)); - return kLiteRtStatusErrorInvalidLegalization; - } - - return kLiteRtStatusOk; -} - 
-LiteRtStatus QnnManager::Init(absl::Span configs, - std::optional shared_library_dir, - std::optional soc_model) { - // If shared_library_dir is provided, add it to the path as it may contain - // libs to be loaded. - // TOOD: This should probably be done upstream in litert_dispatch. - if (shared_library_dir) { - LITERT_LOG(LITERT_INFO, "Adding shared library dir to path: %s", - shared_library_dir->c_str()); - - static constexpr char kAdsp[] = "ADSP_LIBRARY_PATH"; - if (getenv(kAdsp) == nullptr) { - setenv(kAdsp, shared_library_dir->data(), /*overwrite=*/1); - } - - // TODO: Put dynamic loading module in cc or vendor/cc. - litert::internal::PutLibOnLdPath(shared_library_dir->data(), kLibQnnHtpSo); - } - - LITERT_RETURN_IF_ERROR(LoadLib(kLibQnnHtpSo)); - LITERT_RETURN_IF_ERROR(ResolveApi()); - - LITERT_RETURN_IF_ERROR(LoadSystemLib(kLibQnnSystemSo)); - LITERT_RETURN_IF_ERROR(ResolveSystemApi()); - - if (auto status = Api()->logCreate(GetDefaultStdOutLogger(), - QNN_LOG_LEVEL_INFO, &LogHandle()); - status != QNN_SUCCESS) { - LITERT_LOG(LITERT_ERROR, "Failed to create QNN logger: %d", status); - return kLiteRtStatusErrorRuntimeFailure; - } - - if (auto status = - Api()->backendCreate(LogHandle(), configs.data(), &BackendHandle()); - status != QNN_SUCCESS) { - LITERT_LOG(LITERT_ERROR, "Failed to create QNN backend: %d", status); - return kLiteRtStatusErrorRuntimeFailure; - } - - if (soc_model.has_value()) { - soc_model_ = *soc_model; - LITERT_LOG(LITERT_INFO, - "Initializing QNN backend for device architecture %d", - *soc_model); - QnnHtpDevice_CustomConfig_t arch_custom_config = {}; - arch_custom_config.option = QNN_HTP_DEVICE_CONFIG_OPTION_ARCH; - arch_custom_config.arch.arch = *soc_model; - arch_custom_config.arch.deviceId = 0; - - QnnDevice_Config_t arch_device_config = {}; - arch_device_config.option = QNN_DEVICE_CONFIG_OPTION_CUSTOM; - arch_device_config.customConfig = &arch_custom_config; - - const QnnDevice_Config_t* device_configs[2] = { - 
&arch_device_config, - nullptr, - }; - - if (auto status = - Api()->deviceCreate(nullptr, device_configs, &DeviceHandle()); - status != QNN_SUCCESS) { - LITERT_LOG(LITERT_ERROR, "Failed to create QNN device: %d", status); - return kLiteRtStatusErrorRuntimeFailure; - } - } - - return kLiteRtStatusOk; -} - -Expected -QnnManager::CreateSystemContextHandle() { - QnnSystemContext_Handle_t system_context_handle; - if (auto status = SystemApi()->systemContextCreate(&system_context_handle); - status != QNN_SUCCESS) { - LITERT_LOG(LITERT_ERROR, "Failed to create QNN system context: %d", status); - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to create QNN system context"); - } - auto deleter = SystemApi()->systemContextFree; - return SystemContextHandle{system_context_handle, deleter}; -} - -Expected QnnManager::CreateContextHandle( - absl::Span configs) { - Qnn_ContextHandle_t context_handle; - if (auto status = Api()->contextCreate(BackendHandle(), DeviceHandle(), - configs.data(), &context_handle); - status != QNN_SUCCESS) { - LITERT_LOG(LITERT_ERROR, "Failed to create QNN context: %d", status); - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to create QNN context"); - } - auto deleter = Api()->contextFree; - return ContextHandle{context_handle, /*profile=*/nullptr, deleter}; -} - -Expected QnnManager::CreateContextHandle( - absl::Span configs, - absl::Span bytecode, Qnn_ProfileHandle_t profile_handle) { - Qnn_ContextHandle_t context_handle; - if (auto status = Api()->contextCreateFromBinary( - BackendHandle(), DeviceHandle(), configs.data(), bytecode.data(), - bytecode.size(), &context_handle, profile_handle); - status != QNN_SUCCESS) { - LITERT_LOG(LITERT_ERROR, "Failed to create QNN context: %d", status); - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to create QNN context"); - } - auto deleter = Api()->contextFree; - return ContextHandle{context_handle, profile_handle, deleter}; -} - -Expected QnnManager::Create( - 
absl::Span configs, - std::optional shared_library_dir, - std::optional soc_model) { - Ptr qnn_manager(new QnnManager); - if (auto status = qnn_manager->Init(configs, shared_library_dir, soc_model); - status != kLiteRtStatusOk) { - return Unexpected(status, "Failed to set up QNN manager"); - } - return qnn_manager; -} - -absl::Span QnnManager::DefaultBackendConfigs() { - static const QnnBackend_Config_t* configs[] = {nullptr}; - return absl::MakeSpan(configs); -} - -absl::Span QnnManager::DefaultContextConfigs() { - static const QnnContext_Config_t* configs[] = {nullptr}; - return absl::MakeSpan(configs); -} - -absl::Span -QnnManager::WeightSharingContextConfigs() { - static QnnHtpContext_CustomConfig_t customConfig = - QNN_HTP_CONTEXT_CUSTOM_CONFIG_INIT; - customConfig.option = QNN_HTP_CONTEXT_CONFIG_OPTION_WEIGHT_SHARING_ENABLED; - customConfig.weightSharingEnabled = true; - static QnnContext_Config_t contextConfig = QNN_CONTEXT_CONFIG_INIT; - contextConfig.option = QNN_CONTEXT_CONFIG_OPTION_CUSTOM; - contextConfig.customConfig = &customConfig; - static const QnnContext_Config_t* configs[2] = {&contextConfig, nullptr}; - return absl::MakeSpan(configs); -} - -}; // namespace litert::qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager.h deleted file mode 100644 index 30d00ab7169706..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager.h +++ /dev/null @@ -1,239 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_QNN_MANAGER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_QNN_MANAGER_H_ - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "third_party/qairt/latest/include/QNN/HTP/QnnHtpDevice.h" -#include "third_party/qairt/latest/include/QNN/QnnBackend.h" -#include "third_party/qairt/latest/include/QNN/QnnCommon.h" -#include "third_party/qairt/latest/include/QNN/QnnContext.h" -#include "third_party/qairt/latest/include/QNN/QnnInterface.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "third_party/qairt/latest/include/QNN/System/QnnSystemContext.h" -#include "third_party/qairt/latest/include/QNN/System/QnnSystemInterface.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" // IWYU pragma: keep -#include "tensorflow/lite/experimental/litert/cc/litert_shared_library.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/common.h" - -//===----------------------------------------------------------------------===// -// -// QnnManger -// -// Syntactic sugar for various Qnn Sdk routines. -// -// Provides various utilities for linking shared libraries at runtime -// against Qnn symbols as well as convience getters and storage of handles -// (pointers). 
Provides simple wrappers for freeing handles and returning -// LiteRtStatus rather than Qnn ones. Additionally exposes hooks for dumping -// api and shared libarary details. -// -// Does not own any memory and will always have trivial cstor/dstor. The -// user is responsible for freeing any Qnn handles explicitly. Note, -// Qnn handles will be automatically freed when the library is unloaded -// if they have been already. -// -//===----------------------------------------------------------------------===// - -namespace litert::qnn { - -class QnnManager; - -namespace internal { - -void Dump(const QnnManager& qnn, std::ostream& out); - -} // namespace internal - -class QnnManager { - friend void internal::Dump(const QnnManager& qnn, std::ostream& out); - - public: - using Ptr = std::unique_ptr; - using SystemContextHandle = - std::unique_ptr::type, - QnnSystemContext_FreeFn_t>; - class ContextHandle; - - ~QnnManager(); - - static Expected Create( - absl::Span configs, - std::optional shared_library_dir = std::nullopt, - std::optional soc_model = std::nullopt); - - static absl::Span DefaultBackendConfigs(); - static absl::Span DefaultContextConfigs(); - static absl::Span WeightSharingContextConfigs(); - - // Get resolved function pointers for qnn sdk calls. Nullptr if functions - // have not been resolved yet. - const QnnApi* Api() const; - - // Get resolved function pointers for qnn sdk calls. Nullptr if functions - // have not been resolved yet. - const QnnSystemApi* SystemApi() const; - - // - // QNN SDK Objects. - // - - // Create system context handle. - Expected CreateSystemContextHandle(); - - // Create a context handle for compilation. - Expected CreateContextHandle( - absl::Span configs); - - // Create a context handle for inference, from a given bytecode. - Expected CreateContextHandle( - absl::Span configs, - absl::Span bytecode, Qnn_ProfileHandle_t profile_handle); - - // - // Context Binary - // - - // Generates QNN context binary from current context. 
Writes to given - // buffer. - LiteRtStatus GenerateContextBinary(Qnn_ContextHandle_t context_handle, - std::vector& buffer); - - LiteRtStatus ValidateOp(const Qnn_OpConfig_t& op_config); - - bool IsLegacySocModel() { return soc_model_ == QNN_HTP_DEVICE_ARCH_V68; } - - private: - QnnManager() = default; - - LiteRtStatus Init(absl::Span configs, - std::optional shared_library_dir, - std::optional soc_model); - - // - // Manage libQnn*.so Loading - // - - // Loads the libQnn*.so at given path. - LiteRtStatus LoadLib(absl::string_view path); - - // Loads the libQnnSystem.so at given path. - LiteRtStatus LoadSystemLib(absl::string_view path); - - // - // Resolve and Access QNN SDK Functions - // - - // Resolve all available QNN SDK functions from (already) loaded so. If - // multiple providers are found, selects the first one with a suitable - // version. Fails if none can be found. - LiteRtStatus ResolveApi(); - - // Resolve all available QNN SDK functions from (already) loaded so. If - // multiple providers are found, selects the first one with a suitable - // version. Fails if none can be found. - LiteRtStatus ResolveSystemApi(); - - // Get qnn log handle. Nullptr if logCreate has not been successfully called. - Qnn_LogHandle_t& LogHandle() { return log_handle_; } - - // Get qnn backend handle. Nullptr if backendCreate has not been successfully - // called. - Qnn_BackendHandle_t& BackendHandle() { return backend_handle_; } - - // Get qnn device handle. Nullptr if deviceCreate has not been successfully - // called. - Qnn_DeviceHandle_t& DeviceHandle() { return device_handle_; } - - // Signal QNN SDK to free any memory related to the device. Does nothing - // if deviceCreate has not been called. - LiteRtStatus FreeDevice(); - - // Signal QNN SDK to free any memory related to logging. Does nothing - // if logCreate has not been called. - LiteRtStatus FreeLogging(); - - // Signal QNN SDK to free any memory related to backend. 
Does nothing - // if backendCreate has not been called. - LiteRtStatus FreeBackend(); - - // Handle to the shared library that implements the API. The library is - // released when the manager is destroyed. - SharedLibrary lib_; - - // Handle to the system shared library that implements the API. The library is - // released when the manager is destroyed. - SharedLibrary lib_system_; - - const QnnInterface_t* interface_ = nullptr; - const QnnSystemInterface_t* system_interface_ = nullptr; - - Qnn_LogHandle_t log_handle_ = nullptr; - Qnn_BackendHandle_t backend_handle_ = nullptr; - Qnn_DeviceHandle_t device_handle_ = nullptr; - QnnHtpDevice_Arch_t soc_model_ = QNN_HTP_DEVICE_ARCH_UNKNOWN; -}; - -// Unfortunately we can't use std::unique_ptr with a deleter because -// QnnContext_FreeFn_t takes a profile handle as a second argument. -class QnnManager::ContextHandle { - public: - ContextHandle(Qnn_ContextHandle_t context_handle, Qnn_ProfileHandle_t profile, - QnnContext_FreeFn_t free_fn) - : context_handle_(context_handle), profile_(profile), free_fn_(free_fn) {} - - ~ContextHandle() { - if (context_handle_ && free_fn_) { - free_fn_(context_handle_, profile_); - } - } - - ContextHandle(ContextHandle&& other) { *this = std::move(other); } - - ContextHandle(const ContextHandle& other) = delete; - - ContextHandle& operator=(ContextHandle&& other) { - std::swap(context_handle_, other.context_handle_); - std::swap(profile_, other.profile_); - std::swap(free_fn_, other.free_fn_); - return *this; - } - - ContextHandle& operator=(const ContextHandle& other) = delete; - - Qnn_ContextHandle_t get() const noexcept { return context_handle_; } - explicit operator bool() const noexcept { return context_handle_ != nullptr; } - - private: - Qnn_ContextHandle_t context_handle_ = nullptr; - Qnn_ProfileHandle_t profile_ = nullptr; - QnnContext_FreeFn_t free_fn_ = nullptr; -}; - -} // namespace litert::qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_QNN_MANAGER_H_ 
diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager_test.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager_test.cc deleted file mode 100644 index 742af4f508dd64..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager_test.cc +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager.h" - -#include - -#include -#include -#include "tensorflow/lite/experimental/litert/test/common.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/tools/dump.h" - -namespace { - -using ::litert::qnn::QnnManager; -using ::litert::qnn::internal::Dump; -using ::testing::HasSubstr; - -// NOTE: This tests that all of the dynamic loading works properly and -// the QNN SDK instance can be properly initialized and destroyed. 
- -TEST(QnnManagerTest, SetupQnnManager) { - auto configs = QnnManager::DefaultBackendConfigs(); - auto qnn = QnnManager::Create(configs); - ASSERT_TRUE(qnn); -} - -TEST(QnnManagerTest, Dump) { - auto configs = QnnManager::DefaultBackendConfigs(); - auto qnn = QnnManager::Create(configs); - ASSERT_TRUE(qnn); - - std::ostringstream dump; - Dump(**qnn, dump); - - EXPECT_THAT(dump.str(), HasSubstr("< QnnInterface_t >")); - EXPECT_THAT(dump.str(), HasSubstr("< QnnSystemInterface_t >")); -} - -} // namespace diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_tensor.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_tensor.cc deleted file mode 100644 index 557a5d2f9ed56c..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_tensor.cc +++ /dev/null @@ -1,104 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_tensor.h" - -#include -#include -#include - -#include "absl/log/absl_check.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" - -namespace litert { -namespace qnn { - -QnnTensor::QnnTensor(const QnnTensor& other) : QnnTensor(other.Tensor()) { - auto status = DeepCopy(); - // This should never fail because the input QnnTensor was already deep-copied. - if (!status) { - LITERT_LOG(LITERT_ERROR, "Failed to build QnnTensor: %s", - status.Error().Message().c_str()); - ABSL_CHECK(status); - } -} - -QnnTensor::QnnTensor(QnnTensor&& other) { - tensor_ = other.tensor_; - // Swap managed memory. - std::swap(name_, other.name_); - std::swap(dimensions_, other.dimensions_); - std::swap(is_dynamic_dimensions_, other.is_dynamic_dimensions_); -} - -Expected QnnTensor::Create(const Qnn_Tensor_t& tensor) { - QnnTensor qnn_tensor(tensor); - if (auto status = qnn_tensor.DeepCopy(); !status) { - return Unexpected(status.Error()); - } - return qnn_tensor; -} - -Expected QnnTensor::DeepCopy() { - if (tensor_.version == QNN_TENSOR_VERSION_1) { - dimensions_.reserve(tensor_.v1.rank); - std::copy(tensor_.v1.dimensions, tensor_.v1.dimensions + tensor_.v1.rank, - std::back_inserter(dimensions_)); - tensor_.v1.dimensions = dimensions_.data(); - - // FIXME: Implement deep copy for quantizeParams. 
- if (tensor_.v1.quantizeParams.quantizationEncoding == - QNN_QUANTIZATION_ENCODING_BLOCKWISE_EXPANSION || - tensor_.v1.quantizeParams.quantizationEncoding == - QNN_QUANTIZATION_ENCODING_VECTOR) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Unsupported QNN quantization"); - } - - } else if (tensor_.version == QNN_TENSOR_VERSION_2) { - dimensions_.reserve(tensor_.v2.rank); - std::copy(tensor_.v2.dimensions, tensor_.v2.dimensions + tensor_.v2.rank, - std::back_inserter(dimensions_)); - tensor_.v2.dimensions = dimensions_.data(); - - if (tensor_.v2.isDynamicDimensions) { - is_dynamic_dimensions_.reserve(tensor_.v2.rank); - std::copy(tensor_.v2.isDynamicDimensions, - tensor_.v2.isDynamicDimensions + tensor_.v2.rank, - std::back_inserter(is_dynamic_dimensions_)); - tensor_.v2.isDynamicDimensions = is_dynamic_dimensions_.data(); - } - - // FIXME: Implement deep copy for quantizeParams. - if (tensor_.v2.quantizeParams.quantizationEncoding == - QNN_QUANTIZATION_ENCODING_BLOCKWISE_EXPANSION || - tensor_.v2.quantizeParams.quantizationEncoding == - QNN_QUANTIZATION_ENCODING_VECTOR) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Unsupported QNN quantization"); - } - - } else { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Unsupported QNN tensor version"); - } - - return {}; -} - -} // namespace qnn -} // namespace litert diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_tensor.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_tensor.h deleted file mode 100644 index c0429ce01864e5..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_tensor.h +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_QNN_TENSOR_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_QNN_TENSOR_H_ - -#include -#include -#include -#include - -#include "absl/strings/string_view.h" -#include "third_party/qairt/latest/include/QNN/QnnInterface.h" -#include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" - -namespace litert::qnn { - -class QnnTensor { - public: - static Expected Create(const Qnn_Tensor_t& tensor); - - QnnTensor(const QnnTensor& other); - QnnTensor(QnnTensor&& other); - - QnnTensor& operator=(const QnnTensor&) = delete; - QnnTensor& operator=(QnnTensor&&) = delete; - - Qnn_Tensor_t& Tensor() { return tensor_; } - const Qnn_Tensor_t& Tensor() const { return tensor_; } - - size_t Rank() const { return dimensions_.size(); } - const uint32_t* Dimensions() const { return dimensions_.data(); } - - private: - explicit QnnTensor(const Qnn_Tensor_t& tensor) : tensor_(tensor) {} - Expected DeepCopy(); - - Qnn_Tensor_t tensor_; - std::string name_; - std::vector dimensions_; - std::vector is_dynamic_dimensions_; -}; - -} // namespace litert::qnn - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_QNN_TENSOR_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/qualcomm_build_defs.bzl b/tensorflow/lite/experimental/litert/vendors/qualcomm/qualcomm_build_defs.bzl deleted file mode 100644 index d4c9c70db3674e..00000000000000 --- 
a/tensorflow/lite/experimental/litert/vendors/qualcomm/qualcomm_build_defs.bzl +++ /dev/null @@ -1,118 +0,0 @@ -# Copyright 2024 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Build definitions for QualComm backend.""" - -load("//tensorflow/lite/experimental/litert/build_common:litert_build_defs.bzl", "append_rule_kwargs", "litert_bin", "litert_lib", "make_rpaths") - -_QNN_LIBCC_X86_64 = [ - # copybara:uncomment_begin(google-only) - # "//third_party/qairt/latest:lib/x86_64-linux-clang/libc++.so.1", - # "//third_party/qairt/latest:lib/x86_64-linux-clang/libc++abi.so.1", - # copybara:uncomment_end -] # @unused - -# TODO: Make rpaths dynamic with "$(location {})". -_QNN_LIB_RPATHS_X86_64 = [ - # copybara:uncomment_begin(google-only) - # "third_party/qairt/latest/lib/x86_64-linux-clang", - # copybara:uncomment_end -] - -_QNN_LIB_HTP_X86_64 = [ - # copybara:uncomment_begin(google-only) - # "//third_party/qairt/latest:lib/x86_64-linux-clang/libQnnHtp.so", - # copybara:uncomment_end -] - -_QNN_LIB_SYSTEM_X86_64 = [ - # copybara:uncomment_begin(google-only) - # "//third_party/qairt/latest:lib/x86_64-linux-clang/libQnnSystem.so", - # copybara:uncomment_end -] - -def _litert_with_qnn_base( - litert_rule, - backend, - include_system, - use_custom_libcc, - **litert_rule_kwargs): - if backend != "htp": - fail("Only htp currently supported") - - if use_custom_libcc: - # TODO: Figure out strategy for custom libcc. 
- fail("Custom libcc not yet supported") - - data_x86_64 = [] - data_x86_64.extend(_QNN_LIB_HTP_X86_64) - if include_system: - data_x86_64.extend(_QNN_LIB_SYSTEM_X86_64) - data = select({ - "//tensorflow:linux_x86_64": data_x86_64, - "//conditions:default": [], - }) - - append_rule_kwargs( - litert_rule_kwargs, - data = data, - linkopts = select({ - "//tensorflow:linux_x86_64": [make_rpaths(_QNN_LIB_RPATHS_X86_64)], - "//conditions:default": [], - }), - ) - - litert_rule(**litert_rule_kwargs) - -def litert_cc_lib_with_qnn( - backend = "htp", - include_system = False, - use_custom_libcc = False, - **litert_lib_kwargs): - """Creates a litert_lib target with QualComm backend dependencies. - - Args: - backend: The backend to use. Currently only "htp" is supported. - include_system: Whether to include libQnnSystem.so. - use_custom_libcc: Whether to use a custom libcc. Not yet supported. - **litert_lib_kwargs: Keyword arguments passed to litert_lib. - """ - _litert_with_qnn_base( - litert_lib, - backend, - include_system, - use_custom_libcc, - **litert_lib_kwargs - ) - -def litert_cc_bin_with_qnn( - backend = "htp", - include_system = False, - use_custom_libcc = False, - **litert_bin_kwargs): - """Creates a litert_bin target with QualComm backend dependencies. - - Args: - backend: The backend to use. Currently only "htp" is supported. - include_system: Whether to include libQnnSystem.so. - use_custom_libcc: Whether to use a custom libcc. Not yet supported. - **litert_bin_kwargs: Keyword arguments passed to litert_bin. 
- """ - _litert_with_qnn_base( - litert_bin, - backend, - include_system, - use_custom_libcc, - **litert_bin_kwargs - ) diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/supported_soc.csv b/tensorflow/lite/experimental/litert/vendors/qualcomm/supported_soc.csv deleted file mode 100644 index 52b7f881f47a3e..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/supported_soc.csv +++ /dev/null @@ -1,41 +0,0 @@ -# manufacturer,model,runtime_library_version,soc_model -Qualcomm,SM8750,v79,69 -Qualcomm,SM8650,v75,57 -Qualcomm,SM8635,v73,68 -Qualcomm,SM8550,v73,43 -Qualcomm,SM7675,v73,70 -Qualcomm,SM7550,v73, -Qualcomm,SM7435,v73, -Qualcomm,SM6450,v73,50 -Qualcomm,QCM8550LA,v73,66 -Qualcomm,QCM8550LE,v73,66 -Qualcomm,SM8475,v69,42 -Qualcomm,SM8450,v69,36 -Qualcomm,SM7475,v69,54 -Qualcomm,SM7450,v69,41 -Qualcomm,SM7425,v69, -Qualcomm,SXR2230P,v69,53 -Qualcomm,SXR2250P,v69, -Qualcomm,SM8350,v68,30 -Qualcomm,SM8350P,v68,30 -Qualcomm,SM7350,v68,32 -Qualcomm,SM7325,v68,35 -Qualcomm,SM7315,v68,38 -Qualcomm,QCM6490,v68,35 -Qualcomm,SM8250,v66,21 -Qualcomm,SM8150,v66, -Qualcomm,SM7250,v66,25 -Qualcomm,SM7225,v66,29 -Qualcomm,SM7125,v66, -Qualcomm,SM6350,v66,29 -Qualcomm,SM6225,v66,40 -Qualcomm,SM6150,v66, -Qualcomm,SM6125,v66, -Qualcomm,SM4350,v66,31 -Qualcomm,QRB5165U,v66,21 -Qualcomm,QRB5165LE,v66,21 -Qualcomm,QCS7230LA,v66,51 -Qualcomm,QCS7230LE,v66,51 -Qualcomm,SM6375,v66,31 -Qualcomm,SM7150,v65, -Qualcomm,SDM845,v65, diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/tools/BUILD b/tensorflow/lite/experimental/litert/vendors/qualcomm/tools/BUILD deleted file mode 100644 index 45df0fef3b5a21..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/tools/BUILD +++ /dev/null @@ -1,31 +0,0 @@ -# Copyright 2024 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], -) - -cc_library( - name = "dump", - srcs = ["dump.cc"], - hdrs = ["dump.h"], - tags = ["nobuilder"], - deps = [ - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager_hdr", - ], -) diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/tools/dump.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/tools/dump.cc deleted file mode 100644 index 0e94b6b0385890..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/tools/dump.cc +++ /dev/null @@ -1,88 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/tools/dump.h" - -#include - -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "third_party/qairt/latest/include/QNN/QnnInterface.h" -#include "third_party/qairt/latest/include/QNN/System/QnnSystemInterface.h" -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager.h" - -namespace litert::qnn::internal { -namespace { - -static constexpr absl::string_view kNullDumpTpl = "%s : nullptr\n"; - -void Dump(const QnnInterface_t* interface, std::ostream& out) { - static constexpr absl::string_view kQnnInterfaceHeader = "< QnnInterface_t >"; - // NOLINTBEGIN - static constexpr absl::string_view kQnnInterfaceDumpTpl = - "\ - %s\n\ - name: %s\n\ - backend_id: %u\n\ - core_api_version: %u.%u.%u\n\ - backend_api_version: %u.%u.%u\n"; - // NOLINTEND - - if (interface == nullptr) { - out << absl::StreamFormat(kNullDumpTpl, kQnnInterfaceHeader); - return; - } - - const auto core_version = interface->apiVersion.coreApiVersion; - const auto backend_version = interface->apiVersion.backendApiVersion; - - out << absl::StreamFormat(kQnnInterfaceDumpTpl, kQnnInterfaceHeader, - interface->providerName, interface->backendId, - core_version.major, core_version.minor, - core_version.patch, backend_version.major, - backend_version.minor, backend_version.patch); -} - -void Dump(const QnnSystemInterface_t* interface, std::ostream& out) { - static constexpr absl::string_view kQnnSystemInterfaceHeader = - "< QnnSystemInterface_t >"; - // NOLINTBEGIN - static constexpr absl::string_view kQnnSystemInterfaceDumpTpl = - "\ - %s\n\ - name: %s\n\ - backend_id: %u\n\ - system_api_version: %u.%u.%u\n"; - // NOLINTEND - - if (interface == nullptr) { - out << absl::StreamFormat(kNullDumpTpl, kQnnSystemInterfaceHeader); - return; - } - - const auto system_version = interface->systemApiVersion; - - out << absl::StreamFormat(kQnnSystemInterfaceDumpTpl, - kQnnSystemInterfaceHeader, 
interface->providerName, - interface->backendId, system_version.major, - system_version.minor, system_version.patch); -} - -} // namespace - -void Dump(const QnnManager& qnn, std::ostream& out) { - Dump(qnn.interface_, out); - Dump(qnn.system_interface_, out); -} -} // namespace litert::qnn::internal diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/tools/dump.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/tools/dump.h deleted file mode 100644 index b64650249af0af..00000000000000 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/tools/dump.h +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_TOOLS_DUMP_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_TOOLS_DUMP_H_ - -#include -#include - -#include "tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager.h" - -namespace litert::qnn::internal { - -void Dump(const QnnManager& qnn, std::ostream& out = std::cerr); - -} - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_TOOLS_DUMP_H_ diff --git a/tensorflow/lite/experimental/microfrontend/lib/fft.cc b/tensorflow/lite/experimental/microfrontend/lib/fft.cc index 8a107e2b492ef5..0f30eec49167f5 100644 --- a/tensorflow/lite/experimental/microfrontend/lib/fft.cc +++ b/tensorflow/lite/experimental/microfrontend/lib/fft.cc @@ -16,6 +16,8 @@ limitations under the License. 
#include +#include + #define FIXED_POINT 16 #include "kiss_fft.h" #include "kiss_fftr.h" diff --git a/tensorflow/lite/experimental/microfrontend/lib/fft_util.cc b/tensorflow/lite/experimental/microfrontend/lib/fft_util.cc index b913f3c0365eb5..18e0d36b53d7d0 100644 --- a/tensorflow/lite/experimental/microfrontend/lib/fft_util.cc +++ b/tensorflow/lite/experimental/microfrontend/lib/fft_util.cc @@ -16,6 +16,9 @@ limitations under the License. #include +#include +#include + #define FIXED_POINT 16 #include "kiss_fft.h" #include "kiss_fftr.h" diff --git a/tensorflow/lite/experimental/resource/resource_variable.cc b/tensorflow/lite/experimental/resource/resource_variable.cc index efa8ac979cb543..ff34db90d43069 100644 --- a/tensorflow/lite/experimental/resource/resource_variable.cc +++ b/tensorflow/lite/experimental/resource/resource_variable.cc @@ -77,7 +77,9 @@ TfLiteStatus ResourceVariable::AssignFrom(const TfLiteTensor* tensor) { tensor_.bytes = old_bytes; } - memcpy(tensor_.data.raw, tensor->data.raw, tensor_.bytes); + if (tensor->data.raw) { + memcpy(tensor_.data.raw, tensor->data.raw, tensor_.bytes); + } is_initialized_ = true; return kTfLiteOk; diff --git a/tensorflow/lite/g3doc/android/delegates/gpu_native.md b/tensorflow/lite/g3doc/android/delegates/gpu_native.md index 2221c2066f9cb6..87ecb7291cec54 100644 --- a/tensorflow/lite/g3doc/android/delegates/gpu_native.md +++ b/tensorflow/lite/g3doc/android/delegates/gpu_native.md @@ -1,20 +1,20 @@ # GPU acceleration delegate with C/C++ API Using graphics processing units (GPUs) to run your machine learning (ML) models -can dramatically improve the performance and the user experience -of your ML-enabled applications. On Android devices, you can enable -GPU-accelerated execution of your models using a -[*delegate*](../../performance/delegates) and one of the following APIs: +can dramatically improve the performance and the user experience of your +ML-enabled applications. 
On Android devices, you can enable GPU-accelerated +execution of your models using a +[*delegate*](https://ai.google.dev/edge/litert/performance/delegates) and one of +the following APIs: -- Interpreter API - [guide](./gpu) -- Task library API - [guide](./gpu_task) -- Native (C/C++) API - this guide +- Interpreter API - [guide](./gpu) +- Task library API - [guide](./gpu_task.md) +- Native (C/C++) API - this guide -This guide covers advanced -uses of the GPU delegate for the C API, C++ API, and use of quantized models. -For more information about using the GPU delegate for TensorFlow Lite, -including best practices and advanced techniques, see the -[GPU delegates](../../performance/gpu) page. +This guide covers advanced uses of the GPU delegate for the C API, C++ API, and +use of quantized models. For more information about using the GPU delegate for +TensorFlow Lite, including best practices and advanced techniques, see the +[GPU delegates](https://ai.google.dev/edge/litert/performance/gpu) page. ## Enable GPU acceleration @@ -65,9 +65,10 @@ thread in which `Interpreter::ModifyGraphWithDelegate()` was called. #### With TensorFlow Lite in Google Play Services: -If you are using TensorFlow Lite in Google Play Services [C API](../native), -you’ll need to use the Java/Kotlin API to check if a GPU delegate is available -for your device before initializing the TensorFlow Lite runtime. +If you are using TensorFlow Lite in Google Play Services +[C API](https://ai.google.dev/edge/litert/android/native), you’ll need to use +the Java/Kotlin API to check if a GPU delegate is available for your device +before initializing the TensorFlow Lite runtime. Add the GPU delegate gradle dependencies to your application: @@ -170,5 +171,6 @@ if (interpreter->ModifyGraphWithDelegate(delegate) != kTfLiteOk) return false; -For more information about running quantized models with GPU acceleration, -see [GPU delegate](../../performance/gpu#quantized-models) overview. 
\ No newline at end of file +For more information about running quantized models with GPU acceleration, see +[GPU delegate](https://ai.google.dev/edge/litert/performance/gpu#quantized_models) +overview. diff --git a/tensorflow/lite/g3doc/android/tutorials/object_detection.md b/tensorflow/lite/g3doc/android/tutorials/object_detection.md index e9ecc441d651f3..640b082311970b 100644 --- a/tensorflow/lite/g3doc/android/tutorials/object_detection.md +++ b/tensorflow/lite/g3doc/android/tutorials/object_detection.md @@ -147,7 +147,7 @@ convert data such as images, into a tensor data format that can be processed by the model you are using. The example app uses the TensorFlow Lite -[Task library for vision](../../inference_with_metadata/task_library/overview#supported_tasks) +[Task library for vision](../../inference_with_metadata/task_library/overview.md#supported-tasks) to enable execution of the object detection machine learning model. The following instructions explain how to add the required library dependencies to your own Android app project. @@ -263,7 +263,7 @@ device, such as Graphics Processing Units (GPUs), Tensor Processing Units TensorFlow Lite models is recommended, but not required. The object detector is initialized using the current settings on the thread that -is using it. You can use CPU and [NNAPI](../../android/delegates/nnapi) +is using it. You can use CPU and [NNAPI](../../android/delegates/nnapi.md) delegates with detectors that are created on the main thread and used on a background thread, but the thread that initialized the detector must use the GPU delegate. @@ -290,7 +290,7 @@ when (currentDelegate) { ``` For more information about using hardware acceleration delegates with TensorFlow -Lite, see [TensorFlow Lite Delegates](../../performance/delegates). +Lite, see [TensorFlow Lite Delegates](../../performance/delegates.md). 
## Prepare data for the model diff --git a/tensorflow/lite/g3doc/android/tutorials/text_classification.md b/tensorflow/lite/g3doc/android/tutorials/text_classification.md index 19baf52bb17ac4..7b614894a4643b 100644 --- a/tensorflow/lite/g3doc/android/tutorials/text_classification.md +++ b/tensorflow/lite/g3doc/android/tutorials/text_classification.md @@ -7,7 +7,7 @@ physical Android device but can also run on a device emulator. The [example application](https://github.com/tensorflow/examples/tree/master/lite/examples/text_classification/android) uses TensorFlow Lite to classify text as either positive or negative, using the -[Task library for natural language (NL)](../../inference_with_metadata/task_library/overview#supported_tasks) +[Task library for natural language (NL)](https://ai.google.dev/edge/litert/libraries/task_library/overview) to enable execution of the text classification machine learning models. If you are updating an existing project, you can use the example application as @@ -31,7 +31,7 @@ text being correctly classified as either positive or negative. For more information on how the models in this tutorial are generated, refer to the -[Text classification with TensorFlow Lite Model Maker](https://www.tensorflow.org/lite/models/modify/model_maker/text_classification) +[Text classification with TensorFlow Lite Model Maker](https://ai.google.dev/edge/litert/libraries/modify/text_classification) tutorial. ## Models and dataset @@ -41,7 +41,7 @@ This tutorial uses models that were trained using the Treebank) dataset. SST-2 contains 67,349 movie reviews for training and 872 movie reviews for testing, with each review categorized as either positive or negative. The models used in this app were trained using the TensorFlow Lite -[Model Maker](https://www.tensorflow.org/lite/models/modify/model_maker/text_classification) +[Model Maker](https://ai.google.dev/edge/litert/libraries/modify/text_classification) tool. 
The example application uses the following pre-trained models: @@ -149,10 +149,10 @@ implement text classification features to your production applications: ## How the example app works {:#how_it_works} The application uses the -[Task library for natural language (NL)](../../inference_with_metadata/task_library/overview#supported_tasks) +[Task library for natural language (NL)](https://ai.google.dev/edge/litert/libraries/task_library/overview) package to implement the text classification models. The two models, Average Word Vector and MobileBERT, were trained using the TensorFlow Lite -[Model Maker](https://www.tensorflow.org/lite/models/modify/model_maker/text_classification). +[Model Maker](https://ai.google.dev/edge/litert/libraries/modify/text_classification). The application runs on CPU by default, with the option of hardware acceleration using the NNAPI delegate. @@ -237,10 +237,10 @@ model with parameters before running predictions with the model. A TensorFlow Lite model is stored as a `*.tflite` file. The model file contains the prediction logic and typically includes -[metadata](../../models/convert/metadata) about how to interpret prediction -results, such as prediction class names. Typically, model files are stored in -the `src/main/assets` directory of your development project, as in the code -example: +[metadata](https://ai.google.dev/edge/litert/models/metadata) about how to +interpret prediction results, such as prediction class names. Typically, model +files are stored in the `src/main/assets` directory of your development project, +as in the code example: - `/src/main/assets/mobilebert.tflite` - `/src/main/assets/wordvec.tflite` @@ -475,7 +475,7 @@ user interface. 
## Next steps * Train and implement the models from scratch with the - [Text classification with TensorFlow Lite Model Maker](https://www.tensorflow.org/lite/models/modify/model_maker/text_classification) + [Text classification with TensorFlow Lite Model Maker](https://ai.google.dev/edge/litert/libraries/modify/text_classification) tutorial. * Explore more [text processing tools for TensorFlow](https://www.tensorflow.org/text). diff --git a/tensorflow/lite/graph_info_test.cc b/tensorflow/lite/graph_info_test.cc index 4ff1efcfd2d6f4..38255e081862ac 100644 --- a/tensorflow/lite/graph_info_test.cc +++ b/tensorflow/lite/graph_info_test.cc @@ -18,7 +18,6 @@ limitations under the License. #include #include -#include #include #include diff --git a/tensorflow/lite/java/jni/BUILD b/tensorflow/lite/java/jni/BUILD index 137ca32b0489d5..9638ee99882514 100644 --- a/tensorflow/lite/java/jni/BUILD +++ b/tensorflow/lite/java/jni/BUILD @@ -1,48 +1,33 @@ -package(default_visibility = ["//tensorflow/lite:__subpackages__"]) +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") -licenses(["notice"]) # Apache 2.0 +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//tensorflow/lite:__subpackages__"], + licenses = ["notice"], +) -# Helper target for exposing JNI headers across multiple platforms. -cc_library( +# We need special handling for JNI inclusion for Android. Rather than duplicating this logic +# for every target that uses JNI, we use a single proxy target that +# encapsulates it. +alias( name = "jni", - hdrs = select({ - # The Android toolchain makes "jni.h" available in the include path. - # For non-Android toolchains, generate jni.h and jni_md.h. 
- "//tensorflow:android": [], - "//conditions:default": [ - ":jni.h", - ":jni_md.h", - ], - }), - includes = select({ - "//tensorflow:android": [], - "//conditions:default": ["."], + actual = select({ + # The Android toolchain makes available in the system include + # path. + # Aliases need to resolve to a single target however, so alias to an + # empty library instead. + # (Making this target a cc_library with empty deps for the Android case + # doesn't work, because go/cpp-features#layering-check requires targets + # to _directly_ depend on libraries they include, and cc_library doesn't + # have any direct equivalent to java_library's 'export' attribute). + "//tensorflow:android": ":empty", + # For non-Android toolchains, depend on the JDK JNI headers. + "//conditions:default": "@bazel_tools//tools/jdk:jni", }), visibility = ["//visibility:public"], ) -# Silly rules to make -# #include -# in the source headers work -# (in combination with the "includes" attribute of the tf_cuda_library rule -# above. Not needed when using the Android toolchain). -# -# Inspired from: -# https://github.com/bazelbuild/bazel/blob/f99a0543f8d97339d32075c7176b79f35be84606/src/main/native/BUILD -# but hopefully there is a simpler alternative to this. 
-genrule( - name = "copy_jni_h", - srcs = ["@bazel_tools//tools/jdk:jni_header"], - outs = ["jni.h"], - cmd = "cp -f $< $@", -) - -genrule( - name = "copy_jni_md_h", - srcs = select({ - "//tensorflow:macos": ["@bazel_tools//tools/jdk:jni_md_header-darwin"], - "//conditions:default": ["@bazel_tools//tools/jdk:jni_md_header-linux"], - }), - outs = ["jni_md.h"], - cmd = "cp -f $< $@", +cc_library( + name = "empty", + compatible_with = get_compatible_with_portable(), ) diff --git a/tensorflow/lite/kernels/BUILD b/tensorflow/lite/kernels/BUILD index e5c263ea6b563e..0cce95049f1adc 100644 --- a/tensorflow/lite/kernels/BUILD +++ b/tensorflow/lite/kernels/BUILD @@ -1499,6 +1499,7 @@ cc_test( ":test_util", "//tensorflow/lite/schema:schema_fbs", "@com_google_googletest//:gtest", + "@eigen_archive//:eigen3", "@flatbuffers", ], ) diff --git a/tensorflow/lite/kernels/bidirectional_sequence_lstm.cc b/tensorflow/lite/kernels/bidirectional_sequence_lstm.cc index 40f3b812825497..6472da7ca6601b 100644 --- a/tensorflow/lite/kernels/bidirectional_sequence_lstm.cc +++ b/tensorflow/lite/kernels/bidirectional_sequence_lstm.cc @@ -22,7 +22,6 @@ limitations under the License. 
#include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/kernels/cpu_backend_context.h" #include "tensorflow/lite/kernels/internal/compatibility.h" -#include "tensorflow/lite/kernels/internal/kernel_utils.h" #include "tensorflow/lite/kernels/internal/tensor_utils.h" #include "tensorflow/lite/kernels/kernel_util.h" #include "tensorflow/lite/kernels/lstm_eval.h" @@ -320,7 +319,7 @@ TfLiteStatus CheckLstmTensorDimensionsAndTypes( const TfLiteTensor* input_gate_bias = GetOptionalInputTensor(context, node, input_gate_bias_tensor); if (use_cifg) { - TF_LITE_ENSURE_EQ(context, input_gate_bias, nullptr); + TF_LITE_ENSURE(context, input_gate_bias == nullptr); } else { TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->size, 1); TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->data[0], n_cell); diff --git a/tensorflow/lite/kernels/concatenation.cc b/tensorflow/lite/kernels/concatenation.cc index 4f116440fd2049..fbf153b90d7327 100644 --- a/tensorflow/lite/kernels/concatenation.cc +++ b/tensorflow/lite/kernels/concatenation.cc @@ -20,6 +20,7 @@ limitations under the License. 
#include #include +#include "Eigen/Core" // from @eigen_archive #include "tensorflow/lite/core/c/builtin_op_data.h" #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/kernels/internal/compatibility.h" @@ -91,6 +92,12 @@ TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node, int axis, case kTfLiteFloat32: TF_LITE_CONCATENATION(float); break; + case kTfLiteFloat16: + TF_LITE_CONCATENATION(Eigen::half); + break; + case kTfLiteBFloat16: + TF_LITE_CONCATENATION(Eigen::bfloat16); + break; case kTfLiteInt32: TF_LITE_CONCATENATION(int32); break; @@ -142,10 +149,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActNone); TF_LITE_ENSURE(context, - input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8 || - input_type == kTfLiteInt8 || input_type == kTfLiteInt16 || - input_type == kTfLiteInt32 || input_type == kTfLiteInt64 || - input_type == kTfLiteBool || input_type == kTfLiteUInt32); + input_type == kTfLiteFloat32 || input_type == kTfLiteFloat16 || + input_type == kTfLiteBFloat16 || + input_type == kTfLiteUInt8 || input_type == kTfLiteInt8 || + input_type == kTfLiteInt16 || input_type == kTfLiteInt32 || + input_type == kTfLiteInt64 || input_type == kTfLiteBool || + input_type == kTfLiteUInt32); // Check to see if we can calculate the output now. 
bool all_inputs_at_prepare = true; diff --git a/tensorflow/lite/kernels/concatenation_test.cc b/tensorflow/lite/kernels/concatenation_test.cc index 685abd5d5e7569..28692ae1528dd3 100644 --- a/tensorflow/lite/kernels/concatenation_test.cc +++ b/tensorflow/lite/kernels/concatenation_test.cc @@ -108,6 +108,29 @@ TEST(ConcatenationOpTest, ThreeDimensionalOneInput) { EXPECT_THAT(m0.GetOutput(), ElementsAreArray({1, 3, 4, 7})); } +TEST(ConcatenationOpTest, ThreeDimensionalOneInputBFloat16) { + ConcatenationOpModel m({TensorType_BFLOAT16, {2, 1, 2}}, + /*axis=*/1, + /*num_inputs=*/1); + m.SetInput( + 0, + {static_cast(1.0f), static_cast(3.0f), + static_cast(4.0f), static_cast(7.0f)}); + ASSERT_EQ(m.Invoke(), kTfLiteOk); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 3, 4, 7})); +} + +TEST(ConcatenationOpTest, ThreeDimensionalOneInputFloat16) { + ConcatenationOpModel m({TensorType_FLOAT16, {2, 1, 2}}, + /*axis=*/1, + /*num_inputs=*/1); + m.SetInput(0, + {static_cast(1.0f), static_cast(3.0f), + static_cast(4.0f), static_cast(7.0f)}); + ASSERT_EQ(m.Invoke(), kTfLiteOk); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 3, 4, 7})); +} + TEST(ConcatenationOpTest, ThreeDimensionalOneInputUInt32) { ConcatenationOpModel m0({TensorType_UINT32, {2, 1, 2}}, /*axis=*/1, /*num_inputs=*/1); @@ -152,6 +175,61 @@ TEST(ConcatenationOpTest, FiveDimensionalTwoInput) { 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24})); } +TEST(ConcatenationOpTest, FiveDimensionalTwoInputBFloat16) { + ConcatenationOpModel m( + {TensorType_BFLOAT16, {2, 1, 2, 1, 3}}, + /*axis=*/0, + /*num_inputs=*/2); + m.SetInput( + 0, + {static_cast(1.0f), static_cast(2.0f), + static_cast(3.0f), static_cast(4.0f), + static_cast(5.0f), static_cast(6.0f), + static_cast(7.0f), static_cast(8.0f), + static_cast(9.0f), static_cast(10.0f), + static_cast(11.0f), + static_cast(12.0f)}); + m.SetInput( + 1, + {static_cast(13.0f), static_cast(14.0f), + static_cast(15.0f), Eigen::bfloat16{16.0f}, + static_cast(17.0f), 
static_cast(18.0f), + static_cast(19.0f), static_cast(20.0f), + static_cast(21.0f), static_cast(22.0f), + static_cast(23.0f), + static_cast(24.0f)}); + ASSERT_EQ(m.Invoke(), kTfLiteOk); + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24})); +} + +TEST(ConcatenationOpTest, FiveDimensionalTwoInputFloat16) { + ConcatenationOpModel m({TensorType_FLOAT16, {2, 1, 2, 1, 3}}, + /*axis=*/0, + /*num_inputs=*/2); + m.SetInput( + 0, {static_cast(1.0f), static_cast(2.0f), + static_cast(3.0f), static_cast(4.0f), + static_cast(5.0f), static_cast(6.0f), + static_cast(7.0f), Eigen::half{8.0f}, + static_cast(9.0f), static_cast(10.0f), + static_cast(11.0f), static_cast(12.0f)}); + m.SetInput( + 1, {static_cast(13.0f), static_cast(14.0f), + Eigen::half{15.0f}, static_cast(16.0f), + Eigen::half{17.0f}, static_cast(18.0f), + static_cast(19.0f), static_cast(20.0f), + static_cast(21.0f), static_cast(22.0f), + static_cast(23.0f), static_cast(24.0f)}); + ASSERT_EQ(m.Invoke(), kTfLiteOk); + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24})); +} + TEST(ConcatenationOpTest, FiveDimensionalTwoInputUInt32) { ConcatenationOpModel m0({TensorType_UINT32, {2, 1, 2, 1, 3}}, /*axis=*/0, diff --git a/tensorflow/lite/kernels/ctc/ctc_beam_search_decoder.cc b/tensorflow/lite/kernels/ctc/ctc_beam_search_decoder.cc index 05f242397f44d9..5046ace8c07eb7 100644 --- a/tensorflow/lite/kernels/ctc/ctc_beam_search_decoder.cc +++ b/tensorflow/lite/kernels/ctc/ctc_beam_search_decoder.cc @@ -13,6 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ #include +#include +#include +#include +#include #include #include "flatbuffers/flexbuffers.h" // from @flatbuffers diff --git a/tensorflow/lite/kernels/ctc/ctc_beam_search_decoder_test.cc b/tensorflow/lite/kernels/ctc/ctc_beam_search_decoder_test.cc index 0ffacfbae80bc8..3e3591f6471f02 100644 --- a/tensorflow/lite/kernels/ctc/ctc_beam_search_decoder_test.cc +++ b/tensorflow/lite/kernels/ctc/ctc_beam_search_decoder_test.cc @@ -14,9 +14,10 @@ limitations under the License. ==============================================================================*/ #include -#include +#include #include +#include #include #include "flatbuffers/flexbuffers.h" // from @flatbuffers #include "tensorflow/lite/core/interpreter.h" diff --git a/tensorflow/lite/kernels/dynamic_update_slice.cc b/tensorflow/lite/kernels/dynamic_update_slice.cc index 776379058cc1ff..5c5cbcd8f963b4 100644 --- a/tensorflow/lite/kernels/dynamic_update_slice.cc +++ b/tensorflow/lite/kernels/dynamic_update_slice.cc @@ -219,6 +219,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { DynamicUpdateSlice(operand, update, indices_data_i64.data(), output); break; + case kTfLiteInt16: + DynamicUpdateSlice(operand, update, indices_data_i64.data(), + output); + break; case kTfLiteInt32: DynamicUpdateSlice(operand, update, indices_data_i64.data(), output); diff --git a/tensorflow/lite/kernels/dynamic_update_slice_test.cc b/tensorflow/lite/kernels/dynamic_update_slice_test.cc index 867f2b9b8cc029..373a719d5ac412 100644 --- a/tensorflow/lite/kernels/dynamic_update_slice_test.cc +++ b/tensorflow/lite/kernels/dynamic_update_slice_test.cc @@ -177,6 +177,21 @@ TEST(DynamicUpdateSliceOpTest, SimpleTestI8) { 7, -2, 9})); } +TEST(DynamicUpdateSliceOpTest, SimpleTestI16) { + DynamicUpdateSliceOpModel m({TensorType_INT16, {3, 3}}, + {TensorType_INT16, {2, 1}}, + {TensorType_INT32, {2}}); + m.SetInput({1, 2, 3, // + 4, 5, 6, // + 7, 8, 9}); 
+ m.SetUpdate({-1, -2}); + m.SetStartIndices({1, 1}); + ASSERT_EQ(m.Invoke(), kTfLiteOk); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, // + 4, -1, 6, // + 7, -2, 9})); +} + TEST(DynamicUpdateSliceOpTest, SimpleTestI32) { DynamicUpdateSliceOpModel m({TensorType_INT32, {3, 3}}, {TensorType_INT32, {2, 1}}, diff --git a/tensorflow/lite/kernels/embedding_lookup.cc b/tensorflow/lite/kernels/embedding_lookup.cc index d92701059822f6..158bd8c63bf932 100644 --- a/tensorflow/lite/kernels/embedding_lookup.cc +++ b/tensorflow/lite/kernels/embedding_lookup.cc @@ -29,8 +29,8 @@ limitations under the License. // When indices are out of bound, the ops will not succeed. // -#include - +#include +#include #include #include "tensorflow/lite/c/c_api_types.h" @@ -104,17 +104,17 @@ TfLiteStatus EvalSimple(TfLiteContext* context, TfLiteNode* node, // Propagate empty tensor if input is empty return kTfLiteOk; } - const int64_t row_bytes = value->bytes / row_size; + const size_t row_bytes = value->bytes / row_size; char* output_raw = GetTensorData(output); const char* value_raw = GetTensorData(value); const int32_t* lookup_data = GetTensorData(lookup); for (int i = 0; i < SizeOfDimension(lookup, 0); i++) { - int64_t idx = lookup_data[i]; + const int32_t idx = lookup_data[i]; if (idx >= row_size || idx < 0) { TF_LITE_KERNEL_LOG(context, "Embedding Lookup: index out of bounds. " - "Got %d, and bounds are [0, %d]", + "Got %" PRId32 ", and bounds are [0, %d]", idx, row_size - 1); return kTfLiteError; } else { @@ -142,11 +142,11 @@ TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node, const int32_t* lookup_data = GetTensorData(lookup); for (int i = 0; i < SizeOfDimension(lookup, 0); i++) { - int idx = lookup_data[i]; + const int32_t idx = lookup_data[i]; if (idx >= row_size || idx < 0) { TF_LITE_KERNEL_LOG(context, "Embedding Lookup: index out of bounds. 
" - "Got %d, and bounds are [0, %d]", + "Got %" PRId32 ", and bounds are [0, %d]", idx, row_size - 1); return kTfLiteError; } else { diff --git a/tensorflow/lite/kernels/fully_connected.cc b/tensorflow/lite/kernels/fully_connected.cc index 8bfb045bc1b477..287cf22365d4b4 100644 --- a/tensorflow/lite/kernels/fully_connected.cc +++ b/tensorflow/lite/kernels/fully_connected.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include #include +#include #include #include @@ -31,6 +32,7 @@ limitations under the License. #include "tensorflow/lite/kernels/internal/quantization_util.h" #include "tensorflow/lite/kernels/internal/reference/fully_connected.h" #include "tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h" +#include "tensorflow/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/lite/kernels/internal/reference/sparse_ops/fully_connected.h" #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" #include "tensorflow/lite/kernels/internal/tensor_utils.h" @@ -120,13 +122,15 @@ TfLiteStatus VerifyPerChannelQuantization(TfLiteContext* context, TfLiteStatus VerifyQuantizationZeroPoint(const TfLiteTensor* tensor, int expected_value) { - const auto* params = - reinterpret_cast(tensor->quantization.params); - if (params && params->zero_point && - std::any_of(params->zero_point->data, - params->zero_point->data + params->zero_point->size, - [expected_value](int v) { return v != expected_value; })) { - return kTfLiteError; + if (tensor->quantization.type == kTfLiteAffineQuantization) { + const auto* params = reinterpret_cast( + tensor->quantization.params); + if (params && params->zero_point && + std::any_of(params->zero_point->data, + params->zero_point->data + params->zero_point->size, + [expected_value](int v) { return v != expected_value; })) { + return kTfLiteError; + } } return kTfLiteOk; } @@ -947,6 +951,82 @@ struct SparseHybridFullyConnectedTask : cpu_backend_threadpool::Task { TfLiteTensor* output; }; +inline 
int8_t SignExtendInt4(int8_t value) { return (value ^ 0x8) - 8; } + +TfLiteStatus EvalBlockwise4Bit( + TfLiteContext* context, TfLiteNode* node, + TfLiteFullyConnectedParams* params, OpData* data, const TfLiteTensor* input, + const TfLiteTensor* filter, const TfLiteTensor* bias, + TfLiteTensor* input_quantized, TfLiteTensor* scaling_factors, + TfLiteTensor* accum_scratch, TfLiteTensor* input_offsets, + TfLiteTensor* output) { + const auto quantization_params = + static_cast( + filter->quantization.params); + + const size_t blocksize = quantization_params->blocksize; + const size_t input_channels = filter->dims->data[1]; + const size_t output_channels = filter->dims->data[0]; + const size_t batch_size = NumElements(input) / input_channels; + const size_t num_blocks = input_channels / blocksize; + const TfLiteTensor& scale = context->tensors[quantization_params->scale]; + int num_scales = NumElements(&scale); + std::vector dequantized_scale(num_scales, 0); + const Eigen::half* half_data = reinterpret_cast( + GetTensorData(&scale)); + reference_ops::Dequantize(GetTensorShape(&scale), half_data, + GetTensorShape(&scale), dequantized_scale.data()); + float* output_ptr = GetTensorData(output); + memset(output_ptr, 0, NumElements(output) * sizeof(float)); + std::vector quant_data(NumElements(input)); + std::vector input_scales(batch_size); + std::vector input_zero_points(batch_size); + + const float* input_ptr = GetTensorData(input); + tensor_utils::BatchQuantizeFloats(input_ptr, batch_size, input_channels, + quant_data.data(), input_scales.data(), + input_zero_points.data(), + /*do_asymmetric=*/true); + + const float* bias_data = nullptr; + if (bias) { + bias_data = GetTensorData(bias); + } + const size_t k2 = (input_channels + 1) & 0xFFFFFFFFFFFFFFFE; + const uint8_t* kernel = GetTensorData(filter); + for (size_t mi = 0; mi < batch_size; mi++) { + for (size_t ni = 0; ni < output_channels; ni++) { + float kfsum = 0.0; + for (size_t bi = 0; bi < num_blocks; bi++) { + 
int32_t ksum = 0; + int32_t c_ref_acc = 0; + for (size_t ki = 0; ki < blocksize; ki++) { + const size_t k_index = bi * blocksize + ki; + const size_t nb_index = (ni * k2 + k_index) / 2; + const int8_t k_value = int8_t( + (k_index % 2 == 0) ? (kernel[nb_index] & static_cast(0xF)) + : (kernel[nb_index] >> 4)); + const int32_t kernel_value = SignExtendInt4(k_value); + ksum += kernel_value; + c_ref_acc += + static_cast(quant_data[mi * input_channels + k_index]) * + static_cast(kernel_value); + } + size_t scale_index = ni * num_blocks + bi; + float scale = dequantized_scale[scale_index]; + output_ptr[mi * output_channels + ni] += c_ref_acc * scale; + kfsum += scale * ksum; + } + output_ptr[mi * output_channels + ni] -= (input_zero_points[mi] * kfsum); + output_ptr[mi * output_channels + ni] *= input_scales[mi]; + if (bias_data != nullptr) { + output_ptr[mi * output_channels + ni] += bias_data[ni]; + } + } + } + return kTfLiteOk; +} + TfLiteStatus EvalHybridDense4Bit( TfLiteContext* context, TfLiteNode* node, TfLiteFullyConnectedParams* params, OpData* data, const TfLiteTensor* input, @@ -1134,6 +1214,7 @@ void FullyConnectedInt8(const OpData* data, const TfLiteTensor* input, op_params, GetTensorShape(input), GetTensorData(input), GetTensorShape(filter), filter_data, GetTensorShape(bias), GetTensorData(bias), GetTensorShape(output), + input->params.scale, output->params.scale, filter->params.scale, GetTensorData(output)); } else { optimized_integer_ops::FullyConnected( @@ -1162,12 +1243,14 @@ void FullyConnectedInt16(const OpData* data, const TfLiteTensor* input, op_params, GetTensorShape(input), GetTensorData(input), GetTensorShape(filter), filter_data, GetTensorShape(bias), GetTensorData(bias), GetTensorShape(output), + input->params.scale, output->params.scale, filter->params.scale, GetTensorData(output)); } else { reference_integer_ops::FullyConnected( op_params, GetTensorShape(input), GetTensorData(input), GetTensorShape(filter), filter_data, GetTensorShape(bias), 
GetTensorData(bias), GetTensorShape(output), + input->params.scale, output->params.scale, filter->params.scale, GetTensorData(output)); } } @@ -1191,12 +1274,16 @@ void FullyConnectedPerChannelInt8(const OpData* data, const TfLiteTensor* input, op_params.rhs_cacheable = IsConstantTensor(input); if (kernel_type == kReference) { + const auto* affine_quantization = + reinterpret_cast( + filter->quantization.params); + const float* filter_scales = affine_quantization->scale->data; reference_integer_ops::FullyConnectedPerChannel( - op_params, data->per_channel_output_multiplier.data(), - data->per_channel_output_shift.data(), GetTensorShape(input), - GetTensorData(input), GetTensorShape(filter), filter_data, - GetTensorShape(bias), GetTensorData(bias), - GetTensorShape(output), GetTensorData(output)); + op_params, GetTensorShape(input), GetTensorData(input), + GetTensorShape(filter), filter_data, GetTensorShape(bias), + GetTensorData(bias), GetTensorShape(output), + input->params.scale, output->params.scale, filter_scales, + GetTensorData(output)); } else { optimized_integer_ops::FullyConnectedPerChannel( op_params, data->per_channel_output_multiplier.data(), @@ -1220,21 +1307,24 @@ void FullyConnectedPerChannelInt16( op_params.output_offset = output->params.zero_point; op_params.quantized_activation_min = data->output_activation_min; op_params.quantized_activation_max = data->output_activation_max; + const auto* affine_quantization = + reinterpret_cast(filter->quantization.params); + const float* filter_scales = affine_quantization->scale->data; if (data->quantized_bias_type == kTfLiteInt32) { reference_integer_ops::FullyConnectedPerChannel( - op_params, data->per_channel_output_multiplier.data(), - data->per_channel_output_shift.data(), GetTensorShape(input), - GetTensorData(input), GetTensorShape(filter), filter_data, - GetTensorShape(bias), GetTensorData(bias), - GetTensorShape(output), GetTensorData(output)); + op_params, GetTensorShape(input), 
GetTensorData(input), + GetTensorShape(filter), filter_data, GetTensorShape(bias), + GetTensorData(bias), GetTensorShape(output), + input->params.scale, output->params.scale, filter_scales, + GetTensorData(output)); } else { reference_integer_ops::FullyConnectedPerChannel( - op_params, data->per_channel_output_multiplier.data(), - data->per_channel_output_shift.data(), GetTensorShape(input), - GetTensorData(input), GetTensorShape(filter), filter_data, - GetTensorShape(bias), GetTensorData(bias), - GetTensorShape(output), GetTensorData(output)); + op_params, GetTensorShape(input), GetTensorData(input), + GetTensorShape(filter), filter_data, GetTensorShape(bias), + GetTensorData(bias), GetTensorShape(output), + input->params.scale, output->params.scale, filter_scales, + GetTensorData(output)); } } @@ -1295,9 +1385,18 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node, TF_LITE_ENSURE_OK( context, GetTemporarySafe(context, node, /*index=*/3, &input_offsets)); if (data->op_data_4bit) { - return EvalHybridDense4Bit(context, node, params, data, input, filter, - bias, input_quantized, scaling_factors, - accum_scratch, input_offsets, output); + switch (filter->quantization.type) { + case kTfLiteAffineQuantization: + return EvalHybridDense4Bit(context, node, params, data, input, filter, + bias, input_quantized, scaling_factors, + accum_scratch, input_offsets, output); + case kTfLiteBlockwiseQuantization: + return EvalBlockwise4Bit(context, node, params, data, input, filter, + bias, input_quantized, scaling_factors, + accum_scratch, input_offsets, output); + default: + return kTfLiteError; + } } TfLiteTensor* row_sums; TF_LITE_ENSURE_OK(context, @@ -1324,7 +1423,8 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node, op_params, GetTensorShape(input), GetTensorData(input), GetTensorShape(filter), GetTensorData(filter), GetTensorShape(bias), GetTensorData(bias), - GetTensorShape(output), GetTensorData(output)); + GetTensorShape(output), 
input->params.scale, output->params.scale, + filter->params.scale, GetTensorData(output)); } else { optimized_ops::FullyConnected( op_params, GetTensorShape(input), GetTensorData(input), @@ -1445,7 +1545,8 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node, op_params, GetTensorShape(input), GetTensorData(input), GetTensorShape(filter), GetTensorData(filter), GetTensorShape(bias), GetTensorData(bias), - GetTensorShape(output), GetTensorData(output)); + GetTensorShape(output), input->params.scale, output->params.scale, + filter->params.scale, GetTensorData(output)); } else { optimized_ops::FullyConnected( op_params, GetTensorShape(input), GetTensorData(input), diff --git a/tensorflow/lite/kernels/if_test.cc b/tensorflow/lite/kernels/if_test.cc index 5fd734bba86b4d..cd34ca705965c8 100644 --- a/tensorflow/lite/kernels/if_test.cc +++ b/tensorflow/lite/kernels/if_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include +#include #include #include @@ -170,10 +171,12 @@ TEST_F(IfTest, TestWithXNNPACK) { builder_->BuildFloatIfSubgraph(&interpreter_->primary_subgraph(), 3); const auto opt = TfLiteXNNPackDelegateOptionsDefault(); - TfLiteDelegate* xnnpack_delegate = TfLiteXNNPackDelegateCreate(&opt); + std::unique_ptr xnnpack_delegate( + TfLiteXNNPackDelegateCreate(&opt), TfLiteXNNPackDelegateDelete); interpreter_->primary_subgraph().MarkAsDelegationSkippable(); interpreter_->subgraph(1)->MarkAsDelegationSkippable(); - ASSERT_EQ(interpreter_->ModifyGraphWithDelegate(xnnpack_delegate), kTfLiteOk); + ASSERT_EQ(interpreter_->ModifyGraphWithDelegate(std::move(xnnpack_delegate)), + kTfLiteOk); ASSERT_EQ(interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {1}), kTfLiteOk); ASSERT_EQ(interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {1}), @@ -201,7 +204,6 @@ TEST_F(IfTest, TestWithXNNPACK) { interpreter_->typed_input_tensor(0)[0] = true; ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); - 
TfLiteXNNPackDelegateDelete(xnnpack_delegate); } TEST_F(IfTest, TestInputIsOutput) { diff --git a/tensorflow/lite/kernels/internal/BUILD b/tensorflow/lite/kernels/internal/BUILD index 4be963cc8e3607..353cdcdf23e417 100644 --- a/tensorflow/lite/kernels/internal/BUILD +++ b/tensorflow/lite/kernels/internal/BUILD @@ -2,7 +2,7 @@ load("@bazel_skylib//lib:selects.bzl", "selects") load("//tensorflow:tensorflow.bzl", "transitive_hdrs") load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") load("//tensorflow/lite:build_def.bzl", "tflite_copts") -load("//tensorflow/lite:special_rules.bzl", "tflite_extra_arm_config_settings", "tflite_portable_test_suite_combined") +load("//tensorflow/lite:special_rules.bzl", "tflite_portable_test_suite_combined") load("//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite") package( @@ -20,18 +20,11 @@ HARD_FP_FLAGS_IF_APPLICABLE = select({ }) NEON_FLAGS_IF_APPLICABLE = select({ - ":arm": [ - "-O3", - ], - ":armeabi-v7a": [ - "-O3", - "-mfpu=neon", - ], ":armhf": [ "-O3", "-mfpu=neon", ], - ":armv7a": [ + ":armv7": [ "-O3", "-mfpu=neon", ], @@ -107,147 +100,25 @@ config_setting( ], ) -config_setting( - name = "arm", - values = { - "cpu": "arm", - }, -) - -config_setting( - name = "arm64-v8a", - values = { - "cpu": "arm64-v8a", - }, -) - config_setting( name = "armhf", - values = { - "cpu": "armhf", - }, -) - -config_setting( - name = "armv7a", - values = { - "cpu": "armv7a", - }, -) - -config_setting( - name = "armeabi-v7a", - values = { - "cpu": "armeabi-v7a", - }, -) - -config_setting( - name = "haswell", - values = { - "cpu": "haswell", - }, -) - -config_setting( - name = "ios_armv7", - values = { - "cpu": "ios_armv7", - }, -) - -config_setting( - name = "ios_arm64", - values = { - "cpu": "ios_arm64", - }, -) - -config_setting( - name = "ios_arm64e", - values = { - "cpu": "ios_arm64e", - }, -) - -config_setting( - name = "ios_sim_arm64", - values = { - "cpu": "ios_sim_arm64", - }, -) - 
-config_setting( - name = "visionos_arm64", - values = { - "cpu": "visionos_arm64", - }, -) - -config_setting( - name = "visionos_sim_arm64", - values = { - "cpu": "visionos_sim_arm64", - }, -) - -config_setting( - name = "k8", - values = { - "cpu": "k8", - }, + constraint_values = [ + "@platforms//cpu:armv7e-mf", + ], ) config_setting( name = "x86", - values = { - "cpu": "x86", - }, + constraint_values = [ + "@platforms//cpu:x86_32", + ], ) config_setting( name = "x86_64", - values = { - "cpu": "x86_64", - }, -) - -config_setting( - name = "darwin", - values = { - "cpu": "darwin", - }, -) - -config_setting( - name = "darwin_arm64", - values = { - "cpu": "darwin_arm64", - }, -) - -config_setting( - name = "freebsd", - values = { - "cpu": "freebsd", - }, -) - -config_setting( - name = "windows", - values = { - "cpu": "x64_windows", - }, -) - -config_setting( - name = "raspberry_pi_with_neon", - define_values = { - "raspberry_pi_with_neon": "true", - }, - values = { - "cpu": "armeabi", - }, + constraint_values = [ + "@platforms//cpu:x86_64", + ], ) selects.config_setting_group( @@ -263,7 +134,7 @@ selects.config_setting_group( match_any = [ ":arm32_any", ":aarch64_any", - ] + tflite_extra_arm_config_settings(), + ], ) selects.config_setting_group( @@ -1450,10 +1321,7 @@ cc_test( srcs = ["optimized/avx2_quantization_utils_test.cc"], copts = select( { - ":haswell": [ - "-mavx2", - ], - ":k8": [ + ":x86_64": [ "-mavx2", ], "//conditions:default": [ diff --git a/tensorflow/lite/kernels/internal/reference/fully_connected.h b/tensorflow/lite/kernels/internal/reference/fully_connected.h index ba51cbcfe3e8a0..bccc6220062564 100644 --- a/tensorflow/lite/kernels/internal/reference/fully_connected.h +++ b/tensorflow/lite/kernels/internal/reference/fully_connected.h @@ -16,6 +16,8 @@ limitations under the License. 
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_FULLY_CONNECTED_H_ #include +#include +#include #include "ruy/profiler/instrumentation.h" // from @ruy #include "tensorflow/lite/kernels/internal/common.h" @@ -62,6 +64,59 @@ inline void FullyConnected( } } +// This implementation receives the scales in float and performs requant in +// float to avoid loss of precision. +inline void FullyConnected( + const FullyConnectedParams& params, const RuntimeShape& input_shape, + const uint8_t* input_data, const RuntimeShape& filter_shape, + const uint8_t* filter_data, const RuntimeShape& bias_shape, + const int32_t* bias_data, const RuntimeShape& output_shape, + float input_scale, float output_scale, float filter_scale, + uint8_t* output_data) { + const int32_t input_offset = params.input_offset; + const int32_t filter_offset = params.weights_offset; + const int32_t output_offset = params.output_offset; + const int32_t output_activation_min = params.quantized_activation_min; + const int32_t output_activation_max = params.quantized_activation_max; + TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2); + TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1); + + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + // TODO(b/62193649): This really should be: + // const int batches = ArraySize(output_dims, 1); + // but the current --variable_batch hack consists in overwriting the 3rd + // dimension with the runtime batch size, as we don't keep track for each + // array of which dimension is the batch dimension in it. 
+ const int output_dim_count = output_shape.DimensionsCount(); + const int filter_dim_count = filter_shape.DimensionsCount(); + const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1); + const int output_depth = MatchingDim(filter_shape, filter_dim_count - 2, + output_shape, output_dim_count - 1); + const int accum_depth = filter_shape.Dims(filter_dim_count - 1); + for (int b = 0; b < batches; ++b) { + for (int out_c = 0; out_c < output_depth; ++out_c) { + int32_t acc = 0; + for (int d = 0; d < accum_depth; ++d) { + int32_t input_val = input_data[b * accum_depth + d]; + int32_t filter_val = filter_data[out_c * accum_depth + d]; + acc += (filter_val + filter_offset) * (input_val + input_offset); + } + if (bias_data) { + acc += bias_data[out_c]; + } + const double effective_output_scale = static_cast(input_scale) * + static_cast(filter_scale) / + static_cast(output_scale); + int32_t acc_scaled = static_cast( + round(static_cast(acc) * effective_output_scale)); + acc_scaled += output_offset; + acc_scaled = std::max(acc_scaled, output_activation_min); + acc_scaled = std::min(acc_scaled, output_activation_max); + output_data[out_c + output_depth * b] = static_cast(acc_scaled); + } + } +} + inline void FullyConnected( const FullyConnectedParams& params, const RuntimeShape& input_shape, const uint8_t* input_data, const RuntimeShape& filter_shape, @@ -164,6 +219,60 @@ inline void FullyConnected( } } +// This implementation receives the scales in float and performs requant in +// float to avoid loss of precision. 
+inline void FullyConnected( + const FullyConnectedParams& params, const RuntimeShape& input_shape, + const uint8_t* input_data, const RuntimeShape& filter_shape, + const uint8_t* filter_data, const RuntimeShape& bias_shape, + const int32_t* bias_data, const RuntimeShape& output_shape, + float input_scale, float output_scale, float filter_scale, + int16_t* output_data) { + const int32_t input_offset = params.input_offset; + const int32_t filter_offset = params.weights_offset; + const int32_t output_offset = params.output_offset; + const int32_t output_activation_min = params.quantized_activation_min; + const int32_t output_activation_max = params.quantized_activation_max; + + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + TFLITE_DCHECK_EQ(output_offset, 0); + // TODO(b/62193649): This really should be: + // const int batches = ArraySize(output_dims, 1); + // but the current --variable_batch hack consists in overwriting the 3rd + // dimension with the runtime batch size, as we don't keep track for each + // array of which dimension is the batch dimension in it. + const int output_dim_count = output_shape.DimensionsCount(); + const int filter_dim_count = filter_shape.DimensionsCount(); + const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1); + const int output_depth = MatchingDim(filter_shape, filter_dim_count - 2, + output_shape, output_dim_count - 1); + const int accum_depth = filter_shape.Dims(filter_dim_count - 1); + for (int b = 0; b < batches; ++b) { + for (int out_c = 0; out_c < output_depth; ++out_c) { + // Internal accumulation. + // Initialize accumulator with the bias-value. + int32_t accum = bias_data[out_c]; + // Accumulation loop. 
+ for (int d = 0; d < accum_depth; ++d) { + int16_t input_val = input_data[b * accum_depth + d] + input_offset; + int16_t filter_val = + filter_data[out_c * accum_depth + d] + filter_offset; + accum += filter_val * input_val; + } + const double effective_output_scale = static_cast(input_scale) * + static_cast(filter_scale) / + static_cast(output_scale); + int32_t acc_scaled = static_cast( + round(static_cast(accum) * effective_output_scale)); + // Saturate, cast to int16_t, and store to output array. + acc_scaled = std::max(acc_scaled, output_activation_min - output_offset); + acc_scaled = std::min(acc_scaled, output_activation_max - output_offset); + acc_scaled += output_offset; + output_data[out_c + output_depth * b] = acc_scaled; + } + } +} + inline void ShuffledFullyConnected( const FullyConnectedParams& params, const RuntimeShape& input_shape, const uint8_t* input_data, const RuntimeShape& weights_shape, diff --git a/tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h b/tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h index c6d06077934839..f249beef8503f6 100644 --- a/tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h +++ b/tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h @@ -16,6 +16,8 @@ limitations under the License. #define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_FULLY_CONNECTED_H_ #include +#include +#include #include "tensorflow/lite/kernels/internal/common.h" @@ -74,6 +76,61 @@ void FullyConnectedPerChannel( } } +// This implementation receives the scales in float and performs requant in +// float to avoid loss of precision. 
+template +void FullyConnectedPerChannel( + const FullyConnectedParams& params, const RuntimeShape& input_shape, + const InputType* input_data, const RuntimeShape& filter_shape, + const WeightType* filter_data, const RuntimeShape& bias_shape, + const BiasType* bias_data, const RuntimeShape& output_shape, + float input_scale, float output_scale, const float* filter_scales, + OutputType* output_data) { + const int32_t input_offset = params.input_offset; + const int32_t output_offset = params.output_offset; + const int32_t output_activation_min = params.quantized_activation_min; + const int32_t output_activation_max = params.quantized_activation_max; + TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2); + TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1); + + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + const int filter_dim_count = filter_shape.DimensionsCount(); + + const int output_dim_count = output_shape.DimensionsCount(); + const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1); + const int output_depth = output_shape.Dims(output_dim_count - 1); + TFLITE_DCHECK_LE(output_depth, filter_shape.Dims(filter_dim_count - 2)); + const int accum_depth = filter_shape.Dims(filter_dim_count - 1); + for (int b = 0; b < batches; ++b) { + for (int out_c = 0; out_c < output_depth; ++out_c) { + BiasType acc = 0; + for (int d = 0; d < accum_depth; ++d) { + int32_t input_val = input_data[b * accum_depth + d]; + int32_t filter_val = filter_data[out_c * accum_depth + d]; + acc += filter_val * (input_val + input_offset); + } + if (bias_data) { + acc += bias_data[out_c]; + } + + const float scale = filter_scales[out_c]; + const double filter_scale = static_cast(scale); + const double effective_output_scale = static_cast(input_scale) * + filter_scale / + static_cast(output_scale); + int32_t acc_scaled = static_cast( + round(static_cast(acc) * effective_output_scale)); + + acc_scaled += output_offset; + acc_scaled = std::max(acc_scaled, 
output_activation_min); + acc_scaled = std::min(acc_scaled, output_activation_max); + output_data[out_c + output_depth * b] = + static_cast(acc_scaled); + } + } +} + template void FullyConnected(const FullyConnectedParams& params, @@ -122,6 +179,59 @@ void FullyConnected(const FullyConnectedParams& params, } } +// This implementation receives the scales in float and performs requant in +// float to avoid loss of precision. +template +void FullyConnected(const FullyConnectedParams& params, + const RuntimeShape& input_shape, + const InputType* input_data, + const RuntimeShape& filter_shape, + const WeightType* filter_data, + const RuntimeShape& bias_shape, const BiasType* bias_data, + const RuntimeShape& output_shape, float input_scale, + float output_scale, float filter_scale, + OutputType* output_data) { + const int32_t input_offset = params.input_offset; + const int32_t filter_offset = params.weights_offset; + const int32_t output_offset = params.output_offset; + const int32_t output_activation_min = params.quantized_activation_min; + const int32_t output_activation_max = params.quantized_activation_max; + TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2); + TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1); + + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + const int filter_dim_count = filter_shape.DimensionsCount(); + const int output_dim_count = output_shape.DimensionsCount(); + const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1); + const int output_depth = output_shape.Dims(output_dim_count - 1); + TFLITE_DCHECK_LE(output_depth, filter_shape.Dims(filter_dim_count - 2)); + const int accum_depth = filter_shape.Dims(filter_dim_count - 1); + for (int b = 0; b < batches; ++b) { + for (int out_c = 0; out_c < output_depth; ++out_c) { + BiasType acc = 0; + for (int d = 0; d < accum_depth; ++d) { + int32_t input_val = input_data[b * accum_depth + d]; + int32_t filter_val = filter_data[out_c * accum_depth + d]; + acc += 
(filter_val + filter_offset) * (input_val + input_offset); + } + if (bias_data) { + acc += bias_data[out_c]; + } + const double effective_output_scale = static_cast(input_scale) * + static_cast(filter_scale) / + static_cast(output_scale); + int32_t acc_scaled = static_cast( + round(static_cast(acc) * effective_output_scale)); + acc_scaled += output_offset; + acc_scaled = std::max(acc_scaled, output_activation_min); + acc_scaled = std::min(acc_scaled, output_activation_max); + output_data[out_c + output_depth * b] = + static_cast(acc_scaled); + } + } +} + } // namespace reference_integer_ops } // namespace tflite diff --git a/tensorflow/lite/kernels/lstm.cc b/tensorflow/lite/kernels/lstm.cc index 9f74aad97553a0..a88ba32428f2e7 100644 --- a/tensorflow/lite/kernels/lstm.cc +++ b/tensorflow/lite/kernels/lstm.cc @@ -25,10 +25,10 @@ limitations under the License. #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/kernels/cpu_backend_context.h" #include "tensorflow/lite/kernels/internal/compatibility.h" -#include "tensorflow/lite/kernels/internal/kernel_utils.h" #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/lite/kernels/internal/portable_tensor_utils.h" #include "tensorflow/lite/kernels/internal/quantization_util.h" -#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/internal/runtime_shape.h" #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" #include "tensorflow/lite/kernels/internal/tensor_utils.h" #include "tensorflow/lite/kernels/internal/types.h" @@ -982,7 +982,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, const TfLiteTensor* input_gate_bias = GetOptionalInputTensor(context, node, kInputGateBiasTensor); if (use_cifg) { - TF_LITE_ENSURE_EQ(context, input_gate_bias, nullptr); + TF_LITE_ENSURE(context, input_gate_bias == nullptr); } else { TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->size, 1); TF_LITE_ENSURE_EQ(context, 
input_gate_bias->dims->data[0], n_cell); @@ -1061,7 +1061,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, const TfLiteTensor* input_layer_norm_coefficients = GetOptionalInputTensor( context, node, kInputLayerNormCoefficientsTensor); if (use_cifg) { - TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients, nullptr); + TF_LITE_ENSURE(context, input_layer_norm_coefficients == nullptr); } else { TF_LITE_ENSURE(context, input_layer_norm_coefficients != nullptr); TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients->dims->size, 1); diff --git a/tensorflow/lite/kernels/parse_example/BUILD b/tensorflow/lite/kernels/parse_example/BUILD index b18d23b6c7d607..bbf62c1decb983 100644 --- a/tensorflow/lite/kernels/parse_example/BUILD +++ b/tensorflow/lite/kernels/parse_example/BUILD @@ -24,6 +24,7 @@ cc_library( compatible_with = get_compatible_with_portable(), features = tf_features_nolayering_check_if_ios(), deps = [ + "//tensorflow/core/platform:hash", "//tensorflow/lite:framework", "//tensorflow/lite:string_util", "//tensorflow/lite/core/c:common", @@ -31,6 +32,11 @@ cc_library( "//tensorflow/lite/kernels/internal:tensor", "@com_google_absl//absl/base", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", "@flatbuffers", ] + select({ "//tensorflow:android": [ @@ -111,6 +117,11 @@ cc_library( "//tensorflow/lite/kernels/internal:tensor", "@com_google_absl//absl/base", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", "@flatbuffers", ], ) diff --git a/tensorflow/lite/kernels/parse_example/parse_example.cc b/tensorflow/lite/kernels/parse_example/parse_example.cc index 
ec87aabfc86c95..6d6e77f02bb6cc 100644 --- a/tensorflow/lite/kernels/parse_example/parse_example.cc +++ b/tensorflow/lite/kernels/parse_example/parse_example.cc @@ -16,24 +16,35 @@ limitations under the License. #include #include +#include +#include +#include +#include #include #include -#include #include +#include +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "flatbuffers/flexbuffers.h" // from @flatbuffers +#include "xla/tsl/platform/errors.h" #include "tensorflow/core/example/feature.pb.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/node_def.pb.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/platform/blocking_counter.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" #include "tensorflow/core/platform/errors.h" -#include "tensorflow/core/platform/fingerprint.h" -#include "tensorflow/core/public/session_options.h" +#include "tensorflow/core/platform/stringpiece.h" +#include "tensorflow/core/platform/tstring.h" #include "tensorflow/core/util/example_proto_fast_parsing.h" #include "tensorflow/core/util/presized_cuckoo_map.h" #include "tensorflow/lite/core/c/common.h" -#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" #include "tensorflow/lite/kernels/kernel_util.h" #include "tensorflow/lite/kernels/parse_example/example_proto_fast_parsing.h" #include "tensorflow/lite/mutable_op_resolver.h" @@ -46,7 +57,6 @@ namespace parse_example { namespace { namespace tf = ::tensorflow; -using tf::Status; using tf::StringPiece; using tf::tstring; using tf::example::CopyOrMoveBlock; @@ -116,7 +126,7 @@ bool ParseExample(StringRef serialized, Example* example) { return ParseExample(&stream, example); } -Status 
FastParseSerializedExample( +absl::Status FastParseSerializedExample( StringRef serialized_example, const tstring& example_name, const size_t example_index, const FastParseExampleConfig& config, bool* quick_filter, int quick_filter_size, @@ -139,7 +149,7 @@ Status FastParseSerializedExample( // I.e. last entry in the map overwrites all the previous ones. tensorflow::example::parsed::FeatureMapEntry& name_and_feature = parsed_example[parsed_example_size - i - 1]; - const StringPiece feature_name = name_and_feature.first; + const absl::string_view feature_name = name_and_feature.first; tensorflow::example::parsed::Feature& feature = name_and_feature.second; if (feature_name.length() >= quick_filter_size || !quick_filter[feature_name.length()]) { @@ -153,7 +163,7 @@ Status FastParseSerializedExample( size_t d = d_and_type.first; bool is_dense = d_and_type.second == Type::Dense; - auto example_error = [&](StringPiece suffix) { + auto example_error = [&](absl::string_view suffix) { return tf::errors::Internal("Name: ", example_name, ", Key: ", feature_name, ", Index: ", example_index, ". ", suffix); @@ -164,7 +174,7 @@ Status FastParseSerializedExample( }; tf::DataType example_dtype; - if (feature.ParseDataType(&example_dtype) != absl::OkStatus()) { + if (!feature.ParseDataType(&example_dtype).ok()) { return parse_error(); } if (is_dense) { @@ -184,7 +194,7 @@ Status FastParseSerializedExample( const std::size_t num_elements = config.dense[d].elements_per_stride; const std::size_t offset = example_index * num_elements; - auto shape_error = [&](size_t size, StringPiece type_str) { + auto shape_error = [&](size_t size, absl::string_view type_str) { return example_error(absl::StrCat( "Number of ", type_str, " values != expected. 
" @@ -238,7 +248,7 @@ Status FastParseSerializedExample( "Expected type: ", DataTypeString(config.dense[d].dtype))); } - auto shape_error = [&](size_t size, StringPiece type_str) { + auto shape_error = [&](size_t size, absl::string_view type_str) { return example_error( absl::StrCat("Number of ", type_str, " values is not a multiple of stride length. Saw ", @@ -452,7 +462,7 @@ inline void CopyToBuffer(absl::Span vec, char* tensor_buffer, } } -Status FastParseExampleLite( +absl::Status FastParseExampleLite( const FastParseExampleConfig& config, const TfLiteTensor* serialized, absl::Span example_names, bool* quick_filter, int quick_filter_size, const std::unique_ptr& config_index, @@ -465,7 +475,7 @@ Status FastParseExampleLite( std::vector fixed_dense_values(config.dense.size()); std::vector sparse_buffers(config.sparse.size()); std::vector varlen_dense_buffers(config.dense.size()); - Status status_of_minibatch; + absl::Status status_of_minibatch; for (size_t e = 0; e < count; ++e) { status_of_minibatch = FastParseSerializedExample( GetString(serialized, e), @@ -971,8 +981,8 @@ TfLiteStatus EvalParseExample(TfLiteContext* context, TfLiteNode* node) { data->config, serialized, {}, data->quick_filter, data->quick_filter_size, data->config_index, data->config_index_size, &data->hasher, &data->got, stats, context); - if (status != absl::OkStatus()) { - TF_LITE_KERNEL_LOG(context, status.ToString().c_str()); + if (!status.ok()) { + TF_LITE_KERNEL_LOG(context, "%s", status.ToString().c_str()); return kTfLiteError; } return kTfLiteOk; diff --git a/tensorflow/lite/kernels/pooling_test.cc b/tensorflow/lite/kernels/pooling_test.cc index 401bee6a242470..4634c5ce5e7adf 100644 --- a/tensorflow/lite/kernels/pooling_test.cc +++ b/tensorflow/lite/kernels/pooling_test.cc @@ -147,7 +147,7 @@ TEST(FloatPoolingOpTest, AveragePool) { 3, 2, 10, 7, // }); ASSERT_EQ(m.Invoke(), kTfLiteOk); - EXPECT_THAT(m.GetOutput(), Pointwise(FloatingPointEq(), {2.75, 5.75})); + 
EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({2.75, 5.75}))); } TEST(FloatPoolingOpTest, AveragePoolActivationRelu) { @@ -161,7 +161,7 @@ TEST(FloatPoolingOpTest, AveragePoolActivationRelu) { 3, 2, -10, 7, // }); ASSERT_EQ(m.Invoke(), kTfLiteOk); - EXPECT_THAT(m.GetOutput(), Pointwise(FloatingPointEq(), {0.0, 0.75})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({0.0, 0.75}))); } TEST(FloatPoolingOpTest, AveragePoolActivationRelu1) { @@ -175,14 +175,14 @@ TEST(FloatPoolingOpTest, AveragePoolActivationRelu1) { -3, -2, -10, 7, // }); ASSERT_EQ(m.Invoke(), kTfLiteOk); - EXPECT_THAT(m.GetOutput(), Pointwise(FloatingPointEq(), {-1.0, 0.75})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({-1.0, 0.75}))); m.SetInput({ 0, -6, -2, -4, // -3, -2, 10, -7, // }); ASSERT_EQ(m.Invoke(), kTfLiteOk); - EXPECT_THAT(m.GetOutput(), Pointwise(FloatingPointEq(), {-1.0, -0.75})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({-1.0, -0.75}))); } TEST(FloatPoolingOpTest, AveragePoolActivationRelu6) { @@ -196,14 +196,14 @@ TEST(FloatPoolingOpTest, AveragePoolActivationRelu6) { -3, -2, 10, 7, // }); ASSERT_EQ(m.Invoke(), kTfLiteOk); - EXPECT_THAT(m.GetOutput(), Pointwise(FloatingPointEq(), {0.0, 6.0})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({0.0, 6.0}))); m.SetInput({ 0, 6, 12, 4, // 3, 2, 10, 7, // }); ASSERT_EQ(m.Invoke(), kTfLiteOk); - EXPECT_THAT(m.GetOutput(), Pointwise(FloatingPointEq(), {2.75, 6.0})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({2.75, 6.0}))); } TEST(FloatPoolingOpTest, AveragePoolPaddingSameStride1) { @@ -217,9 +217,8 @@ TEST(FloatPoolingOpTest, AveragePoolPaddingSameStride1) { 3, 2, 10, 7, // }); ASSERT_EQ(m.Invoke(), kTfLiteOk); - EXPECT_THAT( - m.GetOutput(), - Pointwise(FloatingPointEq(), {2.75, 5.0, 5.75, 5.5, 2.5, 6.0, 8.5, 7.0})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear( + {2.75, 5.0, 5.75, 5.5, 2.5, 6.0, 8.5, 7.0}))); } 
TEST(FloatPoolingOpTest, AveragePoolPaddingValidStride1) { @@ -233,7 +232,8 @@ TEST(FloatPoolingOpTest, AveragePoolPaddingValidStride1) { 3, 2, 10, 7, // }); ASSERT_EQ(m.Invoke(), kTfLiteOk); - EXPECT_THAT(m.GetOutput(), Pointwise(FloatingPointEq(), {2.75, 5.0, 5.75})); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({2.75, 5.0, 5.75}))); } TEST(QuantizedPoolingOpTest, AveragePool) { @@ -643,7 +643,7 @@ TEST(FloatPoolingOpTest, MaxPoolActivationRelu) { -3, -2, 10.5, 7, // }); ASSERT_EQ(m.Invoke(), kTfLiteOk); - EXPECT_THAT(m.GetOutput(), Pointwise(FloatingPointEq(), {0.0, 10.5})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({0.0, 10.5}))); } TEST(FloatPoolingOpTest, MaxPoolActivationRelu1) { @@ -657,14 +657,14 @@ TEST(FloatPoolingOpTest, MaxPoolActivationRelu1) { -3, -2, -0.3, 0.7, // }); ASSERT_EQ(m.Invoke(), kTfLiteOk); - EXPECT_THAT(m.GetOutput(), Pointwise(FloatingPointEq(), {-1.0, 0.7})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({-1.0, 0.7}))); m.SetInput({ -2.75, -6, -2, -4, // -3, -2, 10, -7, // }); ASSERT_EQ(m.Invoke(), kTfLiteOk); - EXPECT_THAT(m.GetOutput(), Pointwise(FloatingPointEq(), {-1.0, 1.0})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({-1.0, 1.0}))); } TEST(FloatPoolingOpTest, MaxPoolActivationRelu6) { @@ -678,14 +678,14 @@ TEST(FloatPoolingOpTest, MaxPoolActivationRelu6) { -3, -2, 10, 7, // }); ASSERT_EQ(m.Invoke(), kTfLiteOk); - EXPECT_THAT(m.GetOutput(), Pointwise(FloatingPointEq(), {0.0, 6.0})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({0.0, 6.0}))); m.SetInput({ 0, 4.5, 12, 4, // 3, 2, 10, 7, // }); ASSERT_EQ(m.Invoke(), kTfLiteOk); - EXPECT_THAT(m.GetOutput(), Pointwise(FloatingPointEq(), {4.5, 6.0})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({4.5, 6.0}))); } TEST(FloatPoolingOpTest, MaxPoolPaddingSameStride1) { @@ -1063,7 +1063,7 @@ TEST(FloatPoolingOpTest, L2Pool) { 3, 2, 10, 7, // }); ASSERT_EQ(m.Invoke(), kTfLiteOk); - 
EXPECT_THAT(m.GetOutput(), Pointwise(FloatingPointEq(), {3.5, 6.5})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({3.5, 6.5}))); } TEST(FloatPoolingOpTest, L2PoolActivationRelu) { @@ -1118,7 +1118,7 @@ TEST(FloatPoolingOpTest, L2PoolPaddingSame) { 3, 2, 10, 7, // }); ASSERT_EQ(m.Invoke(), kTfLiteOk); - EXPECT_THAT(m.GetOutput(), Pointwise(FloatingPointEq(), {3.5, 6.5})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({3.5, 6.5}))); } TEST(FloatPoolingOpTest, L2PoolPaddingSameSlide1) { @@ -1149,7 +1149,7 @@ TEST(FloatPoolingOpTest, L2PoolPaddingValidSlide1) { 3, 2, 10, 7, // }); ASSERT_EQ(m.Invoke(), kTfLiteOk); - EXPECT_THAT(m.GetOutput(), Pointwise(FloatingPointEq(), {3.5, 6.0, 6.5})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({3.5, 6.0, 6.5}))); } #if GTEST_HAS_DEATH_TEST diff --git a/tensorflow/lite/kernels/register_ref.cc b/tensorflow/lite/kernels/register_ref.cc index 479f3d5c996a8f..a6cdc7b6cc80f4 100644 --- a/tensorflow/lite/kernels/register_ref.cc +++ b/tensorflow/lite/kernels/register_ref.cc @@ -375,10 +375,10 @@ BuiltinRefOpResolver::BuiltinRefOpResolver() { /* max_version = */ 2); AddBuiltin(BuiltinOperator_CAST, Register_CAST(), /* min_version = */ 1, - /* max_version = */ 4); + /* max_version = */ 7); AddBuiltin(BuiltinOperator_DEQUANTIZE, Register_DEQUANTIZE_REF(), /* min_version = */ 1, - /* max_version = */ 4); + /* max_version = */ 6); AddBuiltin(BuiltinOperator_PRELU, Register_PRELU_REF()); AddBuiltin(BuiltinOperator_MAXIMUM, Register_MAXIMUM_REF(), /* min_version = */ 1, @@ -500,7 +500,7 @@ BuiltinRefOpResolver::BuiltinRefOpResolver() { AddBuiltin(BuiltinOperator_MATRIX_DIAG, Register_MATRIX_DIAG()); AddBuiltin(BuiltinOperator_QUANTIZE, Register_QUANTIZE_REF(), /* min_version = */ 1, - /* max_version = */ 2); + /* max_version = */ 3); AddBuiltin(BuiltinOperator_MATRIX_SET_DIAG, Register_MATRIX_SET_DIAG()); AddBuiltin(BuiltinOperator_IF, Register_IF()); AddBuiltin(BuiltinOperator_WHILE, 
Register_WHILE()); diff --git a/tensorflow/lite/kernels/stablehlo_composite_test.cc b/tensorflow/lite/kernels/stablehlo_composite_test.cc index 65a935ca94a113..612da57d0cddb7 100644 --- a/tensorflow/lite/kernels/stablehlo_composite_test.cc +++ b/tensorflow/lite/kernels/stablehlo_composite_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include #include +#include #include #include @@ -101,9 +102,11 @@ TEST_F(CompositeTest, TestXNNPACKDelegation) { interpreter_->subgraph(1)); const auto opt = TfLiteXNNPackDelegateOptionsDefault(); - TfLiteDelegate* xnnpack_delegate = TfLiteXNNPackDelegateCreate(&opt); + std::unique_ptr xnnpack_delegate( + TfLiteXNNPackDelegateCreate(&opt), TfLiteXNNPackDelegateDelete); interpreter_->primary_subgraph().MarkAsDelegationSkippable(); - ASSERT_EQ(interpreter_->ModifyGraphWithDelegate(xnnpack_delegate), kTfLiteOk); + ASSERT_EQ(interpreter_->ModifyGraphWithDelegate(std::move(xnnpack_delegate)), + kTfLiteOk); ASSERT_EQ(interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {2, 3}), kTfLiteOk); ASSERT_EQ(interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {2, 3}), @@ -131,7 +134,6 @@ TEST_F(CompositeTest, TestXNNPACKDelegation) { ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); - TfLiteXNNPackDelegateDelete(xnnpack_delegate); } } // namespace diff --git a/tensorflow/lite/kernels/stablehlo_scatter.cc b/tensorflow/lite/kernels/stablehlo_scatter.cc index be67dc39e911ac..8caad1ef449976 100644 --- a/tensorflow/lite/kernels/stablehlo_scatter.cc +++ b/tensorflow/lite/kernels/stablehlo_scatter.cc @@ -22,6 +22,7 @@ limitations under the License. 
#include "Eigen/Core" // from @eigen_archive #include "tensorflow/lite/builtin_ops.h" +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/core/c/builtin_op_data.h" #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/core/subgraph.h" @@ -314,6 +315,8 @@ TfLiteStatus EvalWithIndexType(TfLiteContext* context, TfLiteNode* node, return EvalWithTypes(context, node); case kTfLiteUInt64: return EvalWithTypes(context, node); + case kTfLiteBool: + return EvalWithTypes(context, node); default: TF_LITE_KERNEL_LOG( context, "(Index Type: %s, Data Type: %s) currently not supported.\n", diff --git a/tensorflow/lite/kernels/unidirectional_sequence_lstm.cc b/tensorflow/lite/kernels/unidirectional_sequence_lstm.cc index 04167e897fb808..518b5c7d69bccb 100644 --- a/tensorflow/lite/kernels/unidirectional_sequence_lstm.cc +++ b/tensorflow/lite/kernels/unidirectional_sequence_lstm.cc @@ -16,14 +16,20 @@ limitations under the License. #include #include +#include #include +#include +#include +#include +#include #include "tensorflow/lite/core/c/builtin_op_data.h" #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/kernels/cpu_backend_context.h" #include "tensorflow/lite/kernels/internal/compatibility.h" -#include "tensorflow/lite/kernels/internal/kernel_utils.h" +#include "tensorflow/lite/kernels/internal/portable_tensor_utils.h" #include "tensorflow/lite/kernels/internal/quantization_util.h" +#include "tensorflow/lite/kernels/internal/runtime_shape.h" #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" #include "tensorflow/lite/kernels/internal/tensor_utils.h" #include "tensorflow/lite/kernels/kernel_util.h" @@ -562,7 +568,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, const TfLiteTensor* input_gate_bias = GetOptionalInputTensor(context, node, lstm::full::kInputGateBiasTensor); if (use_cifg) { - TF_LITE_ENSURE_EQ(context, input_gate_bias, nullptr); + TF_LITE_ENSURE(context, input_gate_bias == nullptr); } else 
{ TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->size, 1); TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->data[0], n_cell); @@ -642,7 +648,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, const TfLiteTensor* input_layer_norm_coefficients = GetOptionalInputTensor( context, node, lstm::full::kInputLayerNormCoefficientsTensor); if (use_cifg) { - TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients, nullptr); + TF_LITE_ENSURE(context, input_layer_norm_coefficients == nullptr); } else { TF_LITE_ENSURE(context, input_layer_norm_coefficients != nullptr); TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients->dims->size, 1); diff --git a/tensorflow/lite/kernels/while_test.cc b/tensorflow/lite/kernels/while_test.cc index 0e0a3e43a72727..1dee795c72e62a 100644 --- a/tensorflow/lite/kernels/while_test.cc +++ b/tensorflow/lite/kernels/while_test.cc @@ -15,8 +15,10 @@ limitations under the License. #include #include +#include #include +#include #include "tensorflow/lite/core/interpreter.h" #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" @@ -44,10 +46,12 @@ TEST_F(WhileTest, TestWithXNNPACK) { builder_->BuildFloatWhileSubgraph(&interpreter_->primary_subgraph(), 2); const auto opt = TfLiteXNNPackDelegateOptionsDefault(); - TfLiteDelegate* xnnpack_delegate = TfLiteXNNPackDelegateCreate(&opt); + std::unique_ptr xnnpack_delegate( + TfLiteXNNPackDelegateCreate(&opt), TfLiteXNNPackDelegateDelete); interpreter_->primary_subgraph().MarkAsDelegationSkippable(); interpreter_->subgraph(1)->MarkAsDelegationSkippable(); - ASSERT_EQ(interpreter_->ModifyGraphWithDelegate(xnnpack_delegate), kTfLiteOk); + ASSERT_EQ(interpreter_->ModifyGraphWithDelegate(std::move(xnnpack_delegate)), + kTfLiteOk); ASSERT_EQ(interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {1}), kTfLiteOk); ASSERT_EQ(interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {1}), @@ -71,7 +75,6 @@ TEST_F(WhileTest, 
TestWithXNNPACK) { ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); - TfLiteXNNPackDelegateDelete(xnnpack_delegate); } TEST_F(WhileTest, TestInputIsOutput) { diff --git a/tensorflow/lite/objc/BUILD.apple b/tensorflow/lite/objc/BUILD.apple index 856ed56ad35967..0f1dbf75429388 100644 --- a/tensorflow/lite/objc/BUILD.apple +++ b/tensorflow/lite/objc/BUILD.apple @@ -95,7 +95,8 @@ objc_library( # directory name. (See: b/174508866) ios_unit_test( name = "tests", - size = "medium", + size = "small", + timeout = "moderate", minimum_os_version = TFL_MINIMUM_OS_VERSION, runner = tflite_ios_lab_runner("IOS_LATEST"), tags = TFL_DEFAULT_TAGS + TFL_DISABLED_SANITIZER_TAGS, diff --git a/tensorflow/lite/objc/apps/TestApp/TestApp/Base.lproj/LaunchScreen.storyboard b/tensorflow/lite/objc/apps/TestApp/TestApp/Base.lproj/LaunchScreen.storyboard index 6c97d768e15ada..312821c21ae24e 100644 --- a/tensorflow/lite/objc/apps/TestApp/TestApp/Base.lproj/LaunchScreen.storyboard +++ b/tensorflow/lite/objc/apps/TestApp/TestApp/Base.lproj/LaunchScreen.storyboard @@ -1,11 +1,9 @@ - - - - + + - + @@ -21,8 +19,8 @@ -